From 6e638f3c1d2e96c209cc386de16a2a16b4018082 Mon Sep 17 00:00:00 2001 From: Tom Barlow <60068+tombee@users.noreply.github.com> Date: Thu, 8 Jan 2026 13:27:49 +0000 Subject: [PATCH] Add tool result streaming for long-running operations Implements streaming support for tool execution, enabling real-time output visibility for long-running operations like shell commands. - Add StreamingTool interface with ExecuteStream method returning ToolChunk channel - Add EventToolOutput event type for streaming output events - Implement sensitive data redaction (AWS keys, tokens, passwords, URLs) - Add shell tool streaming with line-buffered output and 4KB binary fallback - Update SDK tool adapter to support streaming - Add agent integration with ToolOutputChunks field in step context - Add comprehensive unit and integration tests --- pkg/agent/agent.go | 100 +++- pkg/agent/agent_test.go | 190 ++++++ pkg/tools/builtin/shell.go | 341 +++++++++++ pkg/tools/builtin/shell_integration_test.go | 578 ++++++++++++++++++ pkg/tools/builtin/shell_test.go | 613 ++++++++++++++++++++ pkg/tools/redact.go | 92 +++ pkg/tools/redact_test.go | 456 +++++++++++++++ pkg/tools/registry.go | 215 ++++++- pkg/tools/registry_test.go | 506 ++++++++++++++++ sdk/events.go | 12 + sdk/tool.go | 132 +++++ 11 files changed, 3224 insertions(+), 11 deletions(-) create mode 100644 pkg/tools/builtin/shell_integration_test.go create mode 100644 pkg/tools/redact.go create mode 100644 pkg/tools/redact_test.go diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index c944ebb4..dec55cb1 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -171,6 +171,37 @@ type ToolExecution struct { // DurationMs is the duration in milliseconds (for spec compliance) DurationMs int + + // OutputChunks contains streaming output chunks from the tool execution + OutputChunks []ToolOutputChunk +} + +// ToolOutputChunk represents a streaming output chunk from a tool execution. 
+type ToolOutputChunk struct { + // ToolCallID links to the tool call + ToolCallID string + + // ToolName is the name of the tool + ToolName string + + // Stream identifies the output stream ("stdout", "stderr", or "") + Stream string + + // Data is the chunk content + Data string + + // IsFinal indicates this is the last chunk + IsFinal bool + + // Metadata contains optional metadata + Metadata map[string]interface{} +} + +// StepContext contains contextual information available during agent execution steps. +// This context accumulates data across iterations and is available for reasoning. +type StepContext struct { + // ToolOutputChunks contains all streaming output chunks from tool executions + ToolOutputChunks []ToolOutputChunk } // NewAgent creates a new agent. @@ -216,6 +247,11 @@ func (a *Agent) Run(ctx context.Context, systemPrompt string, userPrompt string) ToolExecutions: []ToolExecution{}, } + // Initialize step context + stepContext := &StepContext{ + ToolOutputChunks: []ToolOutputChunk{}, + } + // Initialize conversation with system and user messages messages := []Message{ {Role: "system", Content: systemPrompt}, @@ -274,7 +310,7 @@ func (a *Agent) Run(ctx context.Context, systemPrompt string, userPrompt string) // Execute tool calls if any if len(response.ToolCalls) > 0 { for _, toolCall := range response.ToolCalls { - execution := a.executeTool(ctx, toolCall) + execution := a.executeTool(ctx, toolCall, stepContext) result.ToolExecutions = append(result.ToolExecutions, execution) // Add tool result to conversation @@ -313,11 +349,12 @@ func (a *Agent) Run(ctx context.Context, systemPrompt string, userPrompt string) return result, fmt.Errorf("max iterations reached") } -// executeTool executes a single tool call. -func (a *Agent) executeTool(ctx context.Context, toolCall ToolCall) ToolExecution { +// executeTool executes a single tool call using streaming execution. 
+func (a *Agent) executeTool(ctx context.Context, toolCall ToolCall, stepContext *StepContext) ToolExecution {
 	startTime := time.Now()
 	execution := ToolExecution{
-		ToolName: toolCall.Name,
+		ToolName:     toolCall.Name,
+		OutputChunks: []ToolOutputChunk{},
 	}
 
 	// Parse arguments
@@ -340,15 +377,70 @@ func (a *Agent) executeTool(ctx context.Context, toolCall ToolCall) ToolExecutio
 
 	execution.Inputs = inputs
 
-	// Execute tool
-	outputs, err := a.registry.Execute(ctx, toolCall.Name, inputs)
+	// Execute tool with streaming support
+	chunks, err := a.registry.ExecuteStream(ctx, toolCall.Name, inputs, toolCall.ID)
+	if err != nil {
+		execution.Success = false
+		execution.Status = "error"
+		execution.Error = err.Error()
+		execution.Duration = time.Since(startTime)
+		execution.DurationMs = int(execution.Duration.Milliseconds())
+		return execution
+	}
+
+	// Process streaming chunks
+	var outputs map[string]interface{}
+	var execError error
+	sawFinal := false
+
+	for chunk := range chunks {
+		// Create output chunk for this execution
+		outputChunk := ToolOutputChunk{
+			ToolCallID: toolCall.ID,
+			ToolName:   toolCall.Name,
+			Stream:     chunk.Stream,
+			Data:       chunk.Data,
+			IsFinal:    chunk.IsFinal,
+			Metadata:   chunk.Metadata,
+		}
+
+		// Store chunk in execution and step context
+		execution.OutputChunks = append(execution.OutputChunks, outputChunk)
+		stepContext.ToolOutputChunks = append(stepContext.ToolOutputChunks, outputChunk)
+
+		// Emit event via callback if configured
+		if a.eventCallback != nil {
+			a.eventCallback("tool.output", map[string]interface{}{
+				"tool_call_id": toolCall.ID,
+				"tool_name":    toolCall.Name,
+				"stream":       chunk.Stream,
+				"data":         chunk.Data,
+				"is_final":     chunk.IsFinal,
+				"metadata":     chunk.Metadata,
+			})
+		}
+
+		// Extract final result
+		if chunk.IsFinal {
+			sawFinal = true
+			outputs = chunk.Result
+			execError = chunk.Error
+		}
+	}
+
+	// The result and error only travel on the final chunk. A stream that closes
+	// without one has no outcome to report, so treat it as a failed execution.
+	if !sawFinal {
+		execError = fmt.Errorf("tool %q closed its output stream without a final chunk", toolCall.Name)
+	}
+
 	execution.Duration = time.Since(startTime)
 	execution.DurationMs = int(execution.Duration.Milliseconds())
 
-	if err != nil {
+	if execError != nil {
 		execution.Success =
false execution.Status = "error" - execution.Error = err.Error() + execution.Error = execError.Error() return execution } diff --git a/pkg/agent/agent_test.go b/pkg/agent/agent_test.go index f0c05da2..090b6d55 100644 --- a/pkg/agent/agent_test.go +++ b/pkg/agent/agent_test.go @@ -560,3 +560,193 @@ func (m *mockStreamingLLMProvider) Stream(ctx context.Context, messages []Messag close(ch) return ch, nil } + +// mockStreamingTool implements StreamingTool for testing +type mockStreamingTool struct { + name string + chunks []tools.ToolChunk +} + +func (m *mockStreamingTool) Name() string { + return m.name +} + +func (m *mockStreamingTool) Description() string { + return "A mock streaming tool" +} + +func (m *mockStreamingTool) Schema() *tools.Schema { + return &tools.Schema{ + Inputs: &tools.ParameterSchema{ + Type: "object", + }, + Outputs: &tools.ParameterSchema{ + Type: "object", + }, + } +} + +func (m *mockStreamingTool) Execute(ctx context.Context, inputs map[string]interface{}) (map[string]interface{}, error) { + // For non-streaming execution, collect all chunks and return the final result + ch, err := m.ExecuteStream(ctx, inputs) + if err != nil { + return nil, err + } + + var result map[string]interface{} + var execError error + for chunk := range ch { + if chunk.IsFinal { + result = chunk.Result + execError = chunk.Error + } + } + + if execError != nil { + return nil, execError + } + return result, nil +} + +func (m *mockStreamingTool) ExecuteStream(ctx context.Context, inputs map[string]interface{}) (<-chan tools.ToolChunk, error) { + ch := make(chan tools.ToolChunk, len(m.chunks)) + go func() { + defer close(ch) + for _, chunk := range m.chunks { + ch <- chunk + } + }() + return ch, nil +} + +func TestAgent_ToolStreamingExecution(t *testing.T) { + // Create a streaming tool that emits chunks + streamingTool := &mockStreamingTool{ + name: "streaming-tool", + chunks: []tools.ToolChunk{ + { + Data: "Line 1\n", + Stream: "stdout", + }, + { + Data: "Line 2\n", + 
Stream: "stdout", + }, + { + Data: "Error message\n", + Stream: "stderr", + }, + { + IsFinal: true, + Result: map[string]interface{}{ + "exit_code": 0, + "duration": 100, + }, + }, + }, + } + + registry := tools.NewRegistry() + if err := registry.Register(streamingTool); err != nil { + t.Fatalf("Failed to register tool: %v", err) + } + + llm := &mockLLMProvider{ + responses: []Response{ + { + Content: "Using streaming tool", + FinishReason: "tool_calls", + ToolCalls: []ToolCall{ + { + ID: "call-1", + Name: "streaming-tool", + Arguments: map[string]interface{}{}, + }, + }, + Usage: TokenUsage{TotalTokens: 10}, + }, + { + Content: "Completed with streaming output", + FinishReason: "stop", + Usage: TokenUsage{TotalTokens: 10}, + }, + }, + } + + // Track events emitted via callback + var capturedEvents []map[string]interface{} + agent := NewAgent(llm, registry).WithEventCallback(func(eventType string, data interface{}) { + if eventType == "tool.output" { + if eventData, ok := data.(map[string]interface{}); ok { + capturedEvents = append(capturedEvents, eventData) + } + } + }) + + ctx := context.Background() + result, err := agent.Run(ctx, "System", "Task") + if err != nil { + t.Fatalf("Run() error = %v", err) + } + + // Verify tool execution + if len(result.ToolExecutions) != 1 { + t.Fatalf("ToolExecutions count = %d, want 1", len(result.ToolExecutions)) + } + + execution := result.ToolExecutions[0] + + // Verify output chunks are captured in execution + if len(execution.OutputChunks) != 4 { + t.Errorf("OutputChunks count = %d, want 4", len(execution.OutputChunks)) + } + + // Verify chunk content + if execution.OutputChunks[0].Data != "Line 1\n" { + t.Errorf("Chunk 0 data = %q, want %q", execution.OutputChunks[0].Data, "Line 1\n") + } + if execution.OutputChunks[0].Stream != "stdout" { + t.Errorf("Chunk 0 stream = %q, want %q", execution.OutputChunks[0].Stream, "stdout") + } + + if execution.OutputChunks[2].Stream != "stderr" { + t.Errorf("Chunk 2 stream = %q, want 
%q", execution.OutputChunks[2].Stream, "stderr") + } + + // Verify final chunk + if !execution.OutputChunks[3].IsFinal { + t.Error("Last chunk should have IsFinal=true") + } + + // Verify events were emitted + if len(capturedEvents) != 4 { + t.Errorf("Captured %d events, want 4", len(capturedEvents)) + } + + // Verify event structure + if len(capturedEvents) > 0 { + firstEvent := capturedEvents[0] + if firstEvent["tool_name"] != "streaming-tool" { + t.Errorf("Event tool_name = %q, want %q", firstEvent["tool_name"], "streaming-tool") + } + if firstEvent["tool_call_id"] != "call-1" { + t.Errorf("Event tool_call_id = %q, want %q", firstEvent["tool_call_id"], "call-1") + } + if firstEvent["data"] != "Line 1\n" { + t.Errorf("Event data = %q, want %q", firstEvent["data"], "Line 1\n") + } + } + + // Verify execution succeeded + if !execution.Success { + t.Error("Tool execution should have succeeded") + } + + // Verify final result + if execution.Outputs == nil { + t.Error("Execution outputs should not be nil") + } + if exitCode, ok := execution.Outputs["exit_code"].(int); !ok || exitCode != 0 { + t.Errorf("Exit code = %v, want 0", execution.Outputs["exit_code"]) + } +} diff --git a/pkg/tools/builtin/shell.go b/pkg/tools/builtin/shell.go index 384e559d..e4dc2c7d 100644 --- a/pkg/tools/builtin/shell.go +++ b/pkg/tools/builtin/shell.go @@ -1,13 +1,16 @@ package builtin import ( + "bytes" "context" "fmt" + "io" "log/slog" "os" "os/exec" "path/filepath" "strings" + "sync" "syscall" "time" @@ -30,6 +33,9 @@ type ShellTool struct { // securityConfig provides enhanced security controls securityConfig *security.ShellSecurityConfig + + // redactor redacts sensitive data from output + redactor *tools.Redactor } // NewShellTool creates a new shell tool with default settings. 
@@ -39,6 +45,7 @@ func NewShellTool() *ShellTool {
 		workingDir:      "",                                    // Current directory
 		allowedCommands: []string{},                            // No restrictions by default
 		securityConfig:  security.DefaultShellSecurityConfig(), // Secure defaults
+		redactor:        tools.NewRedactor(),                   // Default redaction patterns
 	}
 }
 
@@ -285,6 +292,341 @@ func (t *ShellTool) Execute(ctx context.Context, inputs map[string]interface{})
 	}, nil
 }
 
+// ExecuteStream runs the shell command and streams output line-by-line.
+// It implements the StreamingTool interface for real-time output visibility.
+//
+// Validation failures are returned synchronously. Otherwise stdout and stderr
+// arrive as redacted data chunks, followed by exactly one final chunk (IsFinal)
+// carrying the result map and/or an Error, after which the channel is closed.
+// On timeout the process receives SIGTERM and, 2 seconds later, SIGKILL.
+func (t *ShellTool) ExecuteStream(ctx context.Context, inputs map[string]interface{}) (<-chan tools.ToolChunk, error) {
+	// Extract and validate command and args
+	command, ok := inputs["command"].(string)
+	if !ok {
+		return nil, &errors.ValidationError{
+			Field:      "command",
+			Message:    "command must be a string",
+			Suggestion: "Provide the command as a string",
+		}
+	}
+
+	// Extract arguments (optional)
+	var args []string
+	if argsRaw, ok := inputs["args"]; ok {
+		argsSlice, ok := argsRaw.([]interface{})
+		if !ok {
+			return nil, &errors.ValidationError{
+				Field:      "args",
+				Message:    "args must be an array",
+				Suggestion: "Provide arguments as an array of strings",
+			}
+		}
+		args = make([]string, len(argsSlice))
+		for i, arg := range argsSlice {
+			argStr, ok := arg.(string)
+			if !ok {
+				return nil, &errors.ValidationError{
+					Field:      fmt.Sprintf("args[%d]", i),
+					Message:    "all args must be strings",
+					Suggestion: "Ensure all arguments are strings",
+				}
+			}
+			args[i] = argStr
+		}
+	}
+
+	// Validate command
+	if err := t.validateCommand(command); err != nil {
+		return nil, fmt.Errorf("command validation failed: %w", err)
+	}
+
+	// Validate with security config
+	if t.securityConfig != nil {
+		if err := t.securityConfig.ValidateCommand(command, args); err != nil {
+			return nil, fmt.Errorf("security validation failed for command %s: %w", command, err)
+		}
+	}
+
+	// Create bounded channel for streaming chunks
+	chunks := make(chan tools.ToolChunk, 256)
+
+	// Start streaming execution in a goroutine
+	go func() {
+		defer close(chunks)
+
+		// Track execution time
+		startTime := time.Now()
+
+		// Set timeout
+		execCtx, cancel := context.WithTimeout(ctx, t.timeout)
+		defer cancel()
+
+		// Create command
+		cmd := exec.CommandContext(execCtx, command, args...)
+		if t.workingDir != "" {
+			cmd.Dir = t.workingDir
+		}
+
+		// Sanitize environment if configured
+		if t.securityConfig != nil && t.securityConfig.SanitizeEnv {
+			cmd.Env = t.securityConfig.SanitizeEnvironment(os.Environ())
+		}
+
+		// Create pipes for stdout and stderr
+		stdoutPipe, err := cmd.StdoutPipe()
+		if err != nil {
+			chunks <- tools.ToolChunk{
+				IsFinal: true,
+				Error:   fmt.Errorf("failed to create stdout pipe: %w", err),
+			}
+			return
+		}
+
+		stderrPipe, err := cmd.StderrPipe()
+		if err != nil {
+			chunks <- tools.ToolChunk{
+				IsFinal: true,
+				Error:   fmt.Errorf("failed to create stderr pipe: %w", err),
+			}
+			return
+		}
+
+		// On timeout or cancellation, ask the process to stop with SIGTERM and
+		// let os/exec escalate to SIGKILL if it is still running 2 seconds later.
+		cmd.Cancel = func() error {
+			return cmd.Process.Signal(syscall.SIGTERM)
+		}
+		cmd.WaitDelay = 2 * time.Second
+
+		// Start command
+		if err := cmd.Start(); err != nil {
+			duration := time.Since(startTime).Milliseconds()
+			chunks <- tools.ToolChunk{
+				IsFinal: true,
+				Result: map[string]interface{}{
+					"success":   false,
+					"stdout":    "",
+					"stderr":    "",
+					"exit_code": -1,
+					"status":    "error",
+					"duration":  duration,
+				},
+			}
+			return
+		}
+
+		// Use WaitGroup to coordinate stdout/stderr goroutines
+		var wg sync.WaitGroup
+		wg.Add(2)
+
+		// Track total output size across both streams
+		var totalSize int64
+		var sizeMutex sync.Mutex
+		var truncated bool
+
+		// Stream stdout with enhanced streamPipe helper
+		go t.streamPipe(execCtx, chunks, stdoutPipe, "stdout", &wg, &totalSize, &sizeMutex, &truncated)
+
+		// Stream stderr with enhanced streamPipe helper
+		go t.streamPipe(execCtx, chunks, stderrPipe, "stderr", &wg, &totalSize, &sizeMutex, &truncated)
+
+		// Reads must finish before cmd.Wait, which closes both pipes as soon as
+		// the process exits. Stop waiting once execCtx is done so that a
+		// descendant holding a pipe open cannot stall the stream past the timeout.
+		streamsDone := make(chan struct{})
+		go func() {
+			wg.Wait()
+			close(streamsDone)
+		}()
+		select {
+		case <-streamsDone:
+		case <-execCtx.Done():
+		}
+
+		// Reap the process. Wait also closes the pipes, which unblocks any
+		// reader that is still running.
+		cmdErr := cmd.Wait()
+		<-streamsDone
+
+		// Calculate duration
+		duration := time.Since(startTime).Milliseconds()
+
+		// Determine status and exit code
+		var exitCode int
+		var status string
+
+		if ctx.Err() == context.DeadlineExceeded || execCtx.Err() == context.DeadlineExceeded {
+			// Timeout occurred; cmd.Cancel/WaitDelay have already stopped the process
+			status = "timeout"
+			exitCode = -1
+		} else if cmdErr != nil {
+			// Command failed
+			status = "completed"
+			if exitErr, ok := cmdErr.(*exec.ExitError); ok {
+				exitCode = exitErr.ExitCode()
+			} else {
+				exitCode = -1
+			}
+		} else {
+			// Command succeeded
+			status = "completed"
+			exitCode = 0
+		}
+
+		// Send final chunk with result
+		finalChunk := tools.ToolChunk{
+			IsFinal: true,
+			Result: map[string]interface{}{
+				"success":   exitCode == 0,
+				"exit_code": exitCode,
+				"status":    status,
+				"duration":  duration,
+			},
+		}
+
+		// Add truncation metadata and error if output was truncated
+		if truncated {
+			finalChunk.Metadata = map[string]interface{}{
+				"truncated": true,
+			}
+			finalChunk.Error = fmt.Errorf("output truncated: exceeded size limit of %d bytes", t.securityConfig.MaxOutputSize)
+		}
+
+		chunks <- finalChunk
+	}()
+
+	return chunks, nil
+}
+
+// emitChunkWithSizeCheck checks size limits before emitting a chunk.
+// Returns true if the chunk was emitted, false if truncated.
+//
+// The byte budget is shared by both streams and updated under sizeMutex. The
+// channel send happens only after the lock is released, so the lock is never
+// held across a blocking send.
+func (t *ShellTool) emitChunkWithSizeCheck(chunks chan<- tools.ToolChunk, data, stream string, totalSize *int64, sizeMutex *sync.Mutex, truncated *bool) bool {
+	// Only account for size when an output limit is configured
+	if t.securityConfig != nil && t.securityConfig.MaxOutputSize > 0 {
+		sizeMutex.Lock()
+		newTotal := *totalSize + int64(len(data))
+		fits := !*truncated && newTotal <= t.securityConfig.MaxOutputSize
+		if fits {
+			*totalSize = newTotal
+		} else {
+			// Once the limit is hit, every later chunk is dropped as well
+			*truncated = true
+		}
+		sizeMutex.Unlock()
+
+		if !fits {
+			return false
+		}
+	}
+
+	chunks <- tools.ToolChunk{
+		Data:   data,
+		Stream: stream,
+	}
+	return true
+}
+
+// streamPipe reads from a pipe and emits chunks with line-buffering and redaction.
+// It handles edge cases including:
+// - Line-buffered output (emits on \n)
+// - 4KB fallback for binary data without newlines
+// - Panic recovery with error reporting
+// - Context cancellation for cleanup
+// - Redaction of sensitive data before emission
+// - Output size limit enforcement with truncation; once the limit is hit the
+//   rest of the pipe is read and discarded so the command is not left
+//   blocked writing to a full pipe
+func (t *ShellTool) streamPipe(ctx context.Context, chunks chan<- tools.ToolChunk, pipe io.ReadCloser, stream string, wg *sync.WaitGroup, totalSize *int64, sizeMutex *sync.Mutex, truncated *bool) {
+	defer wg.Done()
+
+	// Recover from panics and report as error
+	defer func() {
+		if r := recover(); r != nil {
+			chunks <- tools.ToolChunk{
+				Stream: stream,
+				Data:   "",
+				Error:  fmt.Errorf("panic in %s stream: %v", stream, r),
+			}
+		}
+	}()
+
+	// Ensure pipe is closed on exit
+	defer pipe.Close()
+
+	const maxChunkSize = 4 * 1024 // 4KB fallback for binary data
+	buf := make([]byte, maxChunkSize)
+	var pending []byte
+
+	// emit redacts data and forwards it; it reports false once output is truncated
+	emit := func(data []byte) bool {
+		return t.emitChunkWithSizeCheck(chunks, t.redactor.Redact(string(data)), stream, totalSize, sizeMutex, truncated)
+	}
+
+	for {
+		// Check for context cancellation
+		select {
+		case <-ctx.Done():
+			// Context cancelled, emit any pending data and exit
+			if len(pending) > 0 {
+				emit(pending)
+			}
+			return
+		default:
+		}
+
+		// Read from pipe
+		n, err := pipe.Read(buf)
+		if n > 0 {
+			pending = append(pending, buf[:n]...)
+
+			// Process complete lines (line-buffered output)
+			for {
+				idx := bytes.IndexByte(pending, '\n')
+				if idx == -1 {
+					break
+				}
+
+				if !emit(pending[:idx]) {
+					// Size limit exceeded: discard the rest instead of stalling the writer
+					_, _ = io.Copy(io.Discard, pipe)
+					return
+				}
+				pending = pending[idx+1:]
+			}
+
+			// If buffer exceeds 4KB without newline (binary data), emit it
+			if len(pending) >= maxChunkSize {
+				if !emit(pending) {
+					_, _ = io.Copy(io.Discard, pipe)
+					return
+				}
+				pending = pending[:0]
+			}
+		}
+
+		if err != nil {
+			// Report non-EOF errors, except the read failure caused by the pipe
+			// being closed while the context is shutting the command down
+			if err != io.EOF && ctx.Err() == nil {
+				chunks <- tools.ToolChunk{
+					Stream: stream,
+					Error:  fmt.Errorf("error reading %s: %w", stream, err),
+				}
+			}
+
+			// Emit any remaining data (partial line at EOF)
+			if len(pending) > 0 {
+				emit(pending)
+			}
+			return
+		}
+	}
+}
+
 // validateCommand checks if a command is allowed.
func (t *ShellTool) validateCommand(command string) error { // Empty list means all commands allowed (preserve existing behavior) diff --git a/pkg/tools/builtin/shell_integration_test.go b/pkg/tools/builtin/shell_integration_test.go new file mode 100644 index 00000000..7a7ea67f --- /dev/null +++ b/pkg/tools/builtin/shell_integration_test.go @@ -0,0 +1,578 @@ +package builtin + +import ( + "context" + "runtime" + "strings" + "testing" + "time" + + "github.com/tombee/conductor/pkg/security" +) + +// TestShellTool_Integration_MultiSecondStreaming tests that commands producing output +// over multiple seconds stream chunks in real-time rather than buffering everything. +func TestShellTool_Integration_MultiSecondStreaming(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Skipping multi-second streaming test on Windows") + } + + secConfig := security.DefaultShellSecurityConfig() + secConfig.AllowShellExpand = true + + tool := NewShellTool().WithSecurityConfig(secConfig).WithTimeout(10 * time.Second) + ctx := context.Background() + + // Script that outputs one line per second for 3 seconds + chunks, err := tool.ExecuteStream(ctx, map[string]interface{}{ + "command": "sh", + "args": []interface{}{"-c", ` + echo "Line 1" + sleep 1 + echo "Line 2" + sleep 1 + echo "Line 3" + `}, + }) + if err != nil { + t.Fatalf("ExecuteStream() error = %v", err) + } + + // Track when chunks arrive + type chunkEvent struct { + timestamp time.Time + data string + isFinal bool + } + var events []chunkEvent + startTime := time.Now() + + for chunk := range chunks { + events = append(events, chunkEvent{ + timestamp: time.Now(), + data: chunk.Data, + isFinal: chunk.IsFinal, + }) + } + + // Verify we got chunks during execution, not all at the end + if len(events) < 4 { // At least 3 data chunks + 1 final chunk + t.Errorf("Expected at least 4 events (3 lines + final), got %d", len(events)) + } + + // Verify real-time streaming: chunks should arrive with delays between them + // Check that 
the time between first and last data chunk is at least 1.5 seconds + // (accounting for 2 sleep commands of 1 second each) + if len(events) >= 3 { + firstDataIdx := 0 + lastDataIdx := 0 + for i, e := range events { + if !e.isFinal && e.data != "" { + if firstDataIdx == 0 { + firstDataIdx = i + } + lastDataIdx = i + } + } + + if firstDataIdx != lastDataIdx { + duration := events[lastDataIdx].timestamp.Sub(events[firstDataIdx].timestamp) + if duration < 800*time.Millisecond { + t.Errorf("Expected chunks to arrive over at least 0.8s (real-time streaming), got %v", duration) + } + } + } + + // Verify all expected output was received + var outputLines []string + for _, e := range events { + if !e.isFinal && e.data != "" { + outputLines = append(outputLines, e.data) + } + } + + output := strings.Join(outputLines, "\n") + for i := 1; i <= 3; i++ { + expected := "Line " + string(rune('0'+i)) + if !strings.Contains(output, expected) { + t.Errorf("Output missing %q, got: %s", expected, output) + } + } + + // Verify total execution time was reasonable (at least 2 seconds for the sleeps) + totalDuration := time.Since(startTime) + if totalDuration < 2*time.Second { + t.Errorf("Total execution time %v too short for a 2-second command", totalDuration) + } +} + +// TestShellTool_Integration_FirstChunkLatency tests that the first chunk is emitted +// within 100ms of output being available (best-effort). 
+func TestShellTool_Integration_FirstChunkLatency(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Skipping latency test on Windows") + } + + tool := NewShellTool() + ctx := context.Background() + + // Command that produces output immediately + startTime := time.Now() + + chunks, err := tool.ExecuteStream(ctx, map[string]interface{}{ + "command": "echo", + "args": []interface{}{"immediate output"}, + }) + if err != nil { + t.Fatalf("ExecuteStream() error = %v", err) + } + + // Measure time to first chunk + var firstChunkTime time.Time + var gotFirstChunk bool + + for chunk := range chunks { + if !gotFirstChunk && !chunk.IsFinal && chunk.Data != "" { + firstChunkTime = time.Now() + gotFirstChunk = true + } + } + + if !gotFirstChunk { + t.Fatal("No data chunk received") + } + + latency := firstChunkTime.Sub(startTime) + + // Best-effort target: first chunk within 100ms + // We allow up to 200ms for CI environments which may be slower + if latency > 200*time.Millisecond { + t.Logf("WARNING: First chunk latency %v exceeds 100ms target (CI tolerance: 200ms)", latency) + } + + t.Logf("First chunk latency: %v", latency) +} + +// TestShellTool_Integration_MemoryBounded tests that memory usage remains bounded +// regardless of total output size, verifying that streaming doesn't buffer everything. 
+func TestShellTool_Integration_MemoryBounded(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Skipping memory test on Windows") + } + + secConfig := security.DefaultShellSecurityConfig() + secConfig.AllowShellExpand = true + secConfig.MaxOutputSize = 0 // Disable output limits for this test + + tool := NewShellTool().WithSecurityConfig(secConfig).WithTimeout(30 * time.Second) + ctx := context.Background() + + // Record memory before starting + runtime.GC() + var memBefore runtime.MemStats + runtime.ReadMemStats(&memBefore) + + // Command that produces ~10MB of output over time + // We generate 1000 lines of 10KB each = ~10MB total + chunks, err := tool.ExecuteStream(ctx, map[string]interface{}{ + "command": "sh", + "args": []interface{}{"-c", ` + for i in $(seq 1 1000); do + printf '%10240s\n' | tr ' ' 'x' + done + `}, + }) + if err != nil { + t.Fatalf("ExecuteStream() error = %v", err) + } + + // Process chunks but don't accumulate all output (simulating real streaming consumer) + chunkCount := 0 + + for chunk := range chunks { + if !chunk.IsFinal { + chunkCount++ + _ = len(chunk.Data) // Process chunk but don't store + } + + // Periodically check memory during streaming + if chunkCount%100 == 0 { + runtime.GC() + var memDuring runtime.MemStats + runtime.ReadMemStats(&memDuring) + + // Memory growth should be bounded (less than 5MB during processing) + // The actual output is 10MB but streaming should never buffer it all + // Note: memory can decrease due to GC, so only check if it increased + if memDuring.Alloc > memBefore.Alloc { + memGrowth := memDuring.Alloc - memBefore.Alloc + if memGrowth > 5*1024*1024 { + t.Errorf("Memory growth %d bytes exceeds 5MB during streaming", memGrowth) + } + } + } + } + + if chunkCount == 0 { + t.Fatal("No chunks received") + } + + // Verify we got a reasonable number of chunks (streaming, not single buffer dump) + if chunkCount < 100 { + t.Errorf("Expected many chunks for 10MB output, got only %d", chunkCount) + } + + 
t.Logf("Received %d chunks for ~10MB output (avg chunk size: ~%d bytes)", + chunkCount, 10*1024*1024/chunkCount) +} + +// TestShellTool_Integration_RedactionWorks tests that sensitive patterns in output +// are properly redacted during streaming. +func TestShellTool_Integration_RedactionWorks(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Skipping redaction integration test on Windows") + } + + secConfig := security.DefaultShellSecurityConfig() + secConfig.AllowShellExpand = true + + tool := NewShellTool().WithSecurityConfig(secConfig) + ctx := context.Background() + + // Command that outputs multiple sensitive values over time + sensitiveScript := ` + echo "Starting deployment..." + sleep 0.1 + echo "AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE" + sleep 0.1 + echo "AWS_SECRET_ACCESS_KEY=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" + sleep 0.1 + echo "Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0" + sleep 0.1 + echo "DATABASE_URL=postgresql://user:secretpass123@localhost:5432/mydb" + sleep 0.1 + echo "API_KEY=test_apikey_abcdef1234567890ghijklmnop" + sleep 0.1 + echo "Deployment complete!" 
+ ` + + chunks, err := tool.ExecuteStream(ctx, map[string]interface{}{ + "command": "sh", + "args": []interface{}{"-c", sensitiveScript}, + }) + if err != nil { + t.Fatalf("ExecuteStream() error = %v", err) + } + + // Collect all output + var allOutput []string + for chunk := range chunks { + if !chunk.IsFinal && chunk.Stream == "stdout" { + allOutput = append(allOutput, chunk.Data) + } + } + + output := strings.Join(allOutput, "\n") + + // Verify redaction occurred + tests := []struct { + name string + shouldNotContain string + shouldContain string + }{ + { + name: "AWS access key redacted", + shouldNotContain: "AKIAIOSFODNN7EXAMPLE", + shouldContain: "[REDACTED]", + }, + { + name: "AWS secret key redacted", + shouldNotContain: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + shouldContain: "[REDACTED]", + }, + { + name: "Bearer token redacted", + shouldNotContain: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9", + shouldContain: "[REDACTED]", + }, + { + name: "Password in URL redacted", + shouldNotContain: "secretpass123", + shouldContain: "[REDACTED]", + }, + { + name: "API key redacted", + shouldNotContain: "test_apikey_abcdef1234567890ghijklmnop", + shouldContain: "[REDACTED]", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if strings.Contains(output, tt.shouldNotContain) { + t.Errorf("Output contains sensitive data %q, it should be redacted", tt.shouldNotContain) + } + if !strings.Contains(output, tt.shouldContain) { + t.Errorf("Output should contain %q indicating redaction occurred", tt.shouldContain) + } + }) + } + + // Verify non-sensitive content is preserved + if !strings.Contains(output, "Starting deployment") { + t.Error("Non-sensitive content should be preserved") + } + if !strings.Contains(output, "Deployment complete") { + t.Error("Non-sensitive content should be preserved") + } +} + +// TestShellTool_Integration_TruncationWorks tests that output exceeding size limits +// is properly truncated with an appropriate indicator. 
+func TestShellTool_Integration_TruncationWorks(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Skipping truncation integration test on Windows") + } + + secConfig := security.DefaultShellSecurityConfig() + secConfig.AllowShellExpand = true + secConfig.MaxOutputSize = 1024 // 1KB limit + + tool := NewShellTool().WithSecurityConfig(secConfig).WithTimeout(10 * time.Second) + ctx := context.Background() + + // Command that produces much more than 1KB of output + chunks, err := tool.ExecuteStream(ctx, map[string]interface{}{ + "command": "sh", + "args": []interface{}{"-c", ` + for i in $(seq 1 100); do + echo "This is line $i with some extra padding to make it longer" + done + `}, + }) + if err != nil { + t.Fatalf("ExecuteStream() error = %v", err) + } + + // Collect chunks + var dataChunks []string + var finalChunk map[string]interface{} + var finalMetadata map[string]interface{} + var finalError error + + for chunk := range chunks { + if chunk.IsFinal { + finalChunk = chunk.Result + finalMetadata = chunk.Metadata + finalError = chunk.Error + } else if chunk.Stream == "stdout" { + dataChunks = append(dataChunks, chunk.Data) + } + } + + // Verify final chunk received + if finalChunk == nil { + t.Fatal("No final chunk received") + } + + // Verify truncation metadata + if finalMetadata == nil { + t.Fatal("Final chunk should have metadata indicating truncation") + } + + truncated, ok := finalMetadata["truncated"].(bool) + if !ok || !truncated { + t.Errorf("Final chunk metadata should indicate truncation, got: %v", finalMetadata) + } + + // Verify truncation error + if finalError == nil { + t.Error("Final chunk should have error indicating output was truncated") + } else { + errMsg := finalError.Error() + if !strings.Contains(errMsg, "truncated") && !strings.Contains(errMsg, "exceeded") { + t.Errorf("Error should mention truncation, got: %v", errMsg) + } + } + + // Verify total output doesn't exceed limit + totalSize := 0 + for _, chunk := range dataChunks { + 
totalSize += len(chunk) + } + + if totalSize > int(secConfig.MaxOutputSize) { + t.Errorf("Total output %d bytes exceeds limit %d bytes", totalSize, secConfig.MaxOutputSize) + } + + // Verify we got some output (not empty due to truncation) + if len(dataChunks) == 0 { + t.Error("Should receive some data chunks before truncation") + } + + t.Logf("Received %d bytes of output (limit: %d bytes), truncation correctly enforced", + totalSize, secConfig.MaxOutputSize) +} + +// TestShellTool_Integration_NonStreamingCompatibility tests that non-streaming tools +// called via ExecuteStream produce identical output to Execute(). +func TestShellTool_Integration_NonStreamingCompatibility(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Skipping compatibility test on Windows") + } + + secConfig := security.DefaultShellSecurityConfig() + secConfig.AllowShellExpand = true + + tool := NewShellTool().WithSecurityConfig(secConfig) + ctx := context.Background() + + testScript := "echo 'Line 1'; echo 'Line 2' >&2; exit 42" + + // Execute via non-streaming Execute() + execResult, execErr := tool.Execute(ctx, map[string]interface{}{ + "command": "sh", + "args": []interface{}{"-c", testScript}, + }) + + if execErr != nil { + t.Fatalf("Execute() error = %v", execErr) + } + + // Execute via streaming ExecuteStream() + chunks, streamErr := tool.ExecuteStream(ctx, map[string]interface{}{ + "command": "sh", + "args": []interface{}{"-c", testScript}, + }) + + if streamErr != nil { + t.Fatalf("ExecuteStream() error = %v", streamErr) + } + + // Collect streaming output + var stdoutChunks []string + var stderrChunks []string + var streamResult map[string]interface{} + + for chunk := range chunks { + if chunk.IsFinal { + streamResult = chunk.Result + } else { + switch chunk.Stream { + case "stdout": + stdoutChunks = append(stdoutChunks, chunk.Data) + case "stderr": + stderrChunks = append(stderrChunks, chunk.Data) + } + } + } + + if streamResult == nil { + t.Fatal("No final result from 
ExecuteStream") + } + + // Compare exit codes + execExitCode := execResult["exit_code"].(int) + streamExitCode := streamResult["exit_code"].(int) + + if execExitCode != streamExitCode { + t.Errorf("Exit codes differ: Execute=%d, ExecuteStream=%d", execExitCode, streamExitCode) + } + + if execExitCode != 42 { + t.Errorf("Expected exit code 42, got %d", execExitCode) + } + + // Compare success flags + execSuccess := execResult["success"].(bool) + streamSuccess := streamResult["success"].(bool) + + if execSuccess != streamSuccess { + t.Errorf("Success flags differ: Execute=%v, ExecuteStream=%v", execSuccess, streamSuccess) + } + + // Compare stdout (allowing for potential whitespace differences in streaming) + execStdout := strings.TrimSpace(execResult["stdout"].(string)) + streamStdout := strings.TrimSpace(strings.Join(stdoutChunks, "\n")) + + if execStdout != streamStdout { + t.Errorf("Stdout differs:\nExecute: %q\nExecuteStream: %q", execStdout, streamStdout) + } + + // Compare stderr + execStderr := strings.TrimSpace(execResult["stderr"].(string)) + streamStderr := strings.TrimSpace(strings.Join(stderrChunks, "\n")) + + if execStderr != streamStderr { + t.Errorf("Stderr differs:\nExecute: %q\nExecuteStream: %q", execStderr, streamStderr) + } + + t.Log("Execute() and ExecuteStream() produce compatible results") +} + +// TestShellTool_Integration_ConcurrentStreamingExecutions tests that multiple +// concurrent streaming tool executions don't interfere with each other. 
+func TestShellTool_Integration_ConcurrentStreamingExecutions(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Skipping concurrent streaming test on Windows") + } + + // Use security config that allows shell metacharacters for this test + secConfig := security.DefaultShellSecurityConfig() + secConfig.BlockedMetachars = []string{} // Allow all metachars for concurrent test + + tool := NewShellTool().WithTimeout(5 * time.Second).WithSecurityConfig(secConfig) + ctx := context.Background() + + // Run 5 concurrent streaming executions with different outputs + concurrency := 5 + results := make(chan error, concurrency) + + for i := 0; i < concurrency; i++ { + workerID := i + go func() { + chunks, err := tool.ExecuteStream(ctx, map[string]interface{}{ + "command": "sh", + "args": []interface{}{"-c", "echo worker-" + string(rune('0'+workerID)) + "; sleep 0.1"}, + }) + + if err != nil { + results <- err + return + } + + // Verify we get expected output for this worker + var output []string + for chunk := range chunks { + if !chunk.IsFinal && chunk.Stream == "stdout" { + output = append(output, chunk.Data) + } + } + + combined := strings.Join(output, "") + expectedMarker := "worker-" + string(rune('0'+workerID)) + if !strings.Contains(combined, expectedMarker) { + results <- &testError{msg: "Worker " + string(rune('0'+workerID)) + " output missing expected marker"} + return + } + + results <- nil + }() + } + + // Collect results + for i := 0; i < concurrency; i++ { + if err := <-results; err != nil { + t.Errorf("Concurrent execution %d failed: %v", i, err) + } + } +} + +// testError is a simple error type for test results +type testError struct { + msg string +} + +func (e *testError) Error() string { + return e.msg +} diff --git a/pkg/tools/builtin/shell_test.go b/pkg/tools/builtin/shell_test.go index effebe78..6fab345a 100644 --- a/pkg/tools/builtin/shell_test.go +++ b/pkg/tools/builtin/shell_test.go @@ -10,6 +10,7 @@ import ( "time" 
"github.com/tombee/conductor/pkg/security" + "github.com/tombee/conductor/pkg/tools" ) func TestShellTool_Name(t *testing.T) { @@ -565,3 +566,615 @@ func TestValidateCommand_CaseSensitivity(t *testing.T) { t.Error("validateCommand('Git') should not match 'git' on Unix (case-sensitive)") } } + +// TestShellTool_ExecuteStream_BasicStreaming tests basic streaming functionality +func TestShellTool_ExecuteStream_BasicStreaming(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Skipping streaming test on Windows") + } + + tool := NewShellTool() + ctx := context.Background() + + chunks, err := tool.ExecuteStream(ctx, map[string]interface{}{ + "command": "echo", + "args": []interface{}{"hello"}, + }) + if err != nil { + t.Fatalf("ExecuteStream() error = %v", err) + } + + var receivedChunks []string + var finalChunk *struct { + result map[string]interface{} + found bool + } + finalChunk = &struct { + result map[string]interface{} + found bool + }{} + + for chunk := range chunks { + if chunk.IsFinal { + finalChunk.found = true + finalChunk.result = chunk.Result + } else { + receivedChunks = append(receivedChunks, chunk.Data) + if chunk.Stream != "stdout" && chunk.Stream != "stderr" { + t.Errorf("Chunk stream = %s, expected stdout or stderr", chunk.Stream) + } + } + } + + if !finalChunk.found { + t.Fatal("No final chunk received") + } + + if finalChunk.result == nil { + t.Fatal("Final chunk result is nil") + } + + // Verify final chunk contains required fields + if _, ok := finalChunk.result["exit_code"]; !ok { + t.Error("Final chunk result missing exit_code") + } + + if _, ok := finalChunk.result["duration"]; !ok { + t.Error("Final chunk result missing duration") + } + + if _, ok := finalChunk.result["status"]; !ok { + t.Error("Final chunk result missing status") + } + + success, ok := finalChunk.result["success"].(bool) + if !ok { + t.Fatal("success field is not a boolean") + } + + if !success { + t.Errorf("Command should have succeeded: %v", finalChunk.result) + } 
+} + +// TestShellTool_ExecuteStream_StdoutStderr tests separate stdout/stderr streaming +func TestShellTool_ExecuteStream_StdoutStderr(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Skipping stdout/stderr test on Windows") + } + + secConfig := security.DefaultShellSecurityConfig() + secConfig.AllowShellExpand = true + + tool := NewShellTool().WithSecurityConfig(secConfig) + ctx := context.Background() + + chunks, err := tool.ExecuteStream(ctx, map[string]interface{}{ + "command": "sh", + "args": []interface{}{"-c", "echo stdout-line; echo stderr-line >&2"}, + }) + if err != nil { + t.Fatalf("ExecuteStream() error = %v", err) + } + + var stdoutChunks []string + var stderrChunks []string + + for chunk := range chunks { + if chunk.IsFinal { + continue + } + + switch chunk.Stream { + case "stdout": + stdoutChunks = append(stdoutChunks, chunk.Data) + case "stderr": + stderrChunks = append(stderrChunks, chunk.Data) + } + } + + if len(stdoutChunks) == 0 { + t.Error("No stdout chunks received") + } + + if len(stderrChunks) == 0 { + t.Error("No stderr chunks received") + } + + // Verify content + stdoutText := strings.Join(stdoutChunks, "\n") + if !strings.Contains(stdoutText, "stdout-line") { + t.Errorf("stdout does not contain expected text, got: %s", stdoutText) + } + + stderrText := strings.Join(stderrChunks, "\n") + if !strings.Contains(stderrText, "stderr-line") { + t.Errorf("stderr does not contain expected text, got: %s", stderrText) + } +} + +// TestShellTool_ExecuteStream_ExitCode tests that exit code is captured in final chunk +func TestShellTool_ExecuteStream_ExitCode(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Skipping exit code test on Windows") + } + + tool := NewShellTool() + ctx := context.Background() + + chunks, err := tool.ExecuteStream(ctx, map[string]interface{}{ + "command": "sh", + "args": []interface{}{"-c", "exit 42"}, + }) + if err != nil { + t.Fatalf("ExecuteStream() error = %v", err) + } + + var finalResult 
map[string]interface{} + for chunk := range chunks { + if chunk.IsFinal { + finalResult = chunk.Result + } + } + + if finalResult == nil { + t.Fatal("No final result received") + } + + exitCode, ok := finalResult["exit_code"].(int) + if !ok { + t.Fatalf("exit_code is not an int: %T", finalResult["exit_code"]) + } + + if exitCode != 42 { + t.Errorf("exit_code = %d, want 42", exitCode) + } + + success, ok := finalResult["success"].(bool) + if !ok { + t.Fatal("success field is not a boolean") + } + + if success { + t.Error("Command with non-zero exit code should not succeed") + } +} + +// TestShellTool_ExecuteStream_Duration tests that duration is included in final chunk +func TestShellTool_ExecuteStream_Duration(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Skipping duration test on Windows") + } + + tool := NewShellTool() + ctx := context.Background() + + chunks, err := tool.ExecuteStream(ctx, map[string]interface{}{ + "command": "echo", + "args": []interface{}{"test"}, + }) + if err != nil { + t.Fatalf("ExecuteStream() error = %v", err) + } + + var finalResult map[string]interface{} + for chunk := range chunks { + if chunk.IsFinal { + finalResult = chunk.Result + } + } + + if finalResult == nil { + t.Fatal("No final result received") + } + + duration, ok := finalResult["duration"].(int64) + if !ok { + t.Fatalf("duration is not an int64: %T", finalResult["duration"]) + } + + if duration < 0 { + t.Errorf("duration = %d, should be non-negative", duration) + } +} + +// TestShellTool_ExecuteStream_InvalidInputs tests error handling for invalid inputs +func TestShellTool_ExecuteStream_InvalidInputs(t *testing.T) { + tool := NewShellTool() + ctx := context.Background() + + tests := []struct { + name string + inputs map[string]interface{} + }{ + { + name: "missing command", + inputs: map[string]interface{}{}, + }, + { + name: "invalid command type", + inputs: map[string]interface{}{ + "command": 123, + }, + }, + { + name: "invalid args type", + inputs: 
map[string]interface{}{ + "command": "echo", + "args": "not an array", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := tool.ExecuteStream(ctx, tt.inputs) + if err == nil { + t.Error("ExecuteStream() should fail with invalid inputs") + } + }) + } +} + +// TestShellTool_ExecuteStream_MultipleLines tests streaming of multiple output lines +func TestShellTool_ExecuteStream_MultipleLines(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Skipping multi-line test on Windows") + } + + secConfig := security.DefaultShellSecurityConfig() + secConfig.AllowShellExpand = true + + tool := NewShellTool().WithSecurityConfig(secConfig) + ctx := context.Background() + + chunks, err := tool.ExecuteStream(ctx, map[string]interface{}{ + "command": "sh", + "args": []interface{}{"-c", "echo line1; echo line2; echo line3"}, + }) + if err != nil { + t.Fatalf("ExecuteStream() error = %v", err) + } + + var outputLines []string + for chunk := range chunks { + if !chunk.IsFinal && chunk.Stream == "stdout" { + outputLines = append(outputLines, chunk.Data) + } + } + + if len(outputLines) < 3 { + t.Errorf("Expected at least 3 output lines, got %d", len(outputLines)) + } + + // Verify we got the expected lines + output := strings.Join(outputLines, "\n") + for _, expected := range []string{"line1", "line2", "line3"} { + if !strings.Contains(output, expected) { + t.Errorf("Output does not contain %s, got: %s", expected, output) + } + } +} + +// TestShellTool_ExecuteStream_ChannelClosed tests that channel is properly closed +func TestShellTool_ExecuteStream_ChannelClosed(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Skipping channel test on Windows") + } + + tool := NewShellTool() + ctx := context.Background() + + chunks, err := tool.ExecuteStream(ctx, map[string]interface{}{ + "command": "echo", + "args": []interface{}{"test"}, + }) + if err != nil { + t.Fatalf("ExecuteStream() error = %v", err) + } + + // Drain the channel + 
chunkCount := 0
+	for range chunks {
+		chunkCount++
+	}
+
+	if chunkCount == 0 {
+		t.Error("Expected at least one chunk (final chunk)")
+	}
+
+	// Try to read from closed channel - should return immediately with zero value
+	chunk, ok := <-chunks
+	if ok {
+		t.Errorf("Channel should be closed, but received chunk: %+v", chunk)
+	}
+}
+
+// TestShellTool_ExecuteStream_BinaryFallback tests 4KB fallback for binary data without newlines
+func TestShellTool_ExecuteStream_BinaryFallback(t *testing.T) {
+	if runtime.GOOS == "windows" {
+		t.Skip("Skipping binary fallback test on Windows")
+	}
+
+	secConfig := security.DefaultShellSecurityConfig()
+	secConfig.AllowShellExpand = true
+
+	tool := NewShellTool().WithSecurityConfig(secConfig)
+	ctx := context.Background()
+
+	// Generate 5KB of data without newlines to test 4KB fallback
+	chunks, err := tool.ExecuteStream(ctx, map[string]interface{}{
+		"command": "sh",
+		"args":    []interface{}{"-c", "printf '%5120s' | tr ' ' 'x'"},
+	})
+	if err != nil {
+		t.Fatalf("ExecuteStream() error = %v", err)
+	}
+
+	var dataChunks []string
+	for chunk := range chunks {
+		if !chunk.IsFinal && chunk.Stream == "stdout" {
+			dataChunks = append(dataChunks, chunk.Data)
+		}
+	}
+
+	// Need at least 2 chunks (one 4KB chunk + remainder); fatal because dataChunks[0] is indexed below
+	if len(dataChunks) < 2 {
+		t.Fatalf("Expected at least 2 chunks for 5KB binary data, got %d", len(dataChunks))
+	}
+
+	// First chunk should be around 4KB (4096 bytes)
+	if len(dataChunks[0]) < 4000 {
+		t.Errorf("First chunk size = %d, expected around 4KB", len(dataChunks[0]))
+	}
+}
+
+// TestShellTool_ExecuteStream_ContextCancellation tests cleanup on context cancellation
+func TestShellTool_ExecuteStream_ContextCancellation(t *testing.T) {
+	if runtime.GOOS == "windows" {
+		t.Skip("Skipping context cancellation test on Windows")
+	}
+
+	tool := NewShellTool()
+	ctx, cancel := context.WithCancel(context.Background())
+
+	// Start a long-running command
+	chunks, err := 
tool.ExecuteStream(ctx, map[string]interface{}{ + "command": "sleep", + "args": []interface{}{"10"}, + }) + if err != nil { + t.Fatalf("ExecuteStream() error = %v", err) + } + + // Cancel context after a short delay + time.AfterFunc(100*time.Millisecond, cancel) + + // Drain the channel + var receivedFinal bool + for chunk := range chunks { + if chunk.IsFinal { + receivedFinal = true + } + } + + // Should receive final chunk even after cancellation + if !receivedFinal { + t.Error("Should receive final chunk after context cancellation") + } +} + +// TestShellTool_ExecuteStream_Redaction tests sensitive data redaction in output +func TestShellTool_ExecuteStream_Redaction(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Skipping redaction test on Windows") + } + + secConfig := security.DefaultShellSecurityConfig() + secConfig.AllowShellExpand = true + + tool := NewShellTool().WithSecurityConfig(secConfig) + ctx := context.Background() + + tests := []struct { + name string + output string + contains string + notContains string + }{ + { + name: "AWS access key", + output: "AWS Key: AKIAIOSFODNN7EXAMPLE", + contains: "[REDACTED]", + notContains: "AKIAIOSFODNN7EXAMPLE", + }, + { + name: "Bearer token", + output: "Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9", + contains: "[REDACTED]", + notContains: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9", + }, + { + name: "Password in URL", + output: "Database: postgresql://user:secretpass123@localhost/db", + contains: "[REDACTED]", + notContains: "secretpass123", + }, + { + name: "API key", + output: "API_KEY=sk_live_1234567890abcdefghij", + contains: "[REDACTED]", + notContains: "sk_live_1234567890abcdefghij", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + chunks, err := tool.ExecuteStream(ctx, map[string]interface{}{ + "command": "echo", + "args": []interface{}{tt.output}, + }) + if err != nil { + t.Fatalf("ExecuteStream() error = %v", err) + } + + var outputData []string + for 
chunk := range chunks { + if !chunk.IsFinal && chunk.Stream == "stdout" { + outputData = append(outputData, chunk.Data) + } + } + + output := strings.Join(outputData, "\n") + + if !strings.Contains(output, tt.contains) { + t.Errorf("Output should contain %q, got: %s", tt.contains, output) + } + + if strings.Contains(output, tt.notContains) { + t.Errorf("Output should NOT contain sensitive data %q, got: %s", tt.notContains, output) + } + }) + } +} + +// TestShellTool_ExecuteStream_PartialLineAtEOF tests handling of partial lines at EOF +func TestShellTool_ExecuteStream_PartialLineAtEOF(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Skipping partial line test on Windows") + } + + secConfig := security.DefaultShellSecurityConfig() + secConfig.AllowShellExpand = true + + tool := NewShellTool().WithSecurityConfig(secConfig) + ctx := context.Background() + + // Use printf without newline to test partial line handling + chunks, err := tool.ExecuteStream(ctx, map[string]interface{}{ + "command": "printf", + "args": []interface{}{"no-newline-here"}, + }) + if err != nil { + t.Fatalf("ExecuteStream() error = %v", err) + } + + var outputData []string + for chunk := range chunks { + if !chunk.IsFinal && chunk.Stream == "stdout" { + outputData = append(outputData, chunk.Data) + } + } + + // Should receive the partial line + if len(outputData) == 0 { + t.Error("Should receive partial line at EOF") + } + + output := strings.Join(outputData, "") + if !strings.Contains(output, "no-newline-here") { + t.Errorf("Output should contain partial line, got: %s", output) + } +} + +// TestShellTool_ExecuteStream_LineBuffering tests that lines are emitted immediately on newline +func TestShellTool_ExecuteStream_LineBuffering(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Skipping line buffering test on Windows") + } + + secConfig := security.DefaultShellSecurityConfig() + secConfig.AllowShellExpand = true + + tool := NewShellTool().WithSecurityConfig(secConfig) + ctx 
:= context.Background() + + // Output multiple lines with delays to test line buffering + chunks, err := tool.ExecuteStream(ctx, map[string]interface{}{ + "command": "sh", + "args": []interface{}{"-c", "echo first; echo second; echo third"}, + }) + if err != nil { + t.Fatalf("ExecuteStream() error = %v", err) + } + + var lines []string + for chunk := range chunks { + if !chunk.IsFinal && chunk.Stream == "stdout" { + lines = append(lines, chunk.Data) + } + } + + // Should receive each line as a separate chunk + if len(lines) < 3 { + t.Errorf("Expected at least 3 line chunks, got %d", len(lines)) + } + + // Verify each line contains expected content (may have multiple words per chunk) + output := strings.Join(lines, "\n") + for _, expected := range []string{"first", "second", "third"} { + if !strings.Contains(output, expected) { + t.Errorf("Output should contain %q, got: %s", expected, output) + } + } +} + +// TestShellTool_ExecuteStream_SizeLimit tests that output is truncated when size limit is exceeded +func TestShellTool_ExecuteStream_SizeLimit(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Skipping size limit test on Windows") + } + + // Configure a small output size limit (100 bytes) + secConfig := security.DefaultShellSecurityConfig() + secConfig.MaxOutputSize = 100 + secConfig.AllowShellExpand = true + + tool := NewShellTool().WithSecurityConfig(secConfig) + ctx := context.Background() + + // Generate output larger than the limit (500+ bytes) + chunks, err := tool.ExecuteStream(ctx, map[string]interface{}{ + "command": "sh", + "args": []interface{}{"-c", "for i in 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15; do echo 'This is a line of output that should be truncated'; done"}, + }) + if err != nil { + t.Fatalf("ExecuteStream() error = %v", err) + } + + var totalSize int + var finalChunk *tools.ToolChunk + for chunk := range chunks { + if chunk.IsFinal { + finalChunk = &chunk + } else if chunk.Stream == "stdout" { + totalSize += len(chunk.Data) + } + } + + 
// Verify output was truncated + if totalSize > int(secConfig.MaxOutputSize) { + t.Errorf("Total output size %d exceeds limit %d", totalSize, secConfig.MaxOutputSize) + } + + // Verify final chunk has truncation metadata + if finalChunk == nil { + t.Fatal("No final chunk received") + } + + if finalChunk.Metadata == nil { + t.Fatal("Final chunk should have metadata when truncated") + } + + truncated, ok := finalChunk.Metadata["truncated"].(bool) + if !ok || !truncated { + t.Error("Final chunk should have truncated=true in metadata") + } + + // Verify error message indicates truncation + if finalChunk.Error == nil { + t.Error("Final chunk should have error when truncated") + } else if !strings.Contains(finalChunk.Error.Error(), "truncated") { + t.Errorf("Error should mention truncation, got: %v", finalChunk.Error) + } +} diff --git a/pkg/tools/redact.go b/pkg/tools/redact.go new file mode 100644 index 00000000..87b54cab --- /dev/null +++ b/pkg/tools/redact.go @@ -0,0 +1,92 @@ +// Package tools provides utilities for tool execution and output processing. +package tools + +import ( + "regexp" + "sync" +) + +// Redactor detects and redacts sensitive data patterns in strings. +// It uses compiled regex patterns to identify common sensitive values like API keys, +// tokens, passwords, and cloud provider credentials. +type Redactor struct { + patterns []*redactionPattern + mu sync.RWMutex // Protects patterns for thread-safe usage +} + +// redactionPattern represents a compiled pattern with its replacement string. +type redactionPattern struct { + regex *regexp.Regexp + replacement string +} + +// NewRedactor creates a new redactor with default patterns for common sensitive data. +// Patterns include: +// - AWS access keys (AKIA...) 
+//   - AWS secret keys in configuration
+//   - Bearer tokens
+//   - API keys in various formats
+//   - Passwords in URLs
+//   - Database connection strings with credentials
+func NewRedactor() *Redactor {
+	r := &Redactor{
+		patterns: make([]*redactionPattern, 0),
+	}
+
+	// AWS Access Key IDs (start with AKIA and are 20 characters)
+	r.addPattern(`AKIA[A-Z0-9]{16}`, "[REDACTED]")
+
+	// AWS Secret Access Keys in configuration contexts
+	// Matches patterns like: aws_secret_access_key = "base64string"
+	// or secret_key: "base64string"; 40 or more base64 chars, so the tail of a longer value never leaks
+	r.addPattern(`(?i)(aws[_-]?secret[_-]?access[_-]?key|secret[_-]?key|aws[_-]?secret)\s*[=:]\s*['\"]?([A-Za-z0-9/+=]{40,})['\"]?`, "$1=[REDACTED]")
+
+	// Bearer tokens in Authorization headers
+	// Only match token-like values: RFC 6750 b64token alphabet (incl. ~ + / and = padding), at least 10 chars
+	r.addPattern(`(?i)Bearer\s+([a-zA-Z0-9_\-\.~+/]{10,}=*)`, "Bearer [REDACTED]")
+
+	// API keys in various formats
+	// Matches: api_key=xxx, apiKey: xxx, api-key="xxx", etc.
+	r.addPattern(`(?i)(api[_-]?key|apikey)\s*[=:]\s*['\"]?([a-zA-Z0-9_\-]{20,})['\"]?`, "$1=[REDACTED]")
+
+	// Generic token patterns
+	r.addPattern(`(?i)(token|access[_-]?token|auth[_-]?token)\s*[=:]\s*['\"]?([a-zA-Z0-9_\-\.]{20,})['\"]?`, "$1=[REDACTED]")
+
+	// Passwords in URLs (://user:password@host)
+	// Note: @ characters in passwords should be URL-encoded as %40
+	r.addPattern(`://([^:@\s]+):([^@\s]+)@`, "://$1:[REDACTED]@")
+
+	// Database connection strings with passwords
+	// Handle both quoted and unquoted values, with minimum length of 3
+	r.addPattern(`(?i)(password|pwd|pass)\s*=\s*'([^']{3,})'`, "$1=[REDACTED]")
+	r.addPattern(`(?i)(password|pwd|pass)\s*=\s*"([^"]{3,})"`, "$1=[REDACTED]")
+	r.addPattern(`(?i)(password|pwd|pass)\s*=\s*([^;'\"\s]{3,})`, "$1=[REDACTED]")
+
+	// Generic secret patterns (for environment variables or config)
+	r.addPattern(`(?i)(secret|private[_-]?key)\s*[=:]\s*['\"]?([a-zA-Z0-9_\-/+=]{20,})['\"]?`, "$1=[REDACTED]")
+
+	return r
+}
+
+// 
addPattern compiles and adds a new redaction pattern. +func (r *Redactor) addPattern(pattern, replacement string) { + regex := regexp.MustCompile(pattern) + r.patterns = append(r.patterns, &redactionPattern{ + regex: regex, + replacement: replacement, + }) +} + +// Redact scans the input string and replaces all matches of sensitive patterns with [REDACTED]. +// It applies all patterns in sequence and returns the redacted string. +// This method is thread-safe and can be called concurrently. +func (r *Redactor) Redact(s string) string { + r.mu.RLock() + defer r.mu.RUnlock() + + result := s + for _, p := range r.patterns { + result = p.regex.ReplaceAllString(result, p.replacement) + } + return result +} diff --git a/pkg/tools/redact_test.go b/pkg/tools/redact_test.go new file mode 100644 index 00000000..4fd55ec8 --- /dev/null +++ b/pkg/tools/redact_test.go @@ -0,0 +1,456 @@ +package tools + +import ( + "strings" + "testing" +) + +func TestRedactor_AWSAccessKeys(t *testing.T) { + r := NewRedactor() + + tests := []struct { + name string + input string + expected string + }{ + { + name: "AWS access key in plain text", + input: "AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE", + expected: "AWS_ACCESS_KEY_ID=[REDACTED]", + }, + { + name: "AWS access key in log output", + input: "Using credentials AKIAIOSFODNN7EXAMPLE for deployment", + expected: "Using credentials [REDACTED] for deployment", + }, + { + name: "Multiple AWS access keys", + input: "Key1: AKIAIOSFODNN7EXAMPLE, Key2: AKIAJ7EXAMPLE1234567", + expected: "Key1: [REDACTED], Key2: [REDACTED]", + }, + { + name: "Not an AWS key (too short)", + input: "AKIASHORT is not a key", + expected: "AKIASHORT is not a key", + }, + { + name: "Not an AWS key (wrong prefix)", + input: "BKIAIOSFODNN7EXAMPLE is not valid", + expected: "BKIAIOSFODNN7EXAMPLE is not valid", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := r.Redact(tt.input) + if result != tt.expected { + t.Errorf("Redact() = %q, want 
%q", result, tt.expected) + } + }) + } +} + +func TestRedactor_AWSSecretKeys(t *testing.T) { + r := NewRedactor() + + tests := []struct { + name string + input string + expected string + }{ + { + name: "AWS secret key with equals", + input: "aws_secret_access_key=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + expected: "aws_secret_access_key=[REDACTED]", + }, + { + name: "AWS secret key with quotes", + input: `secret_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"`, + expected: `secret_key=[REDACTED]`, + }, + { + name: "AWS secret in environment variable format", + input: "export AWS_SECRET=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + expected: "export AWS_SECRET=[REDACTED]", + }, + { + name: "Case insensitive matching", + input: "AWS_SECRET_ACCESS_KEY=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + expected: "AWS_SECRET_ACCESS_KEY=[REDACTED]", + }, + { + name: "Secret key too short (not 40 chars)", + input: "secret_key=shortkey123", + expected: "secret_key=shortkey123", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := r.Redact(tt.input) + if result != tt.expected { + t.Errorf("Redact() = %q, want %q", result, tt.expected) + } + }) + } +} + +func TestRedactor_BearerTokens(t *testing.T) { + r := NewRedactor() + + tests := []struct { + name string + input string + expected string + }{ + { + name: "Bearer token in Authorization header", + input: "Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9", + expected: "Authorization: Bearer [REDACTED]", + }, + { + name: "Bearer token with dots (JWT format)", + input: "Bearer abc123.def456.ghi789", + expected: "Bearer [REDACTED]", + }, + { + name: "Case insensitive bearer", + input: "bearer token_value_here", + expected: "Bearer [REDACTED]", + }, + { + name: "Bearer token in log", + input: "Request sent with Bearer sk_live_1234567890abcdef", + expected: "Request sent with Bearer [REDACTED]", + }, + { + name: "Not a bearer token (word too short)", + input: "Bearer auth required", + 
expected: "Bearer auth required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := r.Redact(tt.input) + if result != tt.expected { + t.Errorf("Redact() = %q, want %q", result, tt.expected) + } + }) + } +} + +func TestRedactor_APIKeys(t *testing.T) { + r := NewRedactor() + + tests := []struct { + name string + input string + expected string + }{ + { + name: "API key with underscore", + input: "api_key=sk_live_abcdef1234567890", + expected: "api_key=[REDACTED]", + }, + { + name: "API key with dash", + input: "api-key: pk_test_1234567890abcdefghij", + expected: "api-key=[REDACTED]", + }, + { + name: "camelCase apiKey", + input: `apiKey: "1234567890abcdefghijklmnopqrst"`, + expected: `apiKey=[REDACTED]`, + }, + { + name: "Case insensitive", + input: "API_KEY=abcdef1234567890ghijklmnop", + expected: "API_KEY=[REDACTED]", + }, + { + name: "API key too short", + input: "api_key=short", + expected: "api_key=short", + }, + { + name: "Generic token pattern", + input: "access_token=ghp_1234567890abcdefghijklmnopqrst", + expected: "access_token=[REDACTED]", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := r.Redact(tt.input) + if result != tt.expected { + t.Errorf("Redact() = %q, want %q", result, tt.expected) + } + }) + } +} + +func TestRedactor_PasswordsInURLs(t *testing.T) { + r := NewRedactor() + + tests := []struct { + name string + input string + expected string + }{ + { + name: "HTTP URL with password", + input: "https://user:secretpass@example.com/path", + expected: "https://user:[REDACTED]@example.com/path", + }, + { + name: "Database URL with password", + input: "postgresql://dbuser:dbpass123@localhost:5432/mydb", + expected: "postgresql://dbuser:[REDACTED]@localhost:5432/mydb", + }, + { + name: "MongoDB connection string", + input: "mongodb://admin:complexPass123@mongo.example.com:27017/db", + expected: "mongodb://admin:[REDACTED]@mongo.example.com:27017/db", + }, + { + name: "Redis URL 
with password", + input: "redis://default:secret@redis.example.com:6379/0", + expected: "redis://default:[REDACTED]@redis.example.com:6379/0", + }, + { + name: "URL without password", + input: "https://user@example.com/path", + expected: "https://user@example.com/path", + }, + { + name: "URL without credentials", + input: "https://example.com/path", + expected: "https://example.com/path", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := r.Redact(tt.input) + if result != tt.expected { + t.Errorf("Redact() = %q, want %q", result, tt.expected) + } + }) + } +} + +func TestRedactor_DatabaseConnectionStrings(t *testing.T) { + r := NewRedactor() + + tests := []struct { + name string + input string + expected string + }{ + { + name: "SQL Server connection string", + input: "Server=myserver;Database=mydb;User Id=myuser;Password=mypassword;", + expected: "Server=myserver;Database=mydb;User Id=myuser;Password=[REDACTED];", + }, + { + name: "Connection string with pwd", + input: "Data Source=server;Initial Catalog=db;User ID=user;pwd=secret123", + expected: "Data Source=server;Initial Catalog=db;User ID=user;pwd=[REDACTED]", + }, + { + name: "Case insensitive password", + input: "PASSWORD=MySecretPass123", + expected: "PASSWORD=[REDACTED]", + }, + { + name: "Password with quotes", + input: `password='complex!pass@123'`, + expected: `password=[REDACTED]`, + }, + { + name: "Password with double quotes", + input: `Password="my secret pass"`, + expected: `Password=[REDACTED]`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := r.Redact(tt.input) + if result != tt.expected { + t.Errorf("Redact() = %q, want %q", result, tt.expected) + } + }) + } +} + +func TestRedactor_GenericSecrets(t *testing.T) { + r := NewRedactor() + + tests := []struct { + name string + input string + expected string + }{ + { + name: "Secret in configuration", + input: "client_secret=abcdef1234567890ghijklmnopqrst", + expected: 
"client_secret=[REDACTED]", + }, + { + name: "Private key", + input: "private_key=MIIEvQIBADANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA", + expected: "private_key=[REDACTED]", + }, + { + name: "Private key with dash", + input: "private-key: base64encodedkeydata1234567890", + expected: "private-key=[REDACTED]", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := r.Redact(tt.input) + if result != tt.expected { + t.Errorf("Redact() = %q, want %q", result, tt.expected) + } + }) + } +} + +func TestRedactor_MultiplePatterns(t *testing.T) { + r := NewRedactor() + + tests := []struct { + name string + input string + validate func(t *testing.T, result string) + }{ + { + name: "Multiple sensitive values in one string", + input: `export AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE +export AWS_SECRET_ACCESS_KEY=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY +export API_KEY=sk_live_1234567890abcdefghij`, + validate: func(t *testing.T, result string) { + if strings.Contains(result, "AKIAIOSFODNN7EXAMPLE") { + t.Error("AWS access key not redacted") + } + if strings.Contains(result, "wJalrXUtnFEMI") { + t.Error("AWS secret key not redacted") + } + if strings.Contains(result, "sk_live_1234567890abcdefghij") { + t.Error("API key not redacted") + } + if !strings.Contains(result, "[REDACTED]") { + t.Error("No redaction markers found") + } + }, + }, + { + name: "Log line with URL and bearer token", + input: "Connecting to https://user:pass123@api.example.com with Authorization: Bearer token_abc123def456", + validate: func(t *testing.T, result string) { + if strings.Contains(result, "pass123") { + t.Error("Password not redacted") + } + if strings.Contains(result, "token_abc123def456") { + t.Error("Bearer token not redacted") + } + }, + }, + { + name: "Configuration file content", + input: "db_url=postgresql://user:secret@localhost/db\napi_key=sk_test_abcdefghij1234567890\nbearer_token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.payload.signature", + validate: func(t 
*testing.T, result string) { + if strings.Contains(result, "secret@") { + t.Error("Database password not redacted") + } + if strings.Contains(result, "sk_test_abcdefghij1234567890") { + t.Error("API key not redacted") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := r.Redact(tt.input) + tt.validate(t, result) + }) + } +} + +func TestRedactor_NoFalsePositives(t *testing.T) { + r := NewRedactor() + + tests := []struct { + name string + input string + }{ + { + name: "Normal text should not be redacted", + input: "This is a normal log message without secrets", + }, + { + name: "Variable names without values", + input: "Please set API_KEY and SECRET in your environment", + }, + { + name: "Short password values should not trigger (less than 3 chars)", + input: "password=ab", + }, + { + name: "URLs without credentials", + input: "https://example.com/api/v1/users", + }, + { + name: "Documentation examples", + input: "Use format: api_key=YOUR_KEY_HERE", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := r.Redact(tt.input) + if result != tt.input { + t.Errorf("False positive redaction: input=%q, output=%q", tt.input, result) + } + }) + } +} + +func TestRedactor_ThreadSafety(t *testing.T) { + r := NewRedactor() + + // Run multiple goroutines calling Redact concurrently + done := make(chan bool) + for i := 0; i < 10; i++ { + go func() { + for j := 0; j < 100; j++ { + _ = r.Redact("api_key=sk_live_1234567890abcdefghij") + } + done <- true + }() + } + + // Wait for all goroutines to complete + for i := 0; i < 10; i++ { + <-done + } +} + +func BenchmarkRedactor_Redact(b *testing.B) { + r := NewRedactor() + input := `export AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE +export AWS_SECRET_ACCESS_KEY=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY +export API_KEY=sk_live_1234567890abcdefghij +Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9 +Connection: 
postgresql://user:password123@localhost:5432/mydb` + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = r.Redact(input) + } +} diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go index 3fcb05c7..eea46e99 100644 --- a/pkg/tools/registry.go +++ b/pkg/tools/registry.go @@ -11,6 +11,7 @@ import ( "fmt" "sync" + "github.com/google/uuid" "github.com/tombee/conductor/pkg/errors" ) @@ -29,6 +30,51 @@ type Tool interface { Execute(ctx context.Context, inputs map[string]interface{}) (map[string]interface{}, error) } +// StreamingTool extends Tool with streaming execution support. +// Tools that implement this interface can emit incremental output chunks during execution, +// enabling real-time progress visibility for long-running operations. +type StreamingTool interface { + Tool + + // ExecuteStream runs the tool and streams output chunks via a channel. + // The channel is closed when execution completes. + // Callers should drain the channel to get the complete result. + // + // Returns: + // - A receive-only channel of ToolChunk values + // - An error if the tool fails to start (startup errors only) + // + // Runtime errors during execution are delivered in the final chunk's Error field. + // Exactly one chunk will have IsFinal=true, which is always the last chunk. + ExecuteStream(ctx context.Context, inputs map[string]any) (<-chan ToolChunk, error) +} + +// ToolChunk represents an incremental output from a streaming tool. +// Chunks are emitted during tool execution to provide real-time progress feedback. +type ToolChunk struct { + // Data is the chunk content (e.g., a line of output from stdout/stderr) + Data string + + // Stream identifies the output stream ("stdout", "stderr", or empty for general output) + Stream string + + // IsFinal indicates this is the last chunk with complete results. + // Exactly one chunk will have IsFinal=true, and it will be the last chunk sent. 
+ IsFinal bool + + // Result contains the final tool output (only set when IsFinal is true). + // For non-final chunks, this field is nil. + Result map[string]any + + // Error contains any execution error (only set when IsFinal is true). + // For non-final chunks, this field is nil. + Error error + + // Metadata contains optional additional information about the chunk. + // Common uses include truncation indicators, timing data, or tool-specific context. + Metadata map[string]any +} + // Schema defines the input and output schema for a tool using JSON Schema. type Schema struct { // Inputs defines the expected input parameters @@ -73,11 +119,16 @@ type Property struct { // Registry maintains a collection of registered tools. type Registry struct { - mu sync.RWMutex - tools map[string]Tool - interceptor Interceptor + mu sync.RWMutex + tools map[string]Tool + interceptor Interceptor + eventEmitter EventEmitter } +// EventEmitter emits tool execution events. +// This allows the registry to publish streaming output events to the SDK event system. +type EventEmitter func(ctx context.Context, eventType string, workflowID string, stepID string, data any) + // Interceptor validates tool execution against security policy. // This interface is defined here to avoid circular dependencies. type Interceptor interface { @@ -103,6 +154,14 @@ func (r *Registry) SetInterceptor(interceptor Interceptor) { r.interceptor = interceptor } +// SetEventEmitter sets the event emitter for this registry. +// The emitter will be called for each streaming tool output chunk. +func (r *Registry) SetEventEmitter(emitter EventEmitter) { + r.mu.Lock() + defer r.mu.Unlock() + r.eventEmitter = emitter +} + // Register adds a tool to the registry. // Returns an error if a tool with the same name is already registered. 
func (r *Registry) Register(tool Tool) error { @@ -400,3 +459,153 @@ func (r *Registry) Filter(allowedNames []string) (*Registry, error) { return filtered, nil } + +// SupportsStreaming checks if a tool implements the StreamingTool interface. +// Returns false if the tool is not registered. +func (r *Registry) SupportsStreaming(name string) bool { + r.mu.RLock() + defer r.mu.RUnlock() + + tool, exists := r.tools[name] + if !exists { + return false + } + + _, ok := tool.(StreamingTool) + return ok +} + +// ExecuteStream executes a tool with streaming output support. +// If the tool implements StreamingTool, it delegates to the tool's ExecuteStream method. +// If the tool does not implement StreamingTool, it wraps the standard Execute method +// to emit a single chunk with IsFinal=true containing the complete result. +// +// The toolCallID parameter is used to correlate tool execution with LLM tool calls. +// If empty, a UUID will be generated automatically. +// +// Returns a channel that emits ToolChunk values during execution. +// The channel is closed when execution completes. +// Exactly one chunk will have IsFinal=true, which is always the last chunk. 
+func (r *Registry) ExecuteStream(ctx context.Context, name string, inputs map[string]interface{}, toolCallID string) (<-chan ToolChunk, error) { + // Generate UUID for toolCallID if not provided + if toolCallID == "" { + toolCallID = uuid.New().String() + } + + tool, err := r.Get(name) + if err != nil { + return nil, err + } + + // Validate inputs against schema + if err := r.validateInputs(tool, inputs); err != nil { + return nil, &errors.ValidationError{ + Field: "inputs", + Message: fmt.Sprintf("input validation failed for tool %s: %v", name, err), + Suggestion: "Check the tool schema for required inputs and correct types", + } + } + + // Call security interceptor before execution + r.mu.RLock() + interceptor := r.interceptor + r.mu.RUnlock() + + if interceptor != nil { + if err := interceptor.Intercept(ctx, tool, inputs); err != nil { + return nil, fmt.Errorf("security validation failed for tool %s: %w", name, err) + } + } + + // Check if tool supports streaming + if streamingTool, ok := tool.(StreamingTool); ok { + // Use native streaming support + chunks, err := streamingTool.ExecuteStream(ctx, inputs) + if err != nil { + return nil, fmt.Errorf("tool execution failed for %s: %w", name, err) + } + + // Wrap channel to emit events and call post-execute interceptor + wrappedChunks := make(chan ToolChunk) + go func() { + defer close(wrappedChunks) + var finalResult map[string]interface{} + var finalError error + + for chunk := range chunks { + // Emit event for this chunk + r.emitToolOutputEvent(ctx, toolCallID, name, chunk) + + wrappedChunks <- chunk + if chunk.IsFinal { + finalResult = chunk.Result + finalError = chunk.Error + } + } + + // Call post-execute interceptor after streaming completes + if interceptor != nil { + interceptor.PostExecute(ctx, tool, finalResult, finalError) + } + }() + return wrappedChunks, nil + } + + // Tool does not support streaming - wrap standard Execute method + // to emit a single chunk with IsFinal=true + chunks := make(chan 
ToolChunk, 1) + + go func() { + defer close(chunks) + + result, err := tool.Execute(ctx, inputs) + + // Call post-execute interceptor + if interceptor != nil { + interceptor.PostExecute(ctx, tool, result, err) + } + + // Create single final chunk + chunk := ToolChunk{ + IsFinal: true, + Result: result, + Error: err, + } + + // Emit event for this chunk + r.emitToolOutputEvent(ctx, toolCallID, name, chunk) + + // Emit single final chunk + chunks <- chunk + }() + + return chunks, nil +} + +// emitToolOutputEvent emits a tool output event for a chunk. +// This is called for every chunk produced by ExecuteStream. +func (r *Registry) emitToolOutputEvent(ctx context.Context, toolCallID string, toolName string, chunk ToolChunk) { + r.mu.RLock() + emitter := r.eventEmitter + r.mu.RUnlock() + + if emitter == nil { + return + } + + // Extract workflow and step IDs from context if available + workflowID, _ := ctx.Value("workflow_id").(string) + stepID, _ := ctx.Value("step_id").(string) + + // Create ToolOutputEvent matching sdk/events.go structure + eventData := map[string]any{ + "tool_call_id": toolCallID, + "tool_name": toolName, + "stream": chunk.Stream, + "data": chunk.Data, + "is_final": chunk.IsFinal, + "metadata": chunk.Metadata, + } + + emitter(ctx, "tool.output", workflowID, stepID, eventData) +} diff --git a/pkg/tools/registry_test.go b/pkg/tools/registry_test.go index 57b3caa6..c9a90ebc 100644 --- a/pkg/tools/registry_test.go +++ b/pkg/tools/registry_test.go @@ -540,3 +540,509 @@ func TestHasNamespacePrefix(t *testing.T) { }) } } + +// mockStreamingTool implements the StreamingTool interface for testing +type mockStreamingTool struct { + mockTool + executeStreamFn func(ctx context.Context, inputs map[string]any) (<-chan ToolChunk, error) +} + +func (m *mockStreamingTool) ExecuteStream(ctx context.Context, inputs map[string]any) (<-chan ToolChunk, error) { + if m.executeStreamFn != nil { + return m.executeStreamFn(ctx, inputs) + } + // Default implementation: emit 
single chunk with result + chunks := make(chan ToolChunk, 1) + go func() { + defer close(chunks) + chunks <- ToolChunk{ + Data: "test output", + IsFinal: true, + Result: map[string]any{"result": "success"}, + } + }() + return chunks, nil +} + +func TestRegistry_SupportsStreaming(t *testing.T) { + r := NewRegistry() + + // Register a non-streaming tool + nonStreamingTool := &mockTool{ + name: "non-streaming", + schema: &Schema{ + Inputs: &ParameterSchema{Type: "object"}, + }, + } + if err := r.Register(nonStreamingTool); err != nil { + t.Fatalf("Register(non-streaming) failed: %v", err) + } + + // Register a streaming tool + streamingTool := &mockStreamingTool{ + mockTool: mockTool{ + name: "streaming", + schema: &Schema{ + Inputs: &ParameterSchema{Type: "object"}, + }, + }, + } + if err := r.Register(streamingTool); err != nil { + t.Fatalf("Register(streaming) failed: %v", err) + } + + tests := []struct { + name string + toolName string + want bool + }{ + { + name: "streaming tool", + toolName: "streaming", + want: true, + }, + { + name: "non-streaming tool", + toolName: "non-streaming", + want: false, + }, + { + name: "non-existent tool", + toolName: "non-existent", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := r.SupportsStreaming(tt.toolName) + if got != tt.want { + t.Errorf("SupportsStreaming(%q) = %v, want %v", tt.toolName, got, tt.want) + } + }) + } +} + +func TestRegistry_ExecuteStream_StreamingTool(t *testing.T) { + r := NewRegistry() + + // Create a streaming tool that emits multiple chunks + streamingTool := &mockStreamingTool{ + mockTool: mockTool{ + name: "streaming-tool", + schema: &Schema{ + Inputs: &ParameterSchema{Type: "object"}, + }, + }, + executeStreamFn: func(ctx context.Context, inputs map[string]any) (<-chan ToolChunk, error) { + chunks := make(chan ToolChunk, 3) + go func() { + defer close(chunks) + chunks <- ToolChunk{Data: "chunk 1", Stream: "stdout"} + chunks <- ToolChunk{Data: "chunk 
2", Stream: "stdout"} + chunks <- ToolChunk{ + Data: "chunk 3", + Stream: "stdout", + IsFinal: true, + Result: map[string]any{"exit_code": 0}, + } + }() + return chunks, nil + }, + } + + if err := r.Register(streamingTool); err != nil { + t.Fatalf("Register() failed: %v", err) + } + + ctx := context.Background() + chunks, err := r.ExecuteStream(ctx, "streaming-tool", map[string]interface{}{}, "") + if err != nil { + t.Fatalf("ExecuteStream() failed: %v", err) + } + + // Collect all chunks + var collected []ToolChunk + for chunk := range chunks { + collected = append(collected, chunk) + } + + // Verify we got all chunks + if len(collected) != 3 { + t.Errorf("Expected 3 chunks, got %d", len(collected)) + } + + // Verify chunk contents + if collected[0].Data != "chunk 1" { + t.Errorf("First chunk data = %q, want %q", collected[0].Data, "chunk 1") + } + if collected[1].Data != "chunk 2" { + t.Errorf("Second chunk data = %q, want %q", collected[1].Data, "chunk 2") + } + if !collected[2].IsFinal { + t.Error("Final chunk should have IsFinal=true") + } + if collected[2].Result == nil { + t.Error("Final chunk should have Result set") + } +} + +func TestRegistry_ExecuteStream_NonStreamingTool(t *testing.T) { + r := NewRegistry() + + // Create a non-streaming tool + nonStreamingTool := &mockTool{ + name: "non-streaming-tool", + schema: &Schema{ + Inputs: &ParameterSchema{Type: "object"}, + }, + executeFn: func(ctx context.Context, inputs map[string]interface{}) (map[string]interface{}, error) { + return map[string]interface{}{"result": "success"}, nil + }, + } + + if err := r.Register(nonStreamingTool); err != nil { + t.Fatalf("Register() failed: %v", err) + } + + ctx := context.Background() + chunks, err := r.ExecuteStream(ctx, "non-streaming-tool", map[string]interface{}{}, "") + if err != nil { + t.Fatalf("ExecuteStream() failed: %v", err) + } + + // Collect all chunks + var collected []ToolChunk + for chunk := range chunks { + collected = append(collected, chunk) + } + + 
// Non-streaming tools should emit exactly one chunk + if len(collected) != 1 { + t.Errorf("Expected 1 chunk from non-streaming tool, got %d", len(collected)) + } + + // Verify the single chunk is final and contains the result + chunk := collected[0] + if !chunk.IsFinal { + t.Error("Single chunk from non-streaming tool should have IsFinal=true") + } + if chunk.Result == nil { + t.Error("Single chunk should have Result set") + } + if chunk.Result["result"] != "success" { + t.Errorf("Result = %v, want %v", chunk.Result, map[string]interface{}{"result": "success"}) + } +} + +func TestRegistry_ExecuteStream_GeneratesToolCallID(t *testing.T) { + r := NewRegistry() + + tool := &mockTool{ + name: "test-tool", + schema: &Schema{ + Inputs: &ParameterSchema{Type: "object"}, + }, + } + + if err := r.Register(tool); err != nil { + t.Fatalf("Register() failed: %v", err) + } + + ctx := context.Background() + + // Call with empty toolCallID - should generate UUID + chunks, err := r.ExecuteStream(ctx, "test-tool", map[string]interface{}{}, "") + if err != nil { + t.Fatalf("ExecuteStream() failed: %v", err) + } + + // Drain the channel (required to avoid goroutine leak) + for range chunks { + } + + // The test passes if no error was returned and the channel was successfully created + // UUID generation happens internally and is verified by the lack of errors +} + +func TestRegistry_ExecuteStream_ValidationError(t *testing.T) { + r := NewRegistry() + + tool := &mockTool{ + name: "test-tool", + schema: &Schema{ + Inputs: &ParameterSchema{ + Type: "object", + Required: []string{"required-param"}, + }, + }, + } + + if err := r.Register(tool); err != nil { + t.Fatalf("Register() failed: %v", err) + } + + ctx := context.Background() + + // Call with missing required parameter + _, err := r.ExecuteStream(ctx, "test-tool", map[string]interface{}{}, "") + if err == nil { + t.Error("ExecuteStream() should fail with validation error for missing required parameter") + } +} + +func 
TestRegistry_ExecuteStream_NonExistentTool(t *testing.T) { + r := NewRegistry() + + ctx := context.Background() + _, err := r.ExecuteStream(ctx, "non-existent", map[string]interface{}{}, "") + if err == nil { + t.Error("ExecuteStream() should fail for non-existent tool") + } +} + +func TestRegistry_EventEmission_StreamingTool(t *testing.T) { + r := NewRegistry() + + // Track emitted events + var emittedEvents []map[string]any + r.SetEventEmitter(func(ctx context.Context, eventType string, workflowID string, stepID string, data any) { + emittedEvents = append(emittedEvents, map[string]any{ + "eventType": eventType, + "workflowID": workflowID, + "stepID": stepID, + "data": data, + }) + }) + + // Create a streaming tool that emits multiple chunks + streamingTool := &mockStreamingTool{ + mockTool: mockTool{ + name: "streaming-tool", + schema: &Schema{ + Inputs: &ParameterSchema{Type: "object"}, + }, + }, + executeStreamFn: func(ctx context.Context, inputs map[string]any) (<-chan ToolChunk, error) { + chunks := make(chan ToolChunk, 2) + go func() { + defer close(chunks) + chunks <- ToolChunk{Data: "output line 1", Stream: "stdout"} + chunks <- ToolChunk{ + Data: "output line 2", + Stream: "stdout", + IsFinal: true, + Result: map[string]any{"exit_code": 0}, + } + }() + return chunks, nil + }, + } + + if err := r.Register(streamingTool); err != nil { + t.Fatalf("Register() failed: %v", err) + } + + ctx := context.Background() + chunks, err := r.ExecuteStream(ctx, "streaming-tool", map[string]interface{}{}, "test-call-id") + if err != nil { + t.Fatalf("ExecuteStream() failed: %v", err) + } + + // Drain the chunks channel + for range chunks { + } + + // Verify events were emitted + if len(emittedEvents) != 2 { + t.Errorf("Expected 2 events, got %d", len(emittedEvents)) + } + + // Verify first event + if len(emittedEvents) > 0 { + event := emittedEvents[0] + if event["eventType"] != "tool.output" { + t.Errorf("Event type = %q, want %q", event["eventType"], "tool.output") + } 
+ + data, ok := event["data"].(map[string]any) + if !ok { + t.Fatal("Event data is not a map") + } + + if data["tool_call_id"] != "test-call-id" { + t.Errorf("tool_call_id = %q, want %q", data["tool_call_id"], "test-call-id") + } + if data["tool_name"] != "streaming-tool" { + t.Errorf("tool_name = %q, want %q", data["tool_name"], "streaming-tool") + } + if data["data"] != "output line 1" { + t.Errorf("data = %q, want %q", data["data"], "output line 1") + } + if data["stream"] != "stdout" { + t.Errorf("stream = %q, want %q", data["stream"], "stdout") + } + if data["is_final"] != false { + t.Errorf("is_final = %v, want false", data["is_final"]) + } + } + + // Verify second (final) event + if len(emittedEvents) > 1 { + event := emittedEvents[1] + data, ok := event["data"].(map[string]any) + if !ok { + t.Fatal("Event data is not a map") + } + + if data["is_final"] != true { + t.Errorf("is_final = %v, want true", data["is_final"]) + } + } +} + +func TestRegistry_EventEmission_NonStreamingTool(t *testing.T) { + r := NewRegistry() + + // Track emitted events + var emittedEvents []map[string]any + r.SetEventEmitter(func(ctx context.Context, eventType string, workflowID string, stepID string, data any) { + emittedEvents = append(emittedEvents, map[string]any{ + "eventType": eventType, + "workflowID": workflowID, + "stepID": stepID, + "data": data, + }) + }) + + // Create a non-streaming tool + nonStreamingTool := &mockTool{ + name: "non-streaming-tool", + schema: &Schema{ + Inputs: &ParameterSchema{Type: "object"}, + }, + executeFn: func(ctx context.Context, inputs map[string]interface{}) (map[string]interface{}, error) { + return map[string]interface{}{"result": "success"}, nil + }, + } + + if err := r.Register(nonStreamingTool); err != nil { + t.Fatalf("Register() failed: %v", err) + } + + ctx := context.Background() + chunks, err := r.ExecuteStream(ctx, "non-streaming-tool", map[string]interface{}{}, "test-call-id") + if err != nil { + t.Fatalf("ExecuteStream() failed: 
%v", err) + } + + // Drain the chunks channel + for range chunks { + } + + // Non-streaming tools should emit exactly one event + if len(emittedEvents) != 1 { + t.Errorf("Expected 1 event from non-streaming tool, got %d", len(emittedEvents)) + } + + // Verify the event + if len(emittedEvents) > 0 { + event := emittedEvents[0] + if event["eventType"] != "tool.output" { + t.Errorf("Event type = %q, want %q", event["eventType"], "tool.output") + } + + data, ok := event["data"].(map[string]any) + if !ok { + t.Fatal("Event data is not a map") + } + + if data["tool_call_id"] != "test-call-id" { + t.Errorf("tool_call_id = %q, want %q", data["tool_call_id"], "test-call-id") + } + if data["tool_name"] != "non-streaming-tool" { + t.Errorf("tool_name = %q, want %q", data["tool_name"], "non-streaming-tool") + } + if data["is_final"] != true { + t.Errorf("is_final = %v, want true", data["is_final"]) + } + } +} + +func TestRegistry_EventEmission_NoEmitterSet(t *testing.T) { + r := NewRegistry() + + // No emitter set - should not panic + + tool := &mockTool{ + name: "test-tool", + schema: &Schema{ + Inputs: &ParameterSchema{Type: "object"}, + }, + } + + if err := r.Register(tool); err != nil { + t.Fatalf("Register() failed: %v", err) + } + + ctx := context.Background() + chunks, err := r.ExecuteStream(ctx, "test-tool", map[string]interface{}{}, "") + if err != nil { + t.Fatalf("ExecuteStream() failed: %v", err) + } + + // Drain the chunks channel - should not panic even without emitter + for range chunks { + } +} + +func TestRegistry_EventEmission_WithContextValues(t *testing.T) { + r := NewRegistry() + + // Track emitted events + var emittedEvents []map[string]any + r.SetEventEmitter(func(ctx context.Context, eventType string, workflowID string, stepID string, data any) { + emittedEvents = append(emittedEvents, map[string]any{ + "eventType": eventType, + "workflowID": workflowID, + "stepID": stepID, + "data": data, + }) + }) + + tool := &mockTool{ + name: "test-tool", + schema: 
&Schema{ + Inputs: &ParameterSchema{Type: "object"}, + }, + } + + if err := r.Register(tool); err != nil { + t.Fatalf("Register() failed: %v", err) + } + + // Create context with workflow and step IDs + ctx := context.Background() + ctx = context.WithValue(ctx, "workflow_id", "workflow-123") + ctx = context.WithValue(ctx, "step_id", "step-456") + + chunks, err := r.ExecuteStream(ctx, "test-tool", map[string]interface{}{}, "") + if err != nil { + t.Fatalf("ExecuteStream() failed: %v", err) + } + + // Drain the chunks channel + for range chunks { + } + + // Verify context values were passed to emitter + if len(emittedEvents) > 0 { + event := emittedEvents[0] + if event["workflowID"] != "workflow-123" { + t.Errorf("workflowID = %q, want %q", event["workflowID"], "workflow-123") + } + if event["stepID"] != "step-456" { + t.Errorf("stepID = %q, want %q", event["stepID"], "step-456") + } + } +} diff --git a/sdk/events.go b/sdk/events.go index f3edb2e2..e93e32ef 100644 --- a/sdk/events.go +++ b/sdk/events.go @@ -25,6 +25,7 @@ const ( EventAgentToolResult EventType = "agent.tool_result" // Agent tool result EventAgentComplete EventType = "agent.complete" // Agent completion EventTokenUpdate EventType = "token.update" // Token usage update + EventToolOutput EventType = "tool.output" // Streaming tool output chunk ) // Event represents a workflow event. @@ -125,3 +126,14 @@ func fromLLMTokenUsage(usage llm.TokenUsage) TokenUsage { CacheReadTokens: usage.CacheReadTokens, } } + +// ToolOutputEvent is the Data for EventToolOutput. +// It represents a streaming output chunk from a tool execution. 
+type ToolOutputEvent struct { + ToolCallID string // Links to the tool call + ToolName string // Name of the tool + Stream string // "stdout", "stderr", or "" + Data string // Chunk content + IsFinal bool // True for the final chunk + Metadata map[string]any // Optional metadata +} diff --git a/sdk/tool.go b/sdk/tool.go index f52ce253..9aa4e72d 100644 --- a/sdk/tool.go +++ b/sdk/tool.go @@ -22,6 +22,25 @@ type Tool interface { Execute(ctx context.Context, inputs map[string]any) (map[string]any, error) } +// StreamingTool extends Tool with streaming execution support. +// Tools that implement this interface can emit incremental output chunks during execution, +// enabling real-time progress visibility for long-running operations. +type StreamingTool interface { + Tool + + // ExecuteStream runs the tool and streams output chunks via a channel. + // The channel is closed when execution completes. + // Callers should drain the channel to get the complete result. + // + // Returns: + // - A receive-only channel of ToolChunk values + // - An error if the tool fails to start (startup errors only) + // + // Runtime errors during execution are delivered in the final chunk's Error field. + // Exactly one chunk will have IsFinal=true, which is always the last chunk. + ExecuteStream(ctx context.Context, inputs map[string]any) (<-chan tools.ToolChunk, error) +} + // FuncTool creates a tool from a function. // This is a convenience wrapper for simple tools that don't need complex state. // @@ -79,6 +98,95 @@ func (t *funcTool) Execute(ctx context.Context, inputs map[string]any) (map[stri return t.fn(ctx, inputs) } +// FuncStreamingTool creates a streaming tool from a function. +// This is a convenience wrapper for streaming tools that don't need complex state. 
+// +// Example: +// +// tool := sdk.FuncStreamingTool( +// "tail_log", +// "Stream log file contents in real-time", +// map[string]any{ +// "type": "object", +// "properties": map[string]any{ +// "file": map[string]any{"type": "string"}, +// }, +// "required": []string{"file"}, +// }, +// func(ctx context.Context, inputs map[string]any) (<-chan tools.ToolChunk, error) { +// file := inputs["file"].(string) +// chunks := make(chan tools.ToolChunk) +// go func() { +// defer close(chunks) +// // ... stream file contents ... +// chunks <- tools.ToolChunk{ +// Data: "log line 1\n", +// Stream: "stdout", +// } +// chunks <- tools.ToolChunk{ +// IsFinal: true, +// Result: map[string]any{"lines": 1}, +// } +// }() +// return chunks, nil +// }, +// ) +func FuncStreamingTool(name, description string, schema map[string]any, fn func(ctx context.Context, inputs map[string]any) (<-chan tools.ToolChunk, error)) StreamingTool { + return &funcStreamingTool{ + name: name, + description: description, + schema: schema, + fn: fn, + } +} + +// funcStreamingTool implements StreamingTool using a simple function. 
+type funcStreamingTool struct { + name string + description string + schema map[string]any + fn func(ctx context.Context, inputs map[string]any) (<-chan tools.ToolChunk, error) +} + +func (t *funcStreamingTool) Name() string { + return t.name +} + +func (t *funcStreamingTool) Description() string { + return t.description +} + +func (t *funcStreamingTool) InputSchema() map[string]any { + return t.schema +} + +func (t *funcStreamingTool) Execute(ctx context.Context, inputs map[string]any) (map[string]any, error) { + // For non-streaming execution, collect all chunks and return the final result + chunks, err := t.fn(ctx, inputs) + if err != nil { + return nil, err + } + + // Drain the channel and return the final result + var result map[string]any + var execError error + for chunk := range chunks { + if chunk.IsFinal { + result = chunk.Result + execError = chunk.Error + } + } + + if execError != nil { + return nil, execError + } + return result, nil +} + +func (t *funcStreamingTool) ExecuteStream(ctx context.Context, inputs map[string]any) (<-chan tools.ToolChunk, error) { + return t.fn(ctx, inputs) +} + // sdkToolAdapter adapts an SDK Tool to pkg/tools.Tool interface type sdkToolAdapter struct { tool Tool @@ -138,6 +246,30 @@ func (a *sdkToolAdapter) Execute(ctx context.Context, inputs map[string]interfac return a.tool.Execute(ctx, inputs) } +// ExecuteStream implements the pkg/tools.StreamingTool interface by delegating to +// the wrapped SDK tool if it implements SDK StreamingTool. +// This method is only called if the adapter is used as a StreamingTool. 
+func (a *sdkToolAdapter) ExecuteStream(ctx context.Context, inputs map[string]any) (<-chan tools.ToolChunk, error) { + // Check if the wrapped SDK tool implements streaming + if streamingTool, ok := a.tool.(StreamingTool); ok { + return streamingTool.ExecuteStream(ctx, inputs) + } + + // This should never happen because the registry checks SupportsStreaming first, + // but provide a fallback just in case + chunks := make(chan tools.ToolChunk, 1) + go func() { + defer close(chunks) + result, err := a.tool.Execute(ctx, inputs) + chunks <- tools.ToolChunk{ + IsFinal: true, + Result: result, + Error: err, + } + }() + return chunks, nil +} + // RegisterTool adds a custom tool to the SDK. // The tool will be available to agent loops and LLM steps with tool use. //