diff --git a/cmd/wcp/main.go b/cmd/wcp/main.go index edf129c..9a6a67b 100644 --- a/cmd/wcp/main.go +++ b/cmd/wcp/main.go @@ -3,6 +3,7 @@ package main import ( "context" "fmt" + "log" "os" "path/filepath" "strings" @@ -19,116 +20,85 @@ const ( maxSize = 1024 * 1024 * 15 ) -type keyPair struct { - send []byte - receive []byte -} - -func exchangeKeys(session *xconn.Session) (*keyPair, error) { - publicKey, privateKey, err := berncrypt.CreateX25519KeyPair() +func uploadFile(session *xconn.Session, keys *wampshell.KeyPair, localPath, remotePath string) error { + fileInfo, err := os.Stat(localPath) if err != nil { - return nil, err + return fmt.Errorf("failed to stat local file: %w", err) } - - response := session.Call("wampshell.key.exchange").Arg(publicKey).Do() - if response.Err != nil { - return nil, response.Err + if fileInfo.Size() > maxSize { + return fmt.Errorf("file too large: %d bytes (max %d bytes)", fileInfo.Size(), maxSize) } - publicKeyPeer, err := response.Args.Bytes(0) + data, err := os.ReadFile(localPath) if err != nil { - return nil, err + return fmt.Errorf("failed to read local file: %w", err) } - sharedSecret, err := berncrypt.PerformKeyExchange(privateKey, publicKeyPeer) + ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305(data, keys.Send) if err != nil { - return nil, err + return err } - receiveKey, err := berncrypt.DeriveKeyHKDF(sharedSecret, []byte("backendToFrontend")) - if err != nil { - return nil, err + encryptedPayload := append(nonce, ciphertext...) + callResponse := session.Call("wampshell.shell.upload").Args(remotePath, encryptedPayload).Do() + if callResponse.Err != nil { + return fmt.Errorf("file upload error: %w", callResponse.Err) } - sendKey, err := berncrypt.DeriveKeyHKDF(sharedSecret, []byte("frontendToBackend")) + encResp, err := callResponse.Args.Bytes(0) if err != nil { - return nil, err + return fmt.Errorf("parsing response failed: %w", err) } - return &keyPair{ - send: sendKey, - receive: receiveKey, - }, nil -} -func uploadFile(session *xconn.Session, keys *keyPair, localFile, remoteFile string) error { - file, err := os.Stat(localFile) - if err != nil { - return fmt.Errorf("failed to stat local file: %w", err) - } - if file.Size() > maxSize { - return fmt.Errorf("file too large: %d bytes (max %d bytes)", file.Size(), maxSize) - } - data, err := os.ReadFile(localFile) + plainResp, err := berncrypt.DecryptChaCha20Poly1305(encResp[12:], encResp[:12], keys.Receive) if err != nil { - return fmt.Errorf("failed to read local file: %w", err) + return fmt.Errorf("response decryption failed: %w", err) } - ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305(data, keys.send) - if err != nil { - return err - } - encryptedPayload := append(nonce, ciphertext...) - cmdResponse := session.Call("wampshell.shell.upload").Args(remoteFile, encryptedPayload).Do() - if cmdResponse.Err != nil { - return fmt.Errorf("file upload error: %w", cmdResponse.Err) - } - encResp, err := cmdResponse.Args.Bytes(0) - if err != nil { - return fmt.Errorf("output parsing error: %w", err) - } - resp, err := berncrypt.DecryptChaCha20Poly1305(encResp[12:], encResp[:12], keys.receive) - if err != nil { - return err - } - fmt.Printf("Server response: %s\n", string(resp)) + + log.Printf("Upload response: %s", string(plainResp)) return nil } -func downloadFile(session *xconn.Session, keys *keyPair, remoteFile, localFile string) error { - cmdResponse := session.Call("wampshell.shell.download").Arg(remoteFile).Do() - if cmdResponse.Err != nil { - return fmt.Errorf("file download error: %w", cmdResponse.Err) +func downloadFile(session *xconn.Session, keys *wampshell.KeyPair, remotePath, localPath string) error { + callResponse := session.Call("wampshell.shell.download").Arg(remotePath).Do() + if callResponse.Err != nil { + return fmt.Errorf("file download error: %w", callResponse.Err) } - encResp, err := cmdResponse.Args.Bytes(0) + + encResp, err := callResponse.Args.Bytes(0) if err != nil { - return fmt.Errorf("output parsing error: %w", err) + return fmt.Errorf("failed to parse response: %w", err) } - data, err := berncrypt.DecryptChaCha20Poly1305(encResp[12:], encResp[:12], keys.receive) + + data, err := berncrypt.DecryptChaCha20Poly1305(encResp[12:], encResp[:12], keys.Receive) if err != nil { return err } - if err := os.WriteFile(localFile, data, 0600); err != nil { + + if err := os.WriteFile(localPath, data, 0600); err != nil { return fmt.Errorf("failed to save file: %w", err) } - fmt.Printf("Downloaded %s -> %s (%d bytes)\n", remoteFile, localFile, len(data)) + + log.Printf("Downloaded %s → %s (%d bytes)", remotePath, localPath, len(data)) return nil } -func splitRemote(s string) (user, host, port, path string, err error) { +func parseRemoteTarget(target string) (user, host, port, path string, err error) { port = "8022" - if strings.Contains(s, "@") { - parts := strings.SplitN(s, "@", 2) - user, s = parts[0], parts[1] + if strings.Contains(target, "@") { + parts := strings.SplitN(target, "@", 2) + user, target = parts[0], parts[1] } - parts := strings.SplitN(s, ":", 3) + parts := strings.SplitN(target, ":", 3) switch len(parts) { case 2: host, path = parts[0], parts[1] case 3: host, port, path = parts[0], parts[1], parts[2] default: - err = fmt.Errorf("invalid target: %s", s) + err = fmt.Errorf("invalid target: %s", target) } if user == "" { @@ -148,44 +118,37 @@ func main() { var opts Options parser := flags.NewParser(&opts, flags.Default) - _, err := parser.Parse() - if err != nil { - os.Exit(1) + if _, err := parser.Parse(); err != nil { + log.Fatal(err) } src := opts.Args.Source dst := opts.Args.Target var mode string - var localFile, remoteFile string + var localPath, remotePath string var user, host, port string if strings.Contains(src, ":") && !strings.Contains(dst, ":") { mode = "download" - user, host, port, remoteFile, _ = splitRemote(src) - localFile = dst + user, host, port, remotePath, _ = parseRemoteTarget(src) + localPath = dst } else if !strings.Contains(src, ":") && strings.Contains(dst, ":") { mode = "upload" - localFile = src - user, host, port, remoteFile, _ = splitRemote(dst) + localPath = src + user, host, port, remotePath, _ = parseRemoteTarget(dst) } else { - fmt.Println("Invalid usage: one of source/target must be remote (user@host:path)") - os.Exit(1) + log.Fatal("Invalid usage: one of source/target must be remote (user@host:path)") } privateKey, err := wampshell.ReadPrivateKeyFromFile() if err != nil { - fmt.Printf("Error reading private key: %v\n", err) - os.Exit(1) + log.Fatalf("Reading private key failed: %v", err) } - authExtra := map[string]any{} - authExtra["user"] = user - - authenticator, err := auth.NewCryptoSignAuthenticator("", privateKey, authExtra) + authenticator, err := auth.NewCryptoSignAuthenticator("", privateKey, map[string]any{"user": user}) if err != nil { - fmt.Printf("Error creating crypto sign authenticator: %v\n", err) - os.Exit(1) + log.Fatalf("Creating authenticator failed: %v", err) } client := xconn.Client{ @@ -196,32 +159,30 @@ func main() { url := fmt.Sprintf("rs://%s:%s", host, port) session, err := client.Connect(context.Background(), url, "wampshell") if err != nil { - panic(err) + log.Fatalf("Connection failed: %v", err) } - keys, err := exchangeKeys(session) + keys, err := wampshell.ExchangeKeys(session) if err != nil { - panic(err) + log.Fatalf("Key exchange failed: %v", err) } switch mode { case "upload": - if strings.HasSuffix(remoteFile, "/") { - remoteFile += filepath.Base(localFile) - } else if remoteFile == "" { - remoteFile = filepath.Base(localFile) + if strings.HasSuffix(remotePath, "/") { + remotePath += filepath.Base(localPath) + } else if remotePath == "" { + remotePath = filepath.Base(localPath) } - if err := uploadFile(session, keys, localFile, remoteFile); err != nil { - fmt.Printf("Upload failed: %v\n", err) - os.Exit(1) + if err := uploadFile(session, keys, localPath, remotePath); err != nil { + log.Fatalf("Upload failed: %v", err) } case "download": - if fi, err := os.Stat(localFile); err == nil && fi.IsDir() { - localFile = filepath.Join(localFile, filepath.Base(remoteFile)) + if fi, err := os.Stat(localPath); err == nil && fi.IsDir() { + localPath = filepath.Join(localPath, filepath.Base(remotePath)) } - if err := downloadFile(session, keys, remoteFile, localFile); err != nil { - fmt.Printf("Download failed: %v\n", err) - os.Exit(1) + if err := downloadFile(session, keys, remotePath, localPath); err != nil { + log.Fatalf("Download failed: %v", err) } } } diff --git a/cmd/wsh/main.go b/cmd/wsh/main.go index f124500..4afca96 100644 --- a/cmd/wsh/main.go +++ b/cmd/wsh/main.go @@ -26,49 +26,9 @@ const ( topicAnswererOnCandidate = "wampshell.webrtc.answerer.on_candidate" ) -type keyPair struct { - send []byte - receive []byte -} - -func exchangeKeys(session *xconn.Session) (*keyPair, error) { - publicKey, privateKey, err := berncrypt.CreateX25519KeyPair() - if err != nil { - return nil, err - } - - response := session.Call("wampshell.key.exchange").Arg(publicKey).Do() - if response.Err != nil { - return nil, response.Err - } +func startInteractiveShell(session *xconn.Session, keys *wampshell.KeyPair) error { + const nonceSize = 12 - publicKeyPeer, err := response.Args.Bytes(0) - if err != nil { - return nil, err - } - - sharedSecret, err := berncrypt.PerformKeyExchange(privateKey, publicKeyPeer) - if err != nil { - return nil, err - } - - receiveKey, err := berncrypt.DeriveKeyHKDF(sharedSecret, []byte("backendToFrontend")) - if err != nil { - return nil, err - } - - sendKey, err := berncrypt.DeriveKeyHKDF(sharedSecret, []byte("frontendToBackend")) - if err != nil { - return nil, err - } - - return &keyPair{ - send: sendKey, - receive: receiveKey, - }, nil -} - -func startInteractiveShell(session *xconn.Session, keys *keyPair) error { fd := int(os.Stdin.Fd()) oldState, err := term.MakeRaw(fd) if err != nil { @@ -78,84 +38,88 @@ func startInteractiveShell(session *xconn.Session, keys *keyPair) error { firstProgress := true + readAndEncrypt := func() (*xconn.Progress, error) { + buf := make([]byte, 1024) + n, err := os.Stdin.Read(buf) + if err != nil { + return nil, fmt.Errorf("read error: %w", err) + } + + ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305(buf[:n], keys.Send) + if err != nil { + return nil, fmt.Errorf("encryption error: %w", err) + } + payload := append(nonce, ciphertext...) + return xconn.NewProgress(payload), nil + } + + decryptAndWrite := func(encData []byte) error { + if len(encData) < nonceSize { + return fmt.Errorf("invalid payload from server: too short") + } + plain, err := berncrypt.DecryptChaCha20Poly1305(encData[nonceSize:], encData[:nonceSize], keys.Receive) + if err != nil { + return fmt.Errorf("decryption error: %w", err) + } + _, err = os.Stdout.Write(plain) + return err + } + call := session.Call(procedureInteractive). ProgressSender(func(ctx context.Context) *xconn.Progress { if firstProgress { firstProgress = false return xconn.NewProgress() } - - buf := make([]byte, 1024) - n, err := os.Stdin.Read(buf) + progress, err := readAndEncrypt() if err != nil { + fmt.Fprintln(os.Stderr, err) return xconn.NewFinalProgress() } - - ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305(buf[:n], keys.send) - if err != nil { - fmt.Printf("encryption error: %s", err) - os.Exit(1) - } - payload := append(nonce, ciphertext...) - - return xconn.NewProgress(payload) + return progress }). ProgressReceiver(func(result *xconn.InvocationResult) { if len(result.Args) > 0 { - encData := result.Args[0].([]byte) - - if len(encData) < 12 { - fmt.Fprintln(os.Stderr, "invalid payload from server") - os.Exit(1) - } - - plain, err := berncrypt.DecryptChaCha20Poly1305(encData[12:], encData[:12], keys.receive) - if err != nil { - _ = fmt.Errorf("decryption error: %w", err) + if err := decryptAndWrite(result.Args[0].([]byte)); err != nil { + fmt.Fprintln(os.Stderr, err) } - - os.Stdout.Write(plain) } else { - err = term.Restore(fd, oldState) - if err != nil { - return - } + _ = term.Restore(fd, oldState) os.Exit(0) } }).Do() if call.Err != nil { - log.Fatalf("Shell error: %s", call.Err) + return fmt.Errorf("shell error: %w", call.Err) } return nil } -func runCommand(session *xconn.Session, keys *keyPair, args []string) error { +func runCommand(session *xconn.Session, keys *wampshell.KeyPair, args []string) error { b := []byte(strings.Join(args, " ")) - ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305(b, keys.send) + ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305(b, keys.Send) if err != nil { return fmt.Errorf("encryption error: %w", err) } payload := append(nonce, ciphertext...) - cmdResponse := session.Call(procedureExec).Args(payload).Do() - if cmdResponse.Err != nil { - return fmt.Errorf("command execution error: %w", cmdResponse.Err) + callResponse := session.Call(procedureExec).Args(payload).Do() + if callResponse.Err != nil { + return fmt.Errorf("command execution failed: %w", callResponse.Err) } - output, err := cmdResponse.Args.Bytes(0) + encryptedOutput, err := callResponse.Args.Bytes(0) if err != nil { - fmt.Printf("Output parsing error: %v", err) - os.Exit(1) + return fmt.Errorf("output parsing error: %w", err) } - plain, err := berncrypt.DecryptChaCha20Poly1305(output[12:], output[:12], keys.receive) + plainOutput, err := berncrypt.DecryptChaCha20Poly1305(encryptedOutput[12:], encryptedOutput[:12], keys.Receive) if err != nil { - return fmt.Errorf("decryption error: %w", err) + return fmt.Errorf("decryption failed: %w", err) } - fmt.Print(string(plain)) + fmt.Print(string(plainOutput)) return nil } @@ -174,7 +138,7 @@ func main() { _, err := parser.Parse() if err != nil { - os.Exit(1) + log.Fatalln(err) } target := opts.Args.Target @@ -187,8 +151,7 @@ func main() { } else { user := os.Getenv("USER") if user == "" { - fmt.Println("Error: user not provided and $USER not set") - os.Exit(1) + log.Fatalln("Error: user not provided and $USER not set") } host = target } @@ -202,8 +165,7 @@ func main() { privateKey, err := wampshell.ReadPrivateKeyFromFile() if err != nil { - fmt.Printf("Error reading private key: %v\n", err) - os.Exit(1) + log.Fatalf("Error reading private key: %s", err) } authenticator, err := auth.NewCryptoSignAuthenticator("", privateKey, nil) @@ -240,9 +202,9 @@ func main() { } } - keys, err := exchangeKeys(session) + keys, err := wampshell.ExchangeKeys(session) if err != nil { - panic(err) + log.Fatalf("Failed to exchange keys: %v", err) } if opts.Interactive || len(args) == 0 { diff --git a/cmd/wshd/main.go b/cmd/wshd/main.go index d398046..9937e04 100644 --- a/cmd/wshd/main.go +++ b/cmd/wshd/main.go @@ -94,9 +94,7 @@ func (p *interactiveShellSession) handleShell(e *wampshell.EncryptionManager) fu return func(_ context.Context, inv *xconn.Invocation) *xconn.InvocationResult { caller := inv.Caller() - e.Lock() - key, ok := e.Keys()[inv.Caller()] - e.Unlock() + key, ok := e.Key(inv.Caller()) if !ok { return xconn.NewInvocationError("wamp.error.unavailable", "unavailable") } @@ -173,21 +171,17 @@ func runCommand(cmd string, args ...string) ([]byte, error) { func handleRunCommand(e *wampshell.EncryptionManager) func(_ context.Context, inv *xconn.Invocation) *xconn.InvocationResult { return func(_ context.Context, inv *xconn.Invocation) *xconn.InvocationResult { - - payload, err := inv.ArgBytes(0) + encryptedPayload, err := inv.ArgBytes(0) if err != nil { return xconn.NewInvocationError("wamp.error.invalid_argument", err.Error()) } - e.Lock() - key, ok := e.Keys()[inv.Caller()] - e.Unlock() - + key, ok := e.Key(inv.Caller()) if !ok { return xconn.NewInvocationError("wamp.error.unavailable", "unavailable") } - decryptedPayload, err := berncrypt.DecryptChaCha20Poly1305(payload[12:], payload[:12], key.Receive) + decryptedPayload, err := berncrypt.DecryptChaCha20Poly1305(encryptedPayload[12:], encryptedPayload[:12], key.Receive) if err != nil { return xconn.NewInvocationError("wamp.error.internal_error", err.Error()) } @@ -203,17 +197,13 @@ func handleRunCommand(e *wampshell.EncryptionManager) func(_ context.Context, return xconn.NewInvocationError("wamp.error.internal_error", err.Error()) } - ciphertext1, nonce1, err1 := berncrypt.EncryptChaCha20Poly1305(output, key.Send) - if err1 != nil { - log.Printf("Encryption failed in runCommand: %v", err1) - return xconn.NewInvocationError("wamp.error.internal_error", err1.Error()) + ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305(output, key.Send) + if err != nil { + log.Printf("Encryption failed in runCommand: %v", err) + return xconn.NewInvocationError("wamp.error.internal_error", err.Error()) } - payload1 := make([]byte, len(nonce1)+len(ciphertext1)) - copy(payload1, nonce1) - copy(payload1[len(nonce1):], ciphertext1) - - return xconn.NewInvocationResult(payload1) + return xconn.NewInvocationResult(append(nonce, ciphertext...)) } } @@ -237,9 +227,7 @@ func handleFileUpload(e *wampshell.EncryptionManager) func(_ context.Context, fmt.Sprintf("file content must be []byte, got %s", err.Error())) } - e.Lock() - key, ok := e.Keys()[inv.Caller()] - e.Unlock() + key, ok := e.Key(inv.Caller()) if !ok { return xconn.NewInvocationError("wamp.error.unavailable", "no encryption key for caller") } @@ -273,9 +261,7 @@ func handleFileDownload(e *wampshell.EncryptionManager) func(_ context.Context, return xconn.NewInvocationError("wamp.error.invalid_argument", err.Error()) } - e.Lock() - key, ok := e.Keys()[inv.Caller()] - e.Unlock() + key, ok := e.Key(inv.Caller()) if !ok { return xconn.NewInvocationError("wamp.error.unavailable", "no encryption key for caller") } @@ -294,15 +280,6 @@ func handleFileDownload(e *wampshell.EncryptionManager) func(_ context.Context, } } -func registerProcedure(session *xconn.Session, procedure string, handler xconn.InvocationHandler) error { - response := session.Register(procedure, handler).Do() - if response.Err != nil { - return fmt.Errorf("failed to register procedure %q: %w", procedure, response.Err) - } - log.Printf("Procedure registered: %s", procedure) - return nil -} - func addRealm(router *xconn.Router, realm string) { if router.HasRealm(realm) { return @@ -332,6 +309,8 @@ func main() { authenticator := wampshell.NewAuthenticator(keyStore) router := xconn.NewRouter() + addRealm(router, defaultRealm) + for realm := range authenticator.Realms() { addRealm(router, realm) } @@ -394,9 +373,11 @@ func main() { } for _, proc := range procedures { - if err = registerProcedure(session, proc.name, proc.handler); err != nil { - log.Fatal(err) + registerResponse := session.Register(proc.name, proc.handler).Do() + if registerResponse.Err != nil { + log.Fatalln(registerResponse.Err) } + log.Printf("Procedure registered: %s", proc.name) } log.Printf("listening on rs://%s", address) diff --git a/encryption.go b/encryption.go index 4b34706..b108deb 100644 --- a/encryption.go +++ b/encryption.go @@ -115,3 +115,10 @@ func (e *EncryptionManager) TestEcho(_ context.Context, invocation *xconn.Invoca func (e *EncryptionManager) Keys() map[uint64]*KeyPair { return e.keys } + +func (e *EncryptionManager) Key(sessionID uint64) (*KeyPair, bool) { + e.Lock() + defer e.Unlock() + key, ok := e.keys[sessionID] + return key, ok +} diff --git a/helpers.go b/helpers.go index 1418ebc..5298141 100644 --- a/helpers.go +++ b/helpers.go @@ -7,6 +7,7 @@ import ( "path/filepath" "strings" + berncrypt "github.com/xconnio/berncrypt/go" "github.com/xconnio/wampproto-capnproto/go" "github.com/xconnio/xconn-go" ) @@ -55,3 +56,39 @@ func RunningInSnap() bool { err := cmd.Run() return err == nil } + +func ExchangeKeys(session *xconn.Session) (*KeyPair, error) { + publicKey, privateKey, err := berncrypt.CreateX25519KeyPair() + if err != nil { + return nil, err + } + + response := session.Call("wampshell.key.exchange").Arg(publicKey).Do() + if response.Err != nil { + return nil, response.Err + } + + publicKeyPeer, err := response.Args.Bytes(0) + if err != nil { + return nil, err + } + + sharedSecret, err := berncrypt.PerformKeyExchange(privateKey, publicKeyPeer) + if err != nil { + return nil, err + } + + receiveKey, err := berncrypt.DeriveKeyHKDF(sharedSecret, []byte("backendToFrontend")) + if err != nil { + return nil, err + } + + sendKey, err := berncrypt.DeriveKeyHKDF(sharedSecret, []byte("frontendToBackend")) + if err != nil { + return nil, err + } + return &KeyPair{ + Send: sendKey, + Receive: receiveKey, + }, nil +}