Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 62 additions & 5 deletions crates/openfang-runtime/src/drivers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@ pub mod copilot;
pub mod fallback;
pub mod gemini;
pub mod openai;
pub mod vertex;

use crate::llm_driver::{DriverConfig, LlmDriver, LlmError};
use openfang_types::model_catalog::{
AI21_BASE_URL, ANTHROPIC_BASE_URL, CEREBRAS_BASE_URL, COHERE_BASE_URL, DEEPSEEK_BASE_URL,
FIREWORKS_BASE_URL, GEMINI_BASE_URL, GROQ_BASE_URL, HUGGINGFACE_BASE_URL, LMSTUDIO_BASE_URL,
MINIMAX_BASE_URL, MISTRAL_BASE_URL, MOONSHOT_BASE_URL, OLLAMA_BASE_URL, OPENAI_BASE_URL,
OPENROUTER_BASE_URL, PERPLEXITY_BASE_URL, QIANFAN_BASE_URL, QWEN_BASE_URL,
REPLICATE_BASE_URL, SAMBANOVA_BASE_URL, TOGETHER_BASE_URL, VLLM_BASE_URL, XAI_BASE_URL,
ZHIPU_BASE_URL, ZHIPU_CODING_BASE_URL,
OPENROUTER_BASE_URL, PERPLEXITY_BASE_URL, QIANFAN_BASE_URL, QWEN_BASE_URL, REPLICATE_BASE_URL,
SAMBANOVA_BASE_URL, TOGETHER_BASE_URL, VLLM_BASE_URL, XAI_BASE_URL, ZHIPU_BASE_URL,
ZHIPU_CODING_BASE_URL,
};
use std::sync::Arc;

Expand Down Expand Up @@ -162,6 +163,12 @@ fn provider_defaults(provider: &str) -> Option<ProviderDefaults> {
api_key_env: "QIANFAN_API_KEY",
key_required: true,
}),
"vertex-ai" | "vertex" | "google-vertex" => Some(ProviderDefaults {
// Vertex AI uses OAuth, not API keys - base_url is per-project
base_url: "https://us-central1-aiplatform.googleapis.com",
api_key_env: "GOOGLE_APPLICATION_CREDENTIALS",
key_required: false, // Uses OAuth service account, not API key
}),
_ => None,
}
}
Expand Down Expand Up @@ -250,6 +257,39 @@ pub fn create_driver(config: &DriverConfig) -> Result<Arc<dyn LlmDriver>, LlmErr
)));
}

// Vertex AI — uses Google Cloud OAuth with service account credentials.
// Requires GOOGLE_APPLICATION_CREDENTIALS env var pointing to service account JSON,
// and the service account must be activated via gcloud CLI.
if provider == "vertex-ai" || provider == "vertex" || provider == "google-vertex" {
// Get project_id from environment or service account JSON
let project_id = std::env::var("GOOGLE_CLOUD_PROJECT")
.or_else(|_| std::env::var("GCLOUD_PROJECT"))
.or_else(|_| std::env::var("GCP_PROJECT"))
.or_else(|_| {
// Try to read from service account JSON
if let Ok(creds_path) = std::env::var("GOOGLE_APPLICATION_CREDENTIALS") {
if let Ok(contents) = std::fs::read_to_string(&creds_path) {
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&contents) {
if let Some(proj) = json.get("project_id").and_then(|v| v.as_str()) {
return Ok(proj.to_string());
}
}
}
}
Err(std::env::VarError::NotPresent)
})
.map_err(|_| {
LlmError::MissingApiKey(
"Set GOOGLE_APPLICATION_CREDENTIALS or GOOGLE_CLOUD_PROJECT for Vertex AI"
.to_string(),
)
})?;
let region = std::env::var("GOOGLE_CLOUD_REGION")
.or_else(|_| std::env::var("VERTEX_AI_REGION"))
.unwrap_or_else(|_| "us-central1".to_string());
return Ok(Arc::new(vertex::VertexAIDriver::new(project_id, region)));
}

// All other providers use OpenAI-compatible format
if let Some(defaults) = provider_defaults(provider) {
let api_key = config
Expand Down Expand Up @@ -287,8 +327,8 @@ pub fn create_driver(config: &DriverConfig) -> Result<Arc<dyn LlmDriver>, LlmErr
message: format!(
"Unknown provider '{}'. Supported: anthropic, gemini, openai, groq, openrouter, \
deepseek, together, mistral, fireworks, ollama, vllm, lmstudio, perplexity, \
cohere, ai21, cerebras, sambanova, huggingface, xai, replicate, github-copilot. \
Or set base_url for a custom OpenAI-compatible endpoint.",
cohere, ai21, cerebras, sambanova, huggingface, xai, replicate, github-copilot, \
vertex-ai. Or set base_url for a custom OpenAI-compatible endpoint.",
provider
),
})
Expand Down Expand Up @@ -318,6 +358,7 @@ pub fn known_providers() -> &'static [&'static str] {
"xai",
"replicate",
"github-copilot",
"vertex-ai",
"moonshot",
"qwen",
"minimax",
Expand Down Expand Up @@ -411,6 +452,7 @@ mod tests {
assert!(providers.contains(&"xai"));
assert!(providers.contains(&"replicate"));
assert!(providers.contains(&"github-copilot"));
assert!(providers.contains(&"vertex-ai"));
assert!(providers.contains(&"moonshot"));
assert!(providers.contains(&"qwen"));
assert!(providers.contains(&"minimax"));
Expand Down Expand Up @@ -457,4 +499,19 @@ mod tests {
assert_eq!(d.api_key_env, "HF_API_KEY");
assert!(d.key_required);
}

#[test]
fn test_provider_defaults_vertex_ai() {
let d = provider_defaults("vertex-ai").unwrap();
assert_eq!(d.base_url, "https://us-central1-aiplatform.googleapis.com");
assert_eq!(d.api_key_env, "GOOGLE_APPLICATION_CREDENTIALS");
assert!(!d.key_required); // Uses OAuth, not API key
}

#[test]
fn test_provider_defaults_vertex_alias() {
let d = provider_defaults("vertex").unwrap();
assert_eq!(d.api_key_env, "GOOGLE_APPLICATION_CREDENTIALS");
assert!(!d.key_required);
}
}
Loading