From 3105c549c7303ea122128fb651d28ed9c48b8e62 Mon Sep 17 00:00:00 2001 From: asimfarooq5 Date: Wed, 24 Sep 2025 14:42:05 +0500 Subject: [PATCH 1/5] fix: always add default realm --- cmd/wshd/main.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cmd/wshd/main.go b/cmd/wshd/main.go index d398046..2658788 100644 --- a/cmd/wshd/main.go +++ b/cmd/wshd/main.go @@ -332,6 +332,8 @@ func main() { authenticator := wampshell.NewAuthenticator(keyStore) router := xconn.NewRouter() + addRealm(router, defaultRealm) + for realm := range authenticator.Realms() { addRealm(router, realm) } From 93e15f39f6a5fcb911c6de941582480970109e85 Mon Sep 17 00:00:00 2001 From: asimfarooq5 Date: Wed, 24 Sep 2025 14:45:44 +0500 Subject: [PATCH 2/5] move exchangeKeys in helpers for reusability --- cmd/wcp/main.go | 53 ++++++---------------------------------------- cmd/wsh/main.go | 56 +++++++------------------------------------------ helpers.go | 37 ++++++++++++++++++++++++++++++++ 3 files changed, 50 insertions(+), 96 deletions(-) diff --git a/cmd/wcp/main.go b/cmd/wcp/main.go index edf129c..ef63e8c 100644 --- a/cmd/wcp/main.go +++ b/cmd/wcp/main.go @@ -19,48 +19,7 @@ const ( maxSize = 1024 * 1024 * 15 ) -type keyPair struct { - send []byte - receive []byte -} - -func exchangeKeys(session *xconn.Session) (*keyPair, error) { - publicKey, privateKey, err := berncrypt.CreateX25519KeyPair() - if err != nil { - return nil, err - } - - response := session.Call("wampshell.key.exchange").Arg(publicKey).Do() - if response.Err != nil { - return nil, response.Err - } - - publicKeyPeer, err := response.Args.Bytes(0) - if err != nil { - return nil, err - } - - sharedSecret, err := berncrypt.PerformKeyExchange(privateKey, publicKeyPeer) - if err != nil { - return nil, err - } - - receiveKey, err := berncrypt.DeriveKeyHKDF(sharedSecret, []byte("backendToFrontend")) - if err != nil { - return nil, err - } - - sendKey, err := berncrypt.DeriveKeyHKDF(sharedSecret, []byte("frontendToBackend")) - if err != nil { - return nil, err - } - return &keyPair{ - send: sendKey, - receive: receiveKey, - }, nil -} - -func uploadFile(session *xconn.Session, keys *keyPair, localFile, remoteFile string) error { +func uploadFile(session *xconn.Session, keys *wampshell.KeyPair, localFile, remoteFile string) error { file, err := os.Stat(localFile) if err != nil { return fmt.Errorf("failed to stat local file: %w", err) @@ -72,7 +31,7 @@ func uploadFile(session *xconn.Session, keys *keyPair, localFile, remoteFile str if err != nil { return fmt.Errorf("failed to read local file: %w", err) } - ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305(data, keys.send) + ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305(data, keys.Send) if err != nil { return err } @@ -85,7 +44,7 @@ func uploadFile(session *xconn.Session, keys *keyPair, localFile, remoteFile str if err != nil { return fmt.Errorf("output parsing error: %w", err) } - resp, err := berncrypt.DecryptChaCha20Poly1305(encResp[12:], encResp[:12], keys.receive) + resp, err := berncrypt.DecryptChaCha20Poly1305(encResp[12:], encResp[:12], keys.Receive) if err != nil { return err } @@ -93,7 +52,7 @@ func uploadFile(session *xconn.Session, keys *keyPair, localFile, remoteFile str return nil } -func downloadFile(session *xconn.Session, keys *keyPair, remoteFile, localFile string) error { +func downloadFile(session *xconn.Session, keys *wampshell.KeyPair, remoteFile, localFile string) error { cmdResponse := session.Call("wampshell.shell.download").Arg(remoteFile).Do() if cmdResponse.Err != nil { return fmt.Errorf("file download error: %w", cmdResponse.Err) @@ -102,7 +61,7 @@ func downloadFile(session *xconn.Session, keys *keyPair, remoteFile, localFile s if err != nil { return fmt.Errorf("output parsing error: %w", err) } - data, err := berncrypt.DecryptChaCha20Poly1305(encResp[12:], encResp[:12], keys.receive) + data, err := berncrypt.DecryptChaCha20Poly1305(encResp[12:], encResp[:12], keys.Receive) if err != nil { return err } @@ -199,7 +158,7 @@ func main() { panic(err) } - keys, err := exchangeKeys(session) + keys, err := wampshell.ExchangeKeys(session) if err != nil { panic(err) } diff --git a/cmd/wsh/main.go b/cmd/wsh/main.go index f124500..8637c35 100644 --- a/cmd/wsh/main.go +++ b/cmd/wsh/main.go @@ -26,49 +26,7 @@ const ( topicAnswererOnCandidate = "wampshell.webrtc.answerer.on_candidate" ) -type keyPair struct { - send []byte - receive []byte -} - -func exchangeKeys(session *xconn.Session) (*keyPair, error) { - publicKey, privateKey, err := berncrypt.CreateX25519KeyPair() - if err != nil { - return nil, err - } - - response := session.Call("wampshell.key.exchange").Arg(publicKey).Do() - if response.Err != nil { - return nil, response.Err - } - - publicKeyPeer, err := response.Args.Bytes(0) - if err != nil { - return nil, err - } - - sharedSecret, err := berncrypt.PerformKeyExchange(privateKey, publicKeyPeer) - if err != nil { - return nil, err - } - - receiveKey, err := berncrypt.DeriveKeyHKDF(sharedSecret, []byte("backendToFrontend")) - if err != nil { - return nil, err - } - - sendKey, err := berncrypt.DeriveKeyHKDF(sharedSecret, []byte("frontendToBackend")) - if err != nil { - return nil, err - } - - return &keyPair{ - send: sendKey, - receive: receiveKey, - }, nil -} - -func startInteractiveShell(session *xconn.Session, keys *keyPair) error { +func startInteractiveShell(session *xconn.Session, keys *wampshell.KeyPair) error { fd := int(os.Stdin.Fd()) oldState, err := term.MakeRaw(fd) if err != nil { @@ -91,7 +49,7 @@ func startInteractiveShell(session *xconn.Session, keys *keyPair) error { return xconn.NewFinalProgress() } - ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305(buf[:n], keys.send) + ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305(buf[:n], keys.Send) if err != nil { fmt.Printf("encryption error: %s", err) os.Exit(1) @@ -109,7 +67,7 @@ func startInteractiveShell(session *xconn.Session, keys *keyPair) error { os.Exit(1) } - plain, err := berncrypt.DecryptChaCha20Poly1305(encData[12:], encData[:12], keys.receive) + plain, err := berncrypt.DecryptChaCha20Poly1305(encData[12:], encData[:12], keys.Receive) if err != nil { _ = fmt.Errorf("decryption error: %w", err) } @@ -130,10 +88,10 @@ func startInteractiveShell(session *xconn.Session, keys *keyPair) error { return nil } -func runCommand(session *xconn.Session, keys *keyPair, args []string) error { +func runCommand(session *xconn.Session, keys *wampshell.KeyPair, args []string) error { b := []byte(strings.Join(args, " ")) - ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305(b, keys.send) + ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305(b, keys.Send) if err != nil { return fmt.Errorf("encryption error: %w", err) } @@ -151,7 +109,7 @@ func runCommand(session *xconn.Session, keys *keyPair, args []string) error { os.Exit(1) } - plain, err := berncrypt.DecryptChaCha20Poly1305(output[12:], output[:12], keys.receive) + plain, err := berncrypt.DecryptChaCha20Poly1305(output[12:], output[:12], keys.Receive) if err != nil { return fmt.Errorf("decryption error: %w", err) } @@ -240,7 +198,7 @@ func main() { } } - keys, err := exchangeKeys(session) + keys, err := wampshell.ExchangeKeys(session) if err != nil { panic(err) } diff --git a/helpers.go b/helpers.go index 1418ebc..5298141 100644 --- a/helpers.go +++ b/helpers.go @@ -7,6 +7,7 @@ import ( "path/filepath" "strings" + berncrypt "github.com/xconnio/berncrypt/go" "github.com/xconnio/wampproto-capnproto/go" "github.com/xconnio/xconn-go" ) @@ -55,3 +56,39 @@ func RunningInSnap() bool { err := cmd.Run() return err == nil } + +func ExchangeKeys(session *xconn.Session) (*KeyPair, error) { + publicKey, privateKey, err := berncrypt.CreateX25519KeyPair() + if err != nil { + return nil, err + } + + response := session.Call("wampshell.key.exchange").Arg(publicKey).Do() + if response.Err != nil { + return nil, response.Err + } + + publicKeyPeer, err := response.Args.Bytes(0) + if err != nil { + return nil, err + } + + sharedSecret, err := berncrypt.PerformKeyExchange(privateKey, publicKeyPeer) + if err != nil { + return nil, err + } + + receiveKey, err := berncrypt.DeriveKeyHKDF(sharedSecret, []byte("backendToFrontend")) + if err != nil { + return nil, err + } + + sendKey, err := berncrypt.DeriveKeyHKDF(sharedSecret, []byte("frontendToBackend")) + if err != nil { + return nil, err + } + return &KeyPair{ + Send: sendKey, + Receive: receiveKey, + }, nil +} From 388cb9813d23df9d8bcbf5206a04986ff716a2d4 Mon Sep 17 00:00:00 2001 From: asimfarooq5 Date: Thu, 25 Sep 2025 13:57:37 +0500 Subject: [PATCH 3/5] refactor(wcp): improve naming and error handling --- cmd/wcp/main.go | 120 ++++++++++++++++++++++++------------------------ 1 file changed, 61 insertions(+), 59 deletions(-) diff --git a/cmd/wcp/main.go b/cmd/wcp/main.go index ef63e8c..9a6a67b 100644 --- a/cmd/wcp/main.go +++ b/cmd/wcp/main.go @@ -3,6 +3,7 @@ package main import ( "context" "fmt" + "log" "os" "path/filepath" "strings" @@ -19,75 +20,85 @@ const ( maxSize = 1024 * 1024 * 15 ) -func uploadFile(session *xconn.Session, keys *wampshell.KeyPair, localFile, remoteFile string) error { - file, err := os.Stat(localFile) +func uploadFile(session *xconn.Session, keys *wampshell.KeyPair, localPath, remotePath string) error { + fileInfo, err := os.Stat(localPath) if err != nil { return fmt.Errorf("failed to stat local file: %w", err) } - if file.Size() > maxSize { - return fmt.Errorf("file too large: %d bytes (max %d bytes)", file.Size(), maxSize) + if fileInfo.Size() > maxSize { + return fmt.Errorf("file too large: %d bytes (max %d bytes)", fileInfo.Size(), maxSize) } - data, err := os.ReadFile(localFile) + + data, err := os.ReadFile(localPath) if err != nil { return fmt.Errorf("failed to read local file: %w", err) } + ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305(data, keys.Send) if err != nil { return err } + encryptedPayload := append(nonce, ciphertext...) - cmdResponse := session.Call("wampshell.shell.upload").Args(remoteFile, encryptedPayload).Do() - if cmdResponse.Err != nil { - return fmt.Errorf("file upload error: %w", cmdResponse.Err) + callResponse := session.Call("wampshell.shell.upload").Args(remotePath, encryptedPayload).Do() + if callResponse.Err != nil { + return fmt.Errorf("file upload error: %w", callResponse.Err) } - encResp, err := cmdResponse.Args.Bytes(0) + + encResp, err := callResponse.Args.Bytes(0) if err != nil { - return fmt.Errorf("output parsing error: %w", err) + return fmt.Errorf("parsing response failed: %w", err) } - resp, err := berncrypt.DecryptChaCha20Poly1305(encResp[12:], encResp[:12], keys.Receive) + + plainResp, err := berncrypt.DecryptChaCha20Poly1305(encResp[12:], encResp[:12], keys.Receive) if err != nil { - return err + return fmt.Errorf("response decryption failed: %w", err) } - fmt.Printf("Server response: %s\n", string(resp)) + + log.Printf("Upload response: %s", string(plainResp)) return nil } -func downloadFile(session *xconn.Session, keys *wampshell.KeyPair, remoteFile, localFile string) error { - cmdResponse := session.Call("wampshell.shell.download").Arg(remoteFile).Do() - if cmdResponse.Err != nil { - return fmt.Errorf("file download error: %w", cmdResponse.Err) +func downloadFile(session *xconn.Session, keys *wampshell.KeyPair, remotePath, localPath string) error { + callResponse := session.Call("wampshell.shell.download").Arg(remotePath).Do() + if callResponse.Err != nil { + return fmt.Errorf("file download error: %w", callResponse.Err) } - encResp, err := cmdResponse.Args.Bytes(0) + + encResp, err := callResponse.Args.Bytes(0) if err != nil { - return fmt.Errorf("output parsing error: %w", err) + return fmt.Errorf("failed to parse response: %w", err) } + data, err := berncrypt.DecryptChaCha20Poly1305(encResp[12:], encResp[:12], keys.Receive) if err != nil { return err } - if err := os.WriteFile(localFile, data, 0600); err != nil { + + if err := os.WriteFile(localPath, data, 0600); err != nil { return fmt.Errorf("failed to save file: %w", err) } - fmt.Printf("Downloaded %s -> %s (%d bytes)\n", remoteFile, localFile, len(data)) + + log.Printf("Downloaded %s → %s (%d bytes)", remotePath, localPath, len(data)) return nil } -func splitRemote(s string) (user, host, port, path string, err error) { +func parseRemoteTarget(target string) (user, host, port, path string, err error) { port = "8022" - if strings.Contains(s, "@") { - parts := strings.SplitN(s, "@", 2) - user, s = parts[0], parts[1] + if strings.Contains(target, "@") { + parts := strings.SplitN(target, "@", 2) + user, target = parts[0], parts[1] } - parts := strings.SplitN(s, ":", 3) + parts := strings.SplitN(target, ":", 3) switch len(parts) { case 2: host, path = parts[0], parts[1] case 3: host, port, path = parts[0], parts[1], parts[2] default: - err = fmt.Errorf("invalid target: %s", s) + err = fmt.Errorf("invalid target: %s", target) } if user == "" { @@ -107,44 +118,37 @@ func main() { var opts Options parser := flags.NewParser(&opts, flags.Default) - _, err := parser.Parse() - if err != nil { - os.Exit(1) + if _, err := parser.Parse(); err != nil { + log.Fatal(err) } src := opts.Args.Source dst := opts.Args.Target var mode string - var localFile, remoteFile string + var localPath, remotePath string var user, host, port string if strings.Contains(src, ":") && !strings.Contains(dst, ":") { mode = "download" - user, host, port, remoteFile, _ = splitRemote(src) - localFile = dst + user, host, port, remotePath, _ = parseRemoteTarget(src) + localPath = dst } else if !strings.Contains(src, ":") && strings.Contains(dst, ":") { mode = "upload" - localFile = src - user, host, port, remoteFile, _ = splitRemote(dst) + localPath = src + user, host, port, remotePath, _ = parseRemoteTarget(dst) } else { - fmt.Println("Invalid usage: one of source/target must be remote (user@host:path)") - os.Exit(1) + log.Fatal("Invalid usage: one of source/target must be remote (user@host:path)") } privateKey, err := wampshell.ReadPrivateKeyFromFile() if err != nil { - fmt.Printf("Error reading private key: %v\n", err) - os.Exit(1) + log.Fatalf("Reading private key failed: %v", err) } - authExtra := map[string]any{} - authExtra["user"] = user - - authenticator, err := auth.NewCryptoSignAuthenticator("", privateKey, authExtra) + authenticator, err := auth.NewCryptoSignAuthenticator("", privateKey, map[string]any{"user": user}) if err != nil { - fmt.Printf("Error creating crypto sign authenticator: %v\n", err) - os.Exit(1) + log.Fatalf("Creating authenticator failed: %v", err) } client := xconn.Client{ @@ -155,32 +159,30 @@ func main() { url := fmt.Sprintf("rs://%s:%s", host, port) session, err := client.Connect(context.Background(), url, "wampshell") if err != nil { - panic(err) + log.Fatalf("Connection failed: %v", err) } keys, err := wampshell.ExchangeKeys(session) if err != nil { - panic(err) + log.Fatalf("Key exchange failed: %v", err) } switch mode { case "upload": - if strings.HasSuffix(remoteFile, "/") { - remoteFile += filepath.Base(localFile) - } else if remoteFile == "" { - remoteFile = filepath.Base(localFile) + if strings.HasSuffix(remotePath, "/") { + remotePath += filepath.Base(localPath) + } else if remotePath == "" { + remotePath = filepath.Base(localPath) } - if err := uploadFile(session, keys, localFile, remoteFile); err != nil { - fmt.Printf("Upload failed: %v\n", err) - os.Exit(1) + if err := uploadFile(session, keys, localPath, remotePath); err != nil { + log.Fatalf("Upload failed: %v", err) } case "download": - if fi, err := os.Stat(localFile); err == nil && fi.IsDir() { - localFile = filepath.Join(localFile, filepath.Base(remoteFile)) + if fi, err := os.Stat(localPath); err == nil && fi.IsDir() { + localPath = filepath.Join(localPath, filepath.Base(remotePath)) } - if err := downloadFile(session, keys, remoteFile, localFile); err != nil { - fmt.Printf("Download failed: %v\n", err) - os.Exit(1) + if err := downloadFile(session, keys, remotePath, localPath); err != nil { + log.Fatalf("Download failed: %v", err) } } } From 0753ac7bff94a0a42b3d7ed8b1bf147f4fcc3c80 Mon Sep 17 00:00:00 2001 From: asimfarooq5 Date: Thu, 25 Sep 2025 14:03:26 +0500 Subject: [PATCH 4/5] refactor(wsh): improve naming and error handling --- cmd/wsh/main.go | 92 ++++++++++++++++++++++++++----------------------- 1 file changed, 48 insertions(+), 44 deletions(-) diff --git a/cmd/wsh/main.go b/cmd/wsh/main.go index 8637c35..28babe4 100644 --- a/cmd/wsh/main.go +++ b/cmd/wsh/main.go @@ -27,6 +27,8 @@ const ( ) func startInteractiveShell(session *xconn.Session, keys *wampshell.KeyPair) error { + const nonceSize = 12 + fd := int(os.Stdin.Fd()) oldState, err := term.MakeRaw(fd) if err != nil { @@ -36,54 +38,59 @@ func startInteractiveShell(session *xconn.Session, keys *wampshell.KeyPair) erro firstProgress := true + readAndEncrypt := func() (*xconn.Progress, error) { + buf := make([]byte, 1024) + n, err := os.Stdin.Read(buf) + if err != nil { + return xconn.NewFinalProgress(), nil + } + + ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305(buf[:n], keys.Send) + if err != nil { + return nil, fmt.Errorf("encryption error: %w", err) + } + payload := append(nonce, ciphertext...) + return xconn.NewProgress(payload), nil + } + + decryptAndWrite := func(encData []byte) error { + if len(encData) < nonceSize { + return fmt.Errorf("invalid payload from server: too short") + } + plain, err := berncrypt.DecryptChaCha20Poly1305(encData[nonceSize:], encData[:nonceSize], keys.Receive) + if err != nil { + return fmt.Errorf("decryption error: %w", err) + } + _, err = os.Stdout.Write(plain) + return err + } + call := session.Call(procedureInteractive). ProgressSender(func(ctx context.Context) *xconn.Progress { if firstProgress { firstProgress = false return xconn.NewProgress() } - - buf := make([]byte, 1024) - n, err := os.Stdin.Read(buf) + progress, err := readAndEncrypt() if err != nil { + fmt.Fprintln(os.Stderr, err) return xconn.NewFinalProgress() } - - ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305(buf[:n], keys.Send) - if err != nil { - fmt.Printf("encryption error: %s", err) - os.Exit(1) - } - payload := append(nonce, ciphertext...) - - return xconn.NewProgress(payload) + return progress }). ProgressReceiver(func(result *xconn.InvocationResult) { if len(result.Args) > 0 { - encData := result.Args[0].([]byte) - - if len(encData) < 12 { - fmt.Fprintln(os.Stderr, "invalid payload from server") - os.Exit(1) - } - - plain, err := berncrypt.DecryptChaCha20Poly1305(encData[12:], encData[:12], keys.Receive) - if err != nil { - _ = fmt.Errorf("decryption error: %w", err) + if err := decryptAndWrite(result.Args[0].([]byte)); err != nil { + fmt.Fprintln(os.Stderr, err) } - - os.Stdout.Write(plain) } else { - err = term.Restore(fd, oldState) - if err != nil { - return - } + _ = term.Restore(fd, oldState) os.Exit(0) } }).Do() if call.Err != nil { - log.Fatalf("Shell error: %s", call.Err) + return fmt.Errorf("shell error: %w", call.Err) } return nil } @@ -98,22 +105,21 @@ func runCommand(session *xconn.Session, keys *wampshell.KeyPair, args []string) payload := append(nonce, ciphertext...) - cmdResponse := session.Call(procedureExec).Args(payload).Do() - if cmdResponse.Err != nil { - return fmt.Errorf("command execution error: %w", cmdResponse.Err) + callResponse := session.Call(procedureExec).Args(payload).Do() + if callResponse.Err != nil { + return fmt.Errorf("command execution failed: %w", callResponse.Err) } - output, err := cmdResponse.Args.Bytes(0) + encryptedOutput, err := callResponse.Args.Bytes(0) if err != nil { - fmt.Printf("Output parsing error: %v", err) - os.Exit(1) + return fmt.Errorf("output parsing error: %w", err) } - plain, err := berncrypt.DecryptChaCha20Poly1305(output[12:], output[:12], keys.Receive) + plainOutput, err := berncrypt.DecryptChaCha20Poly1305(encryptedOutput[12:], encryptedOutput[:12], keys.Receive) if err != nil { - return fmt.Errorf("decryption error: %w", err) + return fmt.Errorf("decryption failed: %w", err) } - fmt.Print(string(plain)) + fmt.Print(string(plainOutput)) return nil } @@ -132,7 +138,7 @@ func main() { _, err := parser.Parse() if err != nil { - os.Exit(1) + log.Fatalln(err) } target := opts.Args.Target @@ -145,8 +151,7 @@ func main() { } else { user := os.Getenv("USER") if user == "" { - fmt.Println("Error: user not provided and $USER not set") - os.Exit(1) + log.Fatalln("Error: user not provided and $USER not set") } host = target } @@ -160,8 +165,7 @@ func main() { privateKey, err := wampshell.ReadPrivateKeyFromFile() if err != nil { - fmt.Printf("Error reading private key: %v\n", err) - os.Exit(1) + log.Fatalf("Error reading private key: %s", err) } authenticator, err := auth.NewCryptoSignAuthenticator("", privateKey, nil) @@ -200,7 +204,7 @@ func main() { keys, err := wampshell.ExchangeKeys(session) if err != nil { - panic(err) + log.Fatalf("Failed to exchange keys: %v", err) } if opts.Interactive || len(args) == 0 { From f3ea0b0154e28d9a71e143fcd60afd6457519c7d Mon Sep 17 00:00:00 2001 From: asimfarooq5 Date: Thu, 25 Sep 2025 16:32:58 +0500 Subject: [PATCH 5/5] refactor(wshd): improve naming and error handling --- cmd/wsh/main.go | 2 +- cmd/wshd/main.go | 51 ++++++++++++++---------------------------------- encryption.go | 7 +++++++ 3 files changed, 23 insertions(+), 37 deletions(-) diff --git a/cmd/wsh/main.go b/cmd/wsh/main.go index 28babe4..4afca96 100644 --- a/cmd/wsh/main.go +++ b/cmd/wsh/main.go @@ -42,7 +42,7 @@ func startInteractiveShell(session *xconn.Session, keys *wampshell.KeyPair) erro buf := make([]byte, 1024) n, err := os.Stdin.Read(buf) if err != nil { - return xconn.NewFinalProgress(), nil + return nil, fmt.Errorf("read error: %w", err) } ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305(buf[:n], keys.Send) diff --git a/cmd/wshd/main.go b/cmd/wshd/main.go index 2658788..9937e04 100644 --- a/cmd/wshd/main.go +++ b/cmd/wshd/main.go @@ -94,9 +94,7 @@ func (p *interactiveShellSession) handleShell(e *wampshell.EncryptionManager) fu return func(_ context.Context, inv *xconn.Invocation) *xconn.InvocationResult { caller := inv.Caller() - e.Lock() - key, ok := e.Keys()[inv.Caller()] - e.Unlock() + key, ok := e.Key(inv.Caller()) if !ok { return xconn.NewInvocationError("wamp.error.unavailable", "unavailable") } @@ -173,21 +171,17 @@ func runCommand(cmd string, args ...string) ([]byte, error) { func handleRunCommand(e *wampshell.EncryptionManager) func(_ context.Context, inv *xconn.Invocation) *xconn.InvocationResult { return func(_ context.Context, inv *xconn.Invocation) *xconn.InvocationResult { - - payload, err := inv.ArgBytes(0) + encryptedPayload, err := inv.ArgBytes(0) if err != nil { return xconn.NewInvocationError("wamp.error.invalid_argument", err.Error()) } - e.Lock() - key, ok := e.Keys()[inv.Caller()] - e.Unlock() - + key, ok := e.Key(inv.Caller()) if !ok { return xconn.NewInvocationError("wamp.error.unavailable", "unavailable") } - decryptedPayload, err := berncrypt.DecryptChaCha20Poly1305(payload[12:], payload[:12], key.Receive) + decryptedPayload, err := berncrypt.DecryptChaCha20Poly1305(encryptedPayload[12:], encryptedPayload[:12], key.Receive) if err != nil { return xconn.NewInvocationError("wamp.error.internal_error", err.Error()) } @@ -203,17 +197,13 @@ func handleRunCommand(e *wampshell.EncryptionManager) func(_ context.Context, return xconn.NewInvocationError("wamp.error.internal_error", err.Error()) } - ciphertext1, nonce1, err1 := berncrypt.EncryptChaCha20Poly1305(output, key.Send) - if err1 != nil { - log.Printf("Encryption failed in runCommand: %v", err1) - return xconn.NewInvocationError("wamp.error.internal_error", err1.Error()) + ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305(output, key.Send) + if err != nil { + log.Printf("Encryption failed in runCommand: %v", err) + return xconn.NewInvocationError("wamp.error.internal_error", err.Error()) } - payload1 := make([]byte, len(nonce1)+len(ciphertext1)) - copy(payload1, nonce1) - copy(payload1[len(nonce1):], ciphertext1) - - return xconn.NewInvocationResult(payload1) + return xconn.NewInvocationResult(append(nonce, ciphertext...)) } } @@ -237,9 +227,7 @@ func handleFileUpload(e *wampshell.EncryptionManager) func(_ context.Context, fmt.Sprintf("file content must be []byte, got %s", err.Error())) } - e.Lock() - key, ok := e.Keys()[inv.Caller()] - e.Unlock() + key, ok := e.Key(inv.Caller()) if !ok { return xconn.NewInvocationError("wamp.error.unavailable", "no encryption key for caller") } @@ -273,9 +261,7 @@ func handleFileDownload(e *wampshell.EncryptionManager) func(_ context.Context, return xconn.NewInvocationError("wamp.error.invalid_argument", err.Error()) } - e.Lock() - key, ok := e.Keys()[inv.Caller()] - e.Unlock() + key, ok := e.Key(inv.Caller()) if !ok { return xconn.NewInvocationError("wamp.error.unavailable", "no encryption key for caller") } @@ -294,15 +280,6 @@ func handleFileDownload(e *wampshell.EncryptionManager) func(_ context.Context, } } -func registerProcedure(session *xconn.Session, procedure string, handler xconn.InvocationHandler) error { - response := session.Register(procedure, handler).Do() - if response.Err != nil { - return fmt.Errorf("failed to register procedure %q: %w", procedure, response.Err) - } - log.Printf("Procedure registered: %s", procedure) - return nil -} - func addRealm(router *xconn.Router, realm string) { if router.HasRealm(realm) { return @@ -396,9 +373,11 @@ func main() { } for _, proc := range procedures { - if err = registerProcedure(session, proc.name, proc.handler); err != nil { - log.Fatal(err) + registerResponse := session.Register(proc.name, proc.handler).Do() + if registerResponse.Err != nil { + log.Fatalln(registerResponse.Err) } + log.Printf("Procedure registered: %s", proc.name) } log.Printf("listening on rs://%s", address) diff --git a/encryption.go b/encryption.go index 4b34706..b108deb 100644 --- a/encryption.go +++ b/encryption.go @@ -115,3 +115,10 @@ func (e *EncryptionManager) TestEcho(_ context.Context, invocation *xconn.Invoca func (e *EncryptionManager) Keys() map[uint64]*KeyPair { return e.keys } + +func (e *EncryptionManager) Key(sessionID uint64) (*KeyPair, bool) { + e.Lock() + defer e.Unlock() + key, ok := e.keys[sessionID] + return key, ok +}