Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/providers/claude-adapter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ export class ClaudeCliProvider implements LlmProvider {
sessionId: response.sessionId,
usage: response.usage,
durationMs: response.durationMs,
stopReason: response.stopReason,
toolCalls: response.toolCalls,
};
}

Expand Down Expand Up @@ -71,6 +73,8 @@ export class ClaudeCliProvider implements LlmProvider {
sessionId: response.sessionId,
usage: response.usage,
durationMs: response.durationMs,
stopReason: response.stopReason,
toolCalls: response.toolCalls,
};
}
}
26 changes: 24 additions & 2 deletions src/providers/claude-cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ export interface ClaudeResponse {
};
/** Execution duration in milliseconds. */
durationMs: number;
/** Stop reason indicating why generation stopped. */
stopReason?: "end_turn" | "tool_use" | "max_tokens" | "stop_sequence";
/** Tool calls requested by the LLM (present when stopReason is "tool_use"). */
toolCalls?: Array<{ id: string; name: string; input: unknown }>;
}

/**
Expand Down Expand Up @@ -424,6 +428,8 @@ export async function callClaudeStream(
let stderr = "";
let buffer = "";
let eventCount = 0; // Track event count for debugging
const toolCalls: Array<{ id: string; name: string; input: unknown }> = [];
let stopReason: "end_turn" | "tool_use" | "max_tokens" | "stop_sequence" | undefined;

// Process stdout line-by-line
child.stdout.on("data", async (chunk) => {
Expand Down Expand Up @@ -525,12 +531,19 @@ export async function callClaudeStream(
if (event.type === "tool_use") {
// Tool call event
eventHandled = true;
const streamEvent: StreamEvent = {
type: "tool_call",
const toolCall = {
id: event.id || "",
name: event.name || "",
input: event.input || {},
};
toolCalls.push(toolCall);

const streamEvent: StreamEvent = {
type: "tool_call",
id: toolCall.id,
name: toolCall.name,
input: toolCall.input,
};
await onEvent(streamEvent);
} else if (event.type === "tool_result") {
// Tool result event
Expand Down Expand Up @@ -710,12 +723,21 @@ export async function callClaudeStream(
sessionId,
});

// Determine stop reason
if (toolCalls.length > 0) {
stopReason = "tool_use";
} else {
stopReason = "end_turn";
}

return {
type: "success",
result: completeResult,
sessionId,
usage,
durationMs,
stopReason,
toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
};
} catch (error) {
// Clear timeout if it was set
Expand Down
196 changes: 179 additions & 17 deletions src/providers/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
* - Default: 5 retry attempts with retry enabled
*/

import { LlmProvider, ProviderRequest, ProviderResponse, StreamEvent, Usage, Message } from "./types.js";
import { LlmProvider, ProviderRequest, ProviderResponse, StreamEvent, Usage, Message, ToolDefinitionLike, ToolCall } from "./types.js";
import { Environment } from "../env/environment.js";
import { log } from "../logger.js";
import { withRetry, DEFAULT_RETRY_CONFIG, RetryConfig, parseRetryAfter, ErrorWithRetryMetadata } from "../retry.js";
import { randomUUID } from "node:crypto";

interface GeminiConfig {
apiKey: string;
Expand All @@ -23,14 +24,29 @@ interface GeminiConfig {

interface GeminiContent {
role: "user" | "model";
parts: { text: string }[];
parts: Array<{ text?: string; functionCall?: { name: string; args: unknown } }>;
}

/**
 * A single callable function as it is listed in a
 * `GeminiTool.functionDeclarations` array of a Gemini request.
 */
interface GeminiFunctionDeclaration {
/** Name of the function. */
name: string;
/** Human-readable description of the function. */
description: string;
/** Schema of the function's arguments; the top level is always an object. */
parameters: {
type: "object";
// Presumably per-argument schemas keyed by argument name (JSON-Schema style) — confirm against the Gemini API docs.
properties: Record<string, unknown>;
// Presumably the names of arguments that must be supplied — confirm against the Gemini API docs.
required?: string[];
};
}

/**
 * One entry of the `tools` array in a `GeminiRequest`; groups the function
 * declarations offered to the model.
 */
interface GeminiTool {
/** Functions made available through this tool entry. */
functionDeclarations: GeminiFunctionDeclaration[];
}

interface GeminiRequest {
contents: GeminiContent[];
systemInstruction?: {
parts: { text: string }[];
};
tools?: GeminiTool[];
generationConfig?: {
temperature?: number;
maxOutputTokens?: number;
Expand Down Expand Up @@ -61,6 +77,73 @@ export class GeminiProvider implements LlmProvider {
}));
}

/**
 * Build the Gemini `tools` payload from provider-agnostic tool definitions.
 *
 * @param toolDefinitions - Tool definitions to expose to the model.
 * @returns An empty array when there is nothing to convert; otherwise a
 *   single-element array whose entry carries one function declaration per
 *   input definition, in the same order.
 */
private convertToolDefinitions(toolDefinitions: ToolDefinitionLike[]): GeminiTool[] {
  // Covers both a missing list and an empty one.
  if (!toolDefinitions?.length) {
    return [];
  }

  const declarations = toolDefinitions.map(
    ({ name, description, parameters }): GeminiFunctionDeclaration => ({
      name,
      description,
      parameters,
    }),
  );

  // All declarations are grouped under one tool entry.
  const geminiTool: GeminiTool = { functionDeclarations: declarations };
  return [geminiTool];
}

/**
 * Extract tool calls and a normalized stop reason from a Gemini response.
 *
 * Only the first candidate is inspected. Each part carrying a `functionCall`
 * becomes a `ToolCall` with a locally generated `call_<uuid>` id, because the
 * response shape handled here carries no call id of its own.
 *
 * Stop reason resolution:
 * - no candidate at all            -> "end_turn"
 * - at least one function call     -> "tool_use" (takes precedence over finishReason)
 * - finishReason === "MAX_TOKENS"  -> "max_tokens"
 * - finishReason === "STOP"        -> "end_turn"
 * - any other / missing reason     -> undefined
 *
 * @param response - Parsed Gemini response body.
 * @returns The tool calls in part order, plus the resolved stop reason.
 */
private extractToolCalls(response: {
  candidates?: Array<{
    content?: {
      parts?: Array<{
        text?: string;
        functionCall?: { name: string; args: unknown }
      }>
    };
    finishReason?: string;
  }>;
}): { toolCalls: ToolCall[]; stopReason?: "end_turn" | "tool_use" | "max_tokens" | "stop_sequence" } {
  const toolCalls: ToolCall[] = [];
  let stopReason: "end_turn" | "tool_use" | "max_tokens" | "stop_sequence" | undefined;

  const candidate = response.candidates?.[0];
  if (!candidate) {
    return { toolCalls, stopReason: "end_turn" };
  }

  // Map Gemini's finish reason onto the provider-neutral stop reason.
  const finishReason = candidate.finishReason;
  if (finishReason === "MAX_TOKENS") {
    stopReason = "max_tokens";
  } else if (finishReason === "STOP") {
    stopReason = "end_turn";
  }

  // Collect function calls from the candidate's parts.
  for (const part of candidate.content?.parts ?? []) {
    if (part.functionCall) {
      toolCalls.push({
        id: `call_${randomUUID()}`, // Generate unique ID
        name: part.functionCall.name,
        // A call to a parameterless function may omit `args`; normalize to an
        // empty object so consumers always receive a defined input, matching
        // the Claude CLI path (`event.input || {}`).
        input: part.functionCall.args ?? {},
      });
    }
  }

  // Any tool call means the model stopped in order to have tools executed.
  if (toolCalls.length > 0) {
    stopReason = "tool_use";
  }

  return { toolCalls, stopReason };
}

/**
* Handle error response and check for Retry-After header on 429 status.
*/
Expand Down Expand Up @@ -111,6 +194,12 @@ export class GeminiProvider implements LlmProvider {
}
};

// Add tool definitions if provided
if (request.toolDefinitions && request.toolDefinitions.length > 0) {
geminiBody.tools = this.convertToolDefinitions(request.toolDefinitions);
log.info('[gemini]', `Added ${request.toolDefinitions.length} tool definition(s)`);
}

try {
log.info("[gemini]", `Request [model=${model}]: ${JSON.stringify(geminiBody.contents.slice(-1))}`);

Expand Down Expand Up @@ -155,7 +244,15 @@ export class GeminiProvider implements LlmProvider {
}

return await response.json() as {
candidates?: Array<{ content?: { parts?: Array<{ text?: string }> } }>;
candidates?: Array<{
content?: {
parts?: Array<{
text?: string;
functionCall?: { name: string; args: unknown };
}>
};
finishReason?: string;
}>;
usageMetadata?: { promptTokenCount?: number; candidatesTokenCount?: number };
};
},
Expand All @@ -166,8 +263,18 @@ export class GeminiProvider implements LlmProvider {

const durationMs = env.clock.now() - startTime;

// Extract text
const resultText = result.candidates?.[0]?.content?.parts?.[0]?.text || "";
// Extract tool calls and stop reason
const { toolCalls, stopReason } = this.extractToolCalls(result);

// Extract text from parts that aren't function calls
const textParts: string[] = [];
const parts = result.candidates?.[0]?.content?.parts || [];
for (const part of parts) {
if (part.text) {
textParts.push(part.text);
}
}
const resultText = textParts.join("");

// Usage stats
const usage: Usage = {
Expand All @@ -177,14 +284,23 @@ export class GeminiProvider implements LlmProvider {
costUsd: 0 // We'd need a pricing table to calculate this
};

return {
const response: ProviderResponse = {
type: "success",
result: resultText,
usage,
durationMs,
sessionId: "" // Stateless API
sessionId: "", // Stateless API
stopReason,
};

// Add tool calls if present
if (toolCalls.length > 0) {
response.toolCalls = toolCalls;
log.info('[gemini]', `Extracted ${toolCalls.length} tool call(s)`);
}

return response;

} catch (error) {
return {
type: "error",
Expand Down Expand Up @@ -214,12 +330,18 @@ export class GeminiProvider implements LlmProvider {
}
};

// Add tool definitions if provided
if (request.toolDefinitions && request.toolDefinitions.length > 0) {
geminiBody.tools = this.convertToolDefinitions(request.toolDefinitions);
log.info('[gemini]', `Added ${request.toolDefinitions.length} tool definition(s) to stream request`);
}

try {
log.info("[gemini]", `Streaming Request [model=${model}]`);

// Execute initial connection with retry, but not the streaming itself
// Once streaming starts, we don't retry mid-stream to avoid duplicate tokens
const response = await withRetry(
const streamResponse = await withRetry(
async () => {
const url = `${this.baseUrl}/${model}:streamGenerateContent?key=${this.config.apiKey}&alt=sse`;
const response = await env.http.fetch(url, {
Expand All @@ -245,16 +367,18 @@ export class GeminiProvider implements LlmProvider {
// Now stream the response without retry (already connected)
let completeResult = "";
let usage: Usage = { inputTokens: 0, outputTokens: 0, cacheReadTokens: 0, costUsd: 0 };
const toolCalls: ToolCall[] = [];
let stopReason: "end_turn" | "tool_use" | "max_tokens" | "stop_sequence" | undefined;

// Simple SSE parser
// Node's native fetch response body is a ReadableStream
// We need to read from it.

// In Node 22+ with native fetch, response.body is a Web ReadableStream.
// We can iterate it.
if (!response.body) throw new Error("No response body");
if (!streamResponse.body) throw new Error("No response body");

const reader = response.body.getReader();
const reader = streamResponse.body.getReader();
const decoder = new TextDecoder();
let buffer = "";

Expand All @@ -274,11 +398,40 @@ export class GeminiProvider implements LlmProvider {
try {
const chunk = JSON.parse(jsonStr);

// Extract content delta
const textDelta = chunk.candidates?.[0]?.content?.parts?.[0]?.text;
if (textDelta) {
completeResult += textDelta;
await onEvent({ type: "token", text: textDelta });
// Extract parts from the chunk
const parts = chunk.candidates?.[0]?.content?.parts || [];

for (const part of parts) {
// Handle text delta
if (part.text) {
completeResult += part.text;
await onEvent({ type: "token", text: part.text });
}

// Handle function calls
if (part.functionCall) {
const toolCall: ToolCall = {
id: `call_${randomUUID()}`,
name: part.functionCall.name,
input: part.functionCall.args,
};
toolCalls.push(toolCall);

await onEvent({
type: "tool_call",
id: toolCall.id,
name: toolCall.name,
input: toolCall.input,
});
}
}

// Check finish reason
const finishReason = chunk.candidates?.[0]?.finishReason;
if (finishReason === "MAX_TOKENS") {
stopReason = "max_tokens";
} else if (finishReason === "STOP") {
stopReason = toolCalls.length > 0 ? "tool_use" : "end_turn";
}

// Extract usage from the final chunk usually
Expand All @@ -302,13 +455,22 @@ export class GeminiProvider implements LlmProvider {
durationMs
});

return {
const response: ProviderResponse = {
type: "success",
result: completeResult,
usage,
durationMs
durationMs,
stopReason,
};

// Add tool calls if present
if (toolCalls.length > 0) {
response.toolCalls = toolCalls;
log.info('[gemini]', `Stream extracted ${toolCalls.length} tool call(s)`);
}

return response;

} catch (error) {
const msg = error instanceof Error ? error.message : String(error);
await onEvent({ type: "error", message: msg });
Expand Down
Loading
Loading