From 6ab1c041e8e12da6d3f3605559f58c21d67239ea Mon Sep 17 00:00:00 2001 From: Victor Sumner Date: Sun, 22 Feb 2026 08:38:18 -0500 Subject: [PATCH 1/9] Fix failing tests and harden runtime error handling --- src/agent/channel.rs | 13 +- src/api/bindings.rs | 302 +++++++++++++++++++++++++------------- src/api/config.rs | 20 +-- src/api/server.rs | 14 +- src/config.rs | 36 ++++- src/main.rs | 105 ++++++------- src/memory/search.rs | 87 ++++++----- src/messaging/slack.rs | 39 ++++- src/prompts/engine.rs | 6 + src/prompts/text.rs | 10 +- src/tools/browser.rs | 10 +- src/tools/spawn_worker.rs | 6 +- tests/bulletin.rs | 20 ++- tests/context_dump.rs | 28 +++- 14 files changed, 452 insertions(+), 244 deletions(-) diff --git a/src/agent/channel.rs b/src/agent/channel.rs index 158564718..aa5d53b40 100644 --- a/src/agent/channel.rs +++ b/src/agent/channel.rs @@ -391,8 +391,11 @@ impl Channel { if messages.len() == 1 { // Single message - process normally - let message = messages.into_iter().next().unwrap(); - self.handle_message(message).await + if let Some(message) = messages.into_iter().next() { + self.handle_message(message).await + } else { + Ok(()) + } } else { // Multiple messages - batch them self.handle_message_batch(messages).await @@ -1147,8 +1150,10 @@ impl Channel { for (key, value) in retrigger_metadata { self.pending_retrigger_metadata.insert(key, value); } - self.retrigger_deadline = - Some(tokio::time::Instant::now() + std::time::Duration::from_millis(RETRIGGER_DEBOUNCE_MS)); + self.retrigger_deadline = Some( + tokio::time::Instant::now() + + std::time::Duration::from_millis(RETRIGGER_DEBOUNCE_MS), + ); } } diff --git a/src/api/bindings.rs b/src/api/bindings.rs index 6c9fc0e0d..fa7fc27f7 100644 --- a/src/api/bindings.rs +++ b/src/api/bindings.rs @@ -125,6 +125,31 @@ pub(super) struct UpdateBindingResponse { message: String, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum HotReloadDisposition { + NoCredentials, + Start, + MissingConfig, +} + +fn hot_reload_disposition( + platform: &'static str, + has_credentials: bool, + has_config: bool, +) -> HotReloadDisposition { + match (has_credentials, has_config) { + (false, _) => HotReloadDisposition::NoCredentials, + (true, true) => HotReloadDisposition::Start, + (true, false) => { + tracing::warn!( + platform, + "credentials provided but messaging config is missing after reload" + ); + HotReloadDisposition::MissingConfig + } + } +} + /// List all bindings, optionally filtered by agent_id. pub(super) async fn list_bindings( State(state): State>, @@ -352,120 +377,142 @@ pub(super) async fn create_binding( let manager_guard = state.messaging_manager.read().await; if let Some(manager) = manager_guard.as_ref() { - if let Some(token) = new_discord_token { - let discord_perms = { - let perms_guard = state.discord_permissions.read().await; - match perms_guard.as_ref() { - Some(existing) => existing.clone(), - None => { - drop(perms_guard); - let perms = crate::config::DiscordPermissions::from_config( - new_config - .messaging - .discord - .as_ref() - .expect("discord config exists when token is provided"), - &new_config.bindings, - ); - let arc_swap = - std::sync::Arc::new(arc_swap::ArcSwap::from_pointee(perms)); - state.set_discord_permissions(arc_swap.clone()).await; - arc_swap + let discord_config = new_config.messaging.discord.as_ref(); + if matches!( + hot_reload_disposition( + "discord", + new_discord_token.is_some(), + discord_config.is_some(), + ), + HotReloadDisposition::Start + ) { + if let (Some(token), Some(discord_config)) = + (new_discord_token.as_ref(), discord_config) + { + let discord_perms = { + let perms_guard = state.discord_permissions.read().await; + match perms_guard.as_ref() { + Some(existing) => existing.clone(), + None => { + drop(perms_guard); + let perms = crate::config::DiscordPermissions::from_config( + discord_config, + &new_config.bindings, + ); + let arc_swap = + std::sync::Arc::new(arc_swap::ArcSwap::from_pointee(perms)); + state.set_discord_permissions(arc_swap.clone()).await; + arc_swap + } } + }; + let adapter = + crate::messaging::discord::DiscordAdapter::new(token, discord_perms); + if let Err(error) = manager.register_and_start(adapter).await { + tracing::error!(%error, "failed to hot-start discord adapter"); } - }; - let adapter = crate::messaging::discord::DiscordAdapter::new(&token, discord_perms); - if let Err(error) = manager.register_and_start(adapter).await { - tracing::error!(%error, "failed to hot-start discord adapter"); } } - if let Some((bot_token, app_token)) = new_slack_tokens { - let slack_perms = { - let perms_guard = state.slack_permissions.read().await; - match perms_guard.as_ref() { - Some(existing) => existing.clone(), - None => { - drop(perms_guard); - let perms = crate::config::SlackPermissions::from_config( - new_config - .messaging - .slack - .as_ref() - .expect("slack config exists when tokens are provided"), - &new_config.bindings, - ); - let arc_swap = - std::sync::Arc::new(arc_swap::ArcSwap::from_pointee(perms)); - state.set_slack_permissions(arc_swap.clone()).await; - arc_swap + let slack_config = new_config.messaging.slack.as_ref(); + if matches!( + hot_reload_disposition("slack", new_slack_tokens.is_some(), slack_config.is_some(),), + HotReloadDisposition::Start + ) { + if let (Some((bot_token, app_token)), Some(slack_config)) = + (new_slack_tokens.as_ref(), slack_config) + { + let slack_perms = { + let perms_guard = state.slack_permissions.read().await; + match perms_guard.as_ref() { + Some(existing) => existing.clone(), + None => { + drop(perms_guard); + let perms = crate::config::SlackPermissions::from_config( + slack_config, + &new_config.bindings, + ); + let arc_swap = + std::sync::Arc::new(arc_swap::ArcSwap::from_pointee(perms)); + state.set_slack_permissions(arc_swap.clone()).await; + arc_swap + } } - } - }; - let slack_commands = new_config - .messaging - .slack - .as_ref() - .map(|s| s.commands.clone()) - .unwrap_or_default(); - match crate::messaging::slack::SlackAdapter::new( - &bot_token, - &app_token, - slack_perms, - slack_commands, - ) { - Ok(adapter) => { - if let Err(error) = manager.register_and_start(adapter).await { - tracing::error!(%error, "failed to hot-start slack adapter"); + }; + match crate::messaging::slack::SlackAdapter::new( + bot_token, + app_token, + slack_perms, + slack_config.commands.clone(), + ) { + Ok(adapter) => { + if let Err(error) = manager.register_and_start(adapter).await { + tracing::error!(%error, "failed to hot-start slack adapter"); + } + } + Err(error) => { + tracing::error!(%error, "failed to build slack adapter"); } - } - Err(error) => { - tracing::error!(%error, "failed to build slack adapter"); } } } - if let Some(token) = new_telegram_token { - let telegram_perms = { - let perms = crate::config::TelegramPermissions::from_config( - new_config - .messaging - .telegram - .as_ref() - .expect("telegram config exists when token is provided"), - &new_config.bindings, - ); - std::sync::Arc::new(arc_swap::ArcSwap::from_pointee(perms)) - }; - let adapter = - crate::messaging::telegram::TelegramAdapter::new(&token, telegram_perms); - if let Err(error) = manager.register_and_start(adapter).await { - tracing::error!(%error, "failed to hot-start telegram adapter"); + let telegram_config = new_config.messaging.telegram.as_ref(); + if matches!( + hot_reload_disposition( + "telegram", + new_telegram_token.is_some(), + telegram_config.is_some(), + ), + HotReloadDisposition::Start + ) { + if let (Some(token), Some(telegram_config)) = + (new_telegram_token.as_ref(), telegram_config) + { + let telegram_perms = { + let perms = crate::config::TelegramPermissions::from_config( + telegram_config, + &new_config.bindings, + ); + std::sync::Arc::new(arc_swap::ArcSwap::from_pointee(perms)) + }; + let adapter = + crate::messaging::telegram::TelegramAdapter::new(token, telegram_perms); + if let Err(error) = manager.register_and_start(adapter).await { + tracing::error!(%error, "failed to hot-start telegram adapter"); + } } } - if let Some((username, oauth_token)) = new_twitch_creds { - let twitch_config = new_config - .messaging - .twitch - .as_ref() - .expect("twitch config exists when credentials are provided"); - let twitch_perms = { - let perms = crate::config::TwitchPermissions::from_config( - twitch_config, - &new_config.bindings, + let twitch_config = new_config.messaging.twitch.as_ref(); + if matches!( + hot_reload_disposition( + "twitch", + new_twitch_creds.is_some(), + twitch_config.is_some(), + ), + HotReloadDisposition::Start + ) { + if let (Some((username, oauth_token)), Some(twitch_config)) = + (new_twitch_creds.as_ref(), twitch_config) + { + let twitch_perms = { + let perms = crate::config::TwitchPermissions::from_config( + twitch_config, + &new_config.bindings, + ); + std::sync::Arc::new(arc_swap::ArcSwap::from_pointee(perms)) + }; + let adapter = crate::messaging::twitch::TwitchAdapter::new( + username, + oauth_token, + twitch_config.channels.clone(), + twitch_config.trigger_prefix.clone(), + twitch_perms, ); - std::sync::Arc::new(arc_swap::ArcSwap::from_pointee(perms)) - }; - let adapter = crate::messaging::twitch::TwitchAdapter::new( - &username, - &oauth_token, - twitch_config.channels.clone(), - twitch_config.trigger_prefix.clone(), - twitch_perms, - ); - if let Err(error) = manager.register_and_start(adapter).await { - tracing::error!(%error, "failed to hot-start twitch adapter"); + if let Err(error) = manager.register_and_start(adapter).await { + tracing::error!(%error, "failed to hot-start twitch adapter"); + } } } } @@ -767,3 +814,60 @@ pub(super) async fn delete_binding( message: "Binding deleted.".to_string(), })) } + +#[cfg(test)] +mod tests { + use super::{HotReloadDisposition, hot_reload_disposition}; + + #[test] + fn hot_reload_disposition_starts_when_credentials_and_config_exist() { + assert_eq!( + hot_reload_disposition("discord", true, true), + HotReloadDisposition::Start + ); + assert_eq!( + hot_reload_disposition("slack", true, true), + HotReloadDisposition::Start + ); + assert_eq!( + hot_reload_disposition("telegram", true, true), + HotReloadDisposition::Start + ); + assert_eq!( + hot_reload_disposition("twitch", true, true), + HotReloadDisposition::Start + ); + } + + #[test] + fn hot_reload_disposition_skips_when_credentials_missing() { + assert_eq!( + hot_reload_disposition("discord", false, true), + HotReloadDisposition::NoCredentials + ); + assert_eq!( + hot_reload_disposition("slack", false, false), + HotReloadDisposition::NoCredentials + ); + } + + #[test] + fn hot_reload_disposition_marks_missing_config_when_token_present() { + assert_eq!( + hot_reload_disposition("discord", true, false), + HotReloadDisposition::MissingConfig + ); + assert_eq!( + hot_reload_disposition("slack", true, false), + HotReloadDisposition::MissingConfig + ); + assert_eq!( + hot_reload_disposition("telegram", true, false), + HotReloadDisposition::MissingConfig + ); + assert_eq!( + hot_reload_disposition("twitch", true, false), + HotReloadDisposition::MissingConfig + ); + } +} diff --git a/src/api/config.rs b/src/api/config.rs index 70bd6d1af..72cdad427 100644 --- a/src/api/config.rs +++ b/src/api/config.rs @@ -408,11 +408,13 @@ fn get_agent_table_mut( fn get_or_create_subtable<'a>( agent: &'a mut toml_edit::Table, key: &str, -) -> &'a mut toml_edit::Table { - if !agent.contains_key(key) { +) -> Result<&'a mut toml_edit::Table, StatusCode> { + if !agent.contains_key(key) || !agent[key].is_table() { agent[key] = toml_edit::Item::Table(toml_edit::Table::new()); } - agent[key].as_table_mut().expect("just created as table") + agent[key] + .as_table_mut() + .ok_or(StatusCode::INTERNAL_SERVER_ERROR) } fn update_routing_table( @@ -421,7 +423,7 @@ fn update_routing_table( routing: &RoutingUpdate, ) -> Result<(), StatusCode> { let agent = get_agent_table_mut(doc, agent_idx)?; - let table = get_or_create_subtable(agent, "routing"); + let table = get_or_create_subtable(agent, "routing")?; if let Some(ref v) = routing.channel { table["channel"] = toml_edit::value(v.as_str()); } @@ -479,7 +481,7 @@ fn update_compaction_table( compaction: &CompactionUpdate, ) -> Result<(), StatusCode> { let agent = get_agent_table_mut(doc, agent_idx)?; - let table = get_or_create_subtable(agent, "compaction"); + let table = get_or_create_subtable(agent, "compaction")?; if let Some(v) = compaction.background_threshold { table["background_threshold"] = toml_edit::value(v as f64); } @@ -498,7 +500,7 @@ fn update_cortex_table( cortex: &CortexUpdate, ) -> Result<(), StatusCode> { let agent = get_agent_table_mut(doc, agent_idx)?; - let table = get_or_create_subtable(agent, "cortex"); + let table = get_or_create_subtable(agent, "cortex")?; if let Some(v) = cortex.tick_interval_secs { table["tick_interval_secs"] = toml_edit::value(v as i64); } @@ -529,7 +531,7 @@ fn update_coalesce_table( coalesce: &CoalesceUpdate, ) -> Result<(), StatusCode> { let agent = get_agent_table_mut(doc, agent_idx)?; - let table = get_or_create_subtable(agent, "coalesce"); + let table = get_or_create_subtable(agent, "coalesce")?; if let Some(v) = coalesce.enabled { table["enabled"] = toml_edit::value(v); } @@ -554,7 +556,7 @@ fn update_memory_persistence_table( memory_persistence: &MemoryPersistenceUpdate, ) -> Result<(), StatusCode> { let agent = get_agent_table_mut(doc, agent_idx)?; - let table = get_or_create_subtable(agent, "memory_persistence"); + let table = get_or_create_subtable(agent, "memory_persistence")?; if let Some(v) = memory_persistence.enabled { table["enabled"] = toml_edit::value(v); } @@ -570,7 +572,7 @@ fn update_browser_table( browser: &BrowserUpdate, ) -> Result<(), StatusCode> { let agent = get_agent_table_mut(doc, agent_idx)?; - let table = get_or_create_subtable(agent, "browser"); + let table = get_or_create_subtable(agent, "browser")?; if let Some(v) = browser.enabled { table["enabled"] = toml_edit::value(v); } diff --git a/src/api/server.rs b/src/api/server.rs index 39dcc9d64..ca4c59972 100644 --- a/src/api/server.rs +++ b/src/api/server.rs @@ -7,8 +7,8 @@ use super::{ }; use axum::Json; -use axum::extract::Request; use axum::Router; +use axum::extract::{Request, State}; use axum::http::{StatusCode, Uri, header}; use axum::middleware::{self, Next}; use axum::response::{Html, IntoResponse, Response}; @@ -177,7 +177,11 @@ pub async fn start_http_server( Ok(handle) } -async fn api_auth_middleware(state: Arc, request: Request, next: Next) -> Response { +async fn api_auth_middleware( + State(state): State>, + request: Request, + next: Next, +) -> Response { let Some(expected_token) = state.auth_token.as_deref() else { return next.run(request).await; }; @@ -197,7 +201,11 @@ async fn api_auth_middleware(state: Arc, request: Request, next: Next) if is_authorized { next.run(request).await } else { - (StatusCode::UNAUTHORIZED, Json(json!({"error": "unauthorized"}))).into_response() + ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "unauthorized"})), + ) + .into_response() } } diff --git a/src/config.rs b/src/config.rs index 6e1c1586a..935ccd63a 100644 --- a/src/config.rs +++ b/src/config.rs @@ -2147,18 +2147,22 @@ impl Config { .providers .into_iter() .map(|(provider_id, config)| { - ( + let api_key = resolve_env_value(&config.api_key).ok_or_else(|| { + ConfigError::Invalid(format!( + "Failed to resolve API key for provider '{provider_id}': env var missing" + )) + })?; + Ok(( provider_id.to_lowercase(), ProviderConfig { api_type: config.api_type, base_url: config.base_url, - api_key: resolve_env_value(&config.api_key) - .expect("Failed to resolve API key for provider"), + api_key, name: config.name, }, - ) + )) }) - .collect(), + .collect::, ConfigError>>()?, }; if let Some(anthropic_key) = llm.anthropic_key.clone() { @@ -3812,6 +3816,28 @@ api_key = "static-provider-key" assert_eq!(second_provider.api_key, "static-provider-key"); } + #[test] + fn test_provider_env_key_missing_returns_error() { + let _lock = env_test_lock().lock().unwrap_or_else(|e| e.into_inner()); + let _env = EnvGuard::new(); + + let toml = r#" +[llm.provider.MissingKeyProv] +api_type = "openai_responses" +base_url = "https://api.example.com/v1" +api_key = "env:OPENAI_API_KEY" +"#; + + let parsed: TomlConfig = toml::from_str(toml).expect("failed to parse test TOML"); + let result = Config::from_toml(parsed, PathBuf::from(".")); + let error = result.expect_err("missing env var should return an error"); + assert!( + error + .to_string() + .contains("Failed to resolve API key for provider 'MissingKeyProv'") + ); + } + #[test] fn test_legacy_llm_keys_auto_migrate_to_providers() { let toml = r#" diff --git a/src/main.rs b/src/main.rs index 0b7fc6b4e..6ffefd292 100644 --- a/src/main.rs +++ b/src/main.rs @@ -129,7 +129,7 @@ struct ActiveChannel { fn main() -> anyhow::Result<()> { rustls::crypto::ring::default_provider() .install_default() - .expect("failed to install rustls crypto provider"); + .map_err(|error| anyhow::anyhow!("failed to install rustls crypto provider: {error:?}"))?; let cli = Cli::parse(); let command = cli.command.unwrap_or(Command::Start { foreground: false }); @@ -559,18 +559,20 @@ fn get_agent_config<'a>( config: &'a spacebot::config::Config, agent_id: Option<&str>, ) -> anyhow::Result<&'a spacebot::config::AgentConfig> { - let agent_id = agent_id.unwrap_or_else(|| { - if config.agents.is_empty() { - panic!("no agents configured"); - } - &config.agents[0].id - }); + let resolved_agent_id = match agent_id { + Some(id) => id, + None => config + .agents + .first() + .map(|agent| agent.id.as_str()) + .ok_or_else(|| anyhow::anyhow!("no agents configured"))?, + }; config .agents .iter() - .find(|a| a.id == agent_id) - .with_context(|| format!("agent not found: {agent_id}")) + .find(|agent| agent.id == resolved_agent_id) + .with_context(|| format!("agent not found: {resolved_agent_id}")) } fn load_config( @@ -607,11 +609,8 @@ async fn run( let (agent_remove_tx, mut agent_remove_rx) = mpsc::channel::(8); // Start HTTP API server if enabled - let mut api_state = spacebot::api::ApiState::new_with_provider_sender( - provider_tx, - agent_tx, - agent_remove_tx, - ); + let mut api_state = + spacebot::api::ApiState::new_with_provider_sender(provider_tx, agent_tx, agent_remove_tx); api_state.auth_token = config.api.auth_token.clone(); let api_state = Arc::new(api_state); @@ -1371,13 +1370,13 @@ async fn initialize_agents( if let Some(discord_config) = &config.messaging.discord && discord_config.enabled { - let adapter = spacebot::messaging::discord::DiscordAdapter::new( - &discord_config.token, - discord_permissions - .clone() - .expect("discord permissions initialized when discord is enabled"), - ); - new_messaging_manager.register(adapter).await; + if let Some(perms) = discord_permissions.clone() { + let adapter = + spacebot::messaging::discord::DiscordAdapter::new(&discord_config.token, perms); + new_messaging_manager.register(adapter).await; + } else { + tracing::warn!("discord enabled but permissions were not initialized"); + } } // Shared Slack permissions (hot-reloadable via file watcher) @@ -1392,20 +1391,22 @@ async fn initialize_agents( if let Some(slack_config) = &config.messaging.slack && slack_config.enabled { - match spacebot::messaging::slack::SlackAdapter::new( - &slack_config.bot_token, - &slack_config.app_token, - slack_permissions - .clone() - .expect("slack permissions initialized when slack is enabled"), - slack_config.commands.clone(), - ) { - Ok(adapter) => { - new_messaging_manager.register(adapter).await; - } - Err(error) => { - tracing::error!(%error, "failed to build slack adapter"); + if let Some(perms) = slack_permissions.clone() { + match spacebot::messaging::slack::SlackAdapter::new( + &slack_config.bot_token, + &slack_config.app_token, + perms, + slack_config.commands.clone(), + ) { + Ok(adapter) => { + new_messaging_manager.register(adapter).await; + } + Err(error) => { + tracing::error!(%error, "failed to build slack adapter"); + } } + } else { + tracing::warn!("slack enabled but permissions were not initialized"); } } @@ -1419,13 +1420,13 @@ async fn initialize_agents( if let Some(telegram_config) = &config.messaging.telegram && telegram_config.enabled { - let adapter = spacebot::messaging::telegram::TelegramAdapter::new( - &telegram_config.token, - telegram_permissions - .clone() - .expect("telegram permissions initialized when telegram is enabled"), - ); - new_messaging_manager.register(adapter).await; + if let Some(perms) = telegram_permissions.clone() { + let adapter = + spacebot::messaging::telegram::TelegramAdapter::new(&telegram_config.token, perms); + new_messaging_manager.register(adapter).await; + } else { + tracing::warn!("telegram enabled but permissions were not initialized"); + } } if let Some(webhook_config) = &config.messaging.webhook @@ -1449,16 +1450,18 @@ async fn initialize_agents( if let Some(twitch_config) = &config.messaging.twitch && twitch_config.enabled { - let adapter = spacebot::messaging::twitch::TwitchAdapter::new( - &twitch_config.username, - &twitch_config.oauth_token, - twitch_config.channels.clone(), - twitch_config.trigger_prefix.clone(), - twitch_permissions - .clone() - .expect("twitch permissions initialized when twitch is enabled"), - ); - new_messaging_manager.register(adapter).await; + if let Some(perms) = twitch_permissions.clone() { + let adapter = spacebot::messaging::twitch::TwitchAdapter::new( + &twitch_config.username, + &twitch_config.oauth_token, + twitch_config.channels.clone(), + twitch_config.trigger_prefix.clone(), + perms, + ); + new_messaging_manager.register(adapter).await; + } else { + tracing::warn!("twitch enabled but permissions were not initialized"); + } } let webchat_adapter = Arc::new(spacebot::messaging::webchat::WebChatAdapter::new()); diff --git a/src/memory/search.rs b/src/memory/search.rs index ae6e815f5..761ceab43 100644 --- a/src/memory/search.rs +++ b/src/memory/search.rs @@ -36,8 +36,8 @@ pub enum SearchSort { /// Bundles all memory search dependencies. pub struct MemorySearch { store: Arc, - embedding_table: EmbeddingTable, - embedding_model: Arc, + embedding_table: Option, + embedding_model: Option>, } impl Clone for MemorySearch { @@ -45,7 +45,7 @@ impl Clone for MemorySearch { Self { store: Arc::clone(&self.store), embedding_table: self.embedding_table.clone(), - embedding_model: Arc::clone(&self.embedding_model), + embedding_model: self.embedding_model.clone(), } } } @@ -67,8 +67,17 @@ impl MemorySearch { ) -> Self { Self { store, - embedding_table, - embedding_model, + embedding_table: Some(embedding_table), + embedding_model: Some(embedding_model), + } + } + + #[cfg(test)] + fn new_metadata_only(store: Arc) -> Self { + Self { + store, + embedding_table: None, + embedding_model: None, } } @@ -79,17 +88,23 @@ impl MemorySearch { /// Get a reference to the embedding table. pub fn embedding_table(&self) -> &EmbeddingTable { - &self.embedding_table + self.embedding_table + .as_ref() + .expect("embedding table unavailable for metadata-only MemorySearch") } /// Get a reference to the embedding model. pub fn embedding_model(&self) -> &EmbeddingModel { - &self.embedding_model + self.embedding_model + .as_deref() + .expect("embedding model unavailable for metadata-only MemorySearch") } /// Get a shared handle to the embedding model (for async embed_one). pub fn embedding_model_arc(&self) -> &Arc { - &self.embedding_model + self.embedding_model + .as_ref() + .expect("embedding model unavailable for metadata-only MemorySearch") } /// Unified search entry point. Dispatches to the appropriate strategy @@ -148,6 +163,17 @@ impl MemorySearch { query: &str, config: &SearchConfig, ) -> Result> { + let embedding_table = self.embedding_table.as_ref().ok_or_else(|| { + crate::error::MemoryError::SearchFailed( + "hybrid search requires an embedding table".to_string(), + ) + })?; + let embedding_model = self.embedding_model.as_ref().ok_or_else(|| { + crate::error::MemoryError::SearchFailed( + "hybrid search requires an embedding model".to_string(), + ) + })?; + // Collect results from different sources let mut vector_results = Vec::new(); let mut fts_results = Vec::new(); @@ -156,8 +182,7 @@ impl MemorySearch { // 1. Full-text search via LanceDB // FTS requires an inverted index. If the index doesn't exist yet (empty // table, first run) this will fail — fall back to vector + graph search. - match self - .embedding_table + match embedding_table .text_search(query, config.max_results_per_source) .await { @@ -179,9 +204,8 @@ impl MemorySearch { } // 2. Vector similarity search via LanceDB - let query_embedding = self.embedding_model.embed_one(query).await?; - match self - .embedding_table + let query_embedding = embedding_model.embed_one(query).await?; + match embedding_table .vector_search(&query_embedding, config.max_results_per_source) .await { @@ -557,15 +581,7 @@ mod tests { async fn test_metadata_search_recent() { let (store, _memories) = setup_search_with_memories().await; - // Construct MemorySearch with dummy lance/embedding (we won't use them) - let lance_dir = tempfile::tempdir().unwrap(); - let lance_conn = lancedb::connect(lance_dir.path().to_str().unwrap()) - .execute() - .await - .unwrap(); - let embedding_table = EmbeddingTable::open_or_create(&lance_conn).await.unwrap(); - let embedding_model = Arc::new(EmbeddingModel::new(lance_dir.path()).unwrap()); - let search = MemorySearch::new(store, embedding_table, embedding_model); + let search = MemorySearch::new_metadata_only(store); let config = SearchConfig { mode: SearchMode::Recent, @@ -586,14 +602,7 @@ mod tests { async fn test_metadata_search_important() { let (store, _memories) = setup_search_with_memories().await; - let lance_dir = tempfile::tempdir().unwrap(); - let lance_conn = lancedb::connect(lance_dir.path().to_str().unwrap()) - .execute() - .await - .unwrap(); - let embedding_table = EmbeddingTable::open_or_create(&lance_conn).await.unwrap(); - let embedding_model = Arc::new(EmbeddingModel::new(lance_dir.path()).unwrap()); - let search = MemorySearch::new(store, embedding_table, embedding_model); + let search = MemorySearch::new_metadata_only(store); let config = SearchConfig { mode: SearchMode::Important, @@ -611,14 +620,7 @@ mod tests { async fn test_metadata_search_typed() { let (store, _memories) = setup_search_with_memories().await; - let lance_dir = tempfile::tempdir().unwrap(); - let lance_conn = lancedb::connect(lance_dir.path().to_str().unwrap()) - .execute() - .await - .unwrap(); - let embedding_table = EmbeddingTable::open_or_create(&lance_conn).await.unwrap(); - let embedding_model = Arc::new(EmbeddingModel::new(lance_dir.path()).unwrap()); - let search = MemorySearch::new(store, embedding_table, embedding_model); + let search = MemorySearch::new_metadata_only(store); let config = SearchConfig { mode: SearchMode::Typed, @@ -636,14 +638,7 @@ mod tests { async fn test_metadata_search_typed_empty() { let (store, _memories) = setup_search_with_memories().await; - let lance_dir = tempfile::tempdir().unwrap(); - let lance_conn = lancedb::connect(lance_dir.path().to_str().unwrap()) - .execute() - .await - .unwrap(); - let embedding_table = EmbeddingTable::open_or_create(&lance_conn).await.unwrap(); - let embedding_model = Arc::new(EmbeddingModel::new(lance_dir.path()).unwrap()); - let search = MemorySearch::new(store, embedding_table, embedding_model); + let search = MemorySearch::new_metadata_only(store); let config = SearchConfig { mode: SearchMode::Typed, diff --git a/src/messaging/slack.rs b/src/messaging/slack.rs index 19e896665..309e4af88 100644 --- a/src/messaging/slack.rs +++ b/src/messaging/slack.rs @@ -139,9 +139,14 @@ async fn handle_message_event( } let state_guard = states.read().await; - let adapter_state = state_guard + let Some(adapter_state) = state_guard .get_user_state::>() - .expect("SlackAdapterState must be in user_state"); + .cloned() + else { + tracing::error!("missing SlackAdapterState in socket mode user_state"); + return Ok(()); + }; + drop(state_guard); let user_id = msg_event.sender.user.as_ref().map(|u| u.0.clone()); @@ -238,9 +243,14 @@ async fn handle_app_mention_event( states: SlackClientEventsUserState, ) -> UserCallbackResult<()> { let state_guard = states.read().await; - let adapter_state = state_guard + let Some(adapter_state) = state_guard .get_user_state::>() - .expect("SlackAdapterState must be in user_state"); + .cloned() + else { + tracing::error!("missing SlackAdapterState in socket mode user_state"); + return Ok(()); + }; + drop(state_guard); let user_id = mention.user.0.clone(); @@ -337,9 +347,17 @@ async fn handle_command_event( states: SlackClientEventsUserState, ) -> UserCallbackResult { let state_guard = states.read().await; - let adapter_state = state_guard + let Some(adapter_state) = state_guard .get_user_state::>() - .expect("SlackAdapterState must be in user_state"); + .cloned() + else { + tracing::error!("missing SlackAdapterState in socket mode user_state"); + return Ok(SlackCommandEventResponse { + content: SlackMessageContent::new(), + response_type: Some(SlackMessageResponseType::Ephemeral), + }); + }; + drop(state_guard); let command_str = event.command.0.clone(); let team_id = event.team_id.0.clone(); @@ -476,9 +494,14 @@ async fn handle_interaction_event( }; let state_guard = states.read().await; - let adapter_state = state_guard + let Some(adapter_state) = state_guard .get_user_state::>() - .expect("SlackAdapterState must be in user_state"); + .cloned() + else { + tracing::error!("missing SlackAdapterState in socket mode user_state"); + return Ok(()); + }; + drop(state_guard); let user_id = block_actions .user diff --git a/src/prompts/engine.rs b/src/prompts/engine.rs index 718e0ab57..df2c5a1fd 100644 --- a/src/prompts/engine.rs +++ b/src/prompts/engine.rs @@ -137,11 +137,17 @@ impl PromptEngine { /// /// # Example /// ```rust + /// use minijinja::context; + /// use spacebot::prompts::PromptEngine; + /// + /// let engine = PromptEngine::new("en")?; /// let ctx = context! { /// identity_context => "Some identity text", /// browser_enabled => true, /// }; /// let rendered = engine.render("channel", ctx)?; + /// assert!(!rendered.is_empty()); + /// # Ok::<(), anyhow::Error>(()) /// ``` pub fn render(&self, template_name: &str, context: Value) -> Result { let template = self diff --git a/src/prompts/text.rs b/src/prompts/text.rs index d9614adae..90c8805f3 100644 --- a/src/prompts/text.rs +++ b/src/prompts/text.rs @@ -6,12 +6,12 @@ //! # Usage //! //! ```rust -//! // At startup (main.rs): -//! prompts::text::init("en").expect("invalid language"); +//! use spacebot::prompts::text; //! -//! // Anywhere: -//! let desc = prompts::text::get("tools/file"); -//! let prompt = prompts::text::get("channel"); +//! let desc = text::get("tools/file"); +//! let prompt = text::get("channel"); +//! assert!(!desc.is_empty()); +//! assert!(!prompt.is_empty()); //! ``` use std::sync::OnceLock; diff --git a/src/tools/browser.rs b/src/tools/browser.rs index e650a5ac3..ee6d8e138 100644 --- a/src/tools/browser.rs +++ b/src/tools/browser.rs @@ -4,8 +4,8 @@ //! via headless Chrome using chromiumoxide. Uses an accessibility-tree based //! ref system for LLM-friendly element addressing. -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use crate::config::BrowserConfig; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use chromiumoxide::browser::{Browser, BrowserConfig as ChromeConfig}; use chromiumoxide::page::ScreenshotParams; @@ -17,6 +17,7 @@ use chromiumoxide_cdp::cdp::browser_protocol::input::{ }; use chromiumoxide_cdp::cdp::browser_protocol::page::CaptureScreenshotFormat; use futures::StreamExt as _; +use reqwest::Url; use rig::completion::ToolDefinition; use rig::tool::Tool; use schemars::JsonSchema; @@ -31,9 +32,8 @@ use tokio::task::JoinHandle; /// Blocks private/loopback IPs, link-local addresses, and cloud metadata endpoints /// to prevent server-side request forgery. fn validate_url(url: &str) -> Result<(), BrowserError> { - let parsed = url::Url::parse(url).map_err(|error| { - BrowserError::new(format!("invalid URL '{url}': {error}")) - })?; + let parsed = Url::parse(url) + .map_err(|error| BrowserError::new(format!("invalid URL '{url}': {error}")))?; match parsed.scheme() { "http" | "https" => {} @@ -91,7 +91,7 @@ fn is_blocked_ip(ip: IpAddr) -> bool { || v4.is_link_local() // 169.254.0.0/16 || v4.is_broadcast() // 255.255.255.255 || v4.is_unspecified() // 0.0.0.0 - || is_v4_cgnat(v4) // 100.64.0.0/10 + || is_v4_cgnat(v4) // 100.64.0.0/10 } IpAddr::V6(v6) => { v6.is_loopback() // ::1 diff --git a/src/tools/spawn_worker.rs b/src/tools/spawn_worker.rs index 35157df78..dd189911c 100644 --- a/src/tools/spawn_worker.rs +++ b/src/tools/spawn_worker.rs @@ -112,8 +112,8 @@ impl Tool for SpawnWorkerTool { } }); - if opencode_enabled { - properties.as_object_mut().unwrap().insert( + if opencode_enabled && let Some(properties_object) = properties.as_object_mut() { + properties_object.insert( "worker_type".to_string(), serde_json::json!({ "type": "string", @@ -122,7 +122,7 @@ impl Tool for SpawnWorkerTool { "description": "\"builtin\" (default) runs a Rig agent loop. \"opencode\" spawns a full OpenCode coding agent — use for complex multi-file coding tasks." }), ); - properties.as_object_mut().unwrap().insert( + properties_object.insert( "directory".to_string(), serde_json::json!({ "type": "string", diff --git a/tests/bulletin.rs b/tests/bulletin.rs index a41cb84b9..1d428b0b5 100644 --- a/tests/bulletin.rs +++ b/tests/bulletin.rs @@ -83,6 +83,18 @@ async fn bootstrap_deps() -> anyhow::Result { }) } +async fn bootstrap_deps_or_skip() -> Option { + match bootstrap_deps().await { + Ok(deps) => Some(deps), + Err(error) => { + eprintln!( + "skipping bulletin integration test (requires local ~/.spacebot with credentials and embedding model cache): {error:#}" + ); + None + } + } +} + /// The cortex user prompt references memory types inline. If a new variant is /// added to MemoryType::ALL, this test fails until the type list is updated. #[test] @@ -120,7 +132,9 @@ fn test_bulletin_prompts_cover_all_memory_types() { #[tokio::test] async fn test_memory_recall_returns_results() { - let deps = bootstrap_deps().await.expect("failed to bootstrap"); + let Some(deps) = bootstrap_deps_or_skip().await else { + return; + }; let config = spacebot::memory::search::SearchConfig::default(); let results = deps @@ -147,7 +161,9 @@ async fn test_memory_recall_returns_results() { #[tokio::test] async fn test_bulletin_generation() { - let deps = bootstrap_deps().await.expect("failed to bootstrap"); + let Some(deps) = bootstrap_deps_or_skip().await else { + return; + }; // Verify the bulletin starts empty let before = deps.runtime_config.memory_bulletin.load(); diff --git a/tests/context_dump.rs b/tests/context_dump.rs index 2f94d5ddf..b8e2f1605 100644 --- a/tests/context_dump.rs +++ b/tests/context_dump.rs @@ -85,6 +85,18 @@ async fn bootstrap_deps() -> anyhow::Result<(spacebot::AgentDeps, spacebot::conf Ok((deps, config)) } +async fn bootstrap_deps_or_skip() -> Option<(spacebot::AgentDeps, spacebot::config::Config)> { + match bootstrap_deps().await { + Ok(deps_and_config) => Some(deps_and_config), + Err(error) => { + eprintln!( + "skipping context_dump integration test (requires local ~/.spacebot with credentials and embedding model cache): {error:#}" + ); + None + } + } +} + /// Print a labeled section with a separator. fn print_section(label: &str, content: &str) { let separator = "=".repeat(80); @@ -158,7 +170,9 @@ fn build_channel_system_prompt(rc: &spacebot::config::RuntimeConfig) -> String { #[tokio::test] async fn dump_channel_context() { - let (deps, _config) = bootstrap_deps().await.expect("failed to bootstrap"); + let Some((deps, _config)) = bootstrap_deps_or_skip().await else { + return; + }; let rc = &deps.runtime_config; let prompt = build_channel_system_prompt(rc); @@ -238,7 +252,9 @@ async fn dump_channel_context() { #[tokio::test] async fn dump_branch_context() { - let (deps, _config) = bootstrap_deps().await.expect("failed to bootstrap"); + let Some((deps, _config)) = bootstrap_deps_or_skip().await else { + return; + }; let rc = &deps.runtime_config; let prompt_engine = rc.prompts.load(); @@ -291,7 +307,9 @@ async fn dump_branch_context() { #[tokio::test] async fn dump_worker_context() { - let (deps, _config) = bootstrap_deps().await.expect("failed to bootstrap"); + let Some((deps, _config)) = bootstrap_deps_or_skip().await else { + return; + }; let rc = &deps.runtime_config; let prompt_engine = rc.prompts.load(); @@ -356,7 +374,9 @@ async fn dump_worker_context() { #[tokio::test] async fn dump_all_contexts() { - let (deps, _config) = bootstrap_deps().await.expect("failed to bootstrap"); + let Some((deps, _config)) = bootstrap_deps_or_skip().await else { + return; + }; let rc = &deps.runtime_config; let prompt_engine = rc.prompts.load(); let instance_dir = rc.instance_dir.to_string_lossy(); From c987ef0c1079668b199fc046108f4ef03b215060 Mon Sep 17 00:00:00 2001 From: Victor Sumner Date: Sun, 22 Feb 2026 09:03:25 -0500 Subject: [PATCH 2/9] Harden channel prompt handling and streaming lock safety --- src/agent/channel.rs | 55 ++++++++++++++++++++++---------------- src/agent/compactor.rs | 20 +++++++++----- src/agent/status.rs | 11 ++++++++ src/agent/worker.rs | 17 +++++++++--- src/api/config.rs | 56 ++++++++++++++++++++------------------- src/messaging/slack.rs | 9 ++++--- src/messaging/telegram.rs | 28 ++++++++++++++------ 7 files changed, 125 insertions(+), 71 deletions(-) diff --git a/src/agent/channel.rs b/src/agent/channel.rs index aa5d53b40..70a35ab93 100644 --- a/src/agent/channel.rs +++ b/src/agent/channel.rs @@ -63,19 +63,20 @@ impl ChannelState { /// Cancel a running worker by aborting its tokio task and cleaning up state. /// Returns an error message if the worker is not found. pub async fn cancel_worker(&self, worker_id: WorkerId) -> std::result::Result<(), String> { - let handle = self.worker_handles.write().await.remove(&worker_id); let removed = self .active_workers .write() .await .remove(&worker_id) .is_some(); + let handle = self.worker_handles.write().await.remove(&worker_id); self.worker_inputs.write().await.remove(&worker_id); + let status_removed = self.status_block.write().await.remove_worker(worker_id); if let Some(handle) = handle { handle.abort(); Ok(()) - } else if removed { + } else if removed || status_removed { // Worker was in active_workers but had no handle (shouldn't happen, but handle gracefully) Ok(()) } else { @@ -465,11 +466,15 @@ impl Channel { .get("telegram_chat_type") .and_then(|v| v.as_str()) }); - self.conversation_context = Some( - prompt_engine - .render_conversation_context(&first.source, server_name, channel_name) - .expect("failed to render conversation context"), - ); + match prompt_engine.render_conversation_context(&first.source, server_name, channel_name) + { + Ok(context) => { + self.conversation_context = Some(context); + } + Err(error) => { + tracing::warn!(%error, "failed to render conversation context"); + } + } } // Persist each message to conversation log (individual audit trail) @@ -560,7 +565,7 @@ impl Channel { // Build system prompt with coalesce hint let system_prompt = self .build_system_prompt_with_coalesce(message_count, elapsed_secs, unique_sender_count) - .await; + .await?; { let mut reply_target = self.state.reply_target_message_id.write().await; @@ -597,7 +602,7 @@ impl Channel { message_count: usize, elapsed_secs: f64, unique_senders: usize, - ) -> String { + ) -> Result { let rc = &self.deps.runtime_config; let prompt_engine = rc.prompts.load(); @@ -611,7 +616,7 @@ impl Channel { let opencode_enabled = rc.opencode.load().enabled; let worker_capabilities = prompt_engine .render_worker_capabilities(browser_enabled, web_search_enabled, opencode_enabled) - .expect("failed to render worker capabilities"); + .map_err(|error| AgentError::Other(error.into()))?; let status_text = { let status = self.state.status_block.read().await; @@ -628,7 +633,7 @@ impl Channel { let empty_to_none = |s: String| if s.is_empty() { None } else { Some(s) }; - prompt_engine + Ok(prompt_engine .render_channel_prompt( empty_to_none(identity_context), empty_to_none(memory_bulletin.to_string()), @@ -639,7 +644,7 @@ impl Channel { coalesce_hint, available_channels, ) - .expect("failed to render channel prompt") + .map_err(|error| AgentError::Other(error.into()))?) } /// Handle an incoming message by running the channel's LLM agent loop. @@ -719,14 +724,18 @@ impl Channel { .get("telegram_chat_type") .and_then(|v| v.as_str()) }); - self.conversation_context = Some( - prompt_engine - .render_conversation_context(&message.source, server_name, channel_name) - .expect("failed to render conversation context"), - ); + match prompt_engine.render_conversation_context(&message.source, server_name, channel_name) + { + Ok(context) => { + self.conversation_context = Some(context); + } + Err(error) => { + tracing::warn!(%error, "failed to render conversation context"); + } + } } - let system_prompt = self.build_system_prompt().await; + let system_prompt = self.build_system_prompt().await?; { let mut reply_target = self.state.reply_target_message_id.write().await; @@ -798,7 +807,7 @@ impl Channel { } /// Assemble the full system prompt using the PromptEngine. - async fn build_system_prompt(&self) -> String { + async fn build_system_prompt(&self) -> Result { let rc = &self.deps.runtime_config; let prompt_engine = rc.prompts.load(); @@ -812,7 +821,7 @@ impl Channel { let opencode_enabled = rc.opencode.load().enabled; let worker_capabilities = prompt_engine .render_worker_capabilities(browser_enabled, web_search_enabled, opencode_enabled) - .expect("failed to render worker capabilities"); + .map_err(|error| AgentError::Other(error.into()))?; let status_text = { let status = self.state.status_block.read().await; @@ -823,7 +832,7 @@ impl Channel { let empty_to_none = |s: String| if s.is_empty() { None } else { Some(s) }; - prompt_engine + Ok(prompt_engine .render_channel_prompt( empty_to_none(identity_context), empty_to_none(memory_bulletin.to_string()), @@ -834,7 +843,7 @@ impl Channel { None, // coalesce_hint - only set for batched messages available_channels, ) - .expect("failed to render channel prompt") + .map_err(|error| AgentError::Other(error.into()))?) } /// Register per-turn tools, run the LLM agentic loop, and clean up. @@ -1418,7 +1427,7 @@ pub async fn spawn_worker_from_state( &rc.instance_dir.display().to_string(), &rc.workspace_dir.display().to_string(), ) - .expect("failed to render worker prompt"); + .map_err(|error| AgentError::Other(error.into()))?; let skills = rc.skills.load(); let browser_config = (**rc.browser_config.load()).clone(); let brave_search_key = (**rc.brave_search_key.load()).clone(); diff --git a/src/agent/compactor.rs b/src/agent/compactor.rs index ab17b0680..618309482 100644 --- a/src/agent/compactor.rs +++ b/src/agent/compactor.rs @@ -108,9 +108,19 @@ impl Compactor { let channel_id = self.channel_id.clone(); let deps = self.deps.clone(); let prompt_engine = deps.runtime_config.prompts.load(); - let compactor_prompt = prompt_engine - .render_static("compactor") - .expect("failed to render compactor prompt"); + let compactor_prompt = match prompt_engine.render_static("compactor") { + Ok(prompt) => prompt, + Err(error) => { + tracing::error!( + channel_id = %self.channel_id, + %error, + "failed to render compactor prompt" + ); + let mut flag = self.is_compacting.write().await; + *flag = false; + return; + } + }; tokio::spawn(async move { let result = run_compaction(&deps, &compactor_prompt, &history, fraction).await; @@ -155,9 +165,7 @@ impl Compactor { // Insert a marker at the beginning let prompt_engine = self.deps.runtime_config.prompts.load(); - let marker = prompt_engine - .render_system_truncation(remove_count) - .expect("failed to render truncation message"); + let marker = prompt_engine.render_system_truncation(remove_count)?; history.insert(0, Message::from(marker)); tracing::warn!( diff --git a/src/agent/status.rs b/src/agent/status.rs index 1b8e35003..963f3ad4d 100644 --- a/src/agent/status.rs +++ b/src/agent/status.rs @@ -143,6 +143,17 @@ impl StatusBlock { }); } + /// Remove an active worker by ID. + pub fn remove_worker(&mut self, worker_id: WorkerId) -> bool { + if let Some(position) = self.active_workers.iter().position(|worker| worker.id == worker_id) + { + self.active_workers.remove(position); + true + } else { + false + } + } + /// Render the status block as a string for context injection. pub fn render(&self) -> String { let mut output = String::new(); diff --git a/src/agent/worker.rs b/src/agent/worker.rs index 931d2c8dd..f2f0bece4 100644 --- a/src/agent/worker.rs +++ b/src/agent/worker.rs @@ -357,10 +357,19 @@ impl Worker { self.hook.send_status("compacting (overflow recovery)"); self.force_compact_history(&mut history).await; let prompt_engine = self.deps.runtime_config.prompts.load(); - let overflow_msg = prompt_engine - .render_system_worker_overflow() - .expect("failed to render worker overflow message"); - follow_up_prompt = format!("{follow_up}\n\n{overflow_msg}"); + match prompt_engine.render_system_worker_overflow() { + Ok(overflow_msg) => { + follow_up_prompt = format!("{follow_up}\n\n{overflow_msg}"); + } + Err(render_error) => { + tracing::warn!( + worker_id = %self.id, + %render_error, + "failed to render worker overflow message" + ); + follow_up_prompt = follow_up.clone(); + } + } } Err(error) => { self.write_failure_log(&history, &format!("follow-up failed: {error}")); diff --git a/src/api/config.rs b/src/api/config.rs index 72cdad427..49722f6f9 100644 --- a/src/api/config.rs +++ b/src/api/config.rs @@ -317,7 +317,13 @@ pub(super) async fn update_agent_config( update_discord_table(&mut doc, discord)?; } - tokio::fs::write(&config_path, doc.to_string()) + let updated_content = doc.to_string(); + if let Err(error) = crate::config::Config::validate_toml(&updated_content) { + tracing::warn!(%error, agent_id = %request.agent_id, "config update validation failed"); + return Err(StatusCode::BAD_REQUEST); + } + + tokio::fs::write(&config_path, updated_content) .await .map_err(|error| { tracing::warn!(%error, "failed to write config.toml"); @@ -326,32 +332,28 @@ pub(super) async fn update_agent_config( tracing::info!(agent_id = %request.agent_id, "config.toml updated via API"); - match crate::config::Config::load_from_path(&config_path) { - Ok(new_config) => { - let runtime_configs = state.runtime_configs.load(); - let mcp_managers = state.mcp_managers.load(); - if let (Some(rc), Some(mcp_manager)) = ( - runtime_configs.get(&request.agent_id).cloned(), - mcp_managers.get(&request.agent_id).cloned(), - ) { - rc.reload_config(&new_config, &request.agent_id, &mcp_manager) - .await; - } - if request.discord.is_some() - && let Some(discord_config) = &new_config.messaging.discord - { - let new_perms = crate::config::DiscordPermissions::from_config( - discord_config, - &new_config.bindings, - ); - let perms = state.discord_permissions.read().await; - if let Some(arc_swap) = perms.as_ref() { - arc_swap.store(std::sync::Arc::new(new_perms)); - } - } - } - Err(error) => { - tracing::warn!(%error, "config.toml written but failed to reload immediately"); + let new_config = crate::config::Config::load_from_path(&config_path).map_err(|error| { + tracing::warn!(%error, "config.toml written but failed to reload immediately"); + StatusCode::BAD_REQUEST + })?; + + let runtime_configs = state.runtime_configs.load(); + let mcp_managers = state.mcp_managers.load(); + if let (Some(rc), Some(mcp_manager)) = ( + runtime_configs.get(&request.agent_id).cloned(), + mcp_managers.get(&request.agent_id).cloned(), + ) { + rc.reload_config(&new_config, &request.agent_id, &mcp_manager) + .await; + } + if request.discord.is_some() + && let Some(discord_config) = &new_config.messaging.discord + { + let new_perms = + crate::config::DiscordPermissions::from_config(discord_config, &new_config.bindings); + let perms = state.discord_permissions.read().await; + if let Some(arc_swap) = perms.as_ref() { + arc_swap.store(std::sync::Arc::new(new_perms)); } } diff --git a/src/messaging/slack.rs b/src/messaging/slack.rs index 309e4af88..a4c5b613d 100644 --- a/src/messaging/slack.rs +++ b/src/messaging/slack.rs @@ -995,8 +995,11 @@ impl Messaging for SlackAdapter { } OutboundResponse::StreamChunk(text) => { - let active = self.active_messages.read().await; - if let Some(ts) = active.get(&message.id) { + let stream_ts = { + let active = self.active_messages.read().await; + active.get(&message.id).cloned() + }; + if let Some(ts) = stream_ts { let display_text = if text.len() > 12_000 { let end = text.floor_char_boundary(11_997); format!("{}...", &text[..end]) @@ -1006,7 +1009,7 @@ impl Messaging for SlackAdapter { let req = SlackApiChatUpdateRequest::new( channel_id.clone(), markdown_content(display_text), - SlackTs(ts.clone()), + SlackTs(ts), ); if let Err(error) = session.chat_update(&req).await { tracing::warn!(%error, "failed to edit streaming message"); diff --git a/src/messaging/telegram.rs b/src/messaging/telegram.rs index cd4bdb1f2..386cb6870 100644 --- a/src/messaging/telegram.rs +++ b/src/messaging/telegram.rs @@ -389,13 +389,19 @@ impl Messaging for TelegramAdapter { ); } OutboundResponse::StreamChunk(text) => { - let mut active = self.active_messages.write().await; - if let Some(stream) = active.get_mut(&message.conversation_id) { - // Rate-limit edits to avoid Telegram API throttling - if stream.last_edit.elapsed() < STREAM_EDIT_INTERVAL { - return Ok(()); - } + let stream_target = { + let active = self.active_messages.read().await; + active.get(&message.conversation_id).and_then(|stream| { + // Rate-limit edits to avoid Telegram API throttling. + if stream.last_edit.elapsed() < STREAM_EDIT_INTERVAL { + None + } else { + Some((stream.chat_id, stream.message_id)) + } + }) + }; + if let Some((chat_id, message_id)) = stream_target { let display_text = if text.len() > MAX_MESSAGE_LENGTH { let end = text.floor_char_boundary(MAX_MESSAGE_LENGTH - 3); format!("{}...", &text[..end]) @@ -405,13 +411,19 @@ impl Messaging for TelegramAdapter { if let Err(error) = self .bot - .edit_message_text(stream.chat_id, stream.message_id, display_text) + .edit_message_text(chat_id, message_id, display_text) .send() .await { tracing::debug!(%error, "failed to edit streaming message"); } - stream.last_edit = Instant::now(); + + let mut active = self.active_messages.write().await; + if let Some(stream) = active.get_mut(&message.conversation_id) + && stream.message_id == message_id + { + stream.last_edit = Instant::now(); + } } } OutboundResponse::StreamEnd => { From 1f5d4af2f53fa062864c5e18f86269fafc3dda13 Mon Sep 17 00:00:00 2001 From: Victor Sumner Date: Sun, 22 Feb 2026 09:05:51 -0500 Subject: [PATCH 3/9] Remove remaining channel and worker prompt panics --- src/agent/channel.rs | 11 +++++++---- src/agent/worker.rs | 9 ++++++++- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/agent/channel.rs b/src/agent/channel.rs index 70a35ab93..c43fef1f4 100644 --- a/src/agent/channel.rs +++ b/src/agent/channel.rs @@ -1197,7 +1197,10 @@ impl Channel { .prompts .load() .render_system_retrigger() - .expect("failed to render retrigger message"); + .unwrap_or_else(|error| { + tracing::warn!(%error, "failed to render retrigger message"); + "Background work completed; continue processing.".to_string() + }); let synthetic = InboundMessage { id: uuid::Uuid::new_v4().to_string(), @@ -1269,7 +1272,7 @@ pub async fn spawn_branch_from_state( &rc.instance_dir.display().to_string(), &rc.workspace_dir.display().to_string(), ) - .expect("failed to render branch prompt"); + .map_err(|error| AgentError::Other(error.into()))?; spawn_branch( state, @@ -1293,10 +1296,10 @@ async fn spawn_memory_persistence_branch( let prompt_engine = deps.runtime_config.prompts.load(); let system_prompt = prompt_engine .render_static("memory_persistence") - .expect("failed to render memory_persistence prompt"); + .map_err(|error| AgentError::Other(error.into()))?; let prompt = prompt_engine .render_system_memory_persistence() - .expect("failed to render memory persistence prompt"); + .map_err(|error| AgentError::Other(error.into()))?; spawn_branch( state, diff --git a/src/agent/worker.rs b/src/agent/worker.rs index f2f0bece4..b314e5f03 100644 --- a/src/agent/worker.rs +++ b/src/agent/worker.rs @@ -460,7 +460,14 @@ impl Worker { let prompt_engine = self.deps.runtime_config.prompts.load(); let marker = prompt_engine .render_system_worker_compact(remove_count, &recap) - .expect("failed to render worker compact message"); + .unwrap_or_else(|error| { + tracing::warn!( + worker_id = %self.id, + %error, + "failed to render worker compact message" + ); + format!("[Compacted worker history: removed {remove_count} messages]") + }); history.insert(0, rig::message::Message::from(marker)); tracing::info!( From 35d04da7bdcd605f3da4cf4d6a56e48f8fa24e2d Mon Sep 17 00:00:00 2001 From: Victor Sumner Date: Sun, 22 Feb 2026 09:10:19 -0500 Subject: [PATCH 4/9] Run rustfmt to satisfy CI format check --- src/agent/channel.rs | 14 ++++++++++---- src/agent/status.rs | 5 ++++- src/agent/worker.rs | 5 ++++- src/secrets/store.rs | 2 +- src/tools/cron.rs | 3 ++- 5 files changed, 21 insertions(+), 8 deletions(-) diff --git a/src/agent/channel.rs b/src/agent/channel.rs index c43fef1f4..9b51f792c 100644 --- a/src/agent/channel.rs +++ b/src/agent/channel.rs @@ -466,8 +466,11 @@ impl Channel { .get("telegram_chat_type") .and_then(|v| v.as_str()) }); - match prompt_engine.render_conversation_context(&first.source, server_name, channel_name) - { + match prompt_engine.render_conversation_context( + &first.source, + server_name, + channel_name, + ) { Ok(context) => { self.conversation_context = Some(context); } @@ -724,8 +727,11 @@ impl Channel { .get("telegram_chat_type") .and_then(|v| v.as_str()) }); - match prompt_engine.render_conversation_context(&message.source, server_name, channel_name) - { + match prompt_engine.render_conversation_context( + &message.source, + server_name, + channel_name, + ) { Ok(context) => { self.conversation_context = Some(context); } diff --git a/src/agent/status.rs b/src/agent/status.rs index 963f3ad4d..b494a37d0 100644 --- a/src/agent/status.rs +++ b/src/agent/status.rs @@ -145,7 +145,10 @@ impl StatusBlock { /// Remove an active worker by ID. pub fn remove_worker(&mut self, worker_id: WorkerId) -> bool { - if let Some(position) = self.active_workers.iter().position(|worker| worker.id == worker_id) + if let Some(position) = self + .active_workers + .iter() + .position(|worker| worker.id == worker_id) { self.active_workers.remove(position); true diff --git a/src/agent/worker.rs b/src/agent/worker.rs index b314e5f03..970958c9d 100644 --- a/src/agent/worker.rs +++ b/src/agent/worker.rs @@ -263,7 +263,10 @@ impl Worker { None } }) - .unwrap_or_else(|| "Worker reached maximum segments without a final response.".to_string()); + .unwrap_or_else(|| { + "Worker reached maximum segments without a final response." + .to_string() + }); } self.maybe_compact_history(&mut history).await; diff --git a/src/secrets/store.rs b/src/secrets/store.rs index 5b11681cc..d598be95d 100644 --- a/src/secrets/store.rs +++ b/src/secrets/store.rs @@ -1,7 +1,7 @@ //! Encrypted credentials storage (AES-256-GCM, redb). use crate::error::SecretsError; -use aes_gcm::{aead::Aead, Aes256Gcm, KeyInit, Nonce}; +use aes_gcm::{Aes256Gcm, KeyInit, Nonce, aead::Aead}; use rand::RngCore; use redb::{Database, ReadableTable, TableDefinition}; use sha2::{Digest, Sha256}; diff --git a/src/tools/cron.rs b/src/tools/cron.rs index 5e4729758..83589bb62 100644 --- a/src/tools/cron.rs +++ b/src/tools/cron.rs @@ -182,7 +182,8 @@ impl CronTool { .all(|c| c.is_alphanumeric() || c == '-' || c == '_') { return Err(CronError( - "'id' must be 1-50 characters, alphanumeric with hyphens and underscores only".into(), + "'id' must be 1-50 characters, alphanumeric with hyphens and underscores only" + .into(), )); } From a5d74f3397ddd3b253641ea761ed8c0027da7810 Mon Sep 17 00:00:00 2001 From: Victor Sumner Date: Sun, 22 Feb 2026 09:18:09 -0500 Subject: [PATCH 5/9] Standardize config write validation and runtime reload paths --- src/api/bindings.rs | 356 ++++++++++++++++--------------------------- src/api/config.rs | 112 +++++++++++--- src/api/messaging.rs | 32 ++-- src/api/providers.rs | 11 +- 4 files changed, 243 insertions(+), 268 deletions(-) diff --git a/src/api/bindings.rs b/src/api/bindings.rs index fa7fc27f7..50e39bfca 100644 --- a/src/api/bindings.rs +++ b/src/api/bindings.rs @@ -1,3 +1,6 @@ +use super::config::{ + reload_all_runtime_configs, sync_bindings_and_permissions, write_validated_config, +}; use super::state::ApiState; use axum::Json; @@ -335,12 +338,7 @@ pub(super) async fn create_binding( } bindings_array.push(binding_table); - tokio::fs::write(&config_path, doc.to_string()) - .await - .map_err(|error| { - tracing::warn!(%error, "failed to write config.toml"); - StatusCode::INTERNAL_SERVER_ERROR - })?; + let new_config = write_validated_config(&config_path, doc.to_string()).await?; tracing::info!( agent_id = %request.agent_id, @@ -348,171 +346,145 @@ pub(super) async fn create_binding( "binding created via API" ); - if let Ok(new_config) = crate::config::Config::load_from_path(&config_path) { - let bindings_guard = state.bindings.read().await; - if let Some(bindings_swap) = bindings_guard.as_ref() { - bindings_swap.store(std::sync::Arc::new(new_config.bindings.clone())); - } - drop(bindings_guard); - - if let Some(discord_config) = &new_config.messaging.discord { - let new_perms = crate::config::DiscordPermissions::from_config( - discord_config, - &new_config.bindings, - ); - let perms = state.discord_permissions.read().await; - if let Some(arc_swap) = perms.as_ref() { - arc_swap.store(std::sync::Arc::new(new_perms)); - } - } - - if let Some(slack_config) = &new_config.messaging.slack { - let new_perms = - crate::config::SlackPermissions::from_config(slack_config, &new_config.bindings); - let perms = state.slack_permissions.read().await; - if let Some(arc_swap) = perms.as_ref() { - arc_swap.store(std::sync::Arc::new(new_perms)); - } - } - - let manager_guard = state.messaging_manager.read().await; - if let Some(manager) = manager_guard.as_ref() { - let discord_config = new_config.messaging.discord.as_ref(); - if matches!( - hot_reload_disposition( - "discord", - new_discord_token.is_some(), - discord_config.is_some(), - ), - HotReloadDisposition::Start - ) { - if let (Some(token), Some(discord_config)) = - (new_discord_token.as_ref(), discord_config) - { - let discord_perms = { - let perms_guard = state.discord_permissions.read().await; - match perms_guard.as_ref() { - Some(existing) => existing.clone(), - None => { - drop(perms_guard); - let perms = crate::config::DiscordPermissions::from_config( - discord_config, - &new_config.bindings, - ); - let arc_swap = - std::sync::Arc::new(arc_swap::ArcSwap::from_pointee(perms)); - state.set_discord_permissions(arc_swap.clone()).await; - arc_swap - } + sync_bindings_and_permissions(&state, &new_config).await; + reload_all_runtime_configs(&state, &new_config).await; + + let manager_guard = state.messaging_manager.read().await; + if let Some(manager) = manager_guard.as_ref() { + let discord_config = new_config.messaging.discord.as_ref(); + if matches!( + hot_reload_disposition( + "discord", + new_discord_token.is_some(), + discord_config.is_some(), + ), + HotReloadDisposition::Start + ) { + if let (Some(token), Some(discord_config)) = + (new_discord_token.as_ref(), discord_config) + { + let discord_perms = { + let perms_guard = state.discord_permissions.read().await; + match perms_guard.as_ref() { + Some(existing) => existing.clone(), + None => { + drop(perms_guard); + let perms = crate::config::DiscordPermissions::from_config( + discord_config, + &new_config.bindings, + ); + let arc_swap = + std::sync::Arc::new(arc_swap::ArcSwap::from_pointee(perms)); + state.set_discord_permissions(arc_swap.clone()).await; + arc_swap } - }; - let adapter = - crate::messaging::discord::DiscordAdapter::new(token, discord_perms); - if let Err(error) = manager.register_and_start(adapter).await { - tracing::error!(%error, "failed to hot-start discord adapter"); } + }; + let adapter = crate::messaging::discord::DiscordAdapter::new(token, discord_perms); + if let Err(error) = manager.register_and_start(adapter).await { + tracing::error!(%error, "failed to hot-start discord adapter"); } } + } - let slack_config = new_config.messaging.slack.as_ref(); - if matches!( - hot_reload_disposition("slack", new_slack_tokens.is_some(), slack_config.is_some(),), - HotReloadDisposition::Start - ) { - if let (Some((bot_token, app_token)), Some(slack_config)) = - (new_slack_tokens.as_ref(), slack_config) - { - let slack_perms = { - let perms_guard = state.slack_permissions.read().await; - match perms_guard.as_ref() { - Some(existing) => existing.clone(), - None => { - drop(perms_guard); - let perms = crate::config::SlackPermissions::from_config( - slack_config, - &new_config.bindings, - ); - let arc_swap = - std::sync::Arc::new(arc_swap::ArcSwap::from_pointee(perms)); - state.set_slack_permissions(arc_swap.clone()).await; - arc_swap - } - } - }; - match crate::messaging::slack::SlackAdapter::new( - bot_token, - app_token, - slack_perms, - slack_config.commands.clone(), - ) { - Ok(adapter) => { - if let Err(error) = manager.register_and_start(adapter).await { - tracing::error!(%error, "failed to hot-start slack adapter"); - } + let slack_config = new_config.messaging.slack.as_ref(); + if matches!( + hot_reload_disposition("slack", new_slack_tokens.is_some(), slack_config.is_some(),), + HotReloadDisposition::Start + ) { + if let (Some((bot_token, app_token)), Some(slack_config)) = + (new_slack_tokens.as_ref(), slack_config) + { + let slack_perms = { + let perms_guard = state.slack_permissions.read().await; + match perms_guard.as_ref() { + Some(existing) => existing.clone(), + None => { + drop(perms_guard); + let perms = crate::config::SlackPermissions::from_config( + slack_config, + &new_config.bindings, + ); + let arc_swap = + std::sync::Arc::new(arc_swap::ArcSwap::from_pointee(perms)); + state.set_slack_permissions(arc_swap.clone()).await; + arc_swap } - Err(error) => { - tracing::error!(%error, "failed to build slack adapter"); + } + }; + match crate::messaging::slack::SlackAdapter::new( + bot_token, + app_token, + slack_perms, + slack_config.commands.clone(), + ) { + Ok(adapter) => { + if let Err(error) = manager.register_and_start(adapter).await { + tracing::error!(%error, "failed to hot-start slack adapter"); } } + Err(error) => { + tracing::error!(%error, "failed to build slack adapter"); + } } } + } - let telegram_config = new_config.messaging.telegram.as_ref(); - if matches!( - hot_reload_disposition( - "telegram", - new_telegram_token.is_some(), - telegram_config.is_some(), - ), - HotReloadDisposition::Start - ) { - if let (Some(token), Some(telegram_config)) = - (new_telegram_token.as_ref(), telegram_config) - { - let telegram_perms = { - let perms = crate::config::TelegramPermissions::from_config( - telegram_config, - &new_config.bindings, - ); - std::sync::Arc::new(arc_swap::ArcSwap::from_pointee(perms)) - }; - let adapter = - crate::messaging::telegram::TelegramAdapter::new(token, telegram_perms); - if let Err(error) = manager.register_and_start(adapter).await { - tracing::error!(%error, "failed to hot-start telegram adapter"); - } + let telegram_config = new_config.messaging.telegram.as_ref(); + if matches!( + hot_reload_disposition( + "telegram", + new_telegram_token.is_some(), + telegram_config.is_some(), + ), + HotReloadDisposition::Start + ) { + if let (Some(token), Some(telegram_config)) = + (new_telegram_token.as_ref(), telegram_config) + { + let telegram_perms = { + let perms = crate::config::TelegramPermissions::from_config( + telegram_config, + &new_config.bindings, + ); + std::sync::Arc::new(arc_swap::ArcSwap::from_pointee(perms)) + }; + let adapter = + crate::messaging::telegram::TelegramAdapter::new(token, telegram_perms); + if let Err(error) = manager.register_and_start(adapter).await { + tracing::error!(%error, "failed to hot-start telegram adapter"); } } + } - let twitch_config = new_config.messaging.twitch.as_ref(); - if matches!( - hot_reload_disposition( - "twitch", - new_twitch_creds.is_some(), - twitch_config.is_some(), - ), - HotReloadDisposition::Start - ) { - if let (Some((username, oauth_token)), Some(twitch_config)) = - (new_twitch_creds.as_ref(), twitch_config) - { - let twitch_perms = { - let perms = crate::config::TwitchPermissions::from_config( - twitch_config, - &new_config.bindings, - ); - std::sync::Arc::new(arc_swap::ArcSwap::from_pointee(perms)) - }; - let adapter = crate::messaging::twitch::TwitchAdapter::new( - username, - oauth_token, - twitch_config.channels.clone(), - twitch_config.trigger_prefix.clone(), - twitch_perms, + let twitch_config = new_config.messaging.twitch.as_ref(); + if matches!( + hot_reload_disposition( + "twitch", + new_twitch_creds.is_some(), + twitch_config.is_some(), + ), + HotReloadDisposition::Start + ) { + if let (Some((username, oauth_token)), Some(twitch_config)) = + (new_twitch_creds.as_ref(), twitch_config) + { + let twitch_perms = { + let perms = crate::config::TwitchPermissions::from_config( + twitch_config, + &new_config.bindings, ); - if let Err(error) = manager.register_and_start(adapter).await { - tracing::error!(%error, "failed to hot-start twitch adapter"); - } + std::sync::Arc::new(arc_swap::ArcSwap::from_pointee(perms)) + }; + let adapter = crate::messaging::twitch::TwitchAdapter::new( + username, + oauth_token, + twitch_config.channels.clone(), + twitch_config.trigger_prefix.clone(), + twitch_perms, + ); + if let Err(error) = manager.register_and_start(adapter).await { + tracing::error!(%error, "failed to hot-start twitch adapter"); } } } @@ -648,12 +620,7 @@ pub(super) async fn update_binding( binding.remove("dm_allowed_users"); } - tokio::fs::write(&config_path, doc.to_string()) - .await - .map_err(|error| { - tracing::warn!(%error, "failed to write config.toml"); - StatusCode::INTERNAL_SERVER_ERROR - })?; + let new_config = write_validated_config(&config_path, doc.to_string()).await?; tracing::info!( agent_id = %request.agent_id, @@ -661,33 +628,8 @@ pub(super) async fn update_binding( "binding updated via API" ); - if let Ok(new_config) = crate::config::Config::load_from_path(&config_path) { - let bindings_guard = state.bindings.read().await; - if let Some(bindings_swap) = bindings_guard.as_ref() { - bindings_swap.store(std::sync::Arc::new(new_config.bindings.clone())); - } - drop(bindings_guard); - - if let Some(discord_config) = &new_config.messaging.discord { - let new_perms = crate::config::DiscordPermissions::from_config( - discord_config, - &new_config.bindings, - ); - let perms = state.discord_permissions.read().await; - if let Some(arc_swap) = perms.as_ref() { - arc_swap.store(std::sync::Arc::new(new_perms)); - } - } - - if let Some(slack_config) = &new_config.messaging.slack { - let new_perms = - crate::config::SlackPermissions::from_config(slack_config, &new_config.bindings); - let perms = state.slack_permissions.read().await; - if let Some(arc_swap) = perms.as_ref() { - arc_swap.store(std::sync::Arc::new(new_perms)); - } - } - } + sync_bindings_and_permissions(&state, &new_config).await; + reload_all_runtime_configs(&state, &new_config).await; Ok(Json(UpdateBindingResponse { success: true, @@ -768,12 +710,7 @@ pub(super) async fn delete_binding( bindings_array.remove(idx); - tokio::fs::write(&config_path, doc.to_string()) - .await - .map_err(|error| { - tracing::warn!(%error, "failed to write config.toml"); - StatusCode::INTERNAL_SERVER_ERROR - })?; + let new_config = write_validated_config(&config_path, doc.to_string()).await?; tracing::info!( agent_id = %request.agent_id, @@ -781,33 +718,8 @@ pub(super) async fn delete_binding( "binding deleted via API" ); - if let Ok(new_config) = crate::config::Config::load_from_path(&config_path) { - let bindings_guard = state.bindings.read().await; - if let Some(bindings_swap) = bindings_guard.as_ref() { - bindings_swap.store(std::sync::Arc::new(new_config.bindings.clone())); - } - drop(bindings_guard); - - if let Some(discord_config) = &new_config.messaging.discord { - let new_perms = crate::config::DiscordPermissions::from_config( - discord_config, - &new_config.bindings, - ); - let perms = state.discord_permissions.read().await; - if let Some(arc_swap) = perms.as_ref() { - arc_swap.store(std::sync::Arc::new(new_perms)); - } - } - - if let Some(slack_config) = &new_config.messaging.slack { - let new_perms = - crate::config::SlackPermissions::from_config(slack_config, &new_config.bindings); - let perms = state.slack_permissions.read().await; - if let Some(arc_swap) = perms.as_ref() { - arc_swap.store(std::sync::Arc::new(new_perms)); - } - } - } + sync_bindings_and_permissions(&state, &new_config).await; + reload_all_runtime_configs(&state, &new_config).await; Ok(Json(DeleteBindingResponse { success: true, diff --git a/src/api/config.rs b/src/api/config.rs index 49722f6f9..3d075825f 100644 --- a/src/api/config.rs +++ b/src/api/config.rs @@ -6,6 +6,99 @@ use axum::http::StatusCode; use serde::{Deserialize, Serialize}; use std::sync::Arc; +/// Validate, persist, and reload a config.toml update. +/// +/// Returns the fully parsed `Config` snapshot that was persisted. +pub(super) async fn write_validated_config( + config_path: &std::path::Path, + updated_content: String, +) -> Result { + if let Err(error) = crate::config::Config::validate_toml(&updated_content) { + tracing::warn!(%error, "config update validation failed"); + return Err(StatusCode::BAD_REQUEST); + } + + tokio::fs::write(config_path, updated_content) + .await + .map_err(|error| { + tracing::warn!(%error, "failed to write config.toml"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + crate::config::Config::load_from_path(config_path).map_err(|error| { + tracing::warn!(%error, "config.toml written but failed to reload immediately"); + StatusCode::BAD_REQUEST + }) +} + +/// Reload all live RuntimeConfig instances from a parsed config snapshot. +pub(super) async fn reload_all_runtime_configs( + state: &Arc, + new_config: &crate::config::Config, +) { + let runtime_configs = state.runtime_configs.load(); + let mcp_managers = state.mcp_managers.load(); + let reload_targets = runtime_configs + .iter() + .filter_map(|(agent_id, runtime_config)| { + mcp_managers.get(agent_id).map(|mcp_manager| { + ( + agent_id.clone(), + runtime_config.clone(), + mcp_manager.clone(), + ) + }) + }) + .collect::>(); + drop(runtime_configs); + drop(mcp_managers); + + for (agent_id, runtime_config, mcp_manager) in reload_targets { + runtime_config + .reload_config(new_config, &agent_id, &mcp_manager) + .await; + } +} + +/// Sync bindings and messaging permission snapshots from a parsed config snapshot. +pub(super) async fn sync_bindings_and_permissions( + state: &Arc, + new_config: &crate::config::Config, +) { + let bindings_guard = state.bindings.read().await; + if let Some(bindings_swap) = bindings_guard.as_ref() { + bindings_swap.store(std::sync::Arc::new(new_config.bindings.clone())); + } + drop(bindings_guard); + + let discord_permissions = new_config + .messaging + .discord + .as_ref() + .map(|discord_config| { + crate::config::DiscordPermissions::from_config(discord_config, &new_config.bindings) + }) + .unwrap_or_default(); + let discord_guard = state.discord_permissions.read().await; + if let Some(arc_swap) = discord_guard.as_ref() { + arc_swap.store(std::sync::Arc::new(discord_permissions)); + } + drop(discord_guard); + + let slack_permissions = new_config + .messaging + .slack + .as_ref() + .map(|slack_config| { + crate::config::SlackPermissions::from_config(slack_config, &new_config.bindings) + }) + .unwrap_or_default(); + let slack_guard = state.slack_permissions.read().await; + if let Some(arc_swap) = slack_guard.as_ref() { + arc_swap.store(std::sync::Arc::new(slack_permissions)); + } +} + #[derive(Serialize, Debug)] pub(super) struct RoutingSection { channel: String, @@ -317,26 +410,9 @@ pub(super) async fn update_agent_config( update_discord_table(&mut doc, discord)?; } - let updated_content = doc.to_string(); - if let Err(error) = crate::config::Config::validate_toml(&updated_content) { - tracing::warn!(%error, agent_id = %request.agent_id, "config update validation failed"); - return Err(StatusCode::BAD_REQUEST); - } - - tokio::fs::write(&config_path, updated_content) - .await - .map_err(|error| { - tracing::warn!(%error, "failed to write config.toml"); - StatusCode::INTERNAL_SERVER_ERROR - })?; - + let new_config = write_validated_config(&config_path, doc.to_string()).await?; tracing::info!(agent_id = %request.agent_id, "config.toml updated via API"); - let new_config = crate::config::Config::load_from_path(&config_path).map_err(|error| { - tracing::warn!(%error, "config.toml written but failed to reload immediately"); - StatusCode::BAD_REQUEST - })?; - let runtime_configs = state.runtime_configs.load(); let mcp_managers = state.mcp_managers.load(); if let (Some(rc), Some(mcp_manager)) = ( diff --git a/src/api/messaging.rs b/src/api/messaging.rs index 59811566c..429e738cf 100644 --- a/src/api/messaging.rs +++ b/src/api/messaging.rs @@ -1,3 +1,6 @@ +use super::config::{ + reload_all_runtime_configs, sync_bindings_and_permissions, write_validated_config, +}; use super::state::ApiState; use axum::Json; @@ -222,19 +225,9 @@ pub(super) async fn disconnect_platform( } } - tokio::fs::write(&config_path, doc.to_string()) - .await - .map_err(|error| { - tracing::warn!(%error, "failed to write config.toml"); - StatusCode::INTERNAL_SERVER_ERROR - })?; - - if let Ok(new_config) = crate::config::Config::load_from_path(&config_path) { - let bindings_guard = state.bindings.read().await; - if let Some(bindings_swap) = bindings_guard.as_ref() { - bindings_swap.store(std::sync::Arc::new(new_config.bindings.clone())); - } - } + let new_config = write_validated_config(&config_path, doc.to_string()).await?; + sync_bindings_and_permissions(&state, &new_config).await; + reload_all_runtime_configs(&state, &new_config).await; let manager_guard = state.messaging_manager.read().await; if let Some(manager) = manager_guard.as_ref() @@ -286,20 +279,15 @@ pub(super) async fn toggle_platform( table["enabled"] = toml_edit::value(request.enabled); - tokio::fs::write(&config_path, doc.to_string()) - .await - .map_err(|error| { - tracing::warn!(%error, "failed to write config.toml"); - StatusCode::INTERNAL_SERVER_ERROR - })?; + let new_config = write_validated_config(&config_path, doc.to_string()).await?; + sync_bindings_and_permissions(&state, &new_config).await; + reload_all_runtime_configs(&state, &new_config).await; let manager_guard = state.messaging_manager.read().await; let manager = manager_guard.as_ref(); if request.enabled { - if let Ok(new_config) = crate::config::Config::load_from_path(&config_path) - && let Some(manager) = manager - { + if let Some(manager) = manager { match platform.as_str() { "discord" => { if let Some(discord_config) = &new_config.messaging.discord { diff --git a/src/api/providers.rs b/src/api/providers.rs index f70c8eac4..d732c6c1c 100644 --- a/src/api/providers.rs +++ b/src/api/providers.rs @@ -1,3 +1,4 @@ +use super::config::{reload_all_runtime_configs, write_validated_config}; use super::state::ApiState; use axum::Json; @@ -445,9 +446,8 @@ pub(super) async fn update_provider( } } - tokio::fs::write(&config_path, doc.to_string()) - .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + let new_config = write_validated_config(&config_path, doc.to_string()).await?; + reload_all_runtime_configs(&state, &new_config).await; state .provider_setup_tx @@ -579,9 +579,8 @@ pub(super) async fn delete_provider( table.remove(key_name); } - tokio::fs::write(&config_path, doc.to_string()) - .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + let new_config = write_validated_config(&config_path, doc.to_string()).await?; + reload_all_runtime_configs(&state, &new_config).await; Ok(Json(ProviderUpdateResponse { success: true, From f1c90ccb9097d5226306b1c636d61930aad7608c Mon Sep 17 00:00:00 2001 From: Victor Sumner Date: Sun, 22 Feb 2026 09:27:24 -0500 Subject: [PATCH 6/9] Add API tests for config write and reload semantics --- src/api/bindings.rs | 284 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 283 insertions(+), 1 deletion(-) diff --git a/src/api/bindings.rs b/src/api/bindings.rs index 50e39bfca..e6adba51e 100644 --- a/src/api/bindings.rs +++ b/src/api/bindings.rs @@ -729,7 +729,90 @@ pub(super) async fn delete_binding( #[cfg(test)] mod tests { - use super::{HotReloadDisposition, hot_reload_disposition}; + use super::{CreateBindingRequest, HotReloadDisposition, hot_reload_disposition}; + use crate::api::providers; + use crate::api::state::ApiState; + use crate::config::{Config, DiscordPermissions, RuntimeConfig, SlackPermissions}; + use arc_swap::ArcSwap; + use axum::Json; + use axum::extract::State; + use axum::http::StatusCode; + use std::collections::HashMap; + use std::sync::Arc; + + const VALID_BASE_CONFIG: &str = r#" +[llm] +anthropic_key = "test-anthropic-key" + +[defaults.routing] +channel = "anthropic/claude-sonnet-4" +branch = "anthropic/claude-sonnet-4" +worker = "anthropic/claude-sonnet-4" +compactor = "anthropic/claude-sonnet-4" +cortex = "anthropic/claude-sonnet-4" + +[[agents]] +id = "main" +default = true +"#; + + const VALID_DISCORD_CONFIG: &str = r#" +[llm] +anthropic_key = "test-anthropic-key" + +[[agents]] +id = "main" +default = true + +[messaging.discord] +enabled = true +token = "discord-test-token" +dm_allowed_users = ["111"] +"#; + + fn new_test_state() -> Arc { + let (provider_setup_tx, _) = tokio::sync::mpsc::channel(8); + let (agent_tx, _) = tokio::sync::mpsc::channel(8); + let (agent_remove_tx, _) = tokio::sync::mpsc::channel(8); + Arc::new(ApiState::new_with_provider_sender( + provider_setup_tx, + agent_tx, + agent_remove_tx, + )) + } + + async fn configure_runtime_for_main( + state: &Arc, + config_path: &std::path::Path, + ) -> Arc { + let config = Config::load_from_path(config_path).expect("config should load for test"); + let resolved_agent = config + .resolve_agents() + .into_iter() + .find(|agent| agent.id == "main") + .expect("main agent should exist in resolved config"); + let runtime_config = Arc::new(RuntimeConfig::new( + &config.instance_dir, + &resolved_agent, + &config.defaults, + crate::prompts::PromptEngine::new("en").expect("prompt engine should build"), + crate::identity::Identity::default(), + crate::skills::SkillSet::default(), + )); + + let mut runtime_configs = HashMap::new(); + runtime_configs.insert("main".to_string(), runtime_config.clone()); + state.set_runtime_configs(runtime_configs); + + let mut mcp_managers = HashMap::new(); + mcp_managers.insert( + "main".to_string(), + Arc::new(crate::mcp::McpManager::new(resolved_agent.mcp.clone())), + ); + state.set_mcp_managers(mcp_managers); + + runtime_config + } #[test] fn hot_reload_disposition_starts_when_credentials_and_config_exist() { @@ -782,4 +865,203 @@ mod tests { HotReloadDisposition::MissingConfig ); } + + #[tokio::test] + async fn create_binding_returns_bad_request_when_config_validation_fails() { + const MISSING_ENV_VAR: &str = "SPACEBOT_TEST_MISSING_PROVIDER_KEY_4E845CE0234A45D6AAB6"; + let config_toml = format!( + r#" +[llm.provider.invalid] +api_type = "openai_completions" +base_url = "https://api.example.com/v1" +api_key = "env:{MISSING_ENV_VAR}" +"# + ); + + let temp_dir = tempfile::tempdir().expect("temp dir should be created"); + let config_path = temp_dir.path().join("config.toml"); + tokio::fs::write(&config_path, config_toml) + .await + .expect("config.toml should be written"); + + let state = new_test_state(); + state.set_config_path(config_path).await; + + let result = super::create_binding( + State(state), + Json(CreateBindingRequest { + agent_id: "main".to_string(), + channel: "discord".to_string(), + guild_id: Some("123".to_string()), + workspace_id: None, + chat_id: None, + channel_ids: vec!["456".to_string()], + require_mention: false, + dm_allowed_users: Vec::new(), + platform_credentials: None, + }), + ) + .await; + + assert!( + matches!(result, Err(StatusCode::BAD_REQUEST)), + "invalid config update should return BAD_REQUEST" + ); + } + + #[tokio::test] + async fn update_provider_writes_config_and_reloads_runtime_routing() { + let temp_dir = tempfile::tempdir().expect("temp dir should be created"); + let config_path = temp_dir.path().join("config.toml"); + tokio::fs::write(&config_path, VALID_BASE_CONFIG) + .await + .expect("config.toml should be written"); + + let state = new_test_state(); + state.set_config_path(config_path.clone()).await; + let runtime_config = configure_runtime_for_main(&state, &config_path).await; + + let initial_routing = runtime_config.routing.load(); + assert_eq!(initial_routing.channel, "anthropic/claude-sonnet-4"); + + let request_body = serde_json::json!({ + "provider": "openai", + "api_key": "test-openai-key", + "model": "openai/gpt-4.1-mini" + }); + let result = providers::update_provider( + State(state), + Json( + serde_json::from_value(request_body).expect("provider request should deserialize"), + ), + ) + .await + .expect("provider update should succeed"); + + let response_json = + serde_json::to_value(result.0).expect("provider response should serialize"); + assert_eq!(response_json["success"], true); + + let config_after_update = tokio::fs::read_to_string(&config_path) + .await + .expect("config should be readable"); + assert!( + config_after_update.contains("openai_key = \"test-openai-key\""), + "updated config should persist provider key" + ); + + let updated_routing = runtime_config.routing.load(); + assert_eq!(updated_routing.channel, "openai/gpt-4.1-mini"); + assert_eq!(updated_routing.branch, "openai/gpt-4.1-mini"); + assert_eq!(updated_routing.worker, "openai/gpt-4.1-mini"); + assert_eq!(updated_routing.compactor, "openai/gpt-4.1-mini"); + assert_eq!(updated_routing.cortex, "openai/gpt-4.1-mini"); + } + + #[tokio::test] + async fn create_binding_writes_config_and_syncs_bindings_and_permissions() { + let temp_dir = tempfile::tempdir().expect("temp dir should be created"); + let config_path = temp_dir.path().join("config.toml"); + tokio::fs::write(&config_path, VALID_DISCORD_CONFIG) + .await + .expect("config.toml should be written"); + + let state = new_test_state(); + state.set_config_path(config_path.clone()).await; + let _runtime_config = configure_runtime_for_main(&state, &config_path).await; + state + .set_bindings(Arc::new(ArcSwap::from_pointee(Vec::new()))) + .await; + state + .set_discord_permissions(Arc::new(ArcSwap::from_pointee( + DiscordPermissions::default(), + ))) + .await; + state + .set_slack_permissions(Arc::new(ArcSwap::from_pointee(SlackPermissions { + workspace_filter: Some(vec!["stale-workspace".to_string()]), + channel_filter: HashMap::from([( + "stale-workspace".to_string(), + vec!["stale-channel".to_string()], + )]), + dm_allowed_users: vec!["stale-user".to_string()], + }))) + .await; + + let result = super::create_binding( + State(state.clone()), + Json(CreateBindingRequest { + agent_id: "main".to_string(), + channel: "discord".to_string(), + guild_id: Some("123".to_string()), + workspace_id: None, + chat_id: None, + channel_ids: vec!["456".to_string()], + require_mention: true, + dm_allowed_users: vec!["789".to_string()], + platform_credentials: None, + }), + ) + .await + .expect("binding create should succeed"); + + assert!(result.0.success); + + let written_config = + Config::load_from_path(&config_path).expect("written config should parse"); + assert_eq!(written_config.bindings.len(), 1); + assert_eq!(written_config.bindings[0].channel, "discord"); + assert_eq!( + written_config.bindings[0].guild_id.as_deref(), + Some("123"), + "binding should be persisted to config.toml" + ); + + let bindings_guard = state.bindings.read().await; + let bindings_swap = bindings_guard + .as_ref() + .expect("bindings ArcSwap should be registered") + .clone(); + drop(bindings_guard); + let live_bindings = bindings_swap.load(); + assert_eq!(live_bindings.len(), 1); + assert_eq!(live_bindings[0].channel, "discord"); + assert_eq!(live_bindings[0].guild_id.as_deref(), Some("123")); + + let discord_guard = state.discord_permissions.read().await; + let discord_swap = discord_guard + .as_ref() + .expect("discord permissions ArcSwap should be registered") + .clone(); + drop(discord_guard); + let discord_permissions = discord_swap.load(); + assert_eq!(discord_permissions.guild_filter, Some(vec![123])); + let guild_channels = discord_permissions + .channel_filter + .get(&123) + .expect("channel filter should include the guild"); + assert_eq!(guild_channels, &vec![456]); + assert!( + discord_permissions.dm_allowed_users.contains(&111), + "dm user from messaging.discord config should be retained" + ); + assert!( + discord_permissions.dm_allowed_users.contains(&789), + "dm user from binding should be merged into permissions" + ); + + let slack_guard = state.slack_permissions.read().await; + let slack_swap = slack_guard + .as_ref() + .expect("slack permissions ArcSwap should be registered") + .clone(); + drop(slack_guard); + let slack_permissions = slack_swap.load(); + assert_eq!( + slack_permissions.workspace_filter, None, + "slack permissions should reset when slack config is absent" + ); + assert!(slack_permissions.channel_filter.is_empty()); + assert!(slack_permissions.dm_allowed_users.is_empty()); + } } From 535703bae9ea9a0da8988930331fd78c143582b7 Mon Sep 17 00:00:00 2001 From: Victor Sumner Date: Sun, 22 Feb 2026 10:20:15 -0500 Subject: [PATCH 7/9] Standardize remaining API config mutation reload paths --- src/api/agents.rs | 70 +++++------------ src/api/config.rs | 8 ++ src/api/mcp.rs | 16 ++-- src/api/settings.rs | 184 +++++++++++++++++++++++++++++++++----------- 4 files changed, 175 insertions(+), 103 deletions(-) diff --git a/src/api/agents.rs b/src/api/agents.rs index d655e4142..a8739ebe2 100644 --- a/src/api/agents.rs +++ b/src/api/agents.rs @@ -1,3 +1,4 @@ +use super::config::{reload_all_runtime_configs, write_validated_config}; use super::state::{AgentInfo, ApiState}; use crate::agent::cortex::CortexLogger; @@ -265,42 +266,21 @@ pub(super) async fn create_agent( new_table["id"] = toml_edit::value(&agent_id); agents_array.push(new_table); - tokio::fs::write(&config_path, doc.to_string()) - .await - .map_err(|error| { - tracing::warn!(%error, "failed to write config.toml"); + let new_config = write_validated_config(&config_path, doc.to_string()).await?; + reload_all_runtime_configs(&state, &new_config).await; + + let agent_config = new_config + .resolve_agents() + .into_iter() + .find(|resolved| resolved.id == agent_id) + .ok_or_else(|| { + tracing::error!( + agent_id = %agent_id, + "newly created agent missing from resolved config" + ); StatusCode::INTERNAL_SERVER_ERROR })?; - let defaults = state.defaults_config.read().await; - let defaults = defaults.as_ref().ok_or_else(|| { - tracing::error!("defaults config not available"); - StatusCode::INTERNAL_SERVER_ERROR - })?; - - let raw_config = crate::config::AgentConfig { - id: agent_id.clone(), - default: false, - workspace: None, - routing: None, - max_concurrent_branches: None, - max_concurrent_workers: None, - max_turns: None, - branch_max_turns: None, - context_window: None, - compaction: None, - memory_persistence: None, - coalesce: None, - ingestion: None, - cortex: None, - browser: None, - mcp: None, - brave_search_key: None, - cron: Vec::new(), - }; - let agent_config = raw_config.resolve(&instance_dir, defaults); - let _ = defaults; - for dir in [ &agent_config.workspace, &agent_config.data_dir, @@ -384,16 +364,7 @@ pub(super) async fn create_agent( .clone() }; - let defaults_for_runtime = { - let guard = state.defaults_config.read().await; - guard - .as_ref() - .ok_or_else(|| { - tracing::error!("defaults config not available"); - StatusCode::INTERNAL_SERVER_ERROR - })? - .clone() - }; + let defaults_for_runtime = new_config.defaults.clone(); let runtime_config = std::sync::Arc::new(crate::config::RuntimeConfig::new( &instance_dir, @@ -588,6 +559,7 @@ pub(super) async fn delete_agent( // Remove the [[agents]] entry from config.toml let config_path = state.config_path.read().await.clone(); + let mut new_config_snapshot: Option = None; if config_path.exists() { let content = tokio::fs::read_to_string(&config_path) .await @@ -619,12 +591,8 @@ pub(super) async fn delete_agent( } } - tokio::fs::write(&config_path, doc.to_string()) - .await - .map_err(|error| { - tracing::warn!(%error, "failed to write config.toml"); - StatusCode::INTERNAL_SERVER_ERROR - })?; + let new_config = write_validated_config(&config_path, doc.to_string()).await?; + new_config_snapshot = Some(new_config); } // Close the SQLite pool before removing state @@ -682,6 +650,10 @@ pub(super) async fn delete_agent( .store(std::sync::Arc::new(sessions)); } + if let Some(new_config) = new_config_snapshot.as_ref() { + reload_all_runtime_configs(&state, new_config).await; + } + // Signal the main event loop to remove the agent if let Err(error) = state.agent_remove_tx.send(agent_id.clone()).await { tracing::error!(%error, "failed to send agent removal to main loop"); diff --git a/src/api/config.rs b/src/api/config.rs index 3d075825f..a2e73d3a2 100644 --- a/src/api/config.rs +++ b/src/api/config.rs @@ -36,6 +36,14 @@ pub(super) async fn reload_all_runtime_configs( state: &Arc, new_config: &crate::config::Config, ) { + let llm_manager_guard = state.llm_manager.read().await; + if let Some(llm_manager) = llm_manager_guard.as_ref() { + llm_manager.reload_config(new_config.llm.clone()); + } + drop(llm_manager_guard); + + state.set_defaults_config(new_config.defaults.clone()).await; + let runtime_configs = state.runtime_configs.load(); let mcp_managers = state.mcp_managers.load(); let reload_targets = runtime_configs diff --git a/src/api/mcp.rs b/src/api/mcp.rs index 1166182ae..a5cc9999a 100644 --- a/src/api/mcp.rs +++ b/src/api/mcp.rs @@ -3,6 +3,7 @@ //! CRUD endpoints for `[[mcp_servers]]` in config.toml, plus per-agent //! connection status. +use super::config::{reload_all_runtime_configs, write_validated_config}; use super::state::ApiState; use axum::Json; @@ -199,9 +200,8 @@ pub(super) async fn create_mcp_server( arr.push(new_table); } - tokio::fs::write(&config_path, doc.to_string()) - .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + let new_config = write_validated_config(&config_path, doc.to_string()).await?; + reload_all_runtime_configs(&state, &new_config).await; Ok(Json(MutationResponse { success: true, @@ -279,9 +279,8 @@ pub(super) async fn update_mcp_server( })); } - tokio::fs::write(&config_path, doc.to_string()) - .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + let new_config = write_validated_config(&config_path, doc.to_string()).await?; + reload_all_runtime_configs(&state, &new_config).await; Ok(Json(MutationResponse { success: true, @@ -341,9 +340,8 @@ pub(super) async fn delete_mcp_server( })); } - tokio::fs::write(&config_path, doc.to_string()) - .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + let new_config = write_validated_config(&config_path, doc.to_string()).await?; + reload_all_runtime_configs(&state, &new_config).await; Ok(Json(MutationResponse { success: true, diff --git a/src/api/settings.rs b/src/api/settings.rs index 754c12f25..ec5cfbf8c 100644 --- a/src/api/settings.rs +++ b/src/api/settings.rs @@ -1,3 +1,6 @@ +use super::config::{ + reload_all_runtime_configs, sync_bindings_and_permissions, write_validated_config, +}; use super::state::ApiState; use axum::Json; @@ -329,9 +332,8 @@ pub(super) async fn update_global_settings( } } - tokio::fs::write(&config_path, doc.to_string()) - .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + let new_config = write_validated_config(&config_path, doc.to_string()).await?; + reload_all_runtime_configs(&state, &new_config).await; let message = if requires_restart { "Settings updated. API server changes require a restart to take effect.".to_string() @@ -412,54 +414,146 @@ pub(super) async fn update_raw_config( return Err(StatusCode::INTERNAL_SERVER_ERROR); } - if let Err(error) = crate::config::Config::validate_toml(&request.content) { - return Ok(Json(RawConfigUpdateResponse { - success: false, - message: format!("Validation error: {error}"), - })); - } - - tokio::fs::write(&config_path, &request.content) - .await - .map_err(|error| { - tracing::warn!(%error, "failed to write config.toml"); - StatusCode::INTERNAL_SERVER_ERROR - })?; + let new_config = match write_validated_config(&config_path, request.content).await { + Ok(config) => config, + Err(StatusCode::BAD_REQUEST) => { + return Ok(Json(RawConfigUpdateResponse { + success: false, + message: "Validation error: config could not be applied.".to_string(), + })); + } + Err(status) => return Err(status), + }; tracing::info!("config.toml updated via raw editor"); - match crate::config::Config::load_from_path(&config_path) { - Ok(new_config) => { - let runtime_configs = state.runtime_configs.load(); - let mcp_managers = state.mcp_managers.load(); - let reload_targets = runtime_configs - .iter() - .filter_map(|(agent_id, runtime_config)| { - mcp_managers.get(agent_id).map(|mcp_manager| { - ( - agent_id.clone(), - runtime_config.clone(), - mcp_manager.clone(), - ) - }) - }) - .collect::>(); - drop(runtime_configs); - drop(mcp_managers); - - for (agent_id, runtime_config, mcp_manager) in reload_targets { - runtime_config - .reload_config(&new_config, &agent_id, &mcp_manager) - .await; - } - } - Err(error) => { - tracing::warn!(%error, "config.toml written but failed to reload immediately"); - } - } + sync_bindings_and_permissions(&state, &new_config).await; + reload_all_runtime_configs(&state, &new_config).await; Ok(Json(RawConfigUpdateResponse { success: true, message: "Config saved and reloaded.".to_string(), })) } + +#[cfg(test)] +mod tests { + use super::{RawConfigUpdateRequest, update_raw_config}; + use crate::api::state::ApiState; + use axum::Json; + use axum::extract::State; + use std::sync::Arc; + + const VALID_BASE_CONFIG: &str = r#" +[llm] +anthropic_key = "test-anthropic-key" + +[defaults.routing] +channel = "anthropic/claude-sonnet-4" +branch = "anthropic/claude-sonnet-4" +worker = "anthropic/claude-sonnet-4" +compactor = "anthropic/claude-sonnet-4" +cortex = "anthropic/claude-sonnet-4" + +[[agents]] +id = "main" +default = true +"#; + + const UPDATED_VALID_CONFIG: &str = r#" +[llm] +anthropic_key = "test-anthropic-key" + +[defaults.routing] +channel = "openai/gpt-4.1-mini" +branch = "openai/gpt-4.1-mini" +worker = "openai/gpt-4.1-mini" +compactor = "openai/gpt-4.1-mini" +cortex = "openai/gpt-4.1-mini" + +[[agents]] +id = "main" +default = true +"#; + + fn new_test_state() -> Arc { + let (provider_setup_tx, _) = tokio::sync::mpsc::channel(8); + let (agent_tx, _) = tokio::sync::mpsc::channel(8); + let (agent_remove_tx, _) = tokio::sync::mpsc::channel(8); + Arc::new(ApiState::new_with_provider_sender( + provider_setup_tx, + agent_tx, + agent_remove_tx, + )) + } + + #[tokio::test] + async fn update_raw_config_returns_failure_without_persisting_on_invalid_content() { + const MISSING_ENV_VAR: &str = "SPACEBOT_TEST_MISSING_PROVIDER_KEY_9C4187220F46465A91C3"; + let invalid_content = format!( + r#" +[llm.provider.invalid] +api_type = "openai_completions" +base_url = "https://api.example.com/v1" +api_key = "env:{MISSING_ENV_VAR}" +"# + ); + + let temp_dir = tempfile::tempdir().expect("temp dir should be created"); + let config_path = temp_dir.path().join("config.toml"); + tokio::fs::write(&config_path, VALID_BASE_CONFIG) + .await + .expect("config.toml should be written"); + + let state = new_test_state(); + state.set_config_path(config_path.clone()).await; + + let response = update_raw_config( + State(state), + Json(RawConfigUpdateRequest { + content: invalid_content, + }), + ) + .await + .expect("invalid raw config should return API response") + .0; + + assert!(!response.success); + assert!(response.message.contains("Validation error")); + + let persisted_content = tokio::fs::read_to_string(&config_path) + .await + .expect("config should be readable after failed update"); + assert_eq!(persisted_content, VALID_BASE_CONFIG); + } + + #[tokio::test] + async fn update_raw_config_persists_valid_content() { + let temp_dir = tempfile::tempdir().expect("temp dir should be created"); + let config_path = temp_dir.path().join("config.toml"); + tokio::fs::write(&config_path, VALID_BASE_CONFIG) + .await + .expect("config.toml should be written"); + + let state = new_test_state(); + state.set_config_path(config_path.clone()).await; + + let response = update_raw_config( + State(state), + Json(RawConfigUpdateRequest { + content: UPDATED_VALID_CONFIG.to_string(), + }), + ) + .await + .expect("valid raw config should succeed") + .0; + + assert!(response.success); + assert!(response.message.contains("saved and reloaded")); + + let persisted_content = tokio::fs::read_to_string(&config_path) + .await + .expect("config should be readable after update"); + assert_eq!(persisted_content, UPDATED_VALID_CONFIG); + } +} From 96848cfb9e21699a3b0e74c49406bd72c3531e8a Mon Sep 17 00:00:00 2001 From: Victor Sumner Date: Sun, 22 Feb 2026 12:07:58 -0500 Subject: [PATCH 8/9] Add locked atomic config writes and API mutation tests --- src/api/agents.rs | 110 +++++++++++++++++++++++++++- src/api/bindings.rs | 6 +- src/api/config.rs | 46 +++++++++++- src/api/mcp.rs | 166 +++++++++++++++++++++++++++++++++++++++---- src/api/messaging.rs | 4 +- src/api/providers.rs | 4 +- src/api/settings.rs | 4 +- src/api/state.rs | 5 +- 8 files changed, 315 insertions(+), 30 deletions(-) diff --git a/src/api/agents.rs b/src/api/agents.rs index a8739ebe2..ffaff630e 100644 --- a/src/api/agents.rs +++ b/src/api/agents.rs @@ -266,7 +266,7 @@ pub(super) async fn create_agent( new_table["id"] = toml_edit::value(&agent_id); agents_array.push(new_table); - let new_config = write_validated_config(&config_path, doc.to_string()).await?; + let new_config = write_validated_config(&state, &config_path, doc.to_string()).await?; reload_all_runtime_configs(&state, &new_config).await; let agent_config = new_config @@ -591,7 +591,7 @@ pub(super) async fn delete_agent( } } - let new_config = write_validated_config(&config_path, doc.to_string()).await?; + let new_config = write_validated_config(&state, &config_path, doc.to_string()).await?; new_config_snapshot = Some(new_config); } @@ -1010,3 +1010,109 @@ pub(super) async fn update_identity( user: updated.user, })) } + +#[cfg(test)] +mod tests { + use super::{DeleteAgentQuery, delete_agent}; + use crate::api::state::{AgentInfo, ApiState}; + use crate::config::Config; + use axum::extract::{Query, State}; + use std::path::PathBuf; + use std::sync::Arc; + + const CONFIG_WITH_TWO_AGENTS: &str = r#" +[llm] +anthropic_key = "test-anthropic-key" + +[defaults.routing] +channel = "anthropic/claude-sonnet-4" +branch = "anthropic/claude-sonnet-4" +worker = "anthropic/claude-sonnet-4" +compactor = "anthropic/claude-sonnet-4" +cortex = "anthropic/claude-sonnet-4" + +[[agents]] +id = "main" +default = true + +[[agents]] +id = "secondary" +default = false +"#; + + fn new_test_state() -> Arc { + let (provider_setup_tx, _) = tokio::sync::mpsc::channel(8); + let (agent_tx, _) = tokio::sync::mpsc::channel(8); + let (agent_remove_tx, _) = tokio::sync::mpsc::channel(8); + Arc::new(ApiState::new_with_provider_sender( + provider_setup_tx, + agent_tx, + agent_remove_tx, + )) + } + + #[tokio::test] + async fn delete_agent_removes_agent_from_config_and_api_state() { + let temp_dir = tempfile::tempdir().expect("temp dir should be created"); + let config_path = temp_dir.path().join("config.toml"); + tokio::fs::write(&config_path, CONFIG_WITH_TWO_AGENTS) + .await + .expect("config.toml should be written"); + + let state = new_test_state(); + state.set_config_path(config_path.clone()).await; + state.set_agent_configs(vec![ + AgentInfo { + id: "main".to_string(), + workspace: PathBuf::from("/tmp/main"), + context_window: 128_000, + max_turns: 5, + max_concurrent_branches: 5, + max_concurrent_workers: 5, + }, + AgentInfo { + id: "secondary".to_string(), + workspace: PathBuf::from("/tmp/secondary"), + context_window: 128_000, + max_turns: 5, + max_concurrent_branches: 5, + max_concurrent_workers: 5, + }, + ]); + + let response = delete_agent( + State(state.clone()), + Query(DeleteAgentQuery { + agent_id: "secondary".to_string(), + }), + ) + .await + .expect("delete_agent should succeed") + .0; + + assert_eq!(response["success"], true); + assert!( + response["message"] + .as_str() + .is_some_and(|message| message.contains("secondary")) + ); + + let persisted_config = + Config::load_from_path(&config_path).expect("config after deletion should parse"); + assert!( + persisted_config + .agents + .iter() + .any(|agent| agent.id == "main") + ); + assert!( + !persisted_config + .agents + .iter() + .any(|agent| agent.id == "secondary") + ); + + let live_agents = state.agent_configs.load(); + assert!(live_agents.iter().all(|agent| agent.id != "secondary")); + } +} diff --git a/src/api/bindings.rs b/src/api/bindings.rs index e6adba51e..d903dd4f4 100644 --- a/src/api/bindings.rs +++ b/src/api/bindings.rs @@ -338,7 +338,7 @@ pub(super) async fn create_binding( } bindings_array.push(binding_table); - let new_config = write_validated_config(&config_path, doc.to_string()).await?; + let new_config = write_validated_config(&state, &config_path, doc.to_string()).await?; tracing::info!( agent_id = %request.agent_id, @@ -620,7 +620,7 @@ pub(super) async fn update_binding( binding.remove("dm_allowed_users"); } - let new_config = write_validated_config(&config_path, doc.to_string()).await?; + let new_config = write_validated_config(&state, &config_path, doc.to_string()).await?; tracing::info!( agent_id = %request.agent_id, @@ -710,7 +710,7 @@ pub(super) async fn delete_binding( bindings_array.remove(idx); - let new_config = write_validated_config(&config_path, doc.to_string()).await?; + let new_config = write_validated_config(&state, &config_path, doc.to_string()).await?; tracing::info!( agent_id = %request.agent_id, diff --git a/src/api/config.rs b/src/api/config.rs index a2e73d3a2..869762ce3 100644 --- a/src/api/config.rs +++ b/src/api/config.rs @@ -10,18 +10,58 @@ use std::sync::Arc; /// /// Returns the fully parsed `Config` snapshot that was persisted. pub(super) async fn write_validated_config( + state: &Arc, config_path: &std::path::Path, updated_content: String, ) -> Result { + let _config_write_guard = state.config_write_lock.lock().await; + if let Err(error) = crate::config::Config::validate_toml(&updated_content) { tracing::warn!(%error, "config update validation failed"); return Err(StatusCode::BAD_REQUEST); } - tokio::fs::write(config_path, updated_content) + let parent_dir = config_path + .parent() + .unwrap_or_else(|| std::path::Path::new(".")); + let file_name = config_path + .file_name() + .and_then(|name| name.to_str()) + .unwrap_or("config.toml"); + let temp_suffix = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|duration| duration.as_nanos()) + .unwrap_or(0); + let temp_path = parent_dir.join(format!( + ".{file_name}.tmp.{}.{}", + std::process::id(), + temp_suffix + )); + + tokio::fs::write(&temp_path, updated_content) + .await + .map_err(|error| { + tracing::warn!(%error, path = %temp_path.display(), "failed to write temp config file"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + tokio::fs::rename(&temp_path, config_path) .await .map_err(|error| { - tracing::warn!(%error, "failed to write config.toml"); + let cleanup_result = std::fs::remove_file(&temp_path); + if let Err(cleanup_error) = cleanup_result { + tracing::debug!( + %cleanup_error, + path = %temp_path.display(), + "failed to clean up temp config file after rename error" + ); + } + tracing::warn!( + %error, + from = %temp_path.display(), + to = %config_path.display(), + "failed to atomically replace config.toml" + ); StatusCode::INTERNAL_SERVER_ERROR })?; @@ -418,7 +458,7 @@ pub(super) async fn update_agent_config( update_discord_table(&mut doc, discord)?; } - let new_config = write_validated_config(&config_path, doc.to_string()).await?; + let new_config = write_validated_config(&state, &config_path, doc.to_string()).await?; tracing::info!(agent_id = %request.agent_id, "config.toml updated via API"); let runtime_configs = state.runtime_configs.load(); diff --git a/src/api/mcp.rs b/src/api/mcp.rs index a5cc9999a..a4bf8cd61 100644 --- a/src/api/mcp.rs +++ b/src/api/mcp.rs @@ -1,6 +1,6 @@ //! API handlers for MCP server management. //! -//! CRUD endpoints for `[[mcp_servers]]` in config.toml, plus per-agent +//! CRUD endpoints for `[[defaults.mcp]]` in config.toml, plus per-agent //! connection status. use super::config::{reload_all_runtime_configs, write_validated_config}; @@ -76,7 +76,11 @@ pub(super) async fn list_mcp_servers( .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; let mut defs = Vec::new(); - if let Some(arr) = doc.get("mcp_servers").and_then(|v| v.as_array_of_tables()) { + if let Some(arr) = doc + .get("defaults") + .and_then(|defaults| defaults.get("mcp")) + .and_then(|mcp| mcp.as_array_of_tables()) + { for table in arr.iter() { let name = table .get("name") @@ -138,7 +142,11 @@ pub(super) async fn create_mcp_server( .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; // Check for duplicates - if let Some(arr) = doc.get("mcp_servers").and_then(|v| v.as_array_of_tables()) { + if let Some(arr) = doc + .get("defaults") + .and_then(|defaults| defaults.get("mcp")) + .and_then(|mcp| mcp.as_array_of_tables()) + { for table in arr.iter() { if table.get("name").and_then(|v| v.as_str()) == Some(&request.name) { return Ok(Json(MutationResponse { @@ -186,21 +194,22 @@ pub(super) async fn create_mcp_server( new_table["headers"] = toml_edit::value(headers_table); } - // Append to [[mcp_servers]] array - if doc.get("mcp_servers").is_none() { - doc.insert( - "mcp_servers", - toml_edit::Item::ArrayOfTables(toml_edit::ArrayOfTables::new()), - ); + // Append to [[defaults.mcp]] array + if doc.get("defaults").is_none() { + doc["defaults"] = toml_edit::Item::Table(toml_edit::Table::new()); + } + if doc["defaults"].get("mcp").is_none() { + doc["defaults"]["mcp"] = toml_edit::Item::ArrayOfTables(toml_edit::ArrayOfTables::new()); } if let Some(arr) = doc - .get_mut("mcp_servers") + .get_mut("defaults") + .and_then(|defaults| defaults.get_mut("mcp")) .and_then(|v| v.as_array_of_tables_mut()) { arr.push(new_table); } - let new_config = write_validated_config(&config_path, doc.to_string()).await?; + let new_config = write_validated_config(&state, &config_path, doc.to_string()).await?; reload_all_runtime_configs(&state, &new_config).await; Ok(Json(MutationResponse { @@ -230,7 +239,8 @@ pub(super) async fn update_mcp_server( .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; let Some(arr) = doc - .get_mut("mcp_servers") + .get_mut("defaults") + .and_then(|defaults| defaults.get_mut("mcp")) .and_then(|v| v.as_array_of_tables_mut()) else { return Ok(Json(MutationResponse { @@ -279,7 +289,7 @@ pub(super) async fn update_mcp_server( })); } - let new_config = write_validated_config(&config_path, doc.to_string()).await?; + let new_config = write_validated_config(&state, &config_path, doc.to_string()).await?; reload_all_runtime_configs(&state, &new_config).await; Ok(Json(MutationResponse { @@ -309,7 +319,8 @@ pub(super) async fn delete_mcp_server( .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; let Some(arr) = doc - .get_mut("mcp_servers") + .get_mut("defaults") + .and_then(|defaults| defaults.get_mut("mcp")) .and_then(|v| v.as_array_of_tables_mut()) else { return Ok(Json(MutationResponse { @@ -340,7 +351,7 @@ pub(super) async fn delete_mcp_server( })); } - let new_config = write_validated_config(&config_path, doc.to_string()).await?; + let new_config = write_validated_config(&state, &config_path, doc.to_string()).await?; reload_all_runtime_configs(&state, &new_config).await; Ok(Json(MutationResponse { @@ -424,3 +435,128 @@ async fn get_server_state(state: &ApiState, server_name: &str) -> String { } "not_connected".into() } + +#[cfg(test)] +mod tests { + use super::{CreateMcpServerRequest, create_mcp_server, delete_mcp_server}; + use crate::api::state::ApiState; + use crate::config::{Config, RuntimeConfig}; + use axum::Json; + use axum::extract::{Path, State}; + use std::collections::HashMap; + use std::sync::Arc; + + const VALID_BASE_CONFIG: &str = r#" +[llm] +anthropic_key = "test-anthropic-key" + +[defaults.routing] +channel = "anthropic/claude-sonnet-4" +branch = "anthropic/claude-sonnet-4" +worker = "anthropic/claude-sonnet-4" +compactor = "anthropic/claude-sonnet-4" +cortex = "anthropic/claude-sonnet-4" + +[[agents]] +id = "main" +default = true +"#; + + fn new_test_state() -> Arc { + let (provider_setup_tx, _) = tokio::sync::mpsc::channel(8); + let (agent_tx, _) = tokio::sync::mpsc::channel(8); + let (agent_remove_tx, _) = tokio::sync::mpsc::channel(8); + Arc::new(ApiState::new_with_provider_sender( + provider_setup_tx, + agent_tx, + agent_remove_tx, + )) + } + + async fn configure_runtime_for_main( + state: &Arc, + config_path: &std::path::Path, + ) -> Arc { + let config = Config::load_from_path(config_path).expect("config should load for test"); + let resolved_agent = config + .resolve_agents() + .into_iter() + .find(|agent| agent.id == "main") + .expect("main agent should exist in resolved config"); + let runtime_config = Arc::new(RuntimeConfig::new( + &config.instance_dir, + &resolved_agent, + &config.defaults, + crate::prompts::PromptEngine::new("en").expect("prompt engine should build"), + crate::identity::Identity::default(), + crate::skills::SkillSet::default(), + )); + + let mut runtime_configs = HashMap::new(); + runtime_configs.insert("main".to_string(), runtime_config.clone()); + state.set_runtime_configs(runtime_configs); + + let mut mcp_managers = HashMap::new(); + mcp_managers.insert( + "main".to_string(), + Arc::new(crate::mcp::McpManager::new(resolved_agent.mcp.clone())), + ); + state.set_mcp_managers(mcp_managers); + + runtime_config + } + + #[tokio::test] + async fn create_and_delete_mcp_server_persist_and_reload_runtime_config() { + let temp_dir = tempfile::tempdir().expect("temp dir should be created"); + let config_path = temp_dir.path().join("config.toml"); + tokio::fs::write(&config_path, VALID_BASE_CONFIG) + .await + .expect("config.toml should be written"); + + let state = new_test_state(); + state.set_config_path(config_path.clone()).await; + let runtime_config = configure_runtime_for_main(&state, &config_path).await; + + let create_response = create_mcp_server( + State(state.clone()), + Json(CreateMcpServerRequest { + name: "test-server".to_string(), + transport: "stdio".to_string(), + enabled: true, + command: Some("echo".to_string()), + args: vec!["hello".to_string()], + env: HashMap::new(), + url: None, + headers: HashMap::new(), + }), + ) + .await + .expect("create_mcp_server should succeed") + .0; + + assert!(create_response.success); + + let persisted_after_create = + Config::load_from_path(&config_path).expect("created config should parse"); + assert_eq!(persisted_after_create.defaults.mcp.len(), 1); + assert_eq!(persisted_after_create.defaults.mcp[0].name, "test-server"); + + let runtime_mcp_after_create = runtime_config.mcp.load(); + assert_eq!(runtime_mcp_after_create.len(), 1); + assert_eq!(runtime_mcp_after_create[0].name, "test-server"); + + let delete_response = delete_mcp_server(State(state), Path("test-server".to_string())) + .await + .expect("delete_mcp_server should succeed") + .0; + assert!(delete_response.success); + + let persisted_after_delete = + Config::load_from_path(&config_path).expect("deleted config should parse"); + assert!(persisted_after_delete.defaults.mcp.is_empty()); + + let runtime_mcp_after_delete = runtime_config.mcp.load(); + assert!(runtime_mcp_after_delete.is_empty()); + } +} diff --git a/src/api/messaging.rs b/src/api/messaging.rs index 429e738cf..cbf5f1b9f 100644 --- a/src/api/messaging.rs +++ b/src/api/messaging.rs @@ -225,7 +225,7 @@ pub(super) async fn disconnect_platform( } } - let new_config = write_validated_config(&config_path, doc.to_string()).await?; + let new_config = write_validated_config(&state, &config_path, doc.to_string()).await?; sync_bindings_and_permissions(&state, &new_config).await; reload_all_runtime_configs(&state, &new_config).await; @@ -279,7 +279,7 @@ pub(super) async fn toggle_platform( table["enabled"] = toml_edit::value(request.enabled); - let new_config = write_validated_config(&config_path, doc.to_string()).await?; + let new_config = write_validated_config(&state, &config_path, doc.to_string()).await?; sync_bindings_and_permissions(&state, &new_config).await; reload_all_runtime_configs(&state, &new_config).await; diff --git a/src/api/providers.rs b/src/api/providers.rs index d732c6c1c..cc44e3daf 100644 --- a/src/api/providers.rs +++ b/src/api/providers.rs @@ -446,7 +446,7 @@ pub(super) async fn update_provider( } } - let new_config = write_validated_config(&config_path, doc.to_string()).await?; + let new_config = write_validated_config(&state, &config_path, doc.to_string()).await?; reload_all_runtime_configs(&state, &new_config).await; state @@ -579,7 +579,7 @@ pub(super) async fn delete_provider( table.remove(key_name); } - let new_config = write_validated_config(&config_path, doc.to_string()).await?; + let new_config = write_validated_config(&state, &config_path, doc.to_string()).await?; reload_all_runtime_configs(&state, &new_config).await; Ok(Json(ProviderUpdateResponse { diff --git a/src/api/settings.rs b/src/api/settings.rs index ec5cfbf8c..b24570ed2 100644 --- a/src/api/settings.rs +++ b/src/api/settings.rs @@ -332,7 +332,7 @@ pub(super) async fn update_global_settings( } } - let new_config = write_validated_config(&config_path, doc.to_string()).await?; + let new_config = write_validated_config(&state, &config_path, doc.to_string()).await?; reload_all_runtime_configs(&state, &new_config).await; let message = if requires_restart { @@ -414,7 +414,7 @@ pub(super) async fn update_raw_config( return Err(StatusCode::INTERNAL_SERVER_ERROR); } - let new_config = match write_validated_config(&config_path, request.content).await { + let new_config = match write_validated_config(&state, &config_path, request.content).await { Ok(config) => config, Err(StatusCode::BAD_REQUEST) => { return Ok(Json(RawConfigUpdateResponse { diff --git a/src/api/state.rs b/src/api/state.rs index d3e3d6288..066d34ac4 100644 --- a/src/api/state.rs +++ b/src/api/state.rs @@ -21,7 +21,7 @@ use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; use std::time::Instant; -use tokio::sync::{RwLock, broadcast, mpsc}; +use tokio::sync::{Mutex, RwLock, broadcast, mpsc}; /// Summary of an agent's configuration, exposed via the API. #[derive(Debug, Clone, Serialize)] @@ -57,6 +57,8 @@ pub struct ApiState { pub agent_workspaces: arc_swap::ArcSwap>, /// Path to the instance config.toml file. pub config_path: RwLock, + /// Serialize config.toml mutations across API endpoints. + pub config_write_lock: Mutex<()>, /// Per-agent cron stores for cron job CRUD operations. pub cron_stores: arc_swap::ArcSwap>>, /// Per-agent cron schedulers for job timer management. @@ -193,6 +195,7 @@ impl ApiState { cortex_chat_sessions: arc_swap::ArcSwap::from_pointee(HashMap::new()), agent_workspaces: arc_swap::ArcSwap::from_pointee(HashMap::new()), config_path: RwLock::new(PathBuf::new()), + config_write_lock: Mutex::new(()), cron_stores: arc_swap::ArcSwap::from_pointee(HashMap::new()), cron_schedulers: arc_swap::ArcSwap::from_pointee(HashMap::new()), runtime_configs: ArcSwap::from_pointee(HashMap::new()), From 15062815f9de30b3ee7f467a0411e302bc0ae101 Mon Sep 17 00:00:00 2001 From: Victor Sumner Date: Sun, 22 Feb 2026 13:07:06 -0500 Subject: [PATCH 9/9] Add MCP compatibility parsing and reload events --- README.md | 6 ++- docs/design-docs/mcp.md | 4 +- src/api/config.rs | 6 ++- src/config.rs | 89 ++++++++++++++++++++++++++++++++++++++++- 4 files changed, 100 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 8a7bafb1d..dbcf30495 100644 --- a/README.md +++ b/README.md @@ -223,19 +223,21 @@ Connect workers to external [MCP](https://modelcontextprotocol.io/) (Model Conte - **API management** — full CRUD API under `/api/mcp/` for managing server definitions and monitoring connection status programmatically ```toml -[[mcp_servers]] +[[defaults.mcp]] name = "filesystem" transport = "stdio" command = "npx" args = ["-y", "@modelcontextprotocol/server-filesystem", "/workspace"] -[[mcp_servers]] +[[defaults.mcp]] name = "sentry" transport = "http" url = "https://mcp.sentry.io" headers = { Authorization = "Bearer ${SENTRY_TOKEN}" } ``` +Legacy `[[mcp_servers]]` entries are still loaded for compatibility, but `[[defaults.mcp]]` is the canonical location. + --- ## How It Works diff --git a/docs/design-docs/mcp.md b/docs/design-docs/mcp.md index 8680bc0b9..43b87d6ad 100644 --- a/docs/design-docs/mcp.md +++ b/docs/design-docs/mcp.md @@ -246,7 +246,7 @@ Per-agent visibility endpoints: - `GET /api/agents/mcp` — list configured MCP servers and their connection status - `POST /api/agents/mcp/reconnect` — force reconnect a specific server by name -CRUD endpoints for managing `[[mcp_servers]]` in config.toml: +CRUD endpoints for managing `[[defaults.mcp]]` in config.toml: - `GET /api/mcp/servers` — list all configured servers with live connection state - `POST /api/mcp/servers` — add a new server definition to config.toml @@ -255,6 +255,8 @@ CRUD endpoints for managing `[[mcp_servers]]` in config.toml: - `POST /api/mcp/servers/{name}/reconnect` — force-reconnect a specific server across all agents - `GET /api/mcp/status` — per-agent connection status +Legacy `[[mcp_servers]]` entries are accepted for backward compatibility during config loading, but new writes and API edits target `[[defaults.mcp]]`. + ### Shutdown `McpManager::disconnect_all()` called during agent shutdown, before database cleanup. Kills child processes, closes HTTP sessions. diff --git a/src/api/config.rs b/src/api/config.rs index 869762ce3..8b3c40311 100644 --- a/src/api/config.rs +++ b/src/api/config.rs @@ -1,4 +1,4 @@ -use super::state::ApiState; +use super::state::{ApiEvent, ApiState}; use axum::Json; use axum::extract::{Query, State}; @@ -106,6 +106,8 @@ pub(super) async fn reload_all_runtime_configs( .reload_config(new_config, &agent_id, &mcp_manager) .await; } + + state.send_event(ApiEvent::ConfigReloaded); } /// Sync bindings and messaging permission snapshots from a parsed config snapshot. @@ -481,6 +483,8 @@ pub(super) async fn update_agent_config( } } + state.send_event(ApiEvent::ConfigReloaded); + get_agent_config( State(state), Query(AgentConfigQuery { diff --git a/src/config.rs b/src/config.rs index 935ccd63a..49b43034e 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1086,6 +1086,9 @@ struct TomlConfig { llm: TomlLlmConfig, #[serde(default)] defaults: TomlDefaultsConfig, + /// Legacy top-level MCP server table. Canonical location is `defaults.mcp`. + #[serde(default)] + mcp_servers: Vec, #[serde(default)] agents: Vec, #[serde(default)] @@ -1642,6 +1645,23 @@ fn parse_mcp_server_config(raw: TomlMcpServerConfig) -> Result }) } +fn merge_defaults_mcp( + canonical: Vec, + legacy: Vec, +) -> Vec { + let mut merged = canonical; + for legacy_config in legacy { + if merged + .iter() + .any(|existing| existing.name == legacy_config.name) + { + continue; + } + merged.push(legacy_config); + } + merged +} + /// Resolve a TomlRoutingConfig against a base RoutingConfig. fn resolve_routing(toml: Option, base: &RoutingConfig) -> RoutingConfig { let Some(t) = toml else { return base.clone() }; @@ -2278,12 +2298,24 @@ impl Config { // Note: We allow boot without provider keys now. System starts in setup mode. // Agents are initialized later when keys are added via API. - let default_mcp = toml + let canonical_default_mcp = toml .defaults .mcp .into_iter() .map(parse_mcp_server_config) .collect::>>()?; + let legacy_default_mcp = toml + .mcp_servers + .into_iter() + .map(parse_mcp_server_config) + .collect::>>()?; + if !legacy_default_mcp.is_empty() { + tracing::warn!( + count = legacy_default_mcp.len(), + "config uses deprecated [[mcp_servers]] table; migrate entries to [[defaults.mcp]]" + ); + } + let default_mcp = merge_defaults_mcp(canonical_default_mcp, legacy_default_mcp); let base_defaults = DefaultsConfig::default(); let defaults = DefaultsConfig { @@ -3906,6 +3938,61 @@ name = "Custom OpenAI" assert_eq!(config.llm.openai_key.as_deref(), Some("legacy-openai-key")); } + #[test] + fn test_legacy_mcp_servers_table_migrates_into_defaults_mcp() { + let toml = r#" +[[defaults.mcp]] +name = "canonical" +transport = "stdio" +command = "canonical-cmd" +args = ["--canonical"] + +[[mcp_servers]] +name = "legacy" +transport = "stdio" +command = "legacy-cmd" +args = ["--legacy"] + +[[mcp_servers]] +name = "canonical" +transport = "stdio" +command = "legacy-override" +args = ["--legacy-override"] +"#; + + let parsed: TomlConfig = toml::from_str(toml).expect("failed to parse test TOML"); + let config = Config::from_toml(parsed, PathBuf::from(".")).expect("failed to build Config"); + + assert_eq!(config.defaults.mcp.len(), 2); + let canonical = config + .defaults + .mcp + .iter() + .find(|server| server.name == "canonical") + .expect("canonical mcp server should exist"); + match &canonical.transport { + McpTransport::Stdio { command, args, .. } => { + assert_eq!(command, "canonical-cmd"); + assert_eq!(args, &vec!["--canonical".to_string()]); + } + McpTransport::Http { .. } => panic!("expected stdio transport"), + } + + let legacy = config + .defaults + .mcp + .iter() + .find(|server| server.name == "legacy") + .expect("legacy mcp server should be merged into defaults.mcp"); + match &legacy.transport { + McpTransport::Stdio { command, args, .. } => { + assert_eq!(command, "legacy-cmd"); + assert_eq!(args, &vec!["--legacy".to_string()]); + } + McpTransport::Http { .. } => panic!("expected stdio transport"), + } + } + #[test] fn test_needs_onboarding_without_config_or_env() { let _lock = env_test_lock()