diff --git a/interface/src/api/client.ts b/interface/src/api/client.ts index 8a320a265..a2e020545 100644 --- a/interface/src/api/client.ts +++ b/interface/src/api/client.ts @@ -661,6 +661,7 @@ export interface CronExecutionsParams { export interface ProviderStatus { anthropic: boolean; openai: boolean; + openai_chatgpt: boolean; openrouter: boolean; zhipu: boolean; groq: boolean; @@ -669,10 +670,12 @@ export interface ProviderStatus { deepseek: boolean; xai: boolean; mistral: boolean; + gemini: boolean; ollama: boolean; opencode_zen: boolean; nvidia: boolean; minimax: boolean; + minimax_cn: boolean; moonshot: boolean; zai_coding_plan: boolean; } @@ -695,6 +698,20 @@ export interface ProviderModelTestResponse { sample: string | null; } +export interface OpenAiOAuthBrowserStartResponse { + success: boolean; + message: string; + authorization_url: string | null; + state: string | null; +} + +export interface OpenAiOAuthBrowserStatusResponse { + found: boolean; + done: boolean; + success: boolean; + message: string | null; +} + // -- Model Types -- export interface ModelInfo { @@ -1153,6 +1170,28 @@ export const api = { } return response.json() as Promise; }, + startOpenAiOAuthBrowser: async (params: {model: string}) => { + const response = await fetch(`${API_BASE}/providers/openai/oauth/browser/start`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + model: params.model, + }), + }); + if (!response.ok) { + throw new Error(`API error: ${response.status}`); + } + return response.json() as Promise; + }, + openAiOAuthBrowserStatus: async (state: string) => { + const response = await fetch( + `${API_BASE}/providers/openai/oauth/browser/status?state=${encodeURIComponent(state)}`, + ); + if (!response.ok) { + throw new Error(`API error: ${response.status}`); + } + return response.json() as Promise; + }, removeProvider: async (provider: string) => { const response = await fetch(`${API_BASE}/providers/${encodeURIComponent(provider)}`, { method: "DELETE", diff --git a/interface/src/components/ModelSelect.tsx b/interface/src/components/ModelSelect.tsx index c4841b861..8abab4391 100644 --- a/interface/src/components/ModelSelect.tsx +++ b/interface/src/components/ModelSelect.tsx @@ -16,6 +16,7 @@ const PROVIDER_LABELS: Record = { anthropic: "Anthropic", openrouter: "OpenRouter", openai: "OpenAI", + "openai-chatgpt": "ChatGPT Plus (OAuth)", deepseek: "DeepSeek", xai: "xAI", mistral: "Mistral", @@ -129,6 +130,7 @@ export function ModelSelect({ "openrouter", "anthropic", "openai", + "openai-chatgpt", "ollama", "deepseek", "xai", diff --git a/interface/src/lib/providerIcons.tsx b/interface/src/lib/providerIcons.tsx index 08abacce5..29487961a 100644 --- a/interface/src/lib/providerIcons.tsx +++ b/interface/src/lib/providerIcons.tsx @@ -99,6 +99,7 @@ export function ProviderIcon({ provider, className = "text-ink-faint", size = 24 const iconMap: Record> = { anthropic: Anthropic, openai: OpenAI, + "openai-chatgpt": OpenAI, openrouter: OpenRouter, groq: Groq, mistral: Mistral, diff --git a/interface/src/routes/Settings.tsx b/interface/src/routes/Settings.tsx index f7f13fb56..087a3d998 100644 --- a/interface/src/routes/Settings.tsx +++ b/interface/src/routes/Settings.tsx @@ -210,6 +210,8 @@ const PROVIDERS = [ }, ] as const; +const CHATGPT_OAUTH_DEFAULT_MODEL = "openai-chatgpt/gpt-5.3-codex"; + export function Settings() { const queryClient = useQueryClient(); const navigate = useNavigate(); @@ -236,6 +238,11 @@ export function Settings() { message: string; sample?: 
string | null; } | null>(null); + const [isPollingOpenAiBrowserOAuth, setIsPollingOpenAiBrowserOAuth] = useState(false); + const [openAiBrowserOAuthMessage, setOpenAiBrowserOAuthMessage] = useState<{ + text: string; + type: "success" | "error"; + } | null>(null); const [message, setMessage] = useState<{ text: string; type: "success" | "error"; @@ -287,6 +294,9 @@ export function Settings() { mutationFn: ({ provider, apiKey, model }: { provider: string; apiKey: string; model: string }) => api.testProviderModel(provider, apiKey, model), }); + const startOpenAiBrowserOAuthMutation = useMutation({ + mutationFn: (params: { model: string }) => api.startOpenAiOAuthBrowser(params), + }); const removeMutation = useMutation({ mutationFn: (provider: string) => api.removeProvider(provider), @@ -347,6 +357,79 @@ export function Settings() { }); }; + const monitorOpenAiBrowserOAuth = async (stateToken: string, popup: Window | null) => { + setIsPollingOpenAiBrowserOAuth(true); + setOpenAiBrowserOAuthMessage(null); + try { + for (let attempt = 0; attempt < 180; attempt += 1) { + const status = await api.openAiOAuthBrowserStatus(stateToken); + if (status.done) { + if (status.success) { + setOpenAiBrowserOAuthMessage({ + text: status.message || "ChatGPT OAuth configured.", + type: "success", + }); + queryClient.invalidateQueries({queryKey: ["providers"]}); + setTimeout(() => { + queryClient.invalidateQueries({queryKey: ["agents"]}); + queryClient.invalidateQueries({queryKey: ["overview"]}); + }, 3000); + } else { + setOpenAiBrowserOAuthMessage({ + text: status.message || "Browser sign-in failed.", + type: "error", + }); + } + return; + } + await new Promise((resolve) => setTimeout(resolve, 2000)); + } + setOpenAiBrowserOAuthMessage({ + text: "Browser sign-in timed out. Please try again.", + type: "error", + }); + } catch (error: any) { + setOpenAiBrowserOAuthMessage({ + text: `Failed to verify browser sign-in: ${error.message}`, + type: "error", + }); + } finally { + setIsPollingOpenAiBrowserOAuth(false); + if (popup && !popup.closed) { + popup.close(); + } + } + }; + + const handleStartChatGptOAuth = async () => { + setOpenAiBrowserOAuthMessage(null); + try { + const result = await startOpenAiBrowserOAuthMutation.mutateAsync({ + model: CHATGPT_OAUTH_DEFAULT_MODEL, + }); + if (!result.success || !result.authorization_url || !result.state) { + setOpenAiBrowserOAuthMessage({ + text: result.message || "Failed to start browser sign-in", + type: "error", + }); + return; + } + + const popup = window.open( + result.authorization_url, + "spacebot-openai-oauth", + "popup=true,width=560,height=780,noopener,noreferrer", + ); + setOpenAiBrowserOAuthMessage({ + text: "Complete sign-in in the browser window. Waiting for callback...", + type: "success", + }); + void monitorOpenAiBrowserOAuth(result.state, popup); + } catch (error: any) { + setOpenAiBrowserOAuthMessage({text: `Failed: ${error.message}`, type: "error"}); + } + }; + const handleClose = () => { setEditingProvider(null); setKeyInput(""); @@ -419,24 +502,36 @@ export function Settings() { ) : (
           {PROVIDERS.map((provider) => (
-            <ProviderCard
-              key={provider.id}
-              provider={provider.id}
-              name={provider.name}
-              description={provider.description}
-              defaultModel={provider.defaultModel}
-              onEdit={() => {
-                setEditingProvider(provider.id);
-                setKeyInput("");
-                setModelInput(provider.defaultModel ?? "");
-                setTestedSignature(null);
-                setTestResult(null);
-                setMessage(null);
-              }}
-              onRemove={() => removeMutation.mutate(provider.id)}
-              removing={removeMutation.isPending}
-            />
+            [
+              <ProviderCard
+                key={provider.id}
+                provider={provider.id}
+                name={provider.name}
+                description={provider.description}
+                defaultModel={provider.defaultModel}
+                onEdit={() => {
+                  setEditingProvider(provider.id);
+                  setKeyInput("");
+                  setModelInput(provider.defaultModel ?? "");
+                  setTestedSignature(null);
+                  setTestResult(null);
+                  setMessage(null);
+                }}
+                onRemove={() => removeMutation.mutate(provider.id)}
+                removing={removeMutation.isPending}
+              />,
+              provider.id === "openai" ? (
+                <ChatGptOAuthCard
+                  key="openai-chatgpt"
+                  configured={Boolean(providersData?.providers.openai_chatgpt)}
+                  defaultModel={CHATGPT_OAUTH_DEFAULT_MODEL}
+                  isPolling={isPollingOpenAiBrowserOAuth}
+                  message={openAiBrowserOAuthMessage}
+                  onSignIn={handleStartChatGptOAuth}
+                />
+              ) : null,
+            ]
           ))}
)} @@ -483,6 +578,8 @@ export function Settings() { {editingProvider === "ollama" ? `Enter your ${editingProviderData?.name} base URL. It will be saved to your instance config.` + : editingProvider === "openai" + ? "Enter an OpenAI API key. The model below will be applied to routing." : `Enter your ${editingProviderData?.name} API key. It will be saved to your instance config.`} @@ -1470,8 +1567,9 @@ function ProviderCard({ provider, name, description, configured, defaultModel, o
{name} {configured && ( - - ● Configured + + )}
@@ -1494,3 +1592,54 @@ function ProviderCard({ provider, name, description, configured, defaultModel, o
   );
 }
+
+interface ChatGptOAuthCardProps {
+  configured: boolean;
+  defaultModel: string;
+  isPolling: boolean;
+  message: { text: string; type: "success" | "error" } | null;
+  onSignIn: () => void;
+}
+
+function ChatGptOAuthCard({ configured, defaultModel, isPolling, message, onSignIn }: ChatGptOAuthCardProps) {
+  return (
+    <div>
+      <div>
+        <ProviderIcon provider="openai-chatgpt" />
+        <div>
+          <span>ChatGPT Plus (OAuth)</span>
+          {configured && (
+            <span>Configured</span>
+          )}
+        </div>
+        <p>
+          Sign in with your ChatGPT Plus account in the browser.
+        </p>
+        <p>
+          Default model: {defaultModel}
+        </p>
+        {message && (
+          <p>
+            {message.text}
+          </p>
+        )}
+      </div>
+      <div>
+        <button type="button" onClick={onSignIn} disabled={isPolling}>
+          {isPolling ? "Waiting for sign-in..." : "Sign in with ChatGPT"}
+        </button>
+      </div>
+    </div>
+ ); +} diff --git a/src/api/models.rs b/src/api/models.rs index df85d8185..35e108689 100644 --- a/src/api/models.rs +++ b/src/api/models.rs @@ -121,6 +121,23 @@ fn is_known_voice_transcription_model(model_id: &str) -> bool { KNOWN_VOICE_TRANSCRIPTION_MODELS.contains(&model_id) } +fn as_openai_chatgpt_model(model: &ModelInfo) -> Option { + if model.provider != "openai" { + return None; + } + + let model_name = model.id.strip_prefix("openai/")?; + Some(ModelInfo { + id: format!("openai-chatgpt/{model_name}"), + name: model.name.clone(), + provider: "openai-chatgpt".into(), + context_window: model.context_window, + tool_call: model.tool_call, + reasoning: model.reasoning, + input_audio: model.input_audio, + }) +} + /// Models from providers not in models.dev (private/custom endpoints). fn extra_models() -> Vec { vec![ @@ -355,17 +372,14 @@ async fn ensure_models_cache() -> Vec { pub(super) async fn configured_providers(config_path: &std::path::Path) -> Vec<&'static str> { let mut providers = Vec::new(); - let content = match tokio::fs::read_to_string(config_path).await { - Ok(c) => c, - Err(_) => return providers, - }; - let doc: toml_edit::DocumentMut = match content.parse() { - Ok(d) => d, - Err(_) => return providers, - }; + let document = tokio::fs::read_to_string(config_path) + .await + .ok() + .and_then(|content| content.parse::().ok()); - let has_key = |key: &str, env_var: &str| -> bool { - if let Some(llm) = doc.get("llm") + let has_key = |key: &str, env_var: &str| { + if let Some(doc) = document.as_ref() + && let Some(llm) = doc.get("llm") && let Some(val) = llm.get(key) && let Some(s) = val.as_str() { @@ -383,6 +397,12 @@ pub(super) async fn configured_providers(config_path: &std::path::Path) -> Vec<& if has_key("openai_key", "OPENAI_API_KEY") { providers.push("openai"); } + if config_path + .parent() + .is_some_and(|instance_dir| crate::openai_auth::credentials_path(instance_dir).exists()) + { + providers.push("openai-chatgpt"); + } if has_key("openrouter_key", "OPENROUTER_API_KEY") { providers.push("openrouter"); } @@ -440,6 +460,11 @@ pub(super) async fn get_models( .as_deref() .map(str::trim) .filter(|provider| !provider.is_empty()); + let requested_provider_for_catalog = if requested_provider == Some("openai-chatgpt") { + Some("openai") + } else { + requested_provider + }; let requested_capability = query .capability .as_deref() @@ -447,11 +472,24 @@ pub(super) async fn get_models( .filter(|capability| !capability.is_empty()); let catalog = ensure_models_cache().await; + let capability_matches = |model: &ModelInfo| { + if let Some(capability) = requested_capability { + match capability { + "input_audio" => model.input_audio, + "voice_transcription" => { + model.input_audio && is_known_voice_transcription_model(&model.id) + } + _ => true, + } + } else { + true + } + }; let mut models: Vec = catalog - .into_iter() + .iter() .filter(|model| { - let provider_match = if let Some(provider) = requested_provider { + let provider_match = if let Some(provider) = requested_provider_for_catalog { model.provider == provider } else { configured.contains(&model.provider.as_str()) @@ -459,21 +497,25 @@ pub(super) async fn get_models( if !provider_match { return false; } - - if let Some(capability) = requested_capability { - return match capability { - "input_audio" => model.input_audio, - "voice_transcription" => { - model.input_audio && is_known_voice_transcription_model(&model.id) - } - _ => true, - }; - } - - true + capability_matches(model) }) + .cloned() .collect(); + if 
requested_provider == Some("openai-chatgpt") { + models = models + .into_iter() + .filter_map(|model| as_openai_chatgpt_model(&model)) + .collect(); + } else if requested_provider.is_none() && configured.contains(&"openai-chatgpt") { + let chatgpt_models: Vec = catalog + .iter() + .filter(|model| model.provider == "openai" && capability_matches(model)) + .filter_map(as_openai_chatgpt_model) + .collect(); + models.extend(chatgpt_models); + } + for model in extra_models() { if let Some(capability) = requested_capability { if capability == "input_audio" && !model.input_audio { diff --git a/src/api/providers.rs b/src/api/providers.rs index 251ceae34..8fb5e54d6 100644 --- a/src/api/providers.rs +++ b/src/api/providers.rs @@ -1,18 +1,45 @@ use super::state::ApiState; +use anyhow::Context as _; use axum::Json; -use axum::extract::State; -use axum::http::StatusCode; +use axum::extract::{Query, State}; +use axum::http::{HeaderMap, StatusCode}; +use axum::response::Html; +use reqwest::Url; use rig::agent::AgentBuilder; use rig::completion::{CompletionModel as _, Prompt as _}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; +use tokio::sync::RwLock; + +const OPENAI_BROWSER_OAUTH_SESSION_TTL_SECS: i64 = 15 * 60; +const OPENAI_BROWSER_OAUTH_REDIRECT_PATH: &str = "/providers/openai/oauth/browser/callback"; + +static OPENAI_BROWSER_OAUTH_SESSIONS: LazyLock>> = + LazyLock::new(|| RwLock::new(HashMap::new())); + +#[derive(Clone, Debug)] +struct BrowserOAuthSession { + pkce_verifier: String, + redirect_uri: String, + model: String, + created_at: i64, + status: BrowserOAuthSessionStatus, +} + +#[derive(Clone, Debug)] +enum BrowserOAuthSessionStatus { + Pending, + Completed(String), + Failed(String), +} #[derive(Serialize)] pub(super) struct ProviderStatus { anthropic: bool, openai: bool, + openai_chatgpt: bool, openrouter: bool, zhipu: bool, groq: bool, @@ -66,6 +93,40 @@ pub(super) struct ProviderModelTestResponse { sample: Option, } +#[derive(Deserialize)] +pub(super) struct OpenAiOAuthBrowserStartRequest { + model: String, +} + +#[derive(Serialize)] +pub(super) struct OpenAiOAuthBrowserStartResponse { + success: bool, + message: String, + authorization_url: Option, + state: Option, +} + +#[derive(Deserialize)] +pub(super) struct OpenAiOAuthBrowserStatusRequest { + state: String, +} + +#[derive(Serialize)] +pub(super) struct OpenAiOAuthBrowserStatusResponse { + found: bool, + done: bool, + success: bool, + message: Option, +} + +#[derive(Deserialize)] +pub(super) struct OpenAiOAuthBrowserCallbackQuery { + code: Option, + state: Option, + error: Option, + error_description: Option, +} + fn provider_toml_key(provider: &str) -> Option<&'static str> { match provider { "anthropic" => Some("anthropic_key"), @@ -94,6 +155,20 @@ fn model_matches_provider(provider: &str, model: &str) -> bool { crate::llm::routing::provider_from_model(model) == provider } +fn normalize_openai_chatgpt_model(model: &str) -> Option { + let trimmed = model.trim(); + let (provider, model_name) = trimmed.split_once('/')?; + if model_name.is_empty() { + return None; + } + + match provider { + "openai" => Some(format!("openai-chatgpt/{model_name}")), + "openai-chatgpt" => Some(trimmed.to_string()), + _ => None, + } +} + fn build_test_llm_config(provider: &str, credential: &str) -> crate::config::LlmConfig { use crate::config::{ApiType, ProviderConfig}; @@ -232,14 +307,188 @@ fn build_test_llm_config(provider: &str, credential: &str) -> crate::config::Llm } } +fn 
apply_model_routing(doc: &mut toml_edit::DocumentMut, model: &str) {
+    if doc.get("defaults").is_none() {
+        doc["defaults"] = toml_edit::Item::Table(toml_edit::Table::new());
+    }
+    if let Some(defaults) = doc.get_mut("defaults").and_then(|item| item.as_table_mut()) {
+        if defaults.get("routing").is_none() {
+            defaults["routing"] = toml_edit::Item::Table(toml_edit::Table::new());
+        }
+        if let Some(routing_table) = defaults
+            .get_mut("routing")
+            .and_then(|item| item.as_table_mut())
+        {
+            routing_table["channel"] = toml_edit::value(model);
+            routing_table["branch"] = toml_edit::value(model);
+            routing_table["worker"] = toml_edit::value(model);
+            routing_table["compactor"] = toml_edit::value(model);
+            routing_table["cortex"] = toml_edit::value(model);
+        }
+    }
+
+    if let Some(agents) = doc
+        .get_mut("agents")
+        .and_then(|agents_item| agents_item.as_array_of_tables_mut())
+        && let Some(default_agent) = agents.iter_mut().find(|agent| {
+            agent
+                .get("default")
+                .and_then(|value| value.as_bool())
+                .unwrap_or(false)
+        })
+    {
+        if default_agent.get("routing").is_none() {
+            default_agent["routing"] = toml_edit::Item::Table(toml_edit::Table::new());
+        }
+        if let Some(routing_table) = default_agent
+            .get_mut("routing")
+            .and_then(|routing_item| routing_item.as_table_mut())
+        {
+            routing_table["channel"] = toml_edit::value(model);
+            routing_table["branch"] = toml_edit::value(model);
+            routing_table["worker"] = toml_edit::value(model);
+            routing_table["compactor"] = toml_edit::value(model);
+            routing_table["cortex"] = toml_edit::value(model);
+        }
+    }
+}
+
+async fn prune_expired_browser_oauth_sessions() {
+    let cutoff = chrono::Utc::now().timestamp() - OPENAI_BROWSER_OAUTH_SESSION_TTL_SECS;
+    let mut sessions = OPENAI_BROWSER_OAUTH_SESSIONS.write().await;
+    sessions.retain(|_, session| session.created_at >= cutoff);
+}
+
+fn resolve_browser_oauth_redirect_uri(headers: &HeaderMap) -> Option<String> {
+    if let Some(origin) = header_value(headers, axum::http::header::ORIGIN.as_str())
+        && let Ok(origin_url) = Url::parse(origin)
+    {
+        let origin = origin_url.origin().ascii_serialization();
+        if origin != "null" {
+            return Some(format!("{origin}{OPENAI_BROWSER_OAUTH_REDIRECT_PATH}"));
+        }
+    }
+
+    if let (Some(proto), Some(host)) = (
+        header_value(headers, "x-forwarded-proto"),
+        header_value(headers, "x-forwarded-host"),
+    ) {
+        let proto = first_header_value(proto);
+        let host = normalize_host(first_header_value(host));
+        return Some(format!(
+            "{proto}://{host}{OPENAI_BROWSER_OAUTH_REDIRECT_PATH}"
+        ));
+    }
+
+    if let Some(host) = header_value(headers, "host") {
+        let host = normalize_host(host);
+        let scheme = if is_local_host(&host) {
+            "http"
+        } else {
+            "https"
+        };
+        return Some(format!(
+            "{scheme}://{host}{OPENAI_BROWSER_OAUTH_REDIRECT_PATH}"
+        ));
+    }
+
+    None
+}
+
+fn header_value(headers: &HeaderMap, name: impl AsRef<str>) -> Option<&str> {
+    headers
+        .get(name.as_ref())
+        .and_then(|value| value.to_str().ok())
+}
+
+fn first_header_value(value: &str) -> &str {
+    value.split(',').next().map(str::trim).unwrap_or(value)
+}
+
+fn normalize_host(host: &str) -> String {
+    let host = host.trim();
+    let colon_count = host.matches(':').count();
+    if colon_count > 1 && !host.starts_with('[') {
+        format!("[{host}]")
+    } else {
+        host.to_string()
+    }
+}
+
+fn is_local_host(host: &str) -> bool {
+    let host = host
+        .trim_start_matches('[')
+        .trim_end_matches(']')
+        .split(':')
+        .next()
+        .unwrap_or(host);
+    matches!(host, "localhost" | "127.0.0.1" | "::1")
+}
+
+fn browser_oauth_success_html() -> String {
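+    // Static page rendered in the OAuth popup once the callback succeeds; the
+    // matching failure page below interpolates an HTML-escaped error message.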
r#" + + + + Spacebot OpenAI Sign-in + + + +
+

Sign-in complete

+

You can close this window and return to Spacebot settings.

+
+ + +"# + .to_string() +} + +fn browser_oauth_error_html(message: &str) -> String { + let escaped = message + .replace('&', "&") + .replace('<', "<") + .replace('>', ">"); + format!( + r#" + + + + Spacebot OpenAI Sign-in + + + +
+

Sign-in failed

+

{}

+
+ +"#, + escaped + ) +} + pub(super) async fn get_providers( State(state): State>, ) -> Result, StatusCode> { let config_path = state.config_path.read().await.clone(); + let instance_dir = (**state.instance_dir.load()).clone(); + let openai_oauth_configured = crate::openai_auth::credentials_path(&instance_dir).exists(); let ( anthropic, openai, + openai_chatgpt, openrouter, zhipu, groq, @@ -280,6 +529,7 @@ pub(super) async fn get_providers( ( has_value("anthropic_key", "ANTHROPIC_API_KEY"), has_value("openai_key", "OPENAI_API_KEY"), + openai_oauth_configured, has_value("openrouter_key", "OPENROUTER_API_KEY"), has_value("zhipu_key", "ZHIPU_API_KEY"), has_value("groq_key", "GROQ_API_KEY"), @@ -302,6 +552,7 @@ pub(super) async fn get_providers( ( std::env::var("ANTHROPIC_API_KEY").is_ok(), std::env::var("OPENAI_API_KEY").is_ok(), + openai_oauth_configured, std::env::var("OPENROUTER_API_KEY").is_ok(), std::env::var("ZHIPU_API_KEY").is_ok(), std::env::var("GROQ_API_KEY").is_ok(), @@ -324,6 +575,7 @@ pub(super) async fn get_providers( let providers = ProviderStatus { anthropic, openai, + openai_chatgpt, openrouter, zhipu, groq, @@ -343,6 +595,7 @@ pub(super) async fn get_providers( }; let has_any = providers.anthropic || providers.openai + || providers.openai_chatgpt || providers.openrouter || providers.zhipu || providers.groq @@ -363,6 +616,256 @@ pub(super) async fn get_providers( Ok(Json(ProvidersResponse { providers, has_any })) } +pub(super) async fn start_openai_browser_oauth( + headers: HeaderMap, + Json(request): Json, +) -> Result, StatusCode> { + if request.model.trim().is_empty() { + return Ok(Json(OpenAiOAuthBrowserStartResponse { + success: false, + message: "Model cannot be empty".to_string(), + authorization_url: None, + state: None, + })); + } + let Some(chatgpt_model) = normalize_openai_chatgpt_model(&request.model) else { + return Ok(Json(OpenAiOAuthBrowserStartResponse { + success: false, + message: format!( + "Model '{}' must use provider 'openai' or 'openai-chatgpt'.", + request.model + ), + authorization_url: None, + state: None, + })); + }; + + let Some(redirect_uri) = resolve_browser_oauth_redirect_uri(&headers) else { + return Ok(Json(OpenAiOAuthBrowserStartResponse { + success: false, + message: "Unable to determine OAuth callback URL. Check your Host/Origin headers." 
+ .to_string(), + authorization_url: None, + state: None, + })); + }; + + prune_expired_browser_oauth_sessions().await; + let browser_authorization = crate::openai_auth::start_browser_authorization(&redirect_uri); + let state_key = browser_authorization.state.clone(); + + OPENAI_BROWSER_OAUTH_SESSIONS.write().await.insert( + state_key.clone(), + BrowserOAuthSession { + pkce_verifier: browser_authorization.pkce_verifier, + redirect_uri, + model: chatgpt_model, + created_at: chrono::Utc::now().timestamp(), + status: BrowserOAuthSessionStatus::Pending, + }, + ); + + Ok(Json(OpenAiOAuthBrowserStartResponse { + success: true, + message: "OpenAI browser OAuth started".to_string(), + authorization_url: Some(browser_authorization.authorization_url), + state: Some(state_key), + })) +} + +pub(super) async fn openai_browser_oauth_status( + Query(request): Query, +) -> Result, StatusCode> { + prune_expired_browser_oauth_sessions().await; + if request.state.trim().is_empty() { + return Ok(Json(OpenAiOAuthBrowserStatusResponse { + found: false, + done: false, + success: false, + message: Some("Missing OAuth state".to_string()), + })); + } + + let sessions = OPENAI_BROWSER_OAUTH_SESSIONS.read().await; + let Some(session) = sessions.get(request.state.trim()) else { + return Ok(Json(OpenAiOAuthBrowserStatusResponse { + found: false, + done: false, + success: false, + message: None, + })); + }; + + let response = match &session.status { + BrowserOAuthSessionStatus::Pending => OpenAiOAuthBrowserStatusResponse { + found: true, + done: false, + success: false, + message: None, + }, + BrowserOAuthSessionStatus::Completed(message) => OpenAiOAuthBrowserStatusResponse { + found: true, + done: true, + success: true, + message: Some(message.clone()), + }, + BrowserOAuthSessionStatus::Failed(message) => OpenAiOAuthBrowserStatusResponse { + found: true, + done: true, + success: false, + message: Some(message.clone()), + }, + }; + Ok(Json(response)) +} + +pub(super) async fn openai_browser_oauth_callback( + State(state): State>, + Query(query): Query, +) -> Html { + prune_expired_browser_oauth_sessions().await; + + let Some(state_key) = query + .state + .as_deref() + .map(str::trim) + .filter(|value| !value.is_empty()) + .map(str::to_string) + else { + return Html(browser_oauth_error_html("Missing OAuth state.")); + }; + + if let Some(error_code) = query.error.as_deref() { + let mut message = format!("OpenAI returned OAuth error: {}", error_code); + if let Some(description) = query.error_description.as_deref() { + message.push_str(&format!(" ({})", description)); + } + if let Some(session) = OPENAI_BROWSER_OAUTH_SESSIONS + .write() + .await + .get_mut(&state_key) + { + session.status = BrowserOAuthSessionStatus::Failed(message.clone()); + } + return Html(browser_oauth_error_html(&message)); + } + + let Some(code) = query + .code + .as_deref() + .map(str::trim) + .filter(|value| !value.is_empty()) + else { + let message = "OpenAI callback did not include an authorization code."; + if let Some(session) = OPENAI_BROWSER_OAUTH_SESSIONS + .write() + .await + .get_mut(&state_key) + { + session.status = BrowserOAuthSessionStatus::Failed(message.to_string()); + } + return Html(browser_oauth_error_html(message)); + }; + + let (pkce_verifier, redirect_uri, model) = { + let sessions = OPENAI_BROWSER_OAUTH_SESSIONS.read().await; + let Some(session) = sessions.get(&state_key) else { + return Html(browser_oauth_error_html( + "OAuth session expired or was not found. 
Start sign-in again.", + )); + }; + ( + session.pkce_verifier.clone(), + session.redirect_uri.clone(), + session.model.clone(), + ) + }; + + let credentials = match crate::openai_auth::exchange_browser_code( + code, + &redirect_uri, + &pkce_verifier, + ) + .await + { + Ok(credentials) => credentials, + Err(error) => { + let message = format!("Failed to exchange OpenAI authorization code: {error}"); + if let Some(session) = OPENAI_BROWSER_OAUTH_SESSIONS + .write() + .await + .get_mut(&state_key) + { + session.status = BrowserOAuthSessionStatus::Failed(message.clone()); + } + return Html(browser_oauth_error_html(&message)); + } + }; + + let persist_result = async { + let instance_dir = (**state.instance_dir.load()).clone(); + crate::openai_auth::save_credentials(&instance_dir, &credentials) + .context("failed to save OpenAI OAuth credentials")?; + + if let Some(llm_manager) = state.llm_manager.read().await.as_ref() { + llm_manager + .set_openai_oauth_credentials(credentials.clone()) + .await; + } + + let config_path = state.config_path.read().await.clone(); + let content = if config_path.exists() { + tokio::fs::read_to_string(&config_path) + .await + .context("failed to read config.toml")? + } else { + String::new() + }; + + let mut doc: toml_edit::DocumentMut = + content.parse().context("failed to parse config.toml")?; + apply_model_routing(&mut doc, &model); + tokio::fs::write(&config_path, doc.to_string()) + .await + .context("failed to write config.toml")?; + + state + .provider_setup_tx + .try_send(crate::ProviderSetupEvent::ProvidersConfigured) + .ok(); + + anyhow::Ok(()) + } + .await; + + match persist_result { + Ok(()) => { + if let Some(session) = OPENAI_BROWSER_OAUTH_SESSIONS + .write() + .await + .get_mut(&state_key) + { + session.status = BrowserOAuthSessionStatus::Completed(format!( + "OpenAI configured via browser OAuth. 
Model '{}' applied to defaults and default agent routing.", + model + )); + } + Html(browser_oauth_success_html()) + } + Err(error) => { + let message = format!("OAuth sign-in completed but finalization failed: {error}"); + if let Some(session) = OPENAI_BROWSER_OAUTH_SESSIONS + .write() + .await + .get_mut(&state_key) + { + session.status = BrowserOAuthSessionStatus::Failed(message.clone()); + } + Html(browser_oauth_error_html(&message)) + } + } +} + pub(super) async fn update_provider( State(state): State>, Json(request): Json, @@ -417,47 +920,7 @@ pub(super) async fn update_provider( } doc["llm"][key_name] = toml_edit::value(request.api_key); - - if doc.get("defaults").is_none() { - doc["defaults"] = toml_edit::Item::Table(toml_edit::Table::new()); - } - if let Some(defaults) = doc.get_mut("defaults").and_then(|d| d.as_table_mut()) { - if defaults.get("routing").is_none() { - defaults["routing"] = toml_edit::Item::Table(toml_edit::Table::new()); - } - if let Some(routing_table) = defaults.get_mut("routing").and_then(|r| r.as_table_mut()) { - routing_table["channel"] = toml_edit::value(request.model.as_str()); - routing_table["branch"] = toml_edit::value(request.model.as_str()); - routing_table["worker"] = toml_edit::value(request.model.as_str()); - routing_table["compactor"] = toml_edit::value(request.model.as_str()); - routing_table["cortex"] = toml_edit::value(request.model.as_str()); - } - } - - if let Some(agents) = doc - .get_mut("agents") - .and_then(|agents_item| agents_item.as_array_of_tables_mut()) - && let Some(default_agent) = agents.iter_mut().find(|agent| { - agent - .get("default") - .and_then(|value| value.as_bool()) - .unwrap_or(false) - }) - { - if default_agent.get("routing").is_none() { - default_agent["routing"] = toml_edit::Item::Table(toml_edit::Table::new()); - } - if let Some(routing_table) = default_agent - .get_mut("routing") - .and_then(|routing_item| routing_item.as_table_mut()) - { - routing_table["channel"] = toml_edit::value(request.model.as_str()); - routing_table["branch"] = toml_edit::value(request.model.as_str()); - routing_table["worker"] = toml_edit::value(request.model.as_str()); - routing_table["compactor"] = toml_edit::value(request.model.as_str()); - routing_table["cortex"] = toml_edit::value(request.model.as_str()); - } - } + apply_model_routing(&mut doc, request.model.as_str()); tokio::fs::write(&config_path, doc.to_string()) .await diff --git a/src/api/server.rs b/src/api/server.rs index a41c7f8d1..31d2116dd 100644 --- a/src/api/server.rs +++ b/src/api/server.rs @@ -123,6 +123,18 @@ pub async fn start_http_server( "/providers", get(providers::get_providers).put(providers::update_provider), ) + .route( + "/providers/openai/oauth/browser/start", + post(providers::start_openai_browser_oauth), + ) + .route( + "/providers/openai/oauth/browser/status", + get(providers::openai_browser_oauth_status), + ) + .route( + "/providers/openai/oauth/browser/callback", + get(providers::openai_browser_oauth_callback), + ) .route("/providers/test", post(providers::test_provider_model)) .route("/providers/{provider}", delete(providers::delete_provider)) .route("/models", get(models::get_models)) diff --git a/src/config.rs b/src/config.rs index 20e3394d7..f2b82b0ce 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1944,7 +1944,9 @@ impl Config { } // OAuth credentials count as configured - if crate::auth::credentials_path(&instance_dir).exists() { + if crate::auth::credentials_path(&instance_dir).exists() + || 
crate::openai_auth::credentials_path(&instance_dir).exists() + { return false; } @@ -4272,6 +4274,26 @@ name = "Custom OpenAI" assert!(!Config::needs_onboarding()); } + #[test] + fn test_needs_onboarding_false_with_openai_oauth_credentials() { + let _lock = env_test_lock() + .lock() + .expect("failed to lock env test mutex"); + let _env = EnvGuard::new(); + + let instance_dir = Config::default_instance_dir(); + let creds = crate::openai_auth::OAuthCredentials { + access_token: "openai-access-token-test".to_string(), + refresh_token: "openai-refresh-token-test".to_string(), + expires_at: chrono::Utc::now().timestamp_millis() + 3_600_000, + account_id: Some("acct_test_123".to_string()), + }; + crate::openai_auth::save_credentials(&instance_dir, &creds) + .expect("failed to save OpenAI OAuth credentials"); + + assert!(!Config::needs_onboarding()); + } + #[test] fn test_load_from_env_populates_legacy_key_and_provider() { let _lock = env_test_lock() diff --git a/src/lib.rs b/src/lib.rs index 98b4eac3f..b06e676ce 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,6 +15,7 @@ pub mod llm; pub mod mcp; pub mod memory; pub mod messaging; +pub mod openai_auth; pub mod opencode; pub mod prompts; pub mod secrets; diff --git a/src/llm/manager.rs b/src/llm/manager.rs index 69fd113ac..32fb97412 100644 --- a/src/llm/manager.rs +++ b/src/llm/manager.rs @@ -8,9 +8,10 @@ //! `reload_config()` when config.toml changes, and all subsequent //! `get_api_key()` calls read the new values lock-free. -use crate::auth::OAuthCredentials; +use crate::auth::OAuthCredentials as AnthropicOAuthCredentials; use crate::config::{ApiType, LlmConfig, ProviderConfig}; use crate::error::{LlmError, Result}; +use crate::openai_auth::OAuthCredentials as OpenAiOAuthCredentials; use anyhow::Context as _; use arc_swap::ArcSwap; @@ -28,8 +29,10 @@ pub struct LlmManager { rate_limited: Arc>>, /// Instance directory for reading/writing OAuth credentials. instance_dir: Option, - /// Cached OAuth credentials (refreshed lazily). - oauth_credentials: RwLock>, + /// Cached Anthropic OAuth credentials (refreshed lazily). + anthropic_oauth_credentials: RwLock>, + /// Cached OpenAI OAuth credentials (refreshed lazily). + openai_oauth_credentials: RwLock>, } impl LlmManager { @@ -45,15 +48,20 @@ impl LlmManager { http_client, rate_limited: Arc::new(RwLock::new(HashMap::new())), instance_dir: None, - oauth_credentials: RwLock::new(None), + anthropic_oauth_credentials: RwLock::new(None), + openai_oauth_credentials: RwLock::new(None), }) } /// Set the instance directory and load any existing OAuth credentials. pub async fn set_instance_dir(&self, instance_dir: PathBuf) { if let Ok(Some(creds)) = crate::auth::load_credentials(&instance_dir) { - tracing::info!("loaded OAuth credentials from auth.json"); - *self.oauth_credentials.write().await = Some(creds); + tracing::info!("loaded Anthropic OAuth credentials from auth.json"); + *self.anthropic_oauth_credentials.write().await = Some(creds); + } + if let Ok(Some(creds)) = crate::openai_auth::load_credentials(&instance_dir) { + tracing::info!("loaded OpenAI OAuth credentials from openai_chatgpt_oauth.json"); + *self.openai_oauth_credentials.write().await = Some(creds); } // Store instance_dir — we can't set it on &self since it's not behind RwLock, // but we only need it for save_credentials which we handle inline. 
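Both OAuth caches are populated again in `with_instance_dir` below. For orientation, a rough sketch of that construction path, assuming a loaded `config`; this mirrors the reload call added to `main.rs` later in this patch, and the assertion values come from `get_openai_chatgpt_provider`:

```rust
// Sketch, not part of the patch: build a manager that can see auth.json and
// openai_chatgpt_oauth.json next to config.toml in the instance directory.
let llm_manager = spacebot::llm::LlmManager::with_instance_dir(
    config.llm.clone(),
    config.instance_dir.clone(),
)
.await?;

// The OAuth-backed provider resolves without a static key; with no saved
// credentials it returns UnknownProvider("openai-chatgpt") instead.
let chatgpt = llm_manager.get_openai_chatgpt_provider().await?;
assert_eq!(chatgpt.base_url, "https://chatgpt.com/backend-api/codex");
```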
@@ -66,14 +74,26 @@ impl LlmManager { .build() .with_context(|| "failed to build HTTP client")?; - let oauth_credentials = match crate::auth::load_credentials(&instance_dir) { + let anthropic_oauth_credentials = match crate::auth::load_credentials(&instance_dir) { Ok(Some(creds)) => { - tracing::info!("loaded OAuth credentials from auth.json"); + tracing::info!("loaded Anthropic OAuth credentials from auth.json"); Some(creds) } Ok(None) => None, Err(error) => { - tracing::warn!(%error, "failed to load OAuth credentials"); + tracing::warn!(%error, "failed to load Anthropic OAuth credentials"); + None + } + }; + + let openai_oauth_credentials = match crate::openai_auth::load_credentials(&instance_dir) { + Ok(Some(creds)) => { + tracing::info!("loaded OpenAI OAuth credentials from openai_chatgpt_oauth.json"); + Some(creds) + } + Ok(None) => None, + Err(error) => { + tracing::warn!(%error, "failed to load OpenAI OAuth credentials"); None } }; @@ -83,7 +103,8 @@ impl LlmManager { http_client, rate_limited: Arc::new(RwLock::new(HashMap::new())), instance_dir: Some(instance_dir), - oauth_credentials: RwLock::new(oauth_credentials), + anthropic_oauth_credentials: RwLock::new(anthropic_oauth_credentials), + openai_oauth_credentials: RwLock::new(openai_oauth_credentials), }) } @@ -110,7 +131,7 @@ impl LlmManager { /// returns the OAuth access token (refreshing if needed). Otherwise /// falls back to the static API key from config. pub async fn get_anthropic_token(&self) -> Result> { - let mut creds_guard = self.oauth_credentials.write().await; + let mut creds_guard = self.anthropic_oauth_credentials.write().await; let Some(creds) = creds_guard.as_ref() else { return Ok(None); }; @@ -120,22 +141,22 @@ impl LlmManager { } // Need to refresh - tracing::info!("OAuth access token expired, refreshing..."); + tracing::info!("Anthropic OAuth access token expired, refreshing..."); match creds.refresh().await { Ok(new_creds) => { // Save to disk if let Some(ref instance_dir) = self.instance_dir && let Err(error) = crate::auth::save_credentials(instance_dir, &new_creds) { - tracing::warn!(%error, "failed to persist refreshed OAuth credentials"); + tracing::warn!(%error, "failed to persist refreshed Anthropic OAuth credentials"); } let token = new_creds.access_token.clone(); *creds_guard = Some(new_creds); - tracing::info!("OAuth token refreshed successfully"); + tracing::info!("Anthropic OAuth token refreshed successfully"); Ok(Some(token)) } Err(error) => { - tracing::error!(%error, "OAuth token refresh failed"); + tracing::error!(%error, "Anthropic OAuth token refresh failed"); // Return the expired token anyway — the API will reject it // and the error message will be clearer than "no key" Ok(Some(creds.access_token.clone())) @@ -169,6 +190,78 @@ impl LlmManager { } } + /// Set OpenAI OAuth credentials in memory after successful auth. + pub async fn set_openai_oauth_credentials(&self, creds: OpenAiOAuthCredentials) { + *self.openai_oauth_credentials.write().await = Some(creds); + } + + /// Get OpenAI OAuth access token if available, refreshing when needed. 
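+    ///
+    /// Mirrors `get_anthropic_token`: when the refresh attempt fails, the
+    /// stale access token is returned anyway so the provider rejects it with
+    /// a clearer error than "no key".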
+    pub async fn get_openai_token(&self) -> Result<Option<String>> {
+        let mut creds_guard = self.openai_oauth_credentials.write().await;
+        let Some(creds) = creds_guard.as_ref() else {
+            return Ok(None);
+        };
+
+        if !creds.is_expired() {
+            return Ok(Some(creds.access_token.clone()));
+        }
+
+        tracing::info!("OpenAI OAuth access token expired, refreshing...");
+        match creds.refresh().await {
+            Ok(new_creds) => {
+                if let Some(ref instance_dir) = self.instance_dir
+                    && let Err(error) =
+                        crate::openai_auth::save_credentials(instance_dir, &new_creds)
+                {
+                    tracing::warn!(%error, "failed to persist refreshed OpenAI OAuth credentials");
+                }
+                let token = new_creds.access_token.clone();
+                *creds_guard = Some(new_creds);
+                tracing::info!("OpenAI OAuth token refreshed successfully");
+                Ok(Some(token))
+            }
+            Err(error) => {
+                tracing::error!(%error, "OpenAI OAuth token refresh failed");
+                Ok(Some(creds.access_token.clone()))
+            }
+        }
+    }
+
+    /// Resolve the OpenAI provider config from static API-key configuration.
+    ///
+    /// OpenAI ChatGPT OAuth is intentionally handled via a separate internal
+    /// provider (`openai-chatgpt`) so a saved OAuth token cannot shadow a
+    /// configured `openai` API key.
+    pub async fn get_openai_provider(&self) -> Result<ProviderConfig> {
+        self.get_provider("openai")
+    }
+
+    /// Resolve the OpenAI ChatGPT OAuth provider config.
+    ///
+    /// This internal provider uses OAuth access tokens from ChatGPT Plus/Pro.
+    pub async fn get_openai_chatgpt_provider(&self) -> Result<ProviderConfig> {
+        let token = self.get_openai_token().await?;
+
+        match token {
+            Some(token) => Ok(ProviderConfig {
+                api_type: ApiType::OpenAiResponses,
+                base_url: "https://chatgpt.com/backend-api/codex".to_string(),
+                api_key: token,
+                name: None,
+            }),
+            None => Err(LlmError::UnknownProvider("openai-chatgpt".to_string()).into()),
+        }
+    }
+
+    /// Get OpenAI OAuth account id (for ChatGPT Plus/Pro account scoping headers).
+    pub async fn get_openai_account_id(&self) -> Option<String> {
+        self.openai_oauth_credentials
+            .read()
+            .await
+            .as_ref()
+            .and_then(|credentials| credentials.account_id.clone())
+    }
+
     /// Get the appropriate API key for a provider.
     pub fn get_api_key(&self, provider_id: &str) -> Result<String> {
         let provider = self.get_provider(provider_id)?;
diff --git a/src/llm/model.rs b/src/llm/model.rs
index e1bf5203c..c2a166dd0 100644
--- a/src/llm/model.rs
+++ b/src/llm/model.rs
@@ -89,15 +89,26 @@ impl SpacebotModel {
             .map(|(provider, _)| provider)
             .unwrap_or("anthropic");
 
-        let provider_config = if provider_id == "anthropic" {
-            self.llm_manager
+        let provider_config = match provider_id {
+            "anthropic" => self
+                .llm_manager
                 .get_anthropic_provider()
                 .await
-                .map_err(|e| CompletionError::ProviderError(e.to_string()))?
-        } else {
-            self.llm_manager
+                .map_err(|e| CompletionError::ProviderError(e.to_string()))?,
+            "openai" => self
+                .llm_manager
+                .get_openai_provider()
+                .await
+                .map_err(|e| CompletionError::ProviderError(e.to_string()))?,
+            "openai-chatgpt" => self
+                .llm_manager
+                .get_openai_chatgpt_provider()
+                .await
+                .map_err(|e| CompletionError::ProviderError(e.to_string()))?,
+            _ => self
+                .llm_manager
                 .get_provider(provider_id)
-                .map_err(|e| CompletionError::ProviderError(e.to_string()))?
+ .map_err(|e| CompletionError::ProviderError(e.to_string()))?, }; if provider_id == "zai-coding-plan" || provider_id == "zhipu" { @@ -536,6 +547,11 @@ impl SpacebotModel { "{}/v1/chat/completions", provider_config.base_url.trim_end_matches('/') ); + let openai_account_id = if self.provider == "openai-chatgpt" { + self.llm_manager.get_openai_account_id().await + } else { + None + }; let mut request_builder = self .llm_manager @@ -543,6 +559,9 @@ impl SpacebotModel { .post(&chat_completions_url) .header("authorization", format!("Bearer {api_key}")) .header("content-type", "application/json"); + if let Some(account_id) = openai_account_id { + request_builder = request_builder.header("chatgpt-account-id", account_id); + } // Kimi endpoints require a specific user-agent header. if chat_completions_url.contains("kimi.com") || chat_completions_url.contains("moonshot.ai") @@ -587,7 +606,12 @@ impl SpacebotModel { provider_config: &ProviderConfig, ) -> Result, CompletionError> { let base_url = provider_config.base_url.trim_end_matches('/'); - let responses_url = format!("{base_url}/v1/responses"); + let is_chatgpt_codex = self.provider == "openai-chatgpt"; + let responses_url = if is_chatgpt_codex { + format!("{base_url}/responses") + } else { + format!("{base_url}/v1/responses") + }; let api_key = provider_config.api_key.as_str(); let input = convert_messages_to_openai_responses(&request.chat_history); @@ -599,16 +623,25 @@ impl SpacebotModel { if let Some(preamble) = &request.preamble { body["instructions"] = serde_json::json!(preamble); + } else if is_chatgpt_codex { + body["instructions"] = serde_json::json!( + "You are Spacebot. Follow instructions exactly and respond concisely." + ); } - if let Some(max_tokens) = request.max_tokens { + if !is_chatgpt_codex && let Some(max_tokens) = request.max_tokens { body["max_output_tokens"] = serde_json::json!(max_tokens); } - if let Some(temperature) = request.temperature { + if !is_chatgpt_codex && let Some(temperature) = request.temperature { body["temperature"] = serde_json::json!(temperature); } + if is_chatgpt_codex { + body["store"] = serde_json::json!(false); + body["stream"] = serde_json::json!(true); + } + if !request.tools.is_empty() { let tools: Vec = request .tools @@ -625,12 +658,35 @@ impl SpacebotModel { body["tools"] = serde_json::json!(tools); } - let response = self + let openai_account_id = if self.provider == "openai-chatgpt" { + self.llm_manager.get_openai_account_id().await + } else { + None + }; + + let mut request_builder = self .llm_manager .http_client() .post(&responses_url) .header("authorization", format!("Bearer {api_key}")) - .header("content-type", "application/json") + .header("content-type", "application/json"); + if let Some(account_id) = openai_account_id { + request_builder = request_builder.header("ChatGPT-Account-Id", account_id); + } + if is_chatgpt_codex { + request_builder = request_builder + .header("originator", "opencode") + .header( + "session_id", + format!("spacebot-{}", chrono::Utc::now().timestamp()), + ) + .header( + "user-agent", + format!("spacebot/{}", env!("CARGO_PKG_VERSION")), + ); + } + + let response = request_builder .json(&body) .send() .await @@ -641,23 +697,25 @@ impl SpacebotModel { CompletionError::ProviderError(format!("failed to read response body: {e}")) })?; - let response_body: serde_json::Value = - serde_json::from_str(&response_text).map_err(|e| { - CompletionError::ProviderError(format!( - "OpenAI Responses API response ({status}) is not valid JSON: {e}\nBody: {}", - 
truncate_body(&response_text) - )) - })?; - if !status.is_success() { - let message = response_body["error"]["message"] - .as_str() - .unwrap_or("unknown error"); + let message = parse_openai_error_message(&response_text) + .unwrap_or_else(|| "unknown error".to_string()); return Err(CompletionError::ProviderError(format!( "OpenAI Responses API error ({status}): {message}" ))); } + let response_body: serde_json::Value = if is_chatgpt_codex { + parse_openai_responses_sse_response(&response_text)? + } else { + serde_json::from_str(&response_text).map_err(|e| { + CompletionError::ProviderError(format!( + "OpenAI Responses API response ({status}) is not valid JSON: {e}\nBody: {}", + truncate_body(&response_text) + )) + })? + }; + parse_openai_responses_response(response_body) } @@ -1412,6 +1470,44 @@ fn parse_openai_responses_response( }) } +fn parse_openai_responses_sse_response( + response_text: &str, +) -> Result { + for line in response_text.lines() { + let Some(data) = line.strip_prefix("data: ") else { + continue; + }; + + if data.trim().is_empty() || data.trim() == "[DONE]" { + continue; + } + + let Ok(event_body) = serde_json::from_str::(data) else { + continue; + }; + + if event_body["type"].as_str() == Some("response.completed") + && let Some(response) = event_body.get("response") + { + return Ok(response.clone()); + } + } + + Err(CompletionError::ProviderError(format!( + "OpenAI Responses SSE stream missing response.completed event.\nBody: {}", + truncate_body(response_text) + ))) +} + +fn parse_openai_error_message(response_text: &str) -> Option { + let parsed = serde_json::from_str::(response_text).ok()?; + parsed["error"]["message"] + .as_str() + .or(parsed["detail"].as_str()) + .or(parsed["message"].as_str()) + .map(ToOwned::to_owned) +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/llm/routing.rs b/src/llm/routing.rs index 2cdbd929a..cae1bb157 100644 --- a/src/llm/routing.rs +++ b/src/llm/routing.rs @@ -190,6 +190,22 @@ pub fn defaults_for_provider(provider: &str) -> RoutingConfig { ..RoutingConfig::default() } } + "openai-chatgpt" => { + let channel: String = "openai-chatgpt/gpt-4.1".into(); + let worker: String = "openai-chatgpt/gpt-4.1-mini".into(); + RoutingConfig { + channel: channel.clone(), + branch: channel.clone(), + worker: worker.clone(), + compactor: worker.clone(), + cortex: worker.clone(), + voice: String::new(), + task_overrides: HashMap::from([("coding".into(), channel.clone())]), + fallbacks: HashMap::from([(channel, vec![worker])]), + rate_limit_cooldown_secs: 60, + ..RoutingConfig::default() + } + } "zhipu" => { let channel: String = "zhipu/glm-4-plus".into(); let worker: String = "zhipu/glm-4-flash".into(); @@ -352,6 +368,7 @@ pub fn provider_to_prefix(provider: &str) -> &str { match provider { "openrouter" => "openrouter/", "openai" => "openai/", + "openai-chatgpt" => "openai-chatgpt/", "anthropic" => "anthropic/", "zhipu" => "zhipu/", "groq" => "groq/", diff --git a/src/main.rs b/src/main.rs index bd0090bb2..cc2c004af 100644 --- a/src/main.rs +++ b/src/main.rs @@ -587,6 +587,15 @@ fn load_config( } } +fn has_provider_credentials( + llm_config: &spacebot::config::LlmConfig, + instance_dir: &std::path::Path, +) -> bool { + llm_config.has_any_key() + || spacebot::auth::credentials_path(instance_dir).exists() + || spacebot::openai_auth::credentials_path(instance_dir).exists() +} + async fn run( config: spacebot::config::Config, foreground: bool, @@ -654,8 +663,7 @@ async fn run( }; // Check if we have provider configuration (API keys or OAuth 
credentials) - let has_providers = - config.llm.has_any_key() || spacebot::auth::credentials_path(&config.instance_dir).exists(); + let has_providers = has_provider_credentials(&config.llm, &config.instance_dir); if !has_providers { tracing::info!("No LLM providers configured. Starting in setup mode."); @@ -715,6 +723,7 @@ async fn run( // Set the config path on the API state for config.toml writes let config_path = config.instance_dir.join("config.toml"); api_state.set_config_path(config_path.clone()).await; + api_state.set_instance_dir(config.instance_dir.clone()); api_state.set_llm_manager(llm_manager.clone()).await; api_state.set_embedding_model(embedding_model.clone()).await; api_state.set_prompt_engine(prompt_engine.clone()).await; @@ -1050,9 +1059,16 @@ async fn run( }; match new_config { - Ok(new_config) if new_config.llm.has_any_key() => { + Ok(new_config) + if has_provider_credentials(&new_config.llm, &new_config.instance_dir) => + { // Rebuild LlmManager with the new keys - match spacebot::llm::LlmManager::new(new_config.llm.clone()).await { + match spacebot::llm::LlmManager::with_instance_dir( + new_config.llm.clone(), + new_config.instance_dir.clone(), + ) + .await + { Ok(new_llm) => { let new_llm_manager = Arc::new(new_llm); let mut new_watcher_agents = Vec::new(); diff --git a/src/openai_auth.rs b/src/openai_auth.rs new file mode 100644 index 000000000..bdf615f01 --- /dev/null +++ b/src/openai_auth.rs @@ -0,0 +1,265 @@ +//! OpenAI ChatGPT Plus OAuth browser flow, token exchange, refresh, and storage. + +use anyhow::{Context as _, Result}; +use base64::Engine as _; +use base64::engine::general_purpose::URL_SAFE_NO_PAD; +use rand::RngCore as _; +use serde::{Deserialize, Serialize}; +use sha2::{Digest as _, Sha256}; +use std::path::{Path, PathBuf}; + +const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann"; +const AUTHORIZE_URL: &str = "https://auth.openai.com/oauth/authorize"; +const OAUTH_TOKEN_URL: &str = "https://auth.openai.com/oauth/token"; +const BROWSER_SCOPES: &str = "openid profile email offline_access"; + +/// Stored OpenAI OAuth credentials. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OAuthCredentials { + pub access_token: String, + pub refresh_token: String, + /// Expiry as Unix timestamp in milliseconds. + pub expires_at: i64, + pub account_id: Option, +} + +impl OAuthCredentials { + /// Check if the access token is expired or about to expire (within 5 minutes). + pub fn is_expired(&self) -> bool { + let now = chrono::Utc::now().timestamp_millis(); + let buffer = 5 * 60 * 1000; + now >= self.expires_at - buffer + } + + /// Refresh the access token and return updated credentials. 
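+    ///
+    /// The token endpoint may omit `refresh_token` and the account-id claims
+    /// from its response; both fall back to the values already stored on the
+    /// credentials being refreshed.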
+    pub async fn refresh(&self) -> Result<Self> {
+        let client = reqwest::Client::new();
+        let response = client
+            .post(OAUTH_TOKEN_URL)
+            .header("Content-Type", "application/x-www-form-urlencoded")
+            .form(&[
+                ("grant_type", "refresh_token"),
+                ("refresh_token", self.refresh_token.as_str()),
+                ("client_id", CLIENT_ID),
+            ])
+            .send()
+            .await
+            .context("failed to send OpenAI OAuth refresh request")?;
+
+        let status = response.status();
+        let body = response
+            .text()
+            .await
+            .context("failed to read OpenAI OAuth refresh response")?;
+
+        if !status.is_success() {
+            anyhow::bail!("OpenAI OAuth refresh failed ({}): {}", status, body);
+        }
+
+        let token_response: TokenResponse =
+            serde_json::from_str(&body).context("failed to parse OpenAI OAuth refresh response")?;
+
+        let account_id = extract_account_id(&token_response).or_else(|| self.account_id.clone());
+        let refresh_token = token_response
+            .refresh_token
+            .unwrap_or_else(|| self.refresh_token.clone());
+
+        Ok(Self {
+            access_token: token_response.access_token,
+            refresh_token,
+            expires_at: chrono::Utc::now().timestamp_millis()
+                + token_response.expires_in.unwrap_or(3600) * 1000,
+            account_id,
+        })
+    }
+}
+
+#[derive(Debug, Deserialize)]
+struct TokenResponse {
+    access_token: String,
+    refresh_token: Option<String>,
+    expires_in: Option<i64>,
+    id_token: Option<String>,
+}
+
+#[derive(Debug, Deserialize)]
+struct TokenClaims {
+    chatgpt_account_id: Option<String>,
+    organizations: Option<Vec<TokenOrganization>>,
+    #[serde(rename = "https://api.openai.com/auth")]
+    openai_auth: Option<TokenOpenAiAuthClaims>,
+}
+
+#[derive(Debug, Deserialize)]
+struct TokenOrganization {
+    id: String,
+}
+
+#[derive(Debug, Deserialize)]
+struct TokenOpenAiAuthClaims {
+    chatgpt_account_id: Option<String>,
+}
+
+/// Data needed to complete OpenAI browser OAuth.
+#[derive(Debug, Clone, Serialize)]
+pub struct BrowserAuthorization {
+    pub authorization_url: String,
+    pub state: String,
+    pub pkce_verifier: String,
+}
+
+fn generate_random_urlsafe_string(bytes_len: usize) -> String {
+    let mut bytes = vec![0u8; bytes_len];
+    rand::rng().fill_bytes(&mut bytes);
+    URL_SAFE_NO_PAD.encode(bytes)
+}
+
+fn generate_pkce() -> (String, String) {
+    let verifier = generate_random_urlsafe_string(64);
+    let challenge = URL_SAFE_NO_PAD.encode(Sha256::digest(verifier.as_bytes()));
+    (verifier, challenge)
+}
+
+/// Build a browser-based OAuth authorization URL using PKCE.
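+///
+/// The returned `state` keys the server-side pending session, and
+/// `pkce_verifier` must be retained for the later `exchange_browser_code`
+/// call; only the SHA-256 challenge derived from the verifier is placed in
+/// the URL.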
+pub fn start_browser_authorization(redirect_uri: &str) -> BrowserAuthorization {
+    let (pkce_verifier, pkce_challenge) = generate_pkce();
+    let state = generate_random_urlsafe_string(32);
+
+    let authorization_url = format!(
+        "{authorize}?response_type=code&client_id={client_id}&redirect_uri={redirect_uri}&scope={scope}&code_challenge={challenge}&code_challenge_method=S256&id_token_add_organizations=true&codex_cli_simplified_flow=true&originator=opencode&state={state}",
+        authorize = AUTHORIZE_URL,
+        client_id = urlencoding::encode(CLIENT_ID),
+        redirect_uri = urlencoding::encode(redirect_uri),
+        scope = urlencoding::encode(BROWSER_SCOPES),
+        challenge = urlencoding::encode(&pkce_challenge),
+        state = urlencoding::encode(&state),
+    );
+
+    BrowserAuthorization {
+        authorization_url,
+        state,
+        pkce_verifier,
+    }
+}
+
+fn parse_jwt_claims(token: &str) -> Option<TokenClaims> {
+    let mut parts = token.split('.');
+    let _header = parts.next()?;
+    let payload = parts.next()?;
+    let _signature = parts.next()?;
+    if parts.next().is_some() {
+        return None;
+    }
+
+    let decoded = URL_SAFE_NO_PAD.decode(payload).ok()?;
+    serde_json::from_slice::<TokenClaims>(&decoded).ok()
+}
+
+fn extract_account_id(token_response: &TokenResponse) -> Option<String> {
+    let from_claims = |claims: TokenClaims| {
+        claims
+            .chatgpt_account_id
+            .or_else(|| claims.openai_auth.and_then(|auth| auth.chatgpt_account_id))
+            .or_else(|| {
+                claims
+                    .organizations
+                    .and_then(|organizations| organizations.into_iter().next())
+                    .map(|organization| organization.id)
+            })
+    };
+
+    token_response
+        .id_token
+        .as_deref()
+        .and_then(parse_jwt_claims)
+        .and_then(from_claims)
+        .or_else(|| parse_jwt_claims(&token_response.access_token).and_then(from_claims))
+}
+
+/// Exchange an OAuth authorization code from browser flow for tokens.
+pub async fn exchange_browser_code(
+    code: &str,
+    redirect_uri: &str,
+    pkce_verifier: &str,
+) -> Result<OAuthCredentials> {
+    let client = reqwest::Client::new();
+    let response = client
+        .post(OAUTH_TOKEN_URL)
+        .header("Content-Type", "application/x-www-form-urlencoded")
+        .form(&[
+            ("grant_type", "authorization_code"),
+            ("code", code),
+            ("redirect_uri", redirect_uri),
+            ("client_id", CLIENT_ID),
+            ("code_verifier", pkce_verifier),
+        ])
+        .send()
+        .await
+        .context("failed to exchange OpenAI browser authorization code for tokens")?;
+
+    let status = response.status();
+    let body = response
+        .text()
+        .await
+        .context("failed to read OpenAI browser token exchange response")?;
+
+    if !status.is_success() {
+        anyhow::bail!(
+            "OpenAI browser token exchange failed ({}): {}",
+            status,
+            body
+        );
+    }
+
+    let token_response: TokenResponse = serde_json::from_str(&body)
+        .context("failed to parse OpenAI browser token exchange response")?;
+    let account_id = extract_account_id(&token_response);
+    let refresh_token = token_response
+        .refresh_token
+        .context("OpenAI browser token response did not include refresh_token")?;
+
+    Ok(OAuthCredentials {
+        access_token: token_response.access_token,
+        refresh_token,
+        expires_at: chrono::Utc::now().timestamp_millis()
+            + token_response.expires_in.unwrap_or(3600) * 1000,
+        account_id,
+    })
+}
+
+/// Path to OpenAI OAuth credentials within the instance directory.
+pub fn credentials_path(instance_dir: &Path) -> PathBuf {
+    instance_dir.join("openai_chatgpt_oauth.json")
+}
+
+/// Load OpenAI OAuth credentials from disk.
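+///
+/// Returns `Ok(None)` when `openai_chatgpt_oauth.json` does not exist yet;
+/// read or parse failures surface as errors rather than `None`.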
+pub fn load_credentials(instance_dir: &Path) -> Result<Option<OAuthCredentials>> {
+    let path = credentials_path(instance_dir);
+    if !path.exists() {
+        return Ok(None);
+    }
+
+    let data = std::fs::read_to_string(&path)
+        .with_context(|| format!("failed to read {}", path.display()))?;
+    let creds: OAuthCredentials =
+        serde_json::from_str(&data).context("failed to parse OpenAI OAuth credentials")?;
+    Ok(Some(creds))
+}
+
+/// Save OpenAI OAuth credentials to disk with restricted permissions (0600).
+pub fn save_credentials(instance_dir: &Path, creds: &OAuthCredentials) -> Result<()> {
+    let path = credentials_path(instance_dir);
+    let data = serde_json::to_string_pretty(creds)
+        .context("failed to serialize OpenAI OAuth credentials")?;
+
+    std::fs::write(&path, &data).with_context(|| format!("failed to write {}", path.display()))?;
+
+    #[cfg(unix)]
+    {
+        use std::os::unix::fs::PermissionsExt;
+        std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o600))
+            .with_context(|| format!("failed to set permissions on {}", path.display()))?;
+    }
+
+    Ok(())
+}
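The new `openai_auth` module is self-contained enough to exercise outside the HTTP handlers. A minimal end-to-end sketch of the browser flow, assuming the crate is imported as `spacebot` (as `main.rs` does); the localhost callback URL and the way `code` is obtained are placeholders, since in the real flow the server derives the redirect URI from request headers and receives the code on the callback route:

```rust
use std::path::Path;

use spacebot::openai_auth;

// Sketch: drive the browser OAuth flow end to end. `code` stands for the
// value OpenAI appends to the redirect URI after the user approves access.
async fn chatgpt_oauth_demo(instance_dir: &Path, code: &str) -> anyhow::Result<()> {
    // Assumed local callback; must match the URI used at authorization time.
    let redirect = "http://localhost:3000/providers/openai/oauth/browser/callback";

    // 1. Build the authorization URL. `state` keys the pending session and
    //    `pkce_verifier` must survive until the code exchange below.
    let auth = openai_auth::start_browser_authorization(redirect);
    println!("open this in a browser: {}", auth.authorization_url);

    // 2. Exchange the authorization code for access/refresh tokens.
    let creds =
        openai_auth::exchange_browser_code(code, redirect, &auth.pkce_verifier).await?;

    // 3. Persist next to config.toml with 0600 permissions; both the manager
    //    and the onboarding check look for this file.
    openai_auth::save_credentials(instance_dir, &creds)?;
    assert!(openai_auth::credentials_path(instance_dir).exists());
    Ok(())
}
```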