Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 92 additions & 8 deletions pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,37 @@ type ToolExecution struct {

// DurationMs is the duration in milliseconds (for spec compliance)
DurationMs int

// OutputChunks contains streaming output chunks from the tool execution
OutputChunks []ToolOutputChunk
}

// ToolOutputChunk represents a streaming output chunk from a tool execution.
type ToolOutputChunk struct {
// ToolCallID links to the tool call
ToolCallID string

// ToolName is the name of the tool
ToolName string

// Stream identifies the output stream ("stdout", "stderr", or "")
Stream string

// Data is the chunk content
Data string

// IsFinal indicates this is the last chunk
IsFinal bool

// Metadata contains optional metadata
Metadata map[string]interface{}
}

// StepContext contains contextual information available during agent execution steps.
// This context accumulates data across iterations and is available for reasoning.
type StepContext struct {
// ToolOutputChunks contains all streaming output chunks from tool executions
ToolOutputChunks []ToolOutputChunk
}

// NewAgent creates a new agent.
Expand Down Expand Up @@ -216,6 +247,11 @@ func (a *Agent) Run(ctx context.Context, systemPrompt string, userPrompt string)
ToolExecutions: []ToolExecution{},
}

// Initialize step context
stepContext := &StepContext{
ToolOutputChunks: []ToolOutputChunk{},
}

// Initialize conversation with system and user messages
messages := []Message{
{Role: "system", Content: systemPrompt},
Expand Down Expand Up @@ -274,7 +310,7 @@ func (a *Agent) Run(ctx context.Context, systemPrompt string, userPrompt string)
// Execute tool calls if any
if len(response.ToolCalls) > 0 {
for _, toolCall := range response.ToolCalls {
execution := a.executeTool(ctx, toolCall)
execution := a.executeTool(ctx, toolCall, stepContext)
result.ToolExecutions = append(result.ToolExecutions, execution)

// Add tool result to conversation
Expand Down Expand Up @@ -313,11 +349,12 @@ func (a *Agent) Run(ctx context.Context, systemPrompt string, userPrompt string)
return result, fmt.Errorf("max iterations reached")
}

// executeTool executes a single tool call.
func (a *Agent) executeTool(ctx context.Context, toolCall ToolCall) ToolExecution {
// executeTool executes a single tool call using streaming execution.
func (a *Agent) executeTool(ctx context.Context, toolCall ToolCall, stepContext *StepContext) ToolExecution {
startTime := time.Now()
execution := ToolExecution{
ToolName: toolCall.Name,
ToolName: toolCall.Name,
OutputChunks: []ToolOutputChunk{},
}

// Parse arguments
Expand All @@ -340,15 +377,62 @@ func (a *Agent) executeTool(ctx context.Context, toolCall ToolCall) ToolExecutio

execution.Inputs = inputs

// Execute tool
outputs, err := a.registry.Execute(ctx, toolCall.Name, inputs)
// Execute tool with streaming support
chunks, err := a.registry.ExecuteStream(ctx, toolCall.Name, inputs, toolCall.ID)
if err != nil {
execution.Success = false
execution.Status = "error"
execution.Error = err.Error()
execution.Duration = time.Since(startTime)
execution.DurationMs = int(execution.Duration.Milliseconds())
return execution
}

// Process streaming chunks
var outputs map[string]interface{}
var execError error

for chunk := range chunks {
// Create output chunk for this execution
outputChunk := ToolOutputChunk{
ToolCallID: toolCall.ID,
ToolName: toolCall.Name,
Stream: chunk.Stream,
Data: chunk.Data,
IsFinal: chunk.IsFinal,
Metadata: chunk.Metadata,
}

// Store chunk in execution and step context
execution.OutputChunks = append(execution.OutputChunks, outputChunk)
stepContext.ToolOutputChunks = append(stepContext.ToolOutputChunks, outputChunk)

// Emit event via callback if configured
if a.eventCallback != nil {
a.eventCallback("tool.output", map[string]interface{}{
"tool_call_id": toolCall.ID,
"tool_name": toolCall.Name,
"stream": chunk.Stream,
"data": chunk.Data,
"is_final": chunk.IsFinal,
"metadata": chunk.Metadata,
})
}

// Extract final result
if chunk.IsFinal {
outputs = chunk.Result
execError = chunk.Error
}
}

execution.Duration = time.Since(startTime)
execution.DurationMs = int(execution.Duration.Milliseconds())

if err != nil {
if execError != nil {
execution.Success = false
execution.Status = "error"
execution.Error = err.Error()
execution.Error = execError.Error()
return execution
}

Expand Down
190 changes: 190 additions & 0 deletions pkg/agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -560,3 +560,193 @@ func (m *mockStreamingLLMProvider) Stream(ctx context.Context, messages []Messag
close(ch)
return ch, nil
}

// mockStreamingTool implements StreamingTool for testing
type mockStreamingTool struct {
name string
chunks []tools.ToolChunk
}

func (m *mockStreamingTool) Name() string {
return m.name
}

func (m *mockStreamingTool) Description() string {
return "A mock streaming tool"
}

func (m *mockStreamingTool) Schema() *tools.Schema {
return &tools.Schema{
Inputs: &tools.ParameterSchema{
Type: "object",
},
Outputs: &tools.ParameterSchema{
Type: "object",
},
}
}

func (m *mockStreamingTool) Execute(ctx context.Context, inputs map[string]interface{}) (map[string]interface{}, error) {
// For non-streaming execution, collect all chunks and return the final result
ch, err := m.ExecuteStream(ctx, inputs)
if err != nil {
return nil, err
}

var result map[string]interface{}
var execError error
for chunk := range ch {
if chunk.IsFinal {
result = chunk.Result
execError = chunk.Error
}
}

if execError != nil {
return nil, execError
}
return result, nil
}

func (m *mockStreamingTool) ExecuteStream(ctx context.Context, inputs map[string]interface{}) (<-chan tools.ToolChunk, error) {
ch := make(chan tools.ToolChunk, len(m.chunks))
go func() {
defer close(ch)
for _, chunk := range m.chunks {
ch <- chunk
}
}()
return ch, nil
}

func TestAgent_ToolStreamingExecution(t *testing.T) {
// Create a streaming tool that emits chunks
streamingTool := &mockStreamingTool{
name: "streaming-tool",
chunks: []tools.ToolChunk{
{
Data: "Line 1\n",
Stream: "stdout",
},
{
Data: "Line 2\n",
Stream: "stdout",
},
{
Data: "Error message\n",
Stream: "stderr",
},
{
IsFinal: true,
Result: map[string]interface{}{
"exit_code": 0,
"duration": 100,
},
},
},
}

registry := tools.NewRegistry()
if err := registry.Register(streamingTool); err != nil {
t.Fatalf("Failed to register tool: %v", err)
}

llm := &mockLLMProvider{
responses: []Response{
{
Content: "Using streaming tool",
FinishReason: "tool_calls",
ToolCalls: []ToolCall{
{
ID: "call-1",
Name: "streaming-tool",
Arguments: map[string]interface{}{},
},
},
Usage: TokenUsage{TotalTokens: 10},
},
{
Content: "Completed with streaming output",
FinishReason: "stop",
Usage: TokenUsage{TotalTokens: 10},
},
},
}

// Track events emitted via callback
var capturedEvents []map[string]interface{}
agent := NewAgent(llm, registry).WithEventCallback(func(eventType string, data interface{}) {
if eventType == "tool.output" {
if eventData, ok := data.(map[string]interface{}); ok {
capturedEvents = append(capturedEvents, eventData)
}
}
})

ctx := context.Background()
result, err := agent.Run(ctx, "System", "Task")
if err != nil {
t.Fatalf("Run() error = %v", err)
}

// Verify tool execution
if len(result.ToolExecutions) != 1 {
t.Fatalf("ToolExecutions count = %d, want 1", len(result.ToolExecutions))
}

execution := result.ToolExecutions[0]

// Verify output chunks are captured in execution
if len(execution.OutputChunks) != 4 {
t.Errorf("OutputChunks count = %d, want 4", len(execution.OutputChunks))
}

// Verify chunk content
if execution.OutputChunks[0].Data != "Line 1\n" {
t.Errorf("Chunk 0 data = %q, want %q", execution.OutputChunks[0].Data, "Line 1\n")
}
if execution.OutputChunks[0].Stream != "stdout" {
t.Errorf("Chunk 0 stream = %q, want %q", execution.OutputChunks[0].Stream, "stdout")
}

if execution.OutputChunks[2].Stream != "stderr" {
t.Errorf("Chunk 2 stream = %q, want %q", execution.OutputChunks[2].Stream, "stderr")
}

// Verify final chunk
if !execution.OutputChunks[3].IsFinal {
t.Error("Last chunk should have IsFinal=true")
}

// Verify events were emitted
if len(capturedEvents) != 4 {
t.Errorf("Captured %d events, want 4", len(capturedEvents))
}

// Verify event structure
if len(capturedEvents) > 0 {
firstEvent := capturedEvents[0]
if firstEvent["tool_name"] != "streaming-tool" {
t.Errorf("Event tool_name = %q, want %q", firstEvent["tool_name"], "streaming-tool")
}
if firstEvent["tool_call_id"] != "call-1" {
t.Errorf("Event tool_call_id = %q, want %q", firstEvent["tool_call_id"], "call-1")
}
if firstEvent["data"] != "Line 1\n" {
t.Errorf("Event data = %q, want %q", firstEvent["data"], "Line 1\n")
}
}

// Verify execution succeeded
if !execution.Success {
t.Error("Tool execution should have succeeded")
}

// Verify final result
if execution.Outputs == nil {
t.Error("Execution outputs should not be nil")
}
if exitCode, ok := execution.Outputs["exit_code"].(int); !ok || exitCode != 0 {
t.Errorf("Exit code = %v, want 0", execution.Outputs["exit_code"])
}
}
Loading
Loading