Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 63 additions & 102 deletions cmd/wcp/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"context"
"fmt"
"log"
"os"
"path/filepath"
"strings"
Expand All @@ -19,116 +20,85 @@ const (
maxSize = 1024 * 1024 * 15
)

type keyPair struct {
send []byte
receive []byte
}

func exchangeKeys(session *xconn.Session) (*keyPair, error) {
publicKey, privateKey, err := berncrypt.CreateX25519KeyPair()
func uploadFile(session *xconn.Session, keys *wampshell.KeyPair, localPath, remotePath string) error {
fileInfo, err := os.Stat(localPath)
if err != nil {
return nil, err
return fmt.Errorf("failed to stat local file: %w", err)
}

response := session.Call("wampshell.key.exchange").Arg(publicKey).Do()
if response.Err != nil {
return nil, response.Err
if fileInfo.Size() > maxSize {
return fmt.Errorf("file too large: %d bytes (max %d bytes)", fileInfo.Size(), maxSize)
}

publicKeyPeer, err := response.Args.Bytes(0)
data, err := os.ReadFile(localPath)
if err != nil {
return nil, err
return fmt.Errorf("failed to read local file: %w", err)
}

sharedSecret, err := berncrypt.PerformKeyExchange(privateKey, publicKeyPeer)
ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305(data, keys.Send)
if err != nil {
return nil, err
return err
}

receiveKey, err := berncrypt.DeriveKeyHKDF(sharedSecret, []byte("backendToFrontend"))
if err != nil {
return nil, err
encryptedPayload := append(nonce, ciphertext...)
callResponse := session.Call("wampshell.shell.upload").Args(remotePath, encryptedPayload).Do()
if callResponse.Err != nil {
return fmt.Errorf("file upload error: %w", callResponse.Err)
}

sendKey, err := berncrypt.DeriveKeyHKDF(sharedSecret, []byte("frontendToBackend"))
encResp, err := callResponse.Args.Bytes(0)
if err != nil {
return nil, err
return fmt.Errorf("parsing response failed: %w", err)
}
return &keyPair{
send: sendKey,
receive: receiveKey,
}, nil
}

func uploadFile(session *xconn.Session, keys *keyPair, localFile, remoteFile string) error {
file, err := os.Stat(localFile)
if err != nil {
return fmt.Errorf("failed to stat local file: %w", err)
}
if file.Size() > maxSize {
return fmt.Errorf("file too large: %d bytes (max %d bytes)", file.Size(), maxSize)
}
data, err := os.ReadFile(localFile)
plainResp, err := berncrypt.DecryptChaCha20Poly1305(encResp[12:], encResp[:12], keys.Receive)
if err != nil {
return fmt.Errorf("failed to read local file: %w", err)
return fmt.Errorf("response decryption failed: %w", err)
}
ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305(data, keys.send)
if err != nil {
return err
}
encryptedPayload := append(nonce, ciphertext...)
cmdResponse := session.Call("wampshell.shell.upload").Args(remoteFile, encryptedPayload).Do()
if cmdResponse.Err != nil {
return fmt.Errorf("file upload error: %w", cmdResponse.Err)
}
encResp, err := cmdResponse.Args.Bytes(0)
if err != nil {
return fmt.Errorf("output parsing error: %w", err)
}
resp, err := berncrypt.DecryptChaCha20Poly1305(encResp[12:], encResp[:12], keys.receive)
if err != nil {
return err
}
fmt.Printf("Server response: %s\n", string(resp))

log.Printf("Upload response: %s", string(plainResp))
return nil
}

func downloadFile(session *xconn.Session, keys *keyPair, remoteFile, localFile string) error {
cmdResponse := session.Call("wampshell.shell.download").Arg(remoteFile).Do()
if cmdResponse.Err != nil {
return fmt.Errorf("file download error: %w", cmdResponse.Err)
func downloadFile(session *xconn.Session, keys *wampshell.KeyPair, remotePath, localPath string) error {
callResponse := session.Call("wampshell.shell.download").Arg(remotePath).Do()
if callResponse.Err != nil {
return fmt.Errorf("file download error: %w", callResponse.Err)
}
encResp, err := cmdResponse.Args.Bytes(0)

encResp, err := callResponse.Args.Bytes(0)
if err != nil {
return fmt.Errorf("output parsing error: %w", err)
return fmt.Errorf("failed to parse response: %w", err)
}
data, err := berncrypt.DecryptChaCha20Poly1305(encResp[12:], encResp[:12], keys.receive)

data, err := berncrypt.DecryptChaCha20Poly1305(encResp[12:], encResp[:12], keys.Receive)
if err != nil {
return err
}
if err := os.WriteFile(localFile, data, 0600); err != nil {

if err := os.WriteFile(localPath, data, 0600); err != nil {
return fmt.Errorf("failed to save file: %w", err)
}
fmt.Printf("Downloaded %s -> %s (%d bytes)\n", remoteFile, localFile, len(data))

log.Printf("Downloaded %s → %s (%d bytes)", remotePath, localPath, len(data))
return nil
}

func splitRemote(s string) (user, host, port, path string, err error) {
func parseRemoteTarget(target string) (user, host, port, path string, err error) {
port = "8022"

if strings.Contains(s, "@") {
parts := strings.SplitN(s, "@", 2)
user, s = parts[0], parts[1]
if strings.Contains(target, "@") {
parts := strings.SplitN(target, "@", 2)
user, target = parts[0], parts[1]
}

parts := strings.SplitN(s, ":", 3)
parts := strings.SplitN(target, ":", 3)
switch len(parts) {
case 2:
host, path = parts[0], parts[1]
case 3:
host, port, path = parts[0], parts[1], parts[2]
default:
err = fmt.Errorf("invalid target: %s", s)
err = fmt.Errorf("invalid target: %s", target)
}

if user == "" {
Expand All @@ -148,44 +118,37 @@ func main() {
var opts Options
parser := flags.NewParser(&opts, flags.Default)

_, err := parser.Parse()
if err != nil {
os.Exit(1)
if _, err := parser.Parse(); err != nil {
log.Fatal(err)
}

src := opts.Args.Source
dst := opts.Args.Target

var mode string
var localFile, remoteFile string
var localPath, remotePath string
var user, host, port string

if strings.Contains(src, ":") && !strings.Contains(dst, ":") {
mode = "download"
user, host, port, remoteFile, _ = splitRemote(src)
localFile = dst
user, host, port, remotePath, _ = parseRemoteTarget(src)
localPath = dst
} else if !strings.Contains(src, ":") && strings.Contains(dst, ":") {
mode = "upload"
localFile = src
user, host, port, remoteFile, _ = splitRemote(dst)
localPath = src
user, host, port, remotePath, _ = parseRemoteTarget(dst)
} else {
fmt.Println("Invalid usage: one of source/target must be remote (user@host:path)")
os.Exit(1)
log.Fatal("Invalid usage: one of source/target must be remote (user@host:path)")
}

privateKey, err := wampshell.ReadPrivateKeyFromFile()
if err != nil {
fmt.Printf("Error reading private key: %v\n", err)
os.Exit(1)
log.Fatalf("Reading private key failed: %v", err)
}

authExtra := map[string]any{}
authExtra["user"] = user

authenticator, err := auth.NewCryptoSignAuthenticator("", privateKey, authExtra)
authenticator, err := auth.NewCryptoSignAuthenticator("", privateKey, map[string]any{"user": user})
if err != nil {
fmt.Printf("Error creating crypto sign authenticator: %v\n", err)
os.Exit(1)
log.Fatalf("Creating authenticator failed: %v", err)
}

client := xconn.Client{
Expand All @@ -196,32 +159,30 @@ func main() {
url := fmt.Sprintf("rs://%s:%s", host, port)
session, err := client.Connect(context.Background(), url, "wampshell")
if err != nil {
panic(err)
log.Fatalf("Connection failed: %v", err)
}

keys, err := exchangeKeys(session)
keys, err := wampshell.ExchangeKeys(session)
if err != nil {
panic(err)
log.Fatalf("Key exchange failed: %v", err)
}

switch mode {
case "upload":
if strings.HasSuffix(remoteFile, "/") {
remoteFile += filepath.Base(localFile)
} else if remoteFile == "" {
remoteFile = filepath.Base(localFile)
if strings.HasSuffix(remotePath, "/") {
remotePath += filepath.Base(localPath)
} else if remotePath == "" {
remotePath = filepath.Base(localPath)
}
if err := uploadFile(session, keys, localFile, remoteFile); err != nil {
fmt.Printf("Upload failed: %v\n", err)
os.Exit(1)
if err := uploadFile(session, keys, localPath, remotePath); err != nil {
log.Fatalf("Upload failed: %v", err)
}
case "download":
if fi, err := os.Stat(localFile); err == nil && fi.IsDir() {
localFile = filepath.Join(localFile, filepath.Base(remoteFile))
if fi, err := os.Stat(localPath); err == nil && fi.IsDir() {
localPath = filepath.Join(localPath, filepath.Base(remotePath))
}
if err := downloadFile(session, keys, remoteFile, localFile); err != nil {
fmt.Printf("Download failed: %v\n", err)
os.Exit(1)
if err := downloadFile(session, keys, remotePath, localPath); err != nil {
log.Fatalf("Download failed: %v", err)
}
}
}
Loading
Loading