From 34b0e7ad0936e0fb3c4a6fc2dbee7a606cb2c4df Mon Sep 17 00:00:00 2001 From: at384 Date: Fri, 27 Feb 2026 01:11:56 +0100 Subject: [PATCH] feat(drivers): add Vertex AI driver with OAuth authentication Adds a new vertex-ai provider that authenticates using GCP service account credentials via OAuth 2.0 bearer tokens. This enables enterprise deployments with existing GCP subscriptions instead of requiring separate API keys. Features: - VertexAIDriver with full streaming support - Token caching (50 min) with auto-refresh via gcloud CLI - Auto-detection of project_id from service account JSON - Security: tokens stored with Zeroizing - Provider aliases: vertex-ai, vertex, google-vertex - 8 unit tests Usage: export GOOGLE_APPLICATION_CREDENTIALS=/path/to/sa.json # Set provider=vertex-ai, model=gemini-2.0-flash in config.toml Closes #11 --- crates/openfang-runtime/src/drivers/mod.rs | 67 +- crates/openfang-runtime/src/drivers/vertex.rs | 789 ++++++++++++++++++ 2 files changed, 851 insertions(+), 5 deletions(-) create mode 100644 crates/openfang-runtime/src/drivers/vertex.rs diff --git a/crates/openfang-runtime/src/drivers/mod.rs b/crates/openfang-runtime/src/drivers/mod.rs index dfbebf4..40a24b0 100644 --- a/crates/openfang-runtime/src/drivers/mod.rs +++ b/crates/openfang-runtime/src/drivers/mod.rs @@ -9,15 +9,16 @@ pub mod copilot; pub mod fallback; pub mod gemini; pub mod openai; +pub mod vertex; use crate::llm_driver::{DriverConfig, LlmDriver, LlmError}; use openfang_types::model_catalog::{ AI21_BASE_URL, ANTHROPIC_BASE_URL, CEREBRAS_BASE_URL, COHERE_BASE_URL, DEEPSEEK_BASE_URL, FIREWORKS_BASE_URL, GEMINI_BASE_URL, GROQ_BASE_URL, HUGGINGFACE_BASE_URL, LMSTUDIO_BASE_URL, MINIMAX_BASE_URL, MISTRAL_BASE_URL, MOONSHOT_BASE_URL, OLLAMA_BASE_URL, OPENAI_BASE_URL, - OPENROUTER_BASE_URL, PERPLEXITY_BASE_URL, QIANFAN_BASE_URL, QWEN_BASE_URL, - REPLICATE_BASE_URL, SAMBANOVA_BASE_URL, TOGETHER_BASE_URL, VLLM_BASE_URL, XAI_BASE_URL, - ZHIPU_BASE_URL, ZHIPU_CODING_BASE_URL, + 
OPENROUTER_BASE_URL, PERPLEXITY_BASE_URL, QIANFAN_BASE_URL, QWEN_BASE_URL, REPLICATE_BASE_URL, + SAMBANOVA_BASE_URL, TOGETHER_BASE_URL, VLLM_BASE_URL, XAI_BASE_URL, ZHIPU_BASE_URL, + ZHIPU_CODING_BASE_URL, }; use std::sync::Arc; @@ -162,6 +163,12 @@ fn provider_defaults(provider: &str) -> Option { api_key_env: "QIANFAN_API_KEY", key_required: true, }), + "vertex-ai" | "vertex" | "google-vertex" => Some(ProviderDefaults { + // Vertex AI uses OAuth, not API keys - base_url is per-project + base_url: "https://us-central1-aiplatform.googleapis.com", + api_key_env: "GOOGLE_APPLICATION_CREDENTIALS", + key_required: false, // Uses OAuth service account, not API key + }), _ => None, } } @@ -250,6 +257,39 @@ pub fn create_driver(config: &DriverConfig) -> Result, LlmErr ))); } + // Vertex AI — uses Google Cloud OAuth with service account credentials. + // Requires GOOGLE_APPLICATION_CREDENTIALS env var pointing to service account JSON, + // and the service account must be activated via gcloud CLI. 
+ if provider == "vertex-ai" || provider == "vertex" || provider == "google-vertex" { + // Get project_id from environment or service account JSON + let project_id = std::env::var("GOOGLE_CLOUD_PROJECT") + .or_else(|_| std::env::var("GCLOUD_PROJECT")) + .or_else(|_| std::env::var("GCP_PROJECT")) + .or_else(|_| { + // Try to read from service account JSON + if let Ok(creds_path) = std::env::var("GOOGLE_APPLICATION_CREDENTIALS") { + if let Ok(contents) = std::fs::read_to_string(&creds_path) { + if let Ok(json) = serde_json::from_str::(&contents) { + if let Some(proj) = json.get("project_id").and_then(|v| v.as_str()) { + return Ok(proj.to_string()); + } + } + } + } + Err(std::env::VarError::NotPresent) + }) + .map_err(|_| { + LlmError::MissingApiKey( + "Set GOOGLE_APPLICATION_CREDENTIALS or GOOGLE_CLOUD_PROJECT for Vertex AI" + .to_string(), + ) + })?; + let region = std::env::var("GOOGLE_CLOUD_REGION") + .or_else(|_| std::env::var("VERTEX_AI_REGION")) + .unwrap_or_else(|_| "us-central1".to_string()); + return Ok(Arc::new(vertex::VertexAIDriver::new(project_id, region))); + } + // All other providers use OpenAI-compatible format if let Some(defaults) = provider_defaults(provider) { let api_key = config @@ -287,8 +327,8 @@ pub fn create_driver(config: &DriverConfig) -> Result, LlmErr message: format!( "Unknown provider '{}'. Supported: anthropic, gemini, openai, groq, openrouter, \ deepseek, together, mistral, fireworks, ollama, vllm, lmstudio, perplexity, \ - cohere, ai21, cerebras, sambanova, huggingface, xai, replicate, github-copilot. \ - Or set base_url for a custom OpenAI-compatible endpoint.", + cohere, ai21, cerebras, sambanova, huggingface, xai, replicate, github-copilot, \ + vertex-ai. 
Or set base_url for a custom OpenAI-compatible endpoint.", provider ), }) @@ -318,6 +358,7 @@ pub fn known_providers() -> &'static [&'static str] { "xai", "replicate", "github-copilot", + "vertex-ai", "moonshot", "qwen", "minimax", @@ -411,6 +452,7 @@ mod tests { assert!(providers.contains(&"xai")); assert!(providers.contains(&"replicate")); assert!(providers.contains(&"github-copilot")); + assert!(providers.contains(&"vertex-ai")); assert!(providers.contains(&"moonshot")); assert!(providers.contains(&"qwen")); assert!(providers.contains(&"minimax")); @@ -457,4 +499,19 @@ mod tests { assert_eq!(d.api_key_env, "HF_API_KEY"); assert!(d.key_required); } + + #[test] + fn test_provider_defaults_vertex_ai() { + let d = provider_defaults("vertex-ai").unwrap(); + assert_eq!(d.base_url, "https://us-central1-aiplatform.googleapis.com"); + assert_eq!(d.api_key_env, "GOOGLE_APPLICATION_CREDENTIALS"); + assert!(!d.key_required); // Uses OAuth, not API key + } + + #[test] + fn test_provider_defaults_vertex_alias() { + let d = provider_defaults("vertex").unwrap(); + assert_eq!(d.api_key_env, "GOOGLE_APPLICATION_CREDENTIALS"); + assert!(!d.key_required); + } } diff --git a/crates/openfang-runtime/src/drivers/vertex.rs b/crates/openfang-runtime/src/drivers/vertex.rs new file mode 100644 index 0000000..4815694 --- /dev/null +++ b/crates/openfang-runtime/src/drivers/vertex.rs @@ -0,0 +1,789 @@ +//! Google Vertex AI driver with OAuth authentication. +//! +//! Uses service account credentials (`GOOGLE_APPLICATION_CREDENTIALS`) to +//! authenticate with Vertex AI's Gemini models via OAuth 2.0 bearer tokens. +//! This enables enterprise deployments without requiring consumer API keys. +//! +//! # Endpoint Format +//! +//! ```text +//! https://{region}-aiplatform.googleapis.com/v1/projects/{project}/locations/{region}/publishers/google/models/{model}:generateContent +//! ``` +//! +//! # Authentication +//! +//! Uses OAuth 2.0 bearer tokens obtained via `gcloud auth print-access-token`. 
+//! Tokens are cached for 50 minutes and automatically refreshed.
+//!
+//! # Environment Variables
+//!
+//! - `GOOGLE_APPLICATION_CREDENTIALS` — Path to service account JSON
+//! - `GOOGLE_CLOUD_PROJECT` / `GCLOUD_PROJECT` / `GCP_PROJECT` — Project ID (optional if in credentials)
+//! - `GOOGLE_CLOUD_REGION` / `VERTEX_AI_REGION` — Region (default: `us-central1`)
+//! - `VERTEX_AI_ACCESS_TOKEN` — Pre-generated token (optional, for testing)
+
+use crate::llm_driver::{CompletionRequest, CompletionResponse, LlmDriver, LlmError, StreamEvent};
+use async_trait::async_trait;
+use futures::StreamExt;
+use openfang_types::message::{
+    ContentBlock, Message, MessageContent, Role, StopReason, TokenUsage,
+};
+use openfang_types::tool::ToolCall;
+use serde::{Deserialize, Serialize};
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+use tokio::sync::RwLock;
+use tracing::{debug, info, warn};
+use zeroize::Zeroizing;
+
+/// Vertex AI driver with OAuth authentication.
+///
+/// Authenticates using GCP service account credentials and OAuth 2.0 bearer tokens.
+/// Tokens are cached with automatic refresh before expiry.
+pub struct VertexAIDriver {
+    /// GCP project ID interpolated into the endpoint URL.
+    project_id: String,
+    /// GCP region used for both the host name and the `locations/` path segment.
+    region: String,
+    /// Cached OAuth access token (zeroized on drop for security).
+    token_cache: Arc<RwLock<TokenCache>>,
+    /// HTTP client used for every request sent by this driver.
+    client: reqwest::Client,
+}
+
+/// Cached OAuth token with expiry tracking.
+///
+/// SECURITY: Token is wrapped in `Zeroizing` to clear memory on drop.
+struct TokenCache {
+    /// Bearer token, or `None` until a token has been stored.
+    token: Option<Zeroizing<String>>,
+    /// Instant from which the cached token is treated as expired.
+    expires_at: Option<Instant>,
+}
+
+impl TokenCache {
+    /// Create an empty cache; `is_valid` is false until a token and expiry are set.
+    fn new() -> Self {
+        Self {
+            token: None,
+            expires_at: None,
+        }
+    }
+
+    /// Return true when a token is present and its expiry instant is still in the future.
+    fn is_valid(&self) -> bool {
+        self.token.is_some()
+            && self
+                .expires_at
+                .map_or(false, |expires| Instant::now() < expires)
+    }
+
+    /// Return a copy of the cached token if it is still valid.
+    ///
+    /// NOTE: the returned `String` is a plain copy and is not zeroized on drop;
+    /// only the cached value itself is protected by `Zeroizing`.
+    fn get(&self) -> Option<String> {
+        if !self.is_valid() {
+            return None;
+        }
+        self.token.as_ref().map(|t| t.as_str().to_string())
+    }
+}
+
+impl VertexAIDriver {
+    /// Create a new Vertex AI driver.
+    ///
+    /// # Arguments
+    /// * `project_id` - GCP project ID
+    /// * `region` - GCP region (e.g., `us-central1`)
+    pub fn new(project_id: String, region: String) -> Self {
+        Self {
+            project_id,
+            region,
+            token_cache: Arc::new(RwLock::new(TokenCache::new())),
+            client: reqwest::Client::new(),
+        }
+    }
+
+    /// Get a valid OAuth access token, refreshing if needed.
+    ///
+    /// # Errors
+    /// Propagates the error from `fetch_access_token` when no cached token is
+    /// valid and a fresh one cannot be obtained.
+    async fn get_access_token(&self) -> Result<String, LlmError> {
+        // Fast path: a shared read lock is enough while the cached token is valid.
+        {
+            let cache = self.token_cache.read().await;
+            if let Some(token) = cache.get() {
+                debug!("Using cached Vertex AI access token");
+                return Ok(token);
+            }
+        }
+
+        // Slow path: take the write lock *before* fetching and re-check the cache,
+        // so concurrent callers that all missed the cache wait for one refresh
+        // instead of each starting their own.
+        let mut cache = self.token_cache.write().await;
+        if let Some(token) = cache.get() {
+            debug!("Using Vertex AI access token refreshed by a concurrent request");
+            return Ok(token);
+        }
+
+        info!("Refreshing Vertex AI OAuth access token");
+        let token = self.fetch_access_token().await?;
+
+        // Cache the token (expires in ~1 hour, we refresh at 50 min)
+        cache.token = Some(Zeroizing::new(token.clone()));
+        cache.expires_at = Some(Instant::now() + Duration::from_secs(50 * 60));
+
+        Ok(token)
+    }
+
+    /// Fetch a new access token using gcloud CLI.
+    ///
+    /// This uses the service account specified in GOOGLE_APPLICATION_CREDENTIALS
+    /// via the gcloud CLI. For production, this should use the google-auth library.
+    async fn fetch_access_token(&self) -> Result<String, LlmError> {
+        // First, check if a pre-generated token is available in env.
+        // The value is trimmed: a trailing newline (e.g. from `$(gcloud ...)` or a
+        // secrets file) would otherwise end up inside the `Authorization` header,
+        // and a whitespace-only value must not count as a token.
+        if let Ok(token) = std::env::var("VERTEX_AI_ACCESS_TOKEN") {
+            let token = token.trim();
+            if !token.is_empty() {
+                debug!("Using pre-set VERTEX_AI_ACCESS_TOKEN");
+                return Ok(token.to_string());
+            }
+        }
+
+        // Try application-default credentials first (uses GOOGLE_APPLICATION_CREDENTIALS)
+        match tokio::process::Command::new("gcloud")
+            .args(["auth", "application-default", "print-access-token"])
+            .output()
+            .await
+        {
+            Ok(output) if output.status.success() => {
+                let token = String::from_utf8_lossy(&output.stdout).trim().to_string();
+                if !token.is_empty() {
+                    debug!("Successfully obtained Vertex AI access token via application-default");
+                    return Ok(token);
+                }
+            }
+            // Best effort only: record why this attempt failed, then fall back below.
+            Ok(output) => {
+                debug!(status = %output.status, "gcloud application-default token unavailable, falling back");
+            }
+            Err(e) => {
+                debug!(error = %e, "Failed to run gcloud for application-default token, falling back");
+            }
+        }
+
+        // Fall back to regular gcloud auth (requires activated service account)
+        let output = tokio::process::Command::new("gcloud")
+            .args(["auth", "print-access-token"])
+            .output()
+            .await
+            .map_err(|e| LlmError::MissingApiKey(format!("Failed to run gcloud: {}", e)))?;
+
+        if !output.status.success() {
+            let stderr = String::from_utf8_lossy(&output.stderr);
+            return Err(LlmError::MissingApiKey(format!(
+                "gcloud auth failed: {}. Ensure GOOGLE_APPLICATION_CREDENTIALS is set and \
+                 run: gcloud auth activate-service-account --key-file=$GOOGLE_APPLICATION_CREDENTIALS",
+                stderr.trim()
+            )));
+        }
+
+        let token = String::from_utf8_lossy(&output.stdout).trim().to_string();
+        if token.is_empty() {
+            return Err(LlmError::MissingApiKey(
+                "Empty access token from gcloud".to_string(),
+            ));
+        }
+
+        debug!("Successfully obtained Vertex AI access token");
+        Ok(token)
+    }
+
+    /// Build the Vertex AI endpoint URL for a model.
+    fn build_endpoint(&self, model: &str, streaming: bool) -> String {
+        // Accept ids written with a leading `models/` (e.g. `models/gemini-2.0-flash`):
+        // the URL below already contains the `/models/` path segment, so the prefix
+        // is dropped to avoid duplicating it.
+        let model_name = model.strip_prefix("models/").unwrap_or(model);
+
+        let method = if streaming {
+            "streamGenerateContent"
+        } else {
+            "generateContent"
+        };
+
+        format!(
+            "https://{region}-aiplatform.googleapis.com/v1/projects/{project}/locations/{region}/publishers/google/models/{model}:{method}",
+            region = self.region,
+            project = self.project_id,
+            model = model_name,
+            method = method
+        )
+    }
+}
+
+// ── Request types (reusing Gemini format) ──────────────────────────────
+
+/// Top-level Gemini/Vertex API request body.
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+struct VertexRequest {
+    /// Conversation turns (system messages are carried separately).
+    contents: Vec<VertexContent>,
+    /// Optional system prompt; omitted from the JSON when `None`.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    system_instruction: Option<VertexContent>,
+    /// Tool declarations; omitted from the JSON when empty.
+    #[serde(skip_serializing_if = "Vec::is_empty")]
+    tools: Vec<VertexToolConfig>,
+    /// Sampling parameters; omitted from the JSON when `None`.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    generation_config: Option<GenerationConfig>,
+}
+
+/// One message: an optional role plus its parts.
+#[derive(Debug, Serialize, Deserialize, Clone)]
+struct VertexContent {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    role: Option<String>,
+    parts: Vec<VertexPart>,
+}
+
+/// A single part of a message.
+///
+/// `untagged`: each variant is told apart by its one distinct JSON key
+/// (`text`, `inlineData`, `functionCall`, `functionResponse`).
+#[derive(Debug, Serialize, Deserialize, Clone)]
+#[serde(untagged)]
+enum VertexPart {
+    Text {
+        text: String,
+    },
+    InlineData {
+        #[serde(rename = "inlineData")]
+        inline_data: VertexInlineData,
+    },
+    FunctionCall {
+        #[serde(rename = "functionCall")]
+        function_call: VertexFunctionCallData,
+    },
+    FunctionResponse {
+        #[serde(rename = "functionResponse")]
+        function_response: VertexFunctionResponseData,
+    },
+}
+
+/// Payload of an `inlineData` part: a MIME type plus the data string.
+#[derive(Debug, Serialize, Deserialize, Clone)]
+struct VertexInlineData {
+    #[serde(rename = "mimeType")]
+    mime_type: String,
+    data: String,
+}
+
+/// Payload of a `functionCall` part.
+#[derive(Debug, Serialize, Deserialize, Clone)]
+struct VertexFunctionCallData {
+    name: String,
+    args: serde_json::Value,
+}
+
+/// Payload of a `functionResponse` part.
+#[derive(Debug, Serialize, Deserialize, Clone)]
+struct VertexFunctionResponseData {
+    name: String,
+    response: serde_json::Value,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+struct VertexToolConfig {
+    function_declarations: Vec<VertexFunctionDeclaration>,
+}
+
+#[derive(Debug, Serialize)]
+struct VertexFunctionDeclaration {
+    name: String,
+    description: String,
+    parameters: serde_json::Value,
+}
+
+// NOTE(review): the numeric types below are assumed to match
+// `CompletionRequest::temperature` / `CompletionRequest::max_tokens` — confirm.
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+struct GenerationConfig {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    temperature: Option<f32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    max_output_tokens: Option<u32>,
+}
+
+// ── Response types ─────────────────────────────────────────────────────
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct VertexResponse {
+    #[serde(default)]
+    candidates: Vec<VertexCandidate>,
+    #[serde(default)]
+    usage_metadata: Option<VertexUsageMetadata>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct VertexCandidate {
+    content: Option<VertexContent>,
+    #[serde(default)]
+    finish_reason: Option<String>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct VertexUsageMetadata {
+    #[serde(default)]
+    prompt_token_count: u64,
+    #[serde(default)]
+    candidates_token_count: u64,
+}
+
+#[derive(Debug, Deserialize)]
+struct VertexErrorResponse {
+    error: VertexErrorDetail,
+}
+
+#[derive(Debug, Deserialize)]
+struct VertexErrorDetail {
+    message: String,
+}
+
+// ── Message conversion ─────────────────────────────────────────────────
+
+/// Convert driver messages into Vertex `contents` plus an optional system instruction.
+///
+/// System-role messages are not emitted as turns; the system text is resolved by
+/// `extract_system`. Messages that produce no parts are dropped.
+fn convert_messages(
+    messages: &[Message],
+    system: &Option<String>,
+) -> (Vec<VertexContent>, Option<VertexContent>) {
+    let mut contents = Vec::new();
+
+    let system_instruction = extract_system(messages, system);
+
+    for msg in messages {
+        let role = match msg.role {
+            Role::User => "user",
+            Role::Assistant => "model",
+            // System text travels in `system_instruction`, never as a turn.
+            Role::System => continue,
+        };
+
+        let parts = match &msg.content {
+            MessageContent::Text(text) => vec![VertexPart::Text { text: text.clone() }],
+            MessageContent::Blocks(blocks) => {
+                let mut parts = Vec::new();
+                for block in blocks {
+                    match block {
+                        ContentBlock::Text { text } => {
+                            parts.push(VertexPart::Text { text: text.clone() });
+                        }
+                        ContentBlock::ToolUse { name, input, .. } => {
+                            parts.push(VertexPart::FunctionCall {
+                                function_call: VertexFunctionCallData {
+                                    name: name.clone(),
+                                    args: input.clone(),
+                                },
+                            });
+                        }
+                        ContentBlock::Image { media_type, data } => {
+                            parts.push(VertexPart::InlineData {
+                                inline_data: VertexInlineData {
+                                    mime_type: media_type.clone(),
+                                    data: data.clone(),
+                                },
+                            });
+                        }
+                        ContentBlock::ToolResult { content, .. } => {
+                            // NOTE(review): `name` is sent empty. The API presumably expects
+                            // the name of the function this result answers; resolving it needs
+                            // the tool-use id/name association, which is not visible here —
+                            // confirm and wire through.
+                            parts.push(VertexPart::FunctionResponse {
+                                function_response: VertexFunctionResponseData {
+                                    name: String::new(),
+                                    response: serde_json::json!({ "result": content }),
+                                },
+                            });
+                        }
+                        // Thinking blocks are intentionally not forwarded.
+                        ContentBlock::Thinking { .. } => {}
+                        _ => {}
+                    }
+                }
+                parts
+            }
+        };
+
+        if !parts.is_empty() {
+            contents.push(VertexContent {
+                role: Some(role.to_string()),
+                parts,
+            });
+        }
+    }
+
+    (contents, system_instruction)
+}
+
+/// Resolve the system prompt: the explicit `system` value wins, otherwise the
+/// first system-role message with plain-text content is used.
+fn extract_system(messages: &[Message], system: &Option<String>) -> Option<VertexContent> {
+    let text = system.clone().or_else(|| {
+        messages.iter().find_map(|m| {
+            if m.role == Role::System {
+                match &m.content {
+                    MessageContent::Text(t) => Some(t.clone()),
+                    _ => None,
+                }
+            } else {
+                None
+            }
+        })
+    })?;
+
+    Some(VertexContent {
+        role: None,
+        parts: vec![VertexPart::Text { text }],
+    })
+}
+
+/// Convert the request's tool definitions into a single Vertex tool entry
+/// (or an empty list when the request has no tools).
+fn convert_tools(request: &CompletionRequest) -> Vec<VertexToolConfig> {
+    if request.tools.is_empty() {
+        return Vec::new();
+    }
+
+    let declarations: Vec<VertexFunctionDeclaration> = request
+        .tools
+        .iter()
+        .map(|t| {
+            let normalized =
+                openfang_types::tool::normalize_schema_for_provider(&t.input_schema, "gemini");
+            VertexFunctionDeclaration {
+                name: t.name.clone(),
+                description: t.description.clone(),
+                parameters: normalized,
+            }
+        })
+        .collect();
+
+    vec![VertexToolConfig {
+        function_declarations: declarations,
+    }]
+}
+
+/// Generate a short random id for a tool call (function-call parts carry no id field here).
+fn new_tool_call_id() -> String {
+    format!("call_{}", &uuid::Uuid::new_v4().to_string()[..8])
+}
+
+/// Map a `finishReason` plus the presence of tool calls to a [`StopReason`].
+///
+/// Tool calls are checked before the plain-stop case so that a response that
+/// finished with `"STOP"` but contains function calls is reported as `ToolUse`.
+/// Truncation (`"MAX_TOKENS"`) takes precedence over both.
+fn map_stop_reason(finish_reason: Option<&str>, has_tool_calls: bool) -> StopReason {
+    match finish_reason {
+        Some("MAX_TOKENS") => StopReason::MaxTokens,
+        _ if has_tool_calls => StopReason::ToolUse,
+        _ => StopReason::EndTurn,
+    }
+}
+
+/// Convert a non-streaming Vertex response into a [`CompletionResponse`].
+///
+/// # Errors
+/// Returns `LlmError::Parse` when the response contains no candidates.
+fn convert_response(resp: VertexResponse) -> Result<CompletionResponse, LlmError> {
+    let candidate = resp
+        .candidates
+        .into_iter()
+        .next()
+        .ok_or_else(|| LlmError::Parse("No candidates in Vertex AI response".to_string()))?;
+
+    let mut content = Vec::new();
+    let mut tool_calls = Vec::new();
+
+    if let Some(vertex_content) = candidate.content {
+        for part in vertex_content.parts {
+            match part {
+                VertexPart::Text { text } => {
+                    content.push(ContentBlock::Text { text });
+                }
+                VertexPart::FunctionCall { function_call } => {
+                    tool_calls.push(ToolCall {
+                        id: new_tool_call_id(),
+                        name: function_call.name,
+                        input: function_call.args,
+                    });
+                }
+                _ => {}
+            }
+        }
+    }
+
+    let stop_reason = map_stop_reason(candidate.finish_reason.as_deref(), !tool_calls.is_empty());
+
+    let usage = resp
+        .usage_metadata
+        .map(|u| TokenUsage {
+            input_tokens: u.prompt_token_count,
+            output_tokens: u.candidates_token_count,
+        })
+        .unwrap_or_default();
+
+    Ok(CompletionResponse {
+        content,
+        stop_reason,
+        tool_calls,
+        usage,
+    })
+}
+
+// ── LlmDriver implementation ──────────────────────────────────────────
+
+#[async_trait]
+impl LlmDriver for VertexAIDriver {
+    /// Send a non-streaming `generateContent` request.
+    ///
+    /// HTTP 429/503 responses are retried up to three times with a linearly
+    /// growing delay (2s, 4s, 6s); any other non-success status is returned as
+    /// `LlmError::Api` with the server's error message when it can be parsed.
+    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, LlmError> {
+        let (contents, system_instruction) = convert_messages(&request.messages, &request.system);
+        let tools = convert_tools(&request);
+
+        let vertex_request = VertexRequest {
+            contents,
+            system_instruction,
+            tools,
+            generation_config: Some(GenerationConfig {
+                temperature: Some(request.temperature),
+                max_output_tokens: Some(request.max_tokens),
+            }),
+        };
+
+        let access_token = self.get_access_token().await?;
+
+        let max_retries = 3;
+        for attempt in 0..=max_retries {
+            let url = self.build_endpoint(&request.model, false);
+            debug!(url = %url, attempt, "Sending Vertex AI request");
+
+            let resp = self
+                .client
+                .post(&url)
+                .header("Authorization", format!("Bearer {}", access_token))
+                .header("Content-Type", "application/json")
+                .json(&vertex_request)
+                .send()
+                .await
+                .map_err(|e| LlmError::Http(e.to_string()))?;
+
+            let status = resp.status().as_u16();
+
+            if status == 429 || status == 503 {
+                if attempt < max_retries {
+                    let retry_ms = (attempt + 1) as u64 * 2000;
+                    warn!(status, retry_ms, "Rate limited/overloaded, retrying");
+                    tokio::time::sleep(Duration::from_millis(retry_ms)).await;
+                    continue;
+                }
+                return Err(if status == 429 {
+                    LlmError::RateLimited {
+                        retry_after_ms: 5000,
+                    }
+                } else {
+                    LlmError::Overloaded {
+                        retry_after_ms: 5000,
+                    }
+                });
+            }
+
+            if !resp.status().is_success() {
+                let body = resp.text().await.unwrap_or_default();
+                // Prefer the structured error message; fall back to the raw body.
+                let message = serde_json::from_str::<VertexErrorResponse>(&body)
+                    .map(|e| e.error.message)
+                    .unwrap_or(body);
+                return Err(LlmError::Api { status, message });
+            }
+
+            let body = resp
+                .text()
+                .await
+                .map_err(|e| LlmError::Http(e.to_string()))?;
+            let vertex_response: VertexResponse =
+                serde_json::from_str(&body).map_err(|e| LlmError::Parse(e.to_string()))?;
+
+            return convert_response(vertex_response);
+        }
+
+        Err(LlmError::Api {
+            status: 0,
+            message: "Max retries exceeded".to_string(),
+        })
+    }
+
+    /// Send a `streamGenerateContent` request (SSE) and forward text deltas on `tx`.
+    ///
+    /// Uses the same retry policy as `complete`. Text parts are sent as
+    /// `StreamEvent::TextDelta` as they arrive; function calls are collected and
+    /// returned in the final response. Send failures on `tx` are ignored.
+    async fn stream(
+        &self,
+        request: CompletionRequest,
+        tx: tokio::sync::mpsc::Sender<StreamEvent>,
+    ) -> Result<CompletionResponse, LlmError> {
+        let (contents, system_instruction) = convert_messages(&request.messages, &request.system);
+        let tools = convert_tools(&request);
+
+        let vertex_request = VertexRequest {
+            contents,
+            system_instruction,
+            tools,
+            generation_config: Some(GenerationConfig {
+                temperature: Some(request.temperature),
+                max_output_tokens: Some(request.max_tokens),
+            }),
+        };
+
+        let access_token = self.get_access_token().await?;
+
+        let max_retries = 3;
+        for attempt in 0..=max_retries {
+            let url = format!("{}?alt=sse", self.build_endpoint(&request.model, true));
+            debug!(url = %url, attempt, "Sending Vertex AI streaming request");
+
+            let resp = self
+                .client
+                .post(&url)
+                .header("Authorization", format!("Bearer {}", access_token))
+                .header("Content-Type", "application/json")
+                .json(&vertex_request)
+                .send()
+                .await
+                .map_err(|e| LlmError::Http(e.to_string()))?;
+
+            let status = resp.status().as_u16();
+
+            if status == 429 || status == 503 {
+                if attempt < max_retries {
+                    let retry_ms = (attempt + 1) as u64 * 2000;
+                    warn!(
+                        status,
+                        retry_ms, "Rate limited/overloaded (stream), retrying"
+                    );
+                    tokio::time::sleep(Duration::from_millis(retry_ms)).await;
+                    continue;
+                }
+                return Err(if status == 429 {
+                    LlmError::RateLimited {
+                        retry_after_ms: 5000,
+                    }
+                } else {
+                    LlmError::Overloaded {
+                        retry_after_ms: 5000,
+                    }
+                });
+            }
+
+            if !resp.status().is_success() {
+                let body = resp.text().await.unwrap_or_default();
+                let message = serde_json::from_str::<VertexErrorResponse>(&body)
+                    .map(|e| e.error.message)
+                    .unwrap_or(body);
+                return Err(LlmError::Api { status, message });
+            }
+
+            // Process SSE stream.
+            // The buffer holds raw bytes: a multi-byte UTF-8 character may be split
+            // across two network chunks, so text is decoded only once a complete
+            // line has been received.
+            let mut byte_stream = resp.bytes_stream();
+            let mut buffer: Vec<u8> = Vec::new();
+            let mut accumulated_text = String::new();
+            let mut final_tool_calls = Vec::new();
+            let mut final_usage = None;
+            let mut finish_reason: Option<String> = None;
+
+            while let Some(chunk_result) = byte_stream.next().await {
+                let chunk = chunk_result.map_err(|e| LlmError::Http(e.to_string()))?;
+                buffer.extend_from_slice(&chunk);
+
+                // Process complete lines
+                while let Some(line_end) = buffer.iter().position(|&b| b == b'\n') {
+                    let raw_line: Vec<u8> = buffer.drain(..=line_end).collect();
+                    let decoded = String::from_utf8_lossy(&raw_line);
+
+                    // Only `data: ` lines carry payloads; everything else is skipped.
+                    let json_str = match decoded.trim().strip_prefix("data: ") {
+                        Some(payload) => payload,
+                        None => continue,
+                    };
+                    if json_str == "[DONE]" {
+                        break;
+                    }
+
+                    // Events that fail to parse are skipped (best effort).
+                    if let Ok(event) = serde_json::from_str::<VertexResponse>(json_str) {
+                        if let Some(candidate) = event.candidates.into_iter().next() {
+                            // Keep the most recent finish reason reported by the stream.
+                            if candidate.finish_reason.is_some() {
+                                finish_reason = candidate.finish_reason;
+                            }
+                            if let Some(content) = candidate.content {
+                                for part in content.parts {
+                                    match part {
+                                        VertexPart::Text { text } => {
+                                            accumulated_text.push_str(&text);
+                                            let _ = tx.send(StreamEvent::TextDelta { text }).await;
+                                        }
+                                        VertexPart::FunctionCall { function_call } => {
+                                            final_tool_calls.push(ToolCall {
+                                                id: new_tool_call_id(),
+                                                name: function_call.name,
+                                                input: function_call.args,
+                                            });
+                                        }
+                                        _ => {}
+                                    }
+                                }
+                            }
+                        }
+                        if let Some(usage) = event.usage_metadata {
+                            final_usage = Some(TokenUsage {
+                                input_tokens: usage.prompt_token_count,
+                                output_tokens: usage.candidates_token_count,
+                            });
+                        }
+                    }
+                }
+            }
+
+            let stop_reason =
+                map_stop_reason(finish_reason.as_deref(), !final_tool_calls.is_empty());
+
+            let usage = final_usage.unwrap_or_default();
+
+            let _ = tx
+                .send(StreamEvent::ContentComplete { stop_reason, usage })
+                .await;
+
+            let content = if accumulated_text.is_empty() {
+                Vec::new()
+            } else {
+                vec![ContentBlock::Text {
+                    text: accumulated_text,
+                }]
+            };
+
+            return Ok(CompletionResponse {
+                content,
+                stop_reason,
+                tool_calls: final_tool_calls,
+                usage,
+            });
+        }
+
+        Err(LlmError::Api {
+            status: 0,
+            message: "Max retries exceeded".to_string(),
+        })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_vertex_driver_creation() {
+        let driver = VertexAIDriver::new("test-project".to_string(), "us-central1".to_string());
+        assert_eq!(driver.project_id, "test-project");
+        assert_eq!(driver.region, "us-central1");
+    }
+
+    #[test]
+    fn test_build_endpoint_non_streaming() {
+        let driver = VertexAIDriver::new("my-project".to_string(), "us-central1".to_string());
+        let endpoint = driver.build_endpoint("gemini-2.0-flash", false);
+        assert_eq!(
+            endpoint,
+            "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:generateContent"
+        );
+    }
+
+    #[test]
+    fn test_build_endpoint_streaming() {
+        let driver = VertexAIDriver::new("my-project".to_string(), "europe-west4".to_string());
+        let endpoint = driver.build_endpoint("gemini-1.5-pro", true);
+        assert_eq!(
+            endpoint,
+            "https://europe-west4-aiplatform.googleapis.com/v1/projects/my-project/locations/europe-west4/publishers/google/models/gemini-1.5-pro:streamGenerateContent"
+        );
+    }
+
+    #[test]
+    fn test_build_endpoint_strips_model_prefix() {
+        let driver = VertexAIDriver::new("my-project".to_string(), "us-central1".to_string());
+        let endpoint = driver.build_endpoint("models/gemini-2.0-flash", false);
+        assert_eq!(
+            endpoint,
+            "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:generateContent"
+        );
+    }
+
+    #[test]
+    fn test_token_cache_initially_invalid() {
+        let cache = TokenCache::new();
+        assert!(!cache.is_valid());
+        assert!(cache.token.is_none());
+    }
+
+    #[test]
+    fn test_vertex_content_serialization() {
+        let content = VertexContent {
+            role: Some("user".to_string()),
+            parts: vec![VertexPart::Text {
+                text: "Hello".to_string(),
+            }],
+        };
+        let json = serde_json::to_string(&content).unwrap();
+        assert!(json.contains("\"role\":\"user\""));
+        assert!(json.contains("\"text\":\"Hello\""));
+    }
+
+    #[test]
+    fn test_stop_reason_prefers_tool_use_over_stop() {
+        assert!(matches!(
+            map_stop_reason(Some("STOP"), true),
+            StopReason::ToolUse
+        ));
+        assert!(matches!(
+            map_stop_reason(Some("STOP"), false),
+            StopReason::EndTurn
+        ));
+        assert!(matches!(
+            map_stop_reason(Some("MAX_TOKENS"), true),
+            StopReason::MaxTokens
+        ));
+    }
+}