From a697f1c5b0b25f1fd52eff7804e76334ab7d332e Mon Sep 17 00:00:00 2001 From: nmb Date: Tue, 17 Feb 2026 16:30:29 +0100 Subject: [PATCH 01/11] =?UTF-8?q?=E2=9C=A8=20feat(llm):=20add=20Ollama=20C?= =?UTF-8?q?loud=20provider=20support?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 2 +- interface/src/lib/providerIcons.tsx | 1 + src/config.rs | 43 +++++++++++++++++++---------- src/llm/providers.rs | 14 ++++++++-- src/llm/routing.rs | 15 ++++++++++ 5 files changed, 57 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index fa7b8920c..b791364e4 100644 --- a/README.md +++ b/README.md @@ -311,7 +311,7 @@ Read the full vision in [docs/spacedrive.md](docs/spacedrive.md). ### Prerequisites - **Rust** 1.85+ ([rustup](https://rustup.rs/)) -- An LLM API key from any supported provider (Anthropic, OpenAI, OpenRouter, Z.ai, Groq, Together, Fireworks, DeepSeek, xAI, Mistral, or OpenCode Zen) +- An LLM API key from any supported provider (Anthropic, OpenAI, OpenRouter, Ollama Cloud, Z.ai, Groq, Together, Fireworks, DeepSeek, xAI, Mistral, or OpenCode Zen) ### Build and Run diff --git a/interface/src/lib/providerIcons.tsx b/interface/src/lib/providerIcons.tsx index 8d6666bd6..76ba6fdd1 100644 --- a/interface/src/lib/providerIcons.tsx +++ b/interface/src/lib/providerIcons.tsx @@ -60,6 +60,7 @@ export function ProviderIcon({ provider, className = "text-ink-faint", size = 24 anthropic: Anthropic, openai: OpenAI, openrouter: OpenRouter, + ollama: OpenRouter, groq: Groq, mistral: Mistral, deepseek: DeepSeek, diff --git a/src/config.rs b/src/config.rs index 94c289a86..d930df304 100644 --- a/src/config.rs +++ b/src/config.rs @@ -55,6 +55,7 @@ pub struct LlmConfig { pub anthropic_key: Option, pub openai_key: Option, pub openrouter_key: Option, + pub ollama_key: Option, pub zhipu_key: Option, pub groq_key: Option, pub together_key: Option, @@ -68,9 +69,10 @@ pub struct LlmConfig { impl LlmConfig { /// Check if any 
provider key is configured. pub fn has_any_key(&self) -> bool { - self.anthropic_key.is_some() - || self.openai_key.is_some() - || self.openrouter_key.is_some() + self.anthropic_key.is_some() + || self.openai_key.is_some() + || self.openrouter_key.is_some() + || self.ollama_key.is_some() || self.zhipu_key.is_some() || self.groq_key.is_some() || self.together_key.is_some() @@ -869,6 +871,7 @@ struct TomlLlmConfig { anthropic_key: Option, openai_key: Option, openrouter_key: Option, + ollama_key: Option, zhipu_key: Option, groq_key: Option, together_key: Option, @@ -1146,6 +1149,7 @@ impl Config { std::env::var("ANTHROPIC_API_KEY").is_err() && std::env::var("OPENAI_API_KEY").is_err() && std::env::var("OPENROUTER_API_KEY").is_err() + && std::env::var("OLLAMA_API_KEY").is_err() && std::env::var("OPENCODE_ZEN_API_KEY").is_err() } @@ -1183,6 +1187,7 @@ impl Config { anthropic_key: std::env::var("ANTHROPIC_API_KEY").ok(), openai_key: std::env::var("OPENAI_API_KEY").ok(), openrouter_key: std::env::var("OPENROUTER_API_KEY").ok(), + ollama_key: std::env::var("OLLAMA_API_KEY").ok(), zhipu_key: std::env::var("ZHIPU_API_KEY").ok(), groq_key: std::env::var("GROQ_API_KEY").ok(), together_key: std::env::var("TOGETHER_API_KEY").ok(), @@ -1239,8 +1244,8 @@ impl Config { /// Validate a raw TOML string as a valid Spacebot config. /// Returns Ok(()) if the config is structurally valid, or an error describing what's wrong. pub fn validate_toml(content: &str) -> Result<()> { - let toml_config: TomlConfig = toml::from_str(content) - .context("failed to parse config TOML")?; + let toml_config: TomlConfig = + toml::from_str(content).context("failed to parse config TOML")?; // Run full conversion to catch semantic errors (env resolution, defaults, etc.) 
let instance_dir = Self::default_instance_dir(); Self::from_toml(toml_config, instance_dir)?; @@ -1267,6 +1272,12 @@ impl Config { .as_deref() .and_then(resolve_env_value) .or_else(|| std::env::var("OPENROUTER_API_KEY").ok()), + ollama_key: toml + .llm + .ollama_key + .as_deref() + .and_then(resolve_env_value) + .or_else(|| std::env::var("OLLAMA_API_KEY").ok()), zhipu_key: toml .llm .zhipu_key @@ -1939,7 +1950,9 @@ pub fn spawn_file_watcher( // Only forward data modification events, not metadata/access changes use notify::EventKind; match &event.kind { - EventKind::Create(_) | EventKind::Modify(notify::event::ModifyKind::Data(_)) | EventKind::Remove(_) => { + EventKind::Create(_) + | EventKind::Modify(notify::event::ModifyKind::Data(_)) + | EventKind::Remove(_) => { let _ = tx.send(event); } // Also forward Any/Other modify events (some backends don't distinguish) @@ -2248,6 +2261,7 @@ pub fn run_onboarding() -> anyhow::Result> { "Anthropic", "OpenRouter", "OpenAI", + "Ollama Cloud", "Z.ai (GLM)", "Groq", "Together AI", @@ -2267,14 +2281,15 @@ pub fn run_onboarding() -> anyhow::Result> { 0 => ("Anthropic API key", "anthropic_key", "anthropic"), 1 => ("OpenRouter API key", "openrouter_key", "openrouter"), 2 => ("OpenAI API key", "openai_key", "openai"), - 3 => ("Z.ai (GLM) API key", "zhipu_key", "zhipu"), - 4 => ("Groq API key", "groq_key", "groq"), - 5 => ("Together AI API key", "together_key", "together"), - 6 => ("Fireworks AI API key", "fireworks_key", "fireworks"), - 7 => ("DeepSeek API key", "deepseek_key", "deepseek"), - 8 => ("xAI API key", "xai_key", "xai"), - 9 => ("Mistral AI API key", "mistral_key", "mistral"), - 10 => ("OpenCode Zen API key", "opencode_zen_key", "opencode-zen"), + 3 => ("Ollama Cloud API key", "ollama_key", "ollama"), + 4 => ("Z.ai (GLM) API key", "zhipu_key", "zhipu"), + 5 => ("Groq API key", "groq_key", "groq"), + 6 => ("Together AI API key", "together_key", "together"), + 7 => ("Fireworks AI API key", "fireworks_key", "fireworks"), + 
8 => ("DeepSeek API key", "deepseek_key", "deepseek"), + 9 => ("xAI API key", "xai_key", "xai"), + 10 => ("Mistral AI API key", "mistral_key", "mistral"), + 11 => ("OpenCode Zen API key", "opencode_zen_key", "opencode-zen"), _ => unreachable!(), }; diff --git a/src/llm/providers.rs b/src/llm/providers.rs index e3ce47ccc..8e0672149 100644 --- a/src/llm/providers.rs +++ b/src/llm/providers.rs @@ -8,18 +8,26 @@ pub async fn init_providers(config: &LlmConfig) -> Result<()> { // Provider clients are initialized lazily through LlmManager // This module exists for any provider-specific setup that needs to happen // during system startup - + if config.anthropic_key.is_some() { tracing::info!("Anthropic provider configured"); } - + if config.openai_key.is_some() { tracing::info!("OpenAI provider configured"); } + if config.openrouter_key.is_some() { + tracing::info!("OpenRouter provider configured"); + } + + if config.ollama_key.is_some() { + tracing::info!("Ollama provider configured"); + } + if config.opencode_zen_key.is_some() { tracing::info!("OpenCode Zen provider configured"); } - + Ok(()) } diff --git a/src/llm/routing.rs b/src/llm/routing.rs index 5671516cb..d6d61dd8a 100644 --- a/src/llm/routing.rs +++ b/src/llm/routing.rs @@ -153,6 +153,20 @@ pub fn defaults_for_provider(provider: &str) -> RoutingConfig { rate_limit_cooldown_secs: 60, } } + "ollama" => { + let channel: String = "ollama/gpt-oss:120b".into(); + let worker: String = "ollama/gpt-oss:20b".into(); + RoutingConfig { + channel: channel.clone(), + branch: channel.clone(), + worker: worker.clone(), + compactor: worker.clone(), + cortex: worker.clone(), + task_overrides: HashMap::from([("coding".into(), channel.clone())]), + fallbacks: HashMap::from([(channel, vec![worker])]), + rate_limit_cooldown_secs: 60, + } + } "zhipu" => { let channel: String = "zhipu/glm-4-plus".into(); let worker: String = "zhipu/glm-4-flash".into(); @@ -277,6 +291,7 @@ pub fn provider_to_prefix(provider: &str) -> &str { match 
provider { "openrouter" => "openrouter/", "openai" => "openai/", + "ollama" => "ollama/", "anthropic" => "anthropic/", "zhipu" => "zhipu/", "groq" => "groq/", From 02d2a07eab94e61234bf182e580d322508d0dd2f Mon Sep 17 00:00:00 2001 From: nmb Date: Tue, 17 Feb 2026 16:30:57 +0100 Subject: [PATCH 02/11] =?UTF-8?q?=E2=9C=A8=20feat(ui):=20update=20settings?= =?UTF-8?q?=20for=20new=20provider=20and=20options?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- interface/src/routes/Settings.tsx | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/interface/src/routes/Settings.tsx b/interface/src/routes/Settings.tsx index 716e1cd4d..5fd3129ad 100644 --- a/interface/src/routes/Settings.tsx +++ b/interface/src/routes/Settings.tsx @@ -95,6 +95,13 @@ const PROVIDERS = [ placeholder: "sk-...", envVar: "OPENAI_API_KEY", }, + { + id: "ollama", + name: "Ollama Cloud", + description: "Hosted Ollama models via OpenAI-compatible API", + placeholder: "ollama_...", + envVar: "OLLAMA_API_KEY", + }, { id: "zhipu", name: "Z.ai (GLM)", From 96226ca02f9d2e0e6abc84860991e405595a5fdf Mon Sep 17 00:00:00 2001 From: nmb Date: Tue, 17 Feb 2026 16:30:54 +0100 Subject: [PATCH 03/11] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor(history):?= =?UTF-8?q?=20improve=20conversation=20pruning=20and=20management?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/conversation/history.rs | 101 ++++++++++++++++++++++++------------ 1 file changed, 69 insertions(+), 32 deletions(-) diff --git a/src/conversation/history.rs b/src/conversation/history.rs index aeb9b54ea..1916e0ef6 100644 --- a/src/conversation/history.rs +++ b/src/conversation/history.rs @@ -79,7 +79,7 @@ impl ConversationLogger { tokio::spawn(async move { if let Err(error) = sqlx::query( "INSERT INTO conversation_messages (id, channel_id, role, content) \ - VALUES (?, ?, 'assistant', ?)" + VALUES (?, ?, 'assistant', ?)", ) .bind(&id) .bind(&channel_id) @@ 
-103,7 +103,7 @@ impl ConversationLogger { FROM conversation_messages \ WHERE channel_id = ? \ ORDER BY created_at DESC \ - LIMIT ?" + LIMIT ?", ) .bind(channel_id.as_ref()) .bind(limit) @@ -121,7 +121,9 @@ impl ConversationLogger { sender_id: row.try_get("sender_id").ok(), content: row.try_get("content").unwrap_or_default(), metadata: row.try_get("metadata").ok(), - created_at: row.try_get("created_at").unwrap_or_else(|_| chrono::Utc::now()), + created_at: row + .try_get("created_at") + .unwrap_or_else(|_| chrono::Utc::now()), }) .collect(); @@ -142,7 +144,7 @@ impl ConversationLogger { FROM conversation_messages \ WHERE channel_id = ? \ ORDER BY created_at DESC \ - LIMIT ?" + LIMIT ?", ) .bind(channel_id) .bind(limit) @@ -160,14 +162,15 @@ impl ConversationLogger { sender_id: row.try_get("sender_id").ok(), content: row.try_get("content").unwrap_or_default(), metadata: row.try_get("metadata").ok(), - created_at: row.try_get("created_at").unwrap_or_else(|_| chrono::Utc::now()), + created_at: row + .try_get("created_at") + .unwrap_or_else(|_| chrono::Utc::now()), }) .collect(); messages.reverse(); Ok(messages) } - } /// A unified timeline item combining messages, branch runs, and worker runs. @@ -213,7 +216,12 @@ impl ProcessRunLogger { } /// Record a branch starting. Fire-and-forget. - pub fn log_branch_started(&self, channel_id: &ChannelId, branch_id: BranchId, description: &str) { + pub fn log_branch_started( + &self, + channel_id: &ChannelId, + branch_id: BranchId, + description: &str, + ) { let pool = self.pool.clone(); let id = branch_id.to_string(); let channel_id = channel_id.to_string(); @@ -221,7 +229,7 @@ impl ProcessRunLogger { tokio::spawn(async move { if let Err(error) = sqlx::query( - "INSERT INTO branch_runs (id, channel_id, description) VALUES (?, ?, ?)" + "INSERT INTO branch_runs (id, channel_id, description) VALUES (?, ?, ?)", ) .bind(&id) .bind(&channel_id) @@ -254,22 +262,46 @@ impl ProcessRunLogger { }); } - /// Record a worker starting. 
Fire-and-forget. - pub fn log_worker_started(&self, channel_id: Option<&ChannelId>, worker_id: WorkerId, task: &str) { + /// Record a branch failing and mark its run as completed. Fire-and-forget. + pub fn log_branch_failed(&self, branch_id: BranchId, error: &str) { let pool = self.pool.clone(); - let id = worker_id.to_string(); - let channel_id = channel_id.map(|c| c.to_string()); - let task = task.to_string(); + let id = branch_id.to_string(); + let summary = format!("Branch failed: {error}"); tokio::spawn(async move { if let Err(error) = sqlx::query( - "INSERT INTO worker_runs (id, channel_id, task) VALUES (?, ?, ?)" + "UPDATE branch_runs SET conclusion = ?, completed_at = CURRENT_TIMESTAMP WHERE id = ?" ) + .bind(&summary) .bind(&id) - .bind(&channel_id) - .bind(&task) .execute(&pool) .await + { + tracing::warn!(%error, branch_id = %id, "failed to persist branch failure"); + } + }); + } + + /// Record a worker starting. Fire-and-forget. + pub fn log_worker_started( + &self, + channel_id: Option<&ChannelId>, + worker_id: WorkerId, + task: &str, + ) { + let pool = self.pool.clone(); + let id = worker_id.to_string(); + let channel_id = channel_id.map(|c| c.to_string()); + let task = task.to_string(); + + tokio::spawn(async move { + if let Err(error) = + sqlx::query("INSERT INTO worker_runs (id, channel_id, task) VALUES (?, ?, ?)") + .bind(&id) + .bind(&channel_id) + .bind(&task) + .execute(&pool) + .await { tracing::warn!(%error, worker_id = %id, "failed to persist worker start"); } @@ -283,13 +315,11 @@ impl ProcessRunLogger { let status = status.to_string(); tokio::spawn(async move { - if let Err(error) = sqlx::query( - "UPDATE worker_runs SET status = ? WHERE id = ?" - ) - .bind(&status) - .bind(&id) - .execute(&pool) - .await + if let Err(error) = sqlx::query("UPDATE worker_runs SET status = ? 
WHERE id = ?") + .bind(&status) + .bind(&id) + .execute(&pool) + .await { tracing::warn!(%error, worker_id = %id, "failed to persist worker status"); } @@ -327,7 +357,11 @@ impl ProcessRunLogger { limit: i64, before: Option<&str>, ) -> crate::error::Result> { - let before_clause = if before.is_some() { "AND timestamp < ?3" } else { "" }; + let before_clause = if before.is_some() { + "AND timestamp < ?3" + } else { + "" + }; let query_str = format!( "SELECT * FROM ( \ @@ -348,9 +382,7 @@ impl ProcessRunLogger { ) WHERE 1=1 {before_clause} ORDER BY timestamp DESC LIMIT ?2" ); - let mut query = sqlx::query(&query_str) - .bind(channel_id) - .bind(limit); + let mut query = sqlx::query(&query_str).bind(channel_id).bind(limit); if let Some(before_ts) = before { query = query.bind(before_ts); @@ -372,7 +404,8 @@ impl ProcessRunLogger { sender_name: row.try_get("sender_name").ok(), sender_id: row.try_get("sender_id").ok(), content: row.try_get("content").unwrap_or_default(), - created_at: row.try_get::, _>("timestamp") + created_at: row + .try_get::, _>("timestamp") .map(|t| t.to_rfc3339()) .unwrap_or_default(), }), @@ -380,10 +413,12 @@ impl ProcessRunLogger { id: row.try_get("id").unwrap_or_default(), description: row.try_get("description").unwrap_or_default(), conclusion: row.try_get("conclusion").ok(), - started_at: row.try_get::, _>("timestamp") + started_at: row + .try_get::, _>("timestamp") .map(|t| t.to_rfc3339()) .unwrap_or_default(), - completed_at: row.try_get::, _>("completed_at") + completed_at: row + .try_get::, _>("completed_at") .ok() .map(|t| t.to_rfc3339()), }), @@ -392,10 +427,12 @@ impl ProcessRunLogger { task: row.try_get("task").unwrap_or_default(), result: row.try_get("result").ok(), status: row.try_get("status").unwrap_or_default(), - started_at: row.try_get::, _>("timestamp") + started_at: row + .try_get::, _>("timestamp") .map(|t| t.to_rfc3339()) .unwrap_or_default(), - completed_at: row.try_get::, _>("completed_at") + completed_at: row + 
.try_get::, _>("completed_at") .ok() .map(|t| t.to_rfc3339()), }), From 3c06b7382e54d9888f70b256ac8b4558327b69ab Mon Sep 17 00:00:00 2001 From: nmb Date: Tue, 17 Feb 2026 16:31:00 +0100 Subject: [PATCH 04/11] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor(agent):=20i?= =?UTF-8?q?mprove=20cortex,=20branching,=20and=20worker=20coordination?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/agent.rs | 4 +-- src/agent/branch.rs | 42 ++++++++++++++++---------- src/agent/compactor.rs | 35 +++++++++------------- src/agent/cortex.rs | 64 +++++++++++++++++++++++++--------------- src/agent/cortex_chat.rs | 8 +++-- src/agent/ingestion.rs | 60 ++++++++++++++++++++++--------------- src/agent/status.rs | 20 +++++++++++++ 7 files changed, 146 insertions(+), 87 deletions(-) diff --git a/src/agent.rs b/src/agent.rs index 1733f4ef7..dc277fad6 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -1,10 +1,10 @@ //! Agent processes: channels, branches, workers, compactor, cortex. 
-pub mod channel; pub mod branch; -pub mod worker; +pub mod channel; pub mod compactor; pub mod cortex; pub mod cortex_chat; pub mod ingestion; pub mod status; +pub mod worker; diff --git a/src/agent/branch.rs b/src/agent/branch.rs index 11d0ff986..7b8c4e2d2 100644 --- a/src/agent/branch.rs +++ b/src/agent/branch.rs @@ -2,10 +2,10 @@ use crate::agent::compactor::estimate_history_tokens; use crate::error::Result; -use crate::llm::routing::is_context_overflow_error; -use crate::llm::SpacebotModel; -use crate::{BranchId, ChannelId, ProcessId, ProcessType, AgentDeps, ProcessEvent}; use crate::hooks::SpacebotHook; +use crate::llm::SpacebotModel; +use crate::llm::routing::is_context_overflow_error; +use crate::{AgentDeps, BranchId, ChannelId, ProcessEvent, ProcessId, ProcessType}; use rig::agent::AgentBuilder; use rig::completion::{CompletionModel, Prompt}; use rig::tool::server::ToolServerHandle; @@ -44,8 +44,14 @@ impl Branch { ) -> Self { let id = Uuid::new_v4(); let process_id = ProcessId::Branch(id); - let hook = SpacebotHook::new(deps.agent_id.clone(), process_id, ProcessType::Branch, Some(channel_id.clone()), deps.event_tx.clone()); - + let hook = SpacebotHook::new( + deps.agent_id.clone(), + process_id, + ProcessType::Branch, + Some(channel_id.clone()), + deps.event_tx.clone(), + ); + Self { id, channel_id, @@ -58,7 +64,7 @@ impl Branch { max_turns, } } - + /// Run the branch's LLM agent loop and return a conclusion. /// /// Each branch has its own isolated ToolServer with `memory_save` and @@ -70,7 +76,7 @@ impl Branch { /// be large, making them susceptible to overflow on the first LLM call. 
pub async fn run(mut self, prompt: impl Into) -> Result { let prompt = prompt.into(); - + tracing::info!( branch_id = %self.id, channel_id = %self.channel_id, @@ -97,15 +103,17 @@ impl Branch { let mut overflow_retries = 0; let conclusion = loop { - match agent.prompt(¤t_prompt) + match agent + .prompt(¤t_prompt) .with_history(&mut self.history) .with_hook(self.hook.clone()) .await { Ok(response) => break response, Err(rig::completion::PromptError::MaxTurnsError { .. }) => { - let partial = extract_last_assistant_text(&self.history) - .unwrap_or_else(|| "Branch exhausted its turns without a final conclusion.".into()); + let partial = extract_last_assistant_text(&self.history).unwrap_or_else(|| { + "Branch exhausted its turns without a final conclusion.".into() + }); tracing::warn!(branch_id = %self.id, "branch hit max turns, returning partial result"); break partial; } @@ -133,7 +141,8 @@ impl Branch { "branch context overflow, compacting and retrying" ); self.force_compact_history(); - current_prompt = "Continue where you left off. Older context has been compacted.".into(); + current_prompt = + "Continue where you left off. Older context has been compacted.".into(); } Err(error) => { tracing::error!(branch_id = %self.id, %error, "branch LLM call failed"); @@ -149,9 +158,9 @@ impl Branch { channel_id: self.channel_id.clone(), conclusion: conclusion.clone(), }); - + tracing::info!(branch_id = %self.id, "branch completed"); - + Ok(conclusion) } @@ -192,7 +201,9 @@ impl Branch { return; } - let remove_count = ((total as f32 * fraction) as usize).max(1).min(total.saturating_sub(2)); + let remove_count = ((total as f32 * fraction) as usize) + .max(1) + .min(total.saturating_sub(2)); self.history.drain(..remove_count); let marker = format!( @@ -207,7 +218,8 @@ impl Branch { fn extract_last_assistant_text(history: &[rig::message::Message]) -> Option { for message in history.iter().rev() { if let rig::message::Message::Assistant { content, .. 
} = message { - let texts: Vec = content.iter() + let texts: Vec = content + .iter() .filter_map(|c| { if let rig::message::AssistantContent::Text(t) = c { Some(t.text.clone()) diff --git a/src/agent/compactor.rs b/src/agent/compactor.rs index 76b414187..e68eeda56 100644 --- a/src/agent/compactor.rs +++ b/src/agent/compactor.rs @@ -25,11 +25,7 @@ pub struct Compactor { impl Compactor { /// Create a new compactor for a channel. - pub fn new( - channel_id: ChannelId, - deps: AgentDeps, - history: Arc>>, - ) -> Self { + pub fn new(channel_id: ChannelId, deps: AgentDeps, history: Arc>>) -> Self { Self { channel_id, deps, @@ -117,12 +113,7 @@ impl Compactor { .expect("failed to render compactor prompt"); tokio::spawn(async move { - let result = run_compaction( - &deps, - &compactor_prompt, - &history, - fraction, - ).await; + let result = run_compaction(&deps, &compactor_prompt, &history, fraction).await; match result { Ok(turns_compacted) => { @@ -191,7 +182,9 @@ async fn run_compaction( let (removed_messages, remove_count) = { let mut hist = history.write().await; let total = hist.len(); - let remove_count = ((total as f32 * fraction) as usize).max(1).min(total.saturating_sub(2)); + let remove_count = ((total as f32 * fraction) as usize) + .max(1) + .min(total.saturating_sub(2)); if remove_count == 0 { return Ok(0); } @@ -205,12 +198,14 @@ async fn run_compaction( // 3. 
Run the compaction LLM to produce summary + extracted memories let routing = deps.runtime_config.routing.load(); let model_name = routing.resolve(ProcessType::Worker, None).to_string(); - let model = SpacebotModel::make(&deps.llm_manager, &model_name) - .with_routing((**routing).clone()); + let model = + SpacebotModel::make(&deps.llm_manager, &model_name).with_routing((**routing).clone()); // Give the compaction worker memory_save so it can directly persist memories let tool_server: ToolServerHandle = ToolServer::new() - .tool(crate::tools::MemorySaveTool::new(deps.memory_search.clone())) + .tool(crate::tools::MemorySaveTool::new( + deps.memory_search.clone(), + )) .run(); let agent = AgentBuilder::new(model) @@ -220,7 +215,8 @@ async fn run_compaction( .build(); let mut compaction_history = Vec::new(); - let response = agent.prompt(&transcript) + let response = agent + .prompt(&transcript) .with_history(&mut compaction_history) .await; @@ -294,9 +290,7 @@ fn estimate_assistant_content_chars(content: &AssistantContent) -> usize { AssistantContent::ToolCall(tc) => { tc.function.name.len() + tc.function.arguments.to_string().len() } - AssistantContent::Reasoning(r) => { - r.reasoning.iter().map(|s| s.len()).sum() - } + AssistantContent::Reasoning(r) => r.reasoning.iter().map(|s| s.len()).sum(), AssistantContent::Image(_) => 500, } } @@ -339,8 +333,7 @@ fn render_messages_as_transcript(messages: &[Message]) -> String { AssistantContent::ToolCall(tc) => { output.push_str(&format!( "[Tool Call: {}({})]\n", - tc.function.name, - tc.function.arguments + tc.function.name, tc.function.arguments )); } _ => {} diff --git a/src/agent/cortex.rs b/src/agent/cortex.rs index 0b0a20e68..c6220f9c0 100644 --- a/src/agent/cortex.rs +++ b/src/agent/cortex.rs @@ -10,11 +10,11 @@ //! health monitoring and memory consolidation. 
use crate::error::Result; +use crate::hooks::CortexHook; use crate::llm::SpacebotModel; use crate::memory::search::{SearchConfig, SearchMode, SearchSort}; use crate::memory::types::{Association, MemoryType, RelationType}; use crate::{AgentDeps, ProcessEvent, ProcessType}; -use crate::hooks::CortexHook; use rig::agent::AgentBuilder; use rig::completion::{CompletionModel, Prompt}; @@ -180,9 +180,7 @@ impl CortexEventRow { id: self.id, event_type: self.event_type, summary: self.summary, - details: self - .details - .and_then(|d| serde_json::from_str(&d).ok()), + details: self.details.and_then(|d| serde_json::from_str(&d).ok()), created_at: self.created_at.and_utc().to_rfc3339(), } } @@ -275,7 +273,11 @@ async fn run_bulletin_loop(deps: &AgentDeps, logger: &CortexLogger) -> anyhow::R ); logger.log( "bulletin_failed", - &format!("Bulletin generation failed, retrying (attempt {}/{})", attempt + 1, MAX_RETRIES), + &format!( + "Bulletin generation failed, retrying (attempt {}/{})", + attempt + 1, + MAX_RETRIES + ), Some(serde_json::json!({ "attempt": attempt + 1, "max_retries": MAX_RETRIES })), ); tokio::time::sleep(Duration::from_secs(RETRY_DELAY_SECS)).await; @@ -401,7 +403,12 @@ async fn gather_bulletin_sections(deps: &AgentDeps) -> String { "- [{}] (importance: {:.1}) {}\n", result.memory.memory_type, result.memory.importance, - result.memory.content.lines().next().unwrap_or(&result.memory.content), + result + .memory + .content + .lines() + .next() + .unwrap_or(&result.memory.content), )); } output.push('\n'); @@ -458,9 +465,7 @@ pub async fn generate_bulletin(deps: &AgentDeps, logger: &CortexLogger) -> bool SpacebotModel::make(&deps.llm_manager, &model_name).with_routing((**routing).clone()); // No tools needed — the LLM just synthesizes the pre-gathered data - let agent = AgentBuilder::new(model) - .preamble(&bulletin_prompt) - .build(); + let agent = AgentBuilder::new(model).preamble(&bulletin_prompt).build(); let synthesis_prompt = prompt_engine 
.render_system_cortex_synthesis(cortex_config.bulletin_max_words, &raw_sections) @@ -587,17 +592,24 @@ async fn generate_profile(deps: &AgentDeps, logger: &CortexLogger) { // Gather context: identity + current bulletin let identity_context = { let rendered = deps.runtime_config.identity.load().render(); - if rendered.is_empty() { None } else { Some(rendered) } + if rendered.is_empty() { + None + } else { + Some(rendered) + } }; let memory_bulletin = { let bulletin = deps.runtime_config.memory_bulletin.load(); - if bulletin.is_empty() { None } else { Some(bulletin.as_ref().clone()) } + if bulletin.is_empty() { + None + } else { + Some(bulletin.as_ref().clone()) + } }; - let synthesis_prompt = match prompt_engine.render_system_profile_synthesis( - identity_context.as_deref(), - memory_bulletin.as_deref(), - ) { + let synthesis_prompt = match prompt_engine + .render_system_profile_synthesis(identity_context.as_deref(), memory_bulletin.as_deref()) + { Ok(p) => p, Err(error) => { tracing::warn!(%error, "failed to render profile synthesis prompt"); @@ -610,9 +622,7 @@ async fn generate_profile(deps: &AgentDeps, logger: &CortexLogger) { let model = SpacebotModel::make(&deps.llm_manager, &model_name).with_routing((**routing).clone()); - let agent = AgentBuilder::new(model) - .preamble(&profile_prompt) - .build(); + let agent = AgentBuilder::new(model).preamble(&profile_prompt).build(); match agent.prompt(&synthesis_prompt).await { Ok(response) => { @@ -679,7 +689,9 @@ async fn generate_profile(deps: &AgentDeps, logger: &CortexLogger) { tracing::warn!(%error, raw = %cleaned, "failed to parse profile LLM response as JSON"); logger.log( "profile_failed", - &format!("Profile generation failed: could not parse LLM response — {error}"), + &format!( + "Profile generation failed: could not parse LLM response — {error}" + ), Some(serde_json::json!({ "error": error.to_string(), "raw_response": cleaned, @@ -711,7 +723,10 @@ async fn generate_profile(deps: &AgentDeps, logger: 
&CortexLogger) { /// Scans memories for embedding similarity and creates association edges /// between related memories. On first run, backfills all existing memories. /// Subsequent runs only process memories created since the last pass. -pub fn spawn_association_loop(deps: AgentDeps, logger: CortexLogger) -> tokio::task::JoinHandle<()> { +pub fn spawn_association_loop( + deps: AgentDeps, + logger: CortexLogger, +) -> tokio::task::JoinHandle<()> { tokio::spawn(async move { if let Err(error) = run_association_loop(&deps, &logger).await { tracing::error!(%error, "cortex association loop exited with error"); @@ -727,7 +742,10 @@ async fn run_association_loop(deps: &AgentDeps, logger: &CortexLogger) -> anyhow // Backfill: process all existing memories on first run let backfill_count = run_association_pass(deps, logger, None).await; - tracing::info!(associations_created = backfill_count, "association backfill complete"); + tracing::info!( + associations_created = backfill_count, + "association backfill complete" + ); let mut last_pass_at = chrono::Utc::now(); @@ -812,8 +830,8 @@ async fn run_association_pass( }; // Weight: map similarity range to 0.5-1.0 - let weight = 0.5 + (similarity - similarity_threshold) - / (1.0 - similarity_threshold) * 0.5; + let weight = + 0.5 + (similarity - similarity_threshold) / (1.0 - similarity_threshold) * 0.5; let association = Association::new(memory_id, &target_id, relation_type) .with_weight(weight.clamp(0.0, 1.0)); diff --git a/src/agent/cortex_chat.rs b/src/agent/cortex_chat.rs index a9e6edb62..6a95f1b4d 100644 --- a/src/agent/cortex_chat.rs +++ b/src/agent/cortex_chat.rs @@ -39,7 +39,10 @@ pub enum CortexChatEvent { /// A tool call started. ToolStarted { tool: String }, /// A tool call completed. - ToolCompleted { tool: String, result_preview: String }, + ToolCompleted { + tool: String, + result_preview: String, + }, /// The full response is ready. Done { full_text: String }, /// An error occurred. 
@@ -378,8 +381,7 @@ impl CortexChatSession { .. } => { if let Some(result) = result { - transcript - .push_str(&format!("*[Worker: {task}]*: {result}\n\n")); + transcript.push_str(&format!("*[Worker: {task}]*: {result}\n\n")); } } } diff --git a/src/agent/ingestion.rs b/src/agent/ingestion.rs index f59263cf4..97a887fbf 100644 --- a/src/agent/ingestion.rs +++ b/src/agent/ingestion.rs @@ -8,10 +8,10 @@ //! content. If the server restarts mid-file, already-completed chunks are //! skipped on the next run. -use crate::config::IngestionConfig; -use crate::llm::SpacebotModel; use crate::AgentDeps; use crate::ProcessType; +use crate::config::IngestionConfig; +use crate::llm::SpacebotModel; use anyhow::Context as _; use rig::agent::AgentBuilder; @@ -29,10 +29,7 @@ use std::time::Duration; /// Runs until the returned JoinHandle is dropped or aborted. Scans the ingest /// directory on a timer, processes any text files found, and deletes them after /// successful ingestion. -pub fn spawn_ingestion_loop( - ingest_dir: PathBuf, - deps: AgentDeps, -) -> tokio::task::JoinHandle<()> { +pub fn spawn_ingestion_loop(ingest_dir: PathBuf, deps: AgentDeps) -> tokio::task::JoinHandle<()> { tokio::spawn(async move { if let Err(error) = run_ingestion_loop(&ingest_dir, &deps).await { tracing::error!(%error, "ingestion loop exited with error"); @@ -129,13 +126,22 @@ fn is_text_file(path: &Path) -> bool { matches!( ext.to_lowercase().as_str(), - "txt" | "md" | "markdown" - | "json" | "jsonl" - | "csv" | "tsv" + "txt" + | "md" + | "markdown" + | "json" + | "jsonl" + | "csv" + | "tsv" | "log" - | "xml" | "yaml" | "yml" | "toml" - | "rst" | "org" - | "html" | "htm" + | "xml" + | "yaml" + | "yml" + | "toml" + | "rst" + | "org" + | "html" + | "htm" ) } @@ -182,7 +188,14 @@ async fn process_file( let remaining = total_chunks - completed.len(); // Record file-level tracking (idempotent — skips if already exists from a previous run) - upsert_ingestion_file(&deps.sqlite_pool, &hash, filename, 
file_size, total_chunks as i64).await?; + upsert_ingestion_file( + &deps.sqlite_pool, + &hash, + filename, + file_size, + total_chunks as i64, + ) + .await?; if !completed.is_empty() { tracing::info!( @@ -264,12 +277,9 @@ async fn process_file( // -- Progress tracking queries -------------------------------------------------- /// Load the set of chunk indices already completed for a given content hash. -async fn load_completed_chunks( - pool: &SqlitePool, - hash: &str, -) -> anyhow::Result> { +async fn load_completed_chunks(pool: &SqlitePool, hash: &str) -> anyhow::Result> { let rows = sqlx::query_scalar::<_, i64>( - "SELECT chunk_index FROM ingestion_progress WHERE content_hash = ?" + "SELECT chunk_index FROM ingestion_progress WHERE content_hash = ?", ) .bind(hash) .fetch_all(pool) @@ -417,13 +427,17 @@ async fn process_chunk( let routing = deps.runtime_config.routing.load(); let model_name = routing.resolve(ProcessType::Branch, None).to_string(); - let model = SpacebotModel::make(&deps.llm_manager, &model_name) - .with_routing((**routing).clone()); + let model = + SpacebotModel::make(&deps.llm_manager, &model_name).with_routing((**routing).clone()); - let conversation_logger = crate::conversation::history::ConversationLogger::new(deps.sqlite_pool.clone()); + let conversation_logger = + crate::conversation::history::ConversationLogger::new(deps.sqlite_pool.clone()); let channel_store = crate::conversation::ChannelStore::new(deps.sqlite_pool.clone()); - let tool_server: ToolServerHandle = - crate::tools::create_branch_tool_server(deps.memory_search.clone(), conversation_logger, channel_store); + let tool_server: ToolServerHandle = crate::tools::create_branch_tool_server( + deps.memory_search.clone(), + conversation_logger, + channel_store, + ); let agent = AgentBuilder::new(model) .preamble(&ingestion_prompt) diff --git a/src/agent/status.rs b/src/agent/status.rs index 1b8e35003..d316e48f9 100644 --- a/src/agent/status.rs +++ b/src/agent/status.rs @@ -118,6 
+118,26 @@ impl StatusBlock { self.completed_items.remove(0); } } + ProcessEvent::BranchFailed { + branch_id, error, .. + } => { + // Remove from active branches, add to completed with error summary. + if let Some(pos) = self.active_branches.iter().position(|b| b.id == *branch_id) { + let branch = self.active_branches.remove(pos); + self.completed_items.push(CompletedItem { + id: branch_id.to_string(), + item_type: CompletedItemType::Branch, + description: branch.description, + completed_at: Utc::now(), + result_summary: format!("Branch failed: {error}"), + }); + } + + // Keep only last 10 completed items + if self.completed_items.len() > 10 { + self.completed_items.remove(0); + } + } _ => {} } } From 5a54ba900a9fa9f28671e90ac24db0ca7cb8bf73 Mon Sep 17 00:00:00 2001 From: nmb Date: Tue, 17 Feb 2026 16:31:06 +0100 Subject: [PATCH 05/11] =?UTF-8?q?=E2=9C=A8=20feat(tools):=20enhance=20erro?= =?UTF-8?q?r=20handling=20and=20add=20worker=20spawning?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/tools.rs | 101 +++++++++++++++++++++--------------- src/tools/branch_tool.rs | 7 ++- src/tools/browser.rs | 62 +++++++--------------- src/tools/cancel.rs | 21 ++++++-- src/tools/channel_recall.rs | 42 ++++++++++----- src/tools/cron.rs | 8 ++- src/tools/exec.rs | 36 ++++++++----- src/tools/file.rs | 12 +++-- src/tools/memory_recall.rs | 26 ++++++++-- src/tools/memory_save.rs | 19 ++++--- src/tools/reply.rs | 3 +- src/tools/route.rs | 13 +++-- src/tools/send_file.rs | 12 ++--- src/tools/shell.rs | 51 +++++++++++------- src/tools/skip.rs | 7 +-- src/tools/spawn_worker.rs | 20 ++++--- 16 files changed, 262 insertions(+), 178 deletions(-) diff --git a/src/tools.rs b/src/tools.rs index c0ec320fc..c08af5eaf 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -21,45 +21,56 @@ //! **Cortex ToolServer** (one per agent): //! 
- `memory_save` — registered at startup -pub mod reply; pub mod branch_tool; -pub mod spawn_worker; -pub mod route; +pub mod browser; pub mod cancel; -pub mod skip; -pub mod react; -pub mod memory_save; -pub mod memory_recall; +pub mod channel_recall; +pub mod cron; +pub mod exec; +pub mod file; pub mod memory_delete; +pub mod memory_recall; +pub mod memory_save; +pub mod react; +pub mod reply; +pub mod route; +pub mod send_file; pub mod set_status; pub mod shell; -pub mod file; -pub mod exec; -pub mod browser; +pub mod skip; +pub mod spawn_worker; pub mod web_search; -pub mod channel_recall; -pub mod cron; -pub mod send_file; -pub use reply::{ReplyTool, ReplyArgs, ReplyOutput, ReplyError}; -pub use branch_tool::{BranchTool, BranchArgs, BranchOutput, BranchError}; -pub use spawn_worker::{SpawnWorkerTool, SpawnWorkerArgs, SpawnWorkerOutput, SpawnWorkerError}; -pub use route::{RouteTool, RouteArgs, RouteOutput, RouteError}; -pub use cancel::{CancelTool, CancelArgs, CancelOutput, CancelError}; -pub use skip::{SkipTool, SkipArgs, SkipOutput, SkipError, SkipFlag, new_skip_flag}; -pub use react::{ReactTool, ReactArgs, ReactOutput, ReactError}; -pub use memory_save::{MemorySaveTool, MemorySaveArgs, MemorySaveOutput, MemorySaveError, AssociationInput}; -pub use memory_recall::{MemoryRecallTool, MemoryRecallArgs, MemoryRecallOutput, MemoryRecallError, MemoryOutput}; -pub use memory_delete::{MemoryDeleteTool, MemoryDeleteArgs, MemoryDeleteOutput, MemoryDeleteError}; -pub use set_status::{SetStatusTool, SetStatusArgs, SetStatusOutput, SetStatusError}; -pub use shell::{ShellTool, ShellArgs, ShellOutput, ShellError, ShellResult}; -pub use file::{FileTool, FileArgs, FileOutput, FileError, FileEntryOutput, FileEntry, FileType}; -pub use exec::{ExecTool, ExecArgs, ExecOutput, ExecError, ExecResult, EnvVar}; -pub use browser::{BrowserTool, BrowserArgs, BrowserOutput, BrowserError, BrowserAction, ActKind, ElementSummary, TabInfo}; -pub use web_search::{WebSearchTool, WebSearchArgs, 
WebSearchOutput, WebSearchError, SearchResult}; -pub use channel_recall::{ChannelRecallTool, ChannelRecallArgs, ChannelRecallOutput, ChannelRecallError}; -pub use cron::{CronTool, CronArgs, CronOutput, CronError}; -pub use send_file::{SendFileTool, SendFileArgs, SendFileOutput, SendFileError}; +pub use branch_tool::{BranchArgs, BranchError, BranchOutput, BranchTool}; +pub use browser::{ + ActKind, BrowserAction, BrowserArgs, BrowserError, BrowserOutput, BrowserTool, ElementSummary, + TabInfo, +}; +pub use cancel::{CancelArgs, CancelError, CancelOutput, CancelTool}; +pub use channel_recall::{ + ChannelRecallArgs, ChannelRecallError, ChannelRecallOutput, ChannelRecallTool, +}; +pub use cron::{CronArgs, CronError, CronOutput, CronTool}; +pub use exec::{EnvVar, ExecArgs, ExecError, ExecOutput, ExecResult, ExecTool}; +pub use file::{FileArgs, FileEntry, FileEntryOutput, FileError, FileOutput, FileTool, FileType}; +pub use memory_delete::{ + MemoryDeleteArgs, MemoryDeleteError, MemoryDeleteOutput, MemoryDeleteTool, +}; +pub use memory_recall::{ + MemoryOutput, MemoryRecallArgs, MemoryRecallError, MemoryRecallOutput, MemoryRecallTool, +}; +pub use memory_save::{ + AssociationInput, MemorySaveArgs, MemorySaveError, MemorySaveOutput, MemorySaveTool, +}; +pub use react::{ReactArgs, ReactError, ReactOutput, ReactTool}; +pub use reply::{ReplyArgs, ReplyError, ReplyOutput, ReplyTool}; +pub use route::{RouteArgs, RouteError, RouteOutput, RouteTool}; +pub use send_file::{SendFileArgs, SendFileError, SendFileOutput, SendFileTool}; +pub use set_status::{SetStatusArgs, SetStatusError, SetStatusOutput, SetStatusTool}; +pub use shell::{ShellArgs, ShellError, ShellOutput, ShellResult, ShellTool}; +pub use skip::{SkipArgs, SkipError, SkipFlag, SkipOutput, SkipTool, new_skip_flag}; +pub use spawn_worker::{SpawnWorkerArgs, SpawnWorkerError, SpawnWorkerOutput, SpawnWorkerTool}; +pub use web_search::{SearchResult, WebSearchArgs, WebSearchError, WebSearchOutput, WebSearchTool}; use 
crate::agent::channel::ChannelState; use crate::config::BrowserConfig; @@ -116,18 +127,24 @@ pub async fn add_channel_tools( skip_flag: SkipFlag, cron_tool: Option, ) -> Result<(), rig::tool::server::ToolServerError> { - handle.add_tool(ReplyTool::new( - response_tx.clone(), - conversation_id, - state.conversation_logger.clone(), - state.channel_id.clone(), - )).await?; + handle + .add_tool(ReplyTool::new( + response_tx.clone(), + conversation_id, + state.conversation_logger.clone(), + state.channel_id.clone(), + )) + .await?; handle.add_tool(BranchTool::new(state.clone())).await?; handle.add_tool(SpawnWorkerTool::new(state.clone())).await?; handle.add_tool(RouteTool::new(state.clone())).await?; handle.add_tool(CancelTool::new(state)).await?; - handle.add_tool(SkipTool::new(skip_flag, response_tx.clone())).await?; - handle.add_tool(SendFileTool::new(response_tx.clone())).await?; + handle + .add_tool(SkipTool::new(skip_flag, response_tx.clone())) + .await?; + handle + .add_tool(SendFileTool::new(response_tx.clone())) + .await?; handle.add_tool(ReactTool::new(response_tx)).await?; if let Some(cron) = cron_tool { handle.add_tool(cron).await?; @@ -196,7 +213,9 @@ pub fn create_worker_tool_server( .tool(ShellTool::new(instance_dir.clone(), workspace.clone())) .tool(FileTool::new(workspace.clone())) .tool(ExecTool::new(instance_dir, workspace)) - .tool(SetStatusTool::new(agent_id, worker_id, channel_id, event_tx)); + .tool(SetStatusTool::new( + agent_id, worker_id, channel_id, event_tx, + )); if browser_config.enabled { server = server.tool(BrowserTool::new(browser_config, screenshot_dir)); diff --git a/src/tools/branch_tool.rs b/src/tools/branch_tool.rs index db543aa2d..6f19aeed0 100644 --- a/src/tools/branch_tool.rs +++ b/src/tools/branch_tool.rs @@ -1,7 +1,7 @@ //! Branch tool for forking context and thinking (channel only). 
-use crate::agent::channel::{ChannelState, spawn_branch_from_state}; use crate::BranchId; +use crate::agent::channel::{ChannelState, spawn_branch_from_state}; use rig::completion::ToolDefinition; use rig::tool::Tool; use schemars::JsonSchema; @@ -75,7 +75,10 @@ impl Tool for BranchTool { Ok(BranchOutput { branch_id, spawned: true, - message: format!("Branch {branch_id} spawned. It will investigate: {}", args.description), + message: format!( + "Branch {branch_id} spawned. It will investigate: {}", + args.description + ), }) } } diff --git a/src/tools/browser.rs b/src/tools/browser.rs index 932fd1c0f..150bd87a6 100644 --- a/src/tools/browser.rs +++ b/src/tools/browser.rs @@ -384,9 +384,9 @@ impl BrowserTool { builder = builder.chrome_executable(path); } - let chrome_config = builder - .build() - .map_err(|error| BrowserError::new(format!("failed to build browser config: {error}")))?; + let chrome_config = builder.build().map_err(|error| { + BrowserError::new(format!("failed to build browser config: {error}")) + })?; tracing::info!( headless = self.config.headless, @@ -398,9 +398,7 @@ impl BrowserTool { .await .map_err(|error| BrowserError::new(format!("failed to launch browser: {error}")))?; - let handler_task = tokio::spawn(async move { - while handler.next().await.is_some() {} - }); + let handler_task = tokio::spawn(async move { while handler.next().await.is_some() {} }); state.browser = Some(browser); state._handler_task = Some(handler_task); @@ -428,7 +426,10 @@ impl BrowserTool { state.element_refs.clear(); state.next_ref = 0; - Ok(BrowserOutput::success(format!("Navigated to {url}")).with_page_info(title, current_url)) + Ok( + BrowserOutput::success(format!("Navigated to {url}")) + .with_page_info(title, current_url), + ) } async fn handle_open(&self, url: Option) -> Result { @@ -500,10 +501,7 @@ impl BrowserTool { }) } - async fn handle_focus( - &self, - target_id: Option, - ) -> Result { + async fn handle_focus(&self, target_id: Option) -> Result { let 
Some(target_id) = target_id else { return Err(BrowserError::new("target_id is required for focus action")); }; @@ -520,9 +518,7 @@ impl BrowserTool { state.element_refs.clear(); state.next_ref = 0; - Ok(BrowserOutput::success(format!( - "Focused tab {target_id}" - ))) + Ok(BrowserOutput::success(format!("Focused tab {target_id}"))) } async fn handle_close_tab( @@ -652,9 +648,7 @@ impl BrowserTool { key: Option, ) -> Result { let Some(act_kind) = act_kind else { - return Err(BrowserError::new( - "act_kind is required for act action", - )); + return Err(BrowserError::new("act_kind is required for act action")); }; let state = self.state.lock().await; @@ -692,8 +686,7 @@ impl BrowserTool { return Err(BrowserError::new("key is required for act:press_key")); }; if element_ref.is_some() { - let element = - self.resolve_element_ref(&state, page, element_ref).await?; + let element = self.resolve_element_ref(&state, page, element_ref).await?; element .press_key(&key) .await @@ -713,12 +706,9 @@ impl BrowserTool { } ActKind::ScrollIntoView => { let element = self.resolve_element_ref(&state, page, element_ref).await?; - element - .scroll_into_view() - .await - .map_err(|error| { - BrowserError::new(format!("scroll_into_view failed: {error}")) - })?; + element.scroll_into_view().await.map_err(|error| { + BrowserError::new(format!("scroll_into_view failed: {error}")) + })?; Ok(BrowserOutput::success("Scrolled element into view")) } ActKind::Focus => { @@ -745,9 +735,7 @@ impl BrowserTool { element .screenshot(CaptureScreenshotFormat::Png) .await - .map_err(|error| { - BrowserError::new(format!("element screenshot failed: {error}")) - })? + .map_err(|error| BrowserError::new(format!("element screenshot failed: {error}")))? 
} else { let params = ScreenshotParams::builder() .format(CaptureScreenshotFormat::Png) @@ -793,10 +781,7 @@ impl BrowserTool { }) } - async fn handle_evaluate( - &self, - script: Option, - ) -> Result { + async fn handle_evaluate(&self, script: Option) -> Result { if !self.config.evaluate_enabled { return Err(BrowserError::new( "JavaScript evaluation is disabled in browser config (set evaluate_enabled = true)", @@ -804,9 +789,7 @@ impl BrowserTool { } let Some(script) = script else { - return Err(BrowserError::new( - "script is required for evaluate action", - )); + return Err(BrowserError::new("script is required for evaluate action")); }; let state = self.state.lock().await; @@ -942,9 +925,7 @@ impl BrowserTool { element_ref: Option, ) -> Result { let Some(ref_id) = element_ref else { - return Err(BrowserError::new( - "element_ref is required for this action", - )); + return Err(BrowserError::new("element_ref is required for this action")); }; let elem_ref = state.element_refs.get(&ref_id).ok_or_else(|| { @@ -966,10 +947,7 @@ impl BrowserTool { } /// Dispatch a key press event to the page via CDP Input domain. 
-async fn dispatch_key_press( - page: &chromiumoxide::Page, - key: &str, -) -> Result<(), BrowserError> { +async fn dispatch_key_press(page: &chromiumoxide::Page, key: &str) -> Result<(), BrowserError> { let key_down = DispatchKeyEventParams::builder() .r#type(DispatchKeyEventType::KeyDown) .key(key) diff --git a/src/tools/cancel.rs b/src/tools/cancel.rs index c80177a27..ea1fd950c 100644 --- a/src/tools/cancel.rs +++ b/src/tools/cancel.rs @@ -85,22 +85,33 @@ impl Tool for CancelTool { async fn call(&self, args: Self::Args) -> Result { match args.process_type.as_str() { "branch" => { - let branch_id = args.process_id.parse::() + let branch_id = args + .process_id + .parse::() .map_err(|e| CancelError(format!("Invalid branch ID: {e}")))?; - self.state.cancel_branch(branch_id).await + self.state + .cancel_branch(branch_id) + .await .map_err(CancelError)?; } "worker" => { - let worker_id = args.process_id.parse::() + let worker_id = args + .process_id + .parse::() .map_err(|e| CancelError(format!("Invalid worker ID: {e}")))?; - self.state.cancel_worker(worker_id).await + self.state + .cancel_worker(worker_id) + .await .map_err(CancelError)?; } other => return Err(CancelError(format!("Unknown process type: {other}"))), } let message = if let Some(reason) = &args.reason { - format!("{} {} cancelled: {reason}", args.process_type, args.process_id) + format!( + "{} {} cancelled: {reason}", + args.process_type, args.process_id + ) } else { format!("{} {} cancelled.", args.process_type, args.process_id) }; diff --git a/src/tools/channel_recall.rs b/src/tools/channel_recall.rs index 7d3751467..3759082bd 100644 --- a/src/tools/channel_recall.rs +++ b/src/tools/channel_recall.rs @@ -20,7 +20,10 @@ pub struct ChannelRecallTool { impl ChannelRecallTool { pub fn new(conversation_logger: ConversationLogger, channel_store: ChannelStore) -> Self { - Self { conversation_logger, channel_store } + Self { + conversation_logger, + channel_store, + } } } @@ -118,7 +121,8 @@ impl Tool for 
ChannelRecallTool { let limit = args.limit.min(MAX_TRANSCRIPT_MESSAGES).max(1); // Resolve channel name to ID - let found = self.channel_store + let found = self + .channel_store .find_by_name(&channel_query) .await .map_err(|e| ChannelRecallError(format!("Failed to search channels: {e}")))?; @@ -133,19 +137,21 @@ impl Tool for ChannelRecallTool { }; // Load transcript - let messages = self.conversation_logger + let messages = self + .conversation_logger .load_channel_transcript(&channel.id, limit) .await .map_err(|e| ChannelRecallError(format!("Failed to load transcript: {e}")))?; - let transcript: Vec = messages.iter().map(|message| { - TranscriptMessage { + let transcript: Vec = messages + .iter() + .map(|message| TranscriptMessage { role: message.role.clone(), sender: message.sender_name.clone(), content: message.content.clone(), timestamp: message.created_at.to_rfc3339(), - } - }).collect(); + }) + .collect(); let summary = format_transcript(&channel.display_name, &channel.id, &transcript); @@ -162,18 +168,20 @@ impl Tool for ChannelRecallTool { impl ChannelRecallTool { async fn list_channels(&self) -> std::result::Result { - let channels = self.channel_store + let channels = self + .channel_store .list_active() .await .map_err(|e| ChannelRecallError(format!("Failed to list channels: {e}")))?; - let entries: Vec = channels.iter().map(|channel| { - ChannelListEntry { + let entries: Vec = channels + .iter() + .map(|channel| ChannelListEntry { channel_id: channel.id.clone(), channel_name: channel.display_name.clone(), last_activity: channel.last_activity_at.to_rfc3339(), - } - }).collect(); + }) + .collect(); let summary = format_channel_list(&entries); @@ -201,14 +209,20 @@ fn format_transcript( } let label = channel_name.as_deref().unwrap_or(channel_id); - let mut output = format!("## Transcript from #{label} ({} messages)\n\n", messages.len()); + let mut output = format!( + "## Transcript from #{label} ({} messages)\n\n", + messages.len() + ); for message in 
messages { let sender = match &message.sender { Some(name) => name.as_str(), None => "assistant", }; - output.push_str(&format!("**{}** ({}): {}\n\n", sender, message.role, message.content)); + output.push_str(&format!( + "**{}** ({}): {}\n\n", + sender, message.role, message.content + )); } output diff --git a/src/tools/cron.rs b/src/tools/cron.rs index c8d56e5bc..20e62d673 100644 --- a/src/tools/cron.rs +++ b/src/tools/cron.rs @@ -139,7 +139,9 @@ impl Tool for CronTool { impl CronTool { async fn create(&self, args: CronArgs) -> Result { - let id = args.id.ok_or_else(|| CronError("'id' is required for create".into()))?; + let id = args + .id + .ok_or_else(|| CronError("'id' is required for create".into()))?; let prompt = args .prompt .ok_or_else(|| CronError("'prompt' is required for create".into()))?; @@ -205,7 +207,9 @@ impl CronTool { prompt: config.prompt, interval_secs: config.interval_secs, delivery_target: config.delivery_target, - active_hours: config.active_hours.map(|(s, e)| format!("{s:02}:00-{e:02}:00")), + active_hours: config + .active_hours + .map(|(s, e)| format!("{s:02}:00-{e:02}:00")), }) .collect(); diff --git a/src/tools/exec.rs b/src/tools/exec.rs index 0a66cae55..f74b7d052 100644 --- a/src/tools/exec.rs +++ b/src/tools/exec.rs @@ -19,7 +19,10 @@ pub struct ExecTool { impl ExecTool { /// Create a new exec tool with the given instance directory for path blocking. pub fn new(instance_dir: PathBuf, workspace: PathBuf) -> Self { - Self { instance_dir, workspace } + Self { + instance_dir, + workspace, + } } /// Check if program arguments reference sensitive instance paths. 
@@ -42,9 +45,13 @@ impl ExecTool { // Block references to sensitive files by name for file in super::shell::SENSITIVE_FILES { if all_args.contains(file) { - if all_args.contains(instance_str.as_ref()) || !all_args.contains(workspace_str.as_ref()) { + if all_args.contains(instance_str.as_ref()) + || !all_args.contains(workspace_str.as_ref()) + { return Err(ExecError { - message: format!("Cannot access {file} — instance configuration is protected."), + message: format!( + "Cannot access {file} — instance configuration is protected." + ), exit_code: -1, }); } @@ -178,7 +185,10 @@ impl Tool for ExecTool { if let Some(ref dir) = args.working_dir { let path = std::path::Path::new(dir); let canonical = path.canonicalize().unwrap_or_else(|_| path.to_path_buf()); - let workspace_canonical = self.workspace.canonicalize().unwrap_or_else(|_| self.workspace.clone()); + let workspace_canonical = self + .workspace + .canonicalize() + .unwrap_or_else(|_| self.workspace.clone()); if !canonical.starts_with(&workspace_canonical) { return Err(ExecError { message: format!( @@ -216,8 +226,7 @@ impl Tool for ExecTool { cmd.env(env_var.key, env_var.value); } - cmd.stdout(Stdio::piped()) - .stderr(Stdio::piped()); + cmd.stdout(Stdio::piped()).stderr(Stdio::piped()); let timeout = tokio::time::Duration::from_secs(args.timeout_seconds); @@ -301,13 +310,14 @@ pub async fn exec( cmd.stdout(Stdio::piped()).stderr(Stdio::piped()); - let output = tokio::time::timeout( - tokio::time::Duration::from_secs(60), - cmd.output(), - ) - .await - .map_err(|_| crate::error::AgentError::Other(anyhow::anyhow!("Execution timed out").into()))? - .map_err(|e| crate::error::AgentError::Other(anyhow::anyhow!("Failed to execute: {e}").into()))?; + let output = tokio::time::timeout(tokio::time::Duration::from_secs(60), cmd.output()) + .await + .map_err(|_| { + crate::error::AgentError::Other(anyhow::anyhow!("Execution timed out").into()) + })? 
+ .map_err(|e| { + crate::error::AgentError::Other(anyhow::anyhow!("Failed to execute: {e}").into()) + })?; Ok(ExecResult { success: output.status.success(), diff --git a/src/tools/file.rs b/src/tools/file.rs index 82897136a..b260ca95b 100644 --- a/src/tools/file.rs +++ b/src/tools/file.rs @@ -35,7 +35,10 @@ impl FileTool { // existing ancestor and append the remaining components. let canonical = best_effort_canonicalize(&resolved); - let workspace_canonical = self.workspace.canonicalize().unwrap_or_else(|_| self.workspace.clone()); + let workspace_canonical = self + .workspace + .canonicalize() + .unwrap_or_else(|_| self.workspace.clone()); if !canonical.starts_with(&workspace_canonical) { return Err(FileError(format!( @@ -285,7 +288,10 @@ async fn do_file_list(path: &Path) -> Result { if total_count > max_entries { entries.push(FileEntryOutput { - name: format!("... and {} more entries (listing capped at {max_entries})", total_count - max_entries), + name: format!( + "... and {} more entries (listing capped at {max_entries})", + total_count - max_entries + ), entry_type: "notice".to_string(), size: 0, }); @@ -301,8 +307,6 @@ async fn do_file_list(path: &Path) -> Result { }) } - - /// File entry metadata (legacy). #[derive(Debug, Clone)] pub struct FileEntry { diff --git a/src/tools/memory_recall.rs b/src/tools/memory_recall.rs index a2aa795a7..9739a7cfd 100644 --- a/src/tools/memory_recall.rs +++ b/src/tools/memory_recall.rs @@ -95,7 +95,11 @@ fn parse_memory_type(s: &str) -> std::result::Result Ok(MemoryType::Todo), other => Err(MemoryRecallError(format!( "unknown memory_type \"{other}\". 
Valid types: {}", - crate::memory::MemoryType::ALL.iter().map(|t| t.to_string()).collect::>().join(", ") + crate::memory::MemoryType::ALL + .iter() + .map(|t| t.to_string()) + .collect::>() + .join(", ") ))), } } @@ -296,7 +300,10 @@ pub async fn memory_recall( sort_by: None, }; - let output = tool.call(args).await.map_err(|e| crate::error::AgentError::Other(anyhow::anyhow!(e)))?; + let output = tool + .call(args) + .await + .map_err(|e| crate::error::AgentError::Other(anyhow::anyhow!(e)))?; // Convert back to Memory type for backward compatibility let store = memory_search.store(); @@ -319,7 +326,10 @@ mod tests { fn test_parse_search_mode_valid() { assert_eq!(parse_search_mode("hybrid").unwrap(), SearchMode::Hybrid); assert_eq!(parse_search_mode("recent").unwrap(), SearchMode::Recent); - assert_eq!(parse_search_mode("important").unwrap(), SearchMode::Important); + assert_eq!( + parse_search_mode("important").unwrap(), + SearchMode::Important + ); assert_eq!(parse_search_mode("typed").unwrap(), SearchMode::Typed); } @@ -332,8 +342,14 @@ mod tests { #[test] fn test_parse_search_sort_valid() { assert_eq!(parse_search_sort("recent").unwrap(), SearchSort::Recent); - assert_eq!(parse_search_sort("importance").unwrap(), SearchSort::Importance); - assert_eq!(parse_search_sort("most_accessed").unwrap(), SearchSort::MostAccessed); + assert_eq!( + parse_search_sort("importance").unwrap(), + SearchSort::Importance + ); + assert_eq!( + parse_search_sort("most_accessed").unwrap(), + SearchSort::MostAccessed + ); } #[test] diff --git a/src/tools/memory_save.rs b/src/tools/memory_save.rs index d965f1675..db2546da3 100644 --- a/src/tools/memory_save.rs +++ b/src/tools/memory_save.rs @@ -1,15 +1,14 @@ //! Memory save tool for channels and branches. 
use crate::error::Result; -use crate::memory::{Memory, MemorySearch, MemoryType}; use crate::memory::types::{Association, CreateAssociationInput}; +use crate::memory::{Memory, MemorySearch, MemoryType}; use rig::completion::ToolDefinition; use rig::tool::Tool; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::sync::Arc; - /// Tool for saving memories to the store. #[derive(Debug, Clone)] pub struct MemorySaveTool { @@ -207,8 +206,8 @@ impl Tool for MemorySaveTool { _ => crate::memory::types::RelationType::RelatedTo, }; - let association = - Association::new(&memory.id, &assoc.target_id, relation_type).with_weight(assoc.weight); + let association = Association::new(&memory.id, &assoc.target_id, relation_type) + .with_weight(assoc.weight); if let Err(error) = store.create_association(&association).await { tracing::warn!( @@ -236,7 +235,12 @@ impl Tool for MemorySaveTool { // Ensure the FTS index exists so full_text_search queries work. // Safe to call repeatedly — no-ops if the index already exists. 
- if let Err(error) = self.memory_search.embedding_table().ensure_fts_index().await { + if let Err(error) = self + .memory_search + .embedding_table() + .ensure_fts_index() + .await + { tracing::warn!(%error, "failed to ensure FTS index after memory save"); } @@ -264,6 +268,9 @@ pub async fn save_fact( associations: vec![], }; - let output = tool.call(args).await.map_err(|e| crate::error::AgentError::Other(anyhow::anyhow!(e)))?; + let output = tool + .call(args) + .await + .map_err(|e| crate::error::AgentError::Other(anyhow::anyhow!(e)))?; Ok(output.memory_id) } diff --git a/src/tools/reply.rs b/src/tools/reply.rs index 67b2fed2e..978b31c54 100644 --- a/src/tools/reply.rs +++ b/src/tools/reply.rs @@ -100,7 +100,8 @@ impl Tool for ReplyTool { "reply tool called" ); - self.conversation_logger.log_bot_message(&self.channel_id, &args.content); + self.conversation_logger + .log_bot_message(&self.channel_id, &args.content); let response = match args.thread_name { Some(ref name) => { diff --git a/src/tools/route.rs b/src/tools/route.rs index 8cbecd3b2..aa631d866 100644 --- a/src/tools/route.rs +++ b/src/tools/route.rs @@ -1,7 +1,7 @@ //! Route tool for sending follow-ups to active workers. 
-use crate::agent::channel::ChannelState; use crate::WorkerId; +use crate::agent::channel::ChannelState; use rig::completion::ToolDefinition; use rig::tool::Tool; use schemars::JsonSchema; @@ -74,7 +74,9 @@ impl Tool for RouteTool { } async fn call(&self, args: Self::Args) -> std::result::Result { - let worker_id = args.worker_id.parse::() + let worker_id = args + .worker_id + .parse::() .map_err(|e| RouteError(format!("Invalid worker ID: {e}")))?; // Look up the input sender for this worker @@ -87,10 +89,11 @@ impl Tool for RouteTool { drop(inputs); // Deliver the message - input_tx.send(args.message).await - .map_err(|_| RouteError(format!( + input_tx.send(args.message).await.map_err(|_| { + RouteError(format!( "Worker {worker_id} has stopped accepting input (channel closed)" - )))?; + )) + })?; tracing::info!( worker_id = %worker_id, diff --git a/src/tools/send_file.rs b/src/tools/send_file.rs index 7a2ea8756..cfe556f4b 100644 --- a/src/tools/send_file.rs +++ b/src/tools/send_file.rs @@ -85,9 +85,9 @@ impl Tool for SendFileTool { return Err(SendFileError("file_path must be an absolute path".into())); } - let metadata = tokio::fs::metadata(&path) - .await - .map_err(|error| SendFileError(format!("can't read file '{}': {error}", path.display())))?; + let metadata = tokio::fs::metadata(&path).await.map_err(|error| { + SendFileError(format!("can't read file '{}': {error}", path.display())) + })?; if !metadata.is_file() { return Err(SendFileError(format!("'{}' is not a file", path.display()))); @@ -101,9 +101,9 @@ impl Tool for SendFileTool { ))); } - let data = tokio::fs::read(&path) - .await - .map_err(|error| SendFileError(format!("failed to read '{}': {error}", path.display())))?; + let data = tokio::fs::read(&path).await.map_err(|error| { + SendFileError(format!("failed to read '{}': {error}", path.display())) + })?; let filename = path .file_name() diff --git a/src/tools/shell.rs b/src/tools/shell.rs index 91b714c05..5ae769c8f 100644 --- a/src/tools/shell.rs 
+++ b/src/tools/shell.rs @@ -22,6 +22,7 @@ pub const SECRET_ENV_VARS: &[&str] = &[ "ANTHROPIC_API_KEY", "OPENAI_API_KEY", "OPENROUTER_API_KEY", + "OLLAMA_API_KEY", "DISCORD_BOT_TOKEN", "SLACK_BOT_TOKEN", "SLACK_APP_TOKEN", @@ -40,7 +41,10 @@ pub struct ShellTool { impl ShellTool { /// Create a new shell tool with the given instance directory for path blocking. pub fn new(instance_dir: PathBuf, workspace: PathBuf) -> Self { - Self { instance_dir, workspace } + Self { + instance_dir, + workspace, + } } /// Check if a command references sensitive instance paths or secret env vars. @@ -58,7 +62,8 @@ impl ShellTool { return Err(ShellError { message: "ACCESS DENIED: Cannot access the instance directory — it contains \ protected configuration and data. Do not attempt to reproduce or \ - guess its contents. Inform the user that this path is restricted.".to_string(), + guess its contents. Inform the user that this path is restricted." + .to_string(), exit_code: -1, }); } @@ -101,9 +106,14 @@ impl ShellTool { // Block broad env dumps that would expose secrets if command.contains("printenv") { let trimmed = command.trim(); - if trimmed == "printenv" || trimmed.ends_with("| printenv") || trimmed.contains("printenv |") || trimmed.contains("printenv >") { + if trimmed == "printenv" + || trimmed.ends_with("| printenv") + || trimmed.contains("printenv |") + || trimmed.contains("printenv >") + { return Err(ShellError { - message: "Cannot dump all environment variables — they may contain secrets.".to_string(), + message: "Cannot dump all environment variables — they may contain secrets." + .to_string(), exit_code: -1, }); } @@ -112,7 +122,8 @@ impl ShellTool { let trimmed = command.trim(); if trimmed == "env" || trimmed.starts_with("env |") || trimmed.starts_with("env >") { return Err(ShellError { - message: "Cannot dump all environment variables — they may contain secrets.".to_string(), + message: "Cannot dump all environment variables — they may contain secrets." 
+ .to_string(), exit_code: -1, }); } @@ -212,7 +223,10 @@ impl Tool for ShellTool { if let Some(ref dir) = args.working_dir { let path = std::path::Path::new(dir); let canonical = path.canonicalize().unwrap_or_else(|_| path.to_path_buf()); - let workspace_canonical = self.workspace.canonicalize().unwrap_or_else(|_| self.workspace.clone()); + let workspace_canonical = self + .workspace + .canonicalize() + .unwrap_or_else(|_| self.workspace.clone()); if !canonical.starts_with(&workspace_canonical) { return Err(ShellError { message: format!( @@ -241,8 +255,7 @@ impl Tool for ShellTool { cmd.current_dir(&self.workspace); } - cmd.stdout(Stdio::piped()) - .stderr(Stdio::piped()); + cmd.stdout(Stdio::piped()).stderr(Stdio::piped()); // Set timeout let timeout = tokio::time::Duration::from_secs(args.timeout_seconds); @@ -306,7 +319,10 @@ fn format_shell_output(exit_code: i32, stdout: &str, stderr: &str) -> String { /// System-internal shell execution that bypasses path restrictions. /// Used by the system itself, not LLM-facing. -pub async fn shell(command: &str, working_dir: Option<&std::path::Path>) -> crate::error::Result { +pub async fn shell( + command: &str, + working_dir: Option<&std::path::Path>, +) -> crate::error::Result { let mut cmd = if cfg!(target_os = "windows") { let mut c = Command::new("cmd"); c.arg("/C").arg(command); @@ -323,13 +339,14 @@ pub async fn shell(command: &str, working_dir: Option<&std::path::Path>) -> crat cmd.stdout(Stdio::piped()).stderr(Stdio::piped()); - let output = tokio::time::timeout( - tokio::time::Duration::from_secs(60), - cmd.output(), - ) - .await - .map_err(|_| crate::error::AgentError::Other(anyhow::anyhow!("Command timed out").into()))? 
- .map_err(|e| crate::error::AgentError::Other(anyhow::anyhow!("Failed to execute command: {e}").into()))?; + let output = tokio::time::timeout(tokio::time::Duration::from_secs(60), cmd.output()) + .await + .map_err(|_| crate::error::AgentError::Other(anyhow::anyhow!("Command timed out").into()))? + .map_err(|e| { + crate::error::AgentError::Other( + anyhow::anyhow!("Failed to execute command: {e}").into(), + ) + })?; Ok(ShellResult { success: output.status.success(), @@ -354,5 +371,3 @@ impl ShellResult { format_shell_output(self.exit_code, &self.stdout, &self.stderr) } } - - diff --git a/src/tools/skip.rs b/src/tools/skip.rs index 57c5a9258..481b5fb8b 100644 --- a/src/tools/skip.rs +++ b/src/tools/skip.rs @@ -84,9 +84,10 @@ impl Tool for SkipTool { self.flag.store(true, Ordering::Relaxed); // Cancel the typing indicator so it doesn't linger - let _ = self.response_tx.send( - OutboundResponse::Status(crate::StatusUpdate::StopTyping) - ).await; + let _ = self + .response_tx + .send(OutboundResponse::Status(crate::StatusUpdate::StopTyping)) + .await; let reason = args.reason.as_deref().unwrap_or("no reason given"); tracing::info!(reason, "skip tool called, suppressing response"); diff --git a/src/tools/spawn_worker.rs b/src/tools/spawn_worker.rs index 69047e8fc..35157df78 100644 --- a/src/tools/spawn_worker.rs +++ b/src/tools/spawn_worker.rs @@ -1,7 +1,9 @@ //! Spawn worker tool for creating new workers. 
-use crate::agent::channel::{ChannelState, spawn_worker_from_state, spawn_opencode_worker_from_state}; use crate::WorkerId; +use crate::agent::channel::{ + ChannelState, spawn_opencode_worker_from_state, spawn_worker_from_state, +}; use rig::completion::ToolDefinition; use rig::tool::Tool; use schemars::JsonSchema; @@ -144,17 +146,13 @@ impl Tool for SpawnWorkerTool { let is_opencode = args.worker_type.as_deref() == Some("opencode"); let worker_id = if is_opencode { - let directory = args.directory.as_deref() - .ok_or_else(|| SpawnWorkerError("directory is required for opencode workers".into()))?; + let directory = args.directory.as_deref().ok_or_else(|| { + SpawnWorkerError("directory is required for opencode workers".into()) + })?; - spawn_opencode_worker_from_state( - &self.state, - &args.task, - directory, - args.interactive, - ) - .await - .map_err(|e| SpawnWorkerError(format!("{e}")))? + spawn_opencode_worker_from_state(&self.state, &args.task, directory, args.interactive) + .await + .map_err(|e| SpawnWorkerError(format!("{e}")))? } else { spawn_worker_from_state( &self.state, From 40edd14ea31ecc987b38427cdc0fda9618a192ef Mon Sep 17 00:00:00 2001 From: nmb Date: Tue, 17 Feb 2026 16:31:09 +0100 Subject: [PATCH 06/11] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor(memory):=20?= =?UTF-8?q?improve=20store=20operations=20and=20search?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/memory.rs | 16 ++-- src/memory/embedding.rs | 14 ++-- src/memory/lance.rs | 117 +++++++++++++------------- src/memory/maintenance.rs | 37 ++++----- src/memory/search.rs | 170 +++++++++++++++++++++++++------------- src/memory/store.rs | 132 ++++++++++++++++++----------- 6 files changed, 291 insertions(+), 195 deletions(-) diff --git a/src/memory.rs b/src/memory.rs index 937951ca1..5609fbc12 100644 --- a/src/memory.rs +++ b/src/memory.rs @@ -1,14 +1,14 @@ //! Memory storage and retrieval system. 
-pub mod store; -pub mod types; -pub mod search; -pub mod lance; pub mod embedding; +pub mod lance; pub mod maintenance; +pub mod search; +pub mod store; +pub mod types; -pub use store::MemoryStore; -pub use types::{Memory, MemoryType, Association, RelationType}; -pub use search::{MemorySearch, SearchConfig, SearchMode, SearchSort, curate_results}; -pub use lance::EmbeddingTable; pub use embedding::EmbeddingModel; +pub use lance::EmbeddingTable; +pub use search::{MemorySearch, SearchConfig, SearchMode, SearchSort, curate_results}; +pub use store::MemoryStore; +pub use types::{Association, Memory, MemoryType, RelationType}; diff --git a/src/memory/embedding.rs b/src/memory/embedding.rs index 4714248e5..bedb3c371 100644 --- a/src/memory/embedding.rs +++ b/src/memory/embedding.rs @@ -22,12 +22,15 @@ impl EmbeddingModel { let model = fastembed::TextEmbedding::try_new(options) .map_err(|e| LlmError::EmbeddingFailed(e.to_string()))?; - Ok(Self { model: Arc::new(model) }) + Ok(Self { + model: Arc::new(model), + }) } /// Generate embeddings for multiple texts (blocking). pub fn embed(&self, texts: Vec) -> Result>> { - self.model.embed(texts, None) + self.model + .embed(texts, None) .map_err(|e| LlmError::EmbeddingFailed(e.to_string()).into()) } @@ -42,8 +45,9 @@ impl EmbeddingModel { let text = text.to_string(); let model = self.model.clone(); let result = tokio::task::spawn_blocking(move || { - model.embed(vec![text], None) - .map_err(|e| crate::Error::Llm(crate::error::LlmError::EmbeddingFailed(e.to_string()))) + model.embed(vec![text], None).map_err(|e| { + crate::Error::Llm(crate::error::LlmError::EmbeddingFailed(e.to_string())) + }) }) .await .map_err(|e| crate::Error::Other(anyhow::anyhow!("embedding task failed: {}", e)))??; @@ -52,8 +56,6 @@ impl EmbeddingModel { } } - - /// Async function to embed text using a shared model. 
pub async fn embed_text(model: &Arc, text: &str) -> Result> { model.embed_one(text).await diff --git a/src/memory/lance.rs b/src/memory/lance.rs index 5338111ce..fb0fcc443 100644 --- a/src/memory/lance.rs +++ b/src/memory/lance.rs @@ -1,9 +1,9 @@ //! LanceDB table management and embedding storage with HNSW vector index and FTS. use crate::error::{DbError, Result}; -use arrow_array::{Array, RecordBatchIterator}; use arrow_array::cast::AsArray; use arrow_array::types::Float32Type; +use arrow_array::{Array, RecordBatchIterator}; use futures::TryStreamExt; use std::sync::Arc; @@ -33,54 +33,49 @@ impl EmbeddingTable { Err(_) => { // Create new table with empty batch let schema = Self::schema(); - + // Create empty RecordBatchIterator - let batches = RecordBatchIterator::new( - vec![].into_iter().map(Ok), - Arc::new(schema), - ); - + let batches = + RecordBatchIterator::new(vec![].into_iter().map(Ok), Arc::new(schema)); + let table = connection .create_table(TABLE_NAME, Box::new(batches)) .execute() .await .map_err(|e| DbError::LanceDb(e.to_string()))?; - + Ok(Self { table }) } } } - + /// Store an embedding with content for a memory. /// The content is stored for FTS search capability. 
- pub async fn store( - &self, - memory_id: &str, - content: &str, - embedding: &[f32], - ) -> Result<()> { + pub async fn store(&self, memory_id: &str, content: &str, embedding: &[f32]) -> Result<()> { if embedding.len() != EMBEDDING_DIM as usize { return Err(DbError::LanceDb(format!( "Embedding dimension mismatch: expected {}, got {}", EMBEDDING_DIM, embedding.len() - )).into()); + )) + .into()); } - + use arrow_array::{FixedSizeListArray, RecordBatch, StringArray}; - + let schema = Self::schema(); - + // Build arrays for the record batch let id_array = StringArray::from(vec![memory_id]); let content_array = StringArray::from(vec![content]); - + // Convert embedding to FixedSizeListArray - let embedding_array = arrow_array::FixedSizeListArray::from_iter_primitive::( - vec![Some(embedding.iter().map(|v| Some(*v)).collect::>())], - EMBEDDING_DIM, - ); - + let embedding_array = + arrow_array::FixedSizeListArray::from_iter_primitive::( + vec![Some(embedding.iter().map(|v| Some(*v)).collect::>())], + EMBEDDING_DIM, + ); + let batch = RecordBatch::try_new( Arc::new(schema), vec![ @@ -90,22 +85,19 @@ impl EmbeddingTable { ], ) .map_err(|e| DbError::LanceDb(e.to_string()))?; - + // Create iterator for IntoArrow trait - let batches = RecordBatchIterator::new( - vec![Ok(batch)], - Arc::new(Self::schema()), - ); - + let batches = RecordBatchIterator::new(vec![Ok(batch)], Arc::new(Self::schema())); + self.table .add(Box::new(batches)) .execute() .await .map_err(|e| DbError::LanceDb(e.to_string()))?; - + Ok(()) } - + /// Delete an embedding by memory ID. pub async fn delete(&self, memory_id: &str) -> Result<()> { let predicate = format!("id = '{}'", memory_id); @@ -113,10 +105,10 @@ impl EmbeddingTable { .delete(&predicate) .await .map_err(|e| DbError::LanceDb(e.to_string()))?; - + Ok(()) } - + /// Vector similarity search using cosine distance. /// Returns (memory_id, distance) pairs sorted by distance (ascending). 
pub async fn vector_search( @@ -129,11 +121,12 @@ impl EmbeddingTable { "Query embedding dimension mismatch: expected {}, got {}", EMBEDDING_DIM, query_embedding.len() - )).into()); + )) + .into()); } - + use lancedb::query::{ExecutableQuery, QueryBase}; - + // Use query() API with nearest_to for vector search let results: Vec = self .table @@ -147,13 +140,16 @@ impl EmbeddingTable { .try_collect() .await .map_err(|e| DbError::LanceDb(e.to_string()))?; - + let mut matches = Vec::new(); for batch in results { - if let (Some(id_col), Some(dist_col)) = (batch.column_by_name("id"), batch.column_by_name("_distance")) { + if let (Some(id_col), Some(dist_col)) = ( + batch.column_by_name("id"), + batch.column_by_name("_distance"), + ) { let ids: &arrow_array::StringArray = id_col.as_string::(); let dists: &arrow_array::PrimitiveArray = dist_col.as_primitive(); - + for i in 0..ids.len() { if ids.is_valid(i) && dists.is_valid(i) { let id = ids.value(i).to_string(); @@ -163,10 +159,10 @@ impl EmbeddingTable { } } } - + Ok(matches) } - + /// Find memories similar to a given memory by its embedding. /// Returns (memory_id, similarity) pairs where similarity = 1.0 - cosine_distance. /// Results exclude the source memory itself. @@ -235,12 +231,14 @@ impl EmbeddingTable { /// Returns (memory_id, score) pairs sorted by score (descending). 
pub async fn text_search(&self, query: &str, limit: usize) -> Result> { use lancedb::query::{ExecutableQuery, QueryBase}; - + // Use full_text_search on the content column let results: Vec = self .table .query() - .full_text_search(lance_index::scalar::FullTextSearchQuery::new(query.to_string())) + .full_text_search(lance_index::scalar::FullTextSearchQuery::new( + query.to_string(), + )) .select(lancedb::query::Select::columns(&["id", "_score"])) .limit(limit) .execute() @@ -249,13 +247,15 @@ impl EmbeddingTable { .try_collect() .await .map_err(|e| DbError::LanceDb(e.to_string()))?; - + let mut matches = Vec::new(); for batch in results { - if let (Some(id_col), Some(score_col)) = (batch.column_by_name("id"), batch.column_by_name("_score")) { + if let (Some(id_col), Some(score_col)) = + (batch.column_by_name("id"), batch.column_by_name("_score")) + { let ids: &arrow_array::StringArray = id_col.as_string::(); let scores: &arrow_array::PrimitiveArray = score_col.as_primitive(); - + for i in 0..ids.len() { if ids.is_valid(i) && scores.is_valid(i) { let id = ids.value(i).to_string(); @@ -265,10 +265,10 @@ impl EmbeddingTable { } } } - + Ok(matches) } - + /// Create HNSW vector index and FTS index for better performance. /// Should be called after enough data accumulates. pub async fn create_indexes(&self) -> Result<()> { @@ -278,19 +278,20 @@ impl EmbeddingTable { .execute() .await .map_err(|e| DbError::LanceDb(format!("Failed to create vector index: {}", e)))?; - + self.ensure_fts_index().await?; - + Ok(()) } - + /// Ensure the FTS index exists on the content column. /// /// LanceDB requires an inverted index for `full_text_search()` queries. /// This is safe to call multiple times — if the index already exists, the /// error is silently ignored. 
pub async fn ensure_fts_index(&self) -> Result<()> { - match self.table + match self + .table .create_index(&["content"], lancedb::index::Index::FTS(Default::default())) .execute() .await @@ -311,7 +312,7 @@ impl EmbeddingTable { } } } - + /// Get the Arrow schema for the embeddings table. fn schema() -> arrow_schema::Schema { arrow_schema::Schema::new(vec![ @@ -320,7 +321,11 @@ impl EmbeddingTable { arrow_schema::Field::new( "embedding", arrow_schema::DataType::FixedSizeList( - Arc::new(arrow_schema::Field::new("item", arrow_schema::DataType::Float32, true)), + Arc::new(arrow_schema::Field::new( + "item", + arrow_schema::DataType::Float32, + true, + )), EMBEDDING_DIM, ), false, diff --git a/src/memory/maintenance.rs b/src/memory/maintenance.rs index 6a8bbdf3b..2dc378ec9 100644 --- a/src/memory/maintenance.rs +++ b/src/memory/maintenance.rs @@ -35,16 +35,16 @@ pub async fn run_maintenance( config: &MaintenanceConfig, ) -> Result { let mut report = MaintenanceReport::default(); - + // Apply decay to all non-identity memories report.decayed = apply_decay(memory_store, config.decay_rate).await?; - + // Prune old, low-importance memories report.pruned = prune_memories(memory_store, config).await?; - + // Merge near-duplicate memories report.merged = merge_similar_memories(memory_store, config.merge_similarity_threshold).await?; - + Ok(report) } @@ -56,17 +56,17 @@ async fn apply_decay(memory_store: &MemoryStore, decay_rate: f32) -> Result Result 0.01 { memory.importance = new_importance.clamp(0.0, 1.0); memory.updated_at = now; @@ -87,19 +87,16 @@ async fn apply_decay(memory_store: &MemoryStore, decay_rate: f32) -> Result Result { +async fn prune_memories(memory_store: &MemoryStore, config: &MaintenanceConfig) -> Result { let now = chrono::Utc::now(); let min_age = chrono::Duration::days(config.min_age_days); let cutoff_date = now - min_age; - + // Get all memories below threshold that are old enough let candidates = sqlx::query( r#" @@ -107,21 +104,21 @@ async fn 
prune_memories( WHERE importance < ? AND memory_type != 'identity' AND created_at < ? - "# + "#, ) .bind(config.prune_threshold) .bind(cutoff_date) .fetch_all(memory_store.pool()) .await?; - + let mut pruned_count = 0; - + for row in candidates { let id: String = sqlx::Row::try_get(&row, "id")?; memory_store.delete(&id).await?; pruned_count += 1; } - + Ok(pruned_count) } diff --git a/src/memory/search.rs b/src/memory/search.rs index 571f7d7b9..b4477d34a 100644 --- a/src/memory/search.rs +++ b/src/memory/search.rs @@ -71,17 +71,17 @@ impl MemorySearch { embedding_model, } } - + /// Get a reference to the memory store. pub fn store(&self) -> &MemoryStore { &self.store } - + /// Get a reference to the embedding table. pub fn embedding_table(&self) -> &EmbeddingTable { &self.embedding_table } - + /// Get a reference to the embedding model. pub fn embedding_model(&self) -> &EmbeddingModel { &self.embedding_model @@ -91,7 +91,7 @@ impl MemorySearch { pub fn embedding_model_arc(&self) -> &Arc { &self.embedding_model } - + /// Unified search entry point. Dispatches to the appropriate strategy /// based on `config.mode`. pub async fn search( @@ -152,16 +152,23 @@ impl MemorySearch { let mut vector_results = Vec::new(); let mut fts_results = Vec::new(); let mut graph_results = Vec::new(); - + // 1. Full-text search via LanceDB // FTS requires an inverted index. If the index doesn't exist yet (empty // table, first run) this will fail — fall back to vector + graph search. - match self.embedding_table.text_search(query, config.max_results_per_source).await { + match self + .embedding_table + .text_search(query, config.max_results_per_source) + .await + { Ok(fts_matches) => { for (memory_id, score) in fts_matches { if let Some(memory) = self.store.load(&memory_id).await? 
{ if !memory.forgotten { - fts_results.push(ScoredMemory { memory, score: score as f64 }); + fts_results.push(ScoredMemory { + memory, + score: score as f64, + }); } } } @@ -170,16 +177,23 @@ impl MemorySearch { tracing::debug!(%error, "FTS search unavailable, falling back to vector + graph"); } } - + // 2. Vector similarity search via LanceDB let query_embedding = self.embedding_model.embed_one(query).await?; - match self.embedding_table.vector_search(&query_embedding, config.max_results_per_source).await { + match self + .embedding_table + .vector_search(&query_embedding, config.max_results_per_source) + .await + { Ok(vector_matches) => { for (memory_id, distance) in vector_matches { let similarity = 1.0 - distance; if let Some(memory) = self.store.load(&memory_id).await? { if !memory.forgotten { - vector_results.push(ScoredMemory { memory, score: similarity as f64 }); + vector_results.push(ScoredMemory { + memory, + score: similarity as f64, + }); } } } @@ -188,39 +202,40 @@ impl MemorySearch { tracing::debug!(%error, "vector search unavailable, falling back to graph only"); } } - + // 3. 
Graph traversal from high-importance memories // Get identity and high-importance memories as starting points let seed_memories = self.store.get_high_importance(0.8, 20).await?; - + for seed in seed_memories { // Check if seed is semantically related to query via simple keyword matching - if query.to_lowercase().split_whitespace().any(|term| { - seed.content.to_lowercase().contains(term) - }) { - graph_results.push(ScoredMemory { - memory: seed.clone(), - score: seed.importance as f64 + if query + .to_lowercase() + .split_whitespace() + .any(|term| seed.content.to_lowercase().contains(term)) + { + graph_results.push(ScoredMemory { + memory: seed.clone(), + score: seed.importance as f64, }); - + // Traverse graph to find related memories - self.traverse_graph(&seed.id, config.max_graph_depth, &mut graph_results).await?; + self.traverse_graph(&seed.id, config.max_graph_depth, &mut graph_results) + .await?; } } - + // 4. Merge results using Reciprocal Rank Fusion (RRF) - let fused_results = reciprocal_rank_fusion( - &vector_results, - &fts_results, - &graph_results, - config.rrf_k, - ); - + let fused_results = + reciprocal_rank_fusion(&vector_results, &fts_results, &graph_results, config.rrf_k); + // Convert to MemorySearchResult with ranks, applying optional type filter let results: Vec = fused_results .into_iter() .filter(|scored| { - config.memory_type.is_none_or(|t| scored.memory.memory_type == t) + config + .memory_type + .is_none_or(|t| scored.memory.memory_type == t) }) .enumerate() .map(|(rank, scored)| MemorySearchResult { @@ -231,10 +246,10 @@ impl MemorySearch { .filter(|r| r.score >= config.min_score) .take(config.max_results_per_source) .collect(); - + Ok(results) } - + /// Traverse the memory graph to find related memories (iterative to avoid async recursion). 
async fn traverse_graph( &self, @@ -243,21 +258,21 @@ impl MemorySearch { results: &mut Vec, ) -> Result<()> { use std::collections::VecDeque; - + // Queue of (memory_id, current_depth) let mut queue: VecDeque<(String, usize)> = VecDeque::new(); let mut visited: std::collections::HashSet = std::collections::HashSet::new(); - + queue.push_back((start_id.to_string(), 0)); visited.insert(start_id.to_string()); - + while let Some((current_id, depth)) = queue.pop_front() { if depth > max_depth { continue; } - + let associations = self.store.get_associations(¤t_id).await?; - + for assoc in associations { // Get the related memory let related_id = if assoc.source_id == current_id { @@ -265,12 +280,12 @@ impl MemorySearch { } else { &assoc.source_id }; - + if visited.contains(related_id) { continue; } visited.insert(related_id.clone()); - + if let Some(memory) = self.store.load(related_id).await? { if memory.forgotten { continue; @@ -283,19 +298,25 @@ impl MemorySearch { RelationType::Contradicts => 0.5, RelationType::PartOf => 0.8, }; - + let score = memory.importance as f64 * assoc.weight as f64 * type_multiplier; - - results.push(ScoredMemory { memory: memory.clone(), score }); - + + results.push(ScoredMemory { + memory: memory.clone(), + score, + }); + // Add to queue for RelatedTo and PartOf relations - if matches!(assoc.relation_type, RelationType::RelatedTo | RelationType::PartOf) { + if matches!( + assoc.relation_type, + RelationType::RelatedTo | RelationType::PartOf + ) { queue.push_back((related_id.clone(), depth + 1)); } } } } - + Ok(()) } } @@ -355,44 +376,50 @@ fn reciprocal_rank_fusion( ) -> Vec { // Build a map of memory ID to RRF score let mut rrf_scores: HashMap = HashMap::new(); - + // Add vector results for (rank, scored) in vector_results.iter().enumerate() { let rrf_score = 1.0 / (k + (rank as f64 + 1.0)); - let entry = rrf_scores.entry(scored.memory.id.clone()) + let entry = rrf_scores + .entry(scored.memory.id.clone()) .or_insert((0.0, 
scored.memory.clone())); entry.0 += rrf_score; } - + // Add FTS results for (rank, scored) in fts_results.iter().enumerate() { let rrf_score = 1.0 / (k + (rank as f64 + 1.0)); - let entry = rrf_scores.entry(scored.memory.id.clone()) + let entry = rrf_scores + .entry(scored.memory.id.clone()) .or_insert((0.0, scored.memory.clone())); entry.0 += rrf_score; } - + // Add graph results for (rank, scored) in graph_results.iter().enumerate() { let rrf_score = 1.0 / (k + (rank as f64 + 1.0)); - let entry = rrf_scores.entry(scored.memory.id.clone()) + let entry = rrf_scores + .entry(scored.memory.id.clone()) .or_insert((0.0, scored.memory.clone())); entry.0 += rrf_score; } - + // Convert to vec and sort by RRF score let mut fused: Vec = rrf_scores .into_iter() .map(|(_, (score, memory))| ScoredMemory { memory, score }) .collect(); - + fused.sort_by(|a, b| b.score.total_cmp(&a.score)); - + fused } /// Curate search results to return only the most relevant. -pub fn curate_results(results: &[MemorySearchResult], max_results: usize) -> Vec<&MemorySearchResult> { +pub fn curate_results( + results: &[MemorySearchResult], + max_results: usize, +) -> Vec<&MemorySearchResult> { results.iter().take(max_results).collect() } @@ -482,11 +509,36 @@ mod tests { let mut memories = Vec::new(); let types_and_importance = [ - ("user identity info", MemoryType::Identity, 1.0, now - Duration::days(30)), - ("recent event", MemoryType::Event, 0.4, now - Duration::hours(1)), - ("important decision", MemoryType::Decision, 0.9, now - Duration::days(2)), - ("casual observation", MemoryType::Observation, 0.2, now - Duration::days(7)), - ("user preference", MemoryType::Preference, 0.7, now - Duration::days(1)), + ( + "user identity info", + MemoryType::Identity, + 1.0, + now - Duration::days(30), + ), + ( + "recent event", + MemoryType::Event, + 0.4, + now - Duration::hours(1), + ), + ( + "important decision", + MemoryType::Decision, + 0.9, + now - Duration::days(2), + ), + ( + "casual observation", + 
MemoryType::Observation, + 0.2, + now - Duration::days(7), + ), + ( + "user preference", + MemoryType::Preference, + 0.7, + now - Duration::days(1), + ), ]; for (content, memory_type, importance, created_at) in types_and_importance { diff --git a/src/memory/store.rs b/src/memory/store.rs index 695c4a1aa..ca6f2cff4 100644 --- a/src/memory/store.rs +++ b/src/memory/store.rs @@ -27,12 +27,12 @@ impl MemoryStore { pub fn new(pool: SqlitePool) -> Arc { Arc::new(Self { pool }) } - + /// Get a reference to the SQLite pool. pub(crate) fn pool(&self) -> &SqlitePool { &self.pool } - + /// Save a new memory to the store. pub async fn save(&self, memory: &Memory) -> Result<()> { sqlx::query( @@ -40,7 +40,7 @@ impl MemoryStore { INSERT INTO memories (id, content, memory_type, importance, created_at, updated_at, last_accessed_at, access_count, source, channel_id, forgotten) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - "# + "#, ) .bind(&memory.id) .bind(&memory.content) @@ -56,10 +56,10 @@ impl MemoryStore { .execute(&self.pool) .await .with_context(|| format!("failed to save memory {}", memory.id))?; - + Ok(()) } - + /// Load a memory by ID. pub async fn load(&self, id: &str) -> Result> { let row = sqlx::query( @@ -68,16 +68,16 @@ impl MemoryStore { last_accessed_at, access_count, source, channel_id, forgotten FROM memories WHERE id = ? - "# + "#, ) .bind(id) .fetch_optional(&self.pool) .await .with_context(|| format!("failed to load memory {}", id))?; - + Ok(row.map(|row| row_to_memory(&row))) } - + /// Update an existing memory. pub async fn update(&self, memory: &Memory) -> Result<()> { sqlx::query( @@ -87,7 +87,7 @@ impl MemoryStore { last_accessed_at = ?, access_count = ?, source = ?, channel_id = ?, forgotten = ? WHERE id = ? 
- "# + "#, ) .bind(&memory.content) .bind(memory.memory_type.to_string()) @@ -102,10 +102,10 @@ impl MemoryStore { .execute(&self.pool) .await .with_context(|| format!("failed to update memory {}", memory.id))?; - + Ok(()) } - + /// Delete a memory by ID. pub async fn delete(&self, id: &str) -> Result<()> { sqlx::query("DELETE FROM memories WHERE id = ?") @@ -113,35 +113,35 @@ impl MemoryStore { .execute(&self.pool) .await .with_context(|| format!("failed to delete memory {}", id))?; - + Ok(()) } - + /// Record access to a memory, updating last_accessed_at and access_count. pub async fn record_access(&self, id: &str) -> Result<()> { let now = chrono::Utc::now(); - + sqlx::query( r#" UPDATE memories SET last_accessed_at = ?, access_count = access_count + 1 WHERE id = ? - "# + "#, ) .bind(now) .bind(id) .execute(&self.pool) .await .with_context(|| format!("failed to record access for memory {}", id))?; - + Ok(()) } - + /// Mark a memory as forgotten. The memory stays in the database but is /// excluded from search results and recall. pub async fn forget(&self, id: &str) -> Result { let result = sqlx::query( - "UPDATE memories SET forgotten = 1, updated_at = ? WHERE id = ? AND forgotten = 0" + "UPDATE memories SET forgotten = 1, updated_at = ? WHERE id = ? AND forgotten = 0", ) .bind(chrono::Utc::now()) .bind(id) @@ -160,7 +160,7 @@ impl MemoryStore { VALUES (?, ?, ?, ?, ?, ?) ON CONFLICT(source_id, target_id, relation_type) DO UPDATE SET weight = excluded.weight - "# + "#, ) .bind(&association.id) .bind(&association.source_id) @@ -176,10 +176,10 @@ impl MemoryStore { association.source_id, association.target_id ) })?; - + Ok(()) } - + /// Get all associations for a memory (both incoming and outgoing). pub async fn get_associations(&self, memory_id: &str) -> Result> { let rows = sqlx::query( @@ -187,25 +187,28 @@ impl MemoryStore { SELECT id, source_id, target_id, relation_type, weight, created_at FROM associations WHERE source_id = ? OR target_id = ? 
- "# + "#, ) .bind(memory_id) .bind(memory_id) .fetch_all(&self.pool) .await .with_context(|| format!("failed to get associations for memory {}", memory_id))?; - + let associations = rows .into_iter() .map(|row| row_to_association(&row)) .collect(); - + Ok(associations) } /// Get all associations where both source and target are in the provided set. /// Used by the graph view to fetch edges between a known set of visible nodes. - pub async fn get_associations_between(&self, memory_ids: &[String]) -> Result> { + pub async fn get_associations_between( + &self, + memory_ids: &[String], + ) -> Result> { if memory_ids.is_empty() { return Ok(Vec::new()); } @@ -233,7 +236,10 @@ impl MemoryStore { .await .context("failed to get associations between memory set")?; - Ok(rows.into_iter().map(|row| row_to_association(&row)).collect()) + Ok(rows + .into_iter() + .map(|row| row_to_association(&row)) + .collect()) } /// Get neighbors of a memory: all associations plus the connected memories. @@ -296,11 +302,11 @@ impl MemoryStore { Ok((neighbors, all_associations)) } - + /// Get memories by type. pub async fn get_by_type(&self, memory_type: MemoryType, limit: i64) -> Result> { let type_str = memory_type.to_string(); - + let rows = sqlx::query( r#" SELECT id, content, memory_type, importance, created_at, updated_at, @@ -309,17 +315,17 @@ impl MemoryStore { WHERE memory_type = ? AND forgotten = 0 ORDER BY importance DESC, updated_at DESC LIMIT ? - "# + "#, ) .bind(&type_str) .bind(limit) .fetch_all(&self.pool) .await .with_context(|| format!("failed to get memories by type {:?}", memory_type))?; - + Ok(rows.into_iter().map(|row| row_to_memory(&row)).collect()) } - + /// Get high-importance memories for injection into context. pub async fn get_high_importance(&self, threshold: f32, limit: i64) -> Result> { let rows = sqlx::query( @@ -330,14 +336,14 @@ impl MemoryStore { WHERE importance >= ? AND forgotten = 0 ORDER BY importance DESC, updated_at DESC LIMIT ? 
- "# + "#, ) .bind(threshold) .bind(limit) .fetch_all(&self.pool) .await .with_context(|| "failed to get high importance memories")?; - + Ok(rows.into_iter().map(|row| row_to_memory(&row)).collect()) } @@ -422,17 +428,23 @@ impl MemoryStore { fn row_to_memory(row: &sqlx::sqlite::SqliteRow) -> Memory { let mem_type_str: String = row.try_get("memory_type").unwrap_or_default(); let memory_type = parse_memory_type(&mem_type_str); - + let channel_id: Option = row.try_get("channel_id").ok(); - + Memory { id: row.try_get("id").unwrap_or_default(), content: row.try_get("content").unwrap_or_default(), memory_type, importance: row.try_get("importance").unwrap_or(0.5), - created_at: row.try_get("created_at").unwrap_or_else(|_| chrono::Utc::now()), - updated_at: row.try_get("updated_at").unwrap_or_else(|_| chrono::Utc::now()), - last_accessed_at: row.try_get("last_accessed_at").unwrap_or_else(|_| chrono::Utc::now()), + created_at: row + .try_get("created_at") + .unwrap_or_else(|_| chrono::Utc::now()), + updated_at: row + .try_get("updated_at") + .unwrap_or_else(|_| chrono::Utc::now()), + last_accessed_at: row + .try_get("last_accessed_at") + .unwrap_or_else(|_| chrono::Utc::now()), access_count: row.try_get("access_count").unwrap_or(0), source: row.try_get("source").ok(), channel_id: channel_id.map(|id| Arc::from(id) as crate::ChannelId), @@ -459,14 +471,16 @@ fn parse_memory_type(s: &str) -> MemoryType { fn row_to_association(row: &sqlx::sqlite::SqliteRow) -> Association { let relation_type_str: String = row.try_get("relation_type").unwrap_or_default(); let relation_type = parse_relation_type(&relation_type_str); - + Association { id: row.try_get("id").unwrap_or_default(), source_id: row.try_get("source_id").unwrap_or_default(), target_id: row.try_get("target_id").unwrap_or_default(), relation_type, weight: row.try_get("weight").unwrap_or(0.5), - created_at: row.try_get("created_at").unwrap_or_else(|_| chrono::Utc::now()), + created_at: row + .try_get("created_at") + 
.unwrap_or_else(|_| chrono::Utc::now()), } } @@ -520,11 +534,28 @@ mod tests { let store = MemoryStore::connect_in_memory().await; let now = Utc::now(); - let old = insert_memory_at(&store, "old", MemoryType::Fact, 0.5, now - Duration::hours(3)).await; - let mid = insert_memory_at(&store, "mid", MemoryType::Fact, 0.5, now - Duration::hours(1)).await; + let old = insert_memory_at( + &store, + "old", + MemoryType::Fact, + 0.5, + now - Duration::hours(3), + ) + .await; + let mid = insert_memory_at( + &store, + "mid", + MemoryType::Fact, + 0.5, + now - Duration::hours(1), + ) + .await; let new = insert_memory_at(&store, "new", MemoryType::Fact, 0.5, now).await; - let results = store.get_sorted(SearchSort::Recent, 10, None).await.unwrap(); + let results = store + .get_sorted(SearchSort::Recent, 10, None) + .await + .unwrap(); assert_eq!(results.len(), 3); assert_eq!(results[0].id, new.id); assert_eq!(results[1].id, mid.id); @@ -540,7 +571,10 @@ mod tests { let high = insert_memory_at(&store, "high", MemoryType::Fact, 0.9, now).await; let medium = insert_memory_at(&store, "medium", MemoryType::Fact, 0.5, now).await; - let results = store.get_sorted(SearchSort::Importance, 10, None).await.unwrap(); + let results = store + .get_sorted(SearchSort::Importance, 10, None) + .await + .unwrap(); assert_eq!(results[0].id, high.id); assert_eq!(results[1].id, medium.id); assert_eq!(results[2].id, low.id); @@ -560,7 +594,10 @@ mod tests { store.record_access(&b.id).await.unwrap(); } - let results = store.get_sorted(SearchSort::MostAccessed, 10, None).await.unwrap(); + let results = store + .get_sorted(SearchSort::MostAccessed, 10, None) + .await + .unwrap(); assert_eq!(results[0].id, b.id); assert_eq!(results[1].id, a.id); } @@ -612,7 +649,10 @@ mod tests { let forgotten = insert_memory_at(&store, "forgotten", MemoryType::Fact, 0.5, now).await; store.forget(&forgotten.id).await.unwrap(); - let results = store.get_sorted(SearchSort::Recent, 10, None).await.unwrap(); + let results = 
store + .get_sorted(SearchSort::Recent, 10, None) + .await + .unwrap(); assert_eq!(results.len(), 1); assert_eq!(results[0].id, visible.id); } From 92533cc2658ee41ba2f338dba9b621f93752bf9e Mon Sep 17 00:00:00 2001 From: nmb Date: Tue, 17 Feb 2026 16:31:20 +0100 Subject: [PATCH 07/11] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor(messaging):?= =?UTF-8?q?=20improve=20platform=20integrations?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/messaging.rs | 6 ++-- src/messaging/discord.rs | 72 +++++++++++++++++++++++++-------------- src/messaging/manager.rs | 18 +++++++--- src/messaging/slack.rs | 44 +++++++++++------------- src/messaging/telegram.rs | 32 ++++++++--------- src/messaging/traits.rs | 61 +++++++++++++++++++-------------- src/messaging/webhook.rs | 2 +- 7 files changed, 135 insertions(+), 100 deletions(-) diff --git a/src/messaging.rs b/src/messaging.rs index ded286bc6..76ecc7a2f 100644 --- a/src/messaging.rs +++ b/src/messaging.rs @@ -1,11 +1,11 @@ //! Messaging adapters (Discord, Slack, Telegram, Webhook). 
-pub mod traits; -pub mod manager; pub mod discord; +pub mod manager; pub mod slack; pub mod telegram; +pub mod traits; pub mod webhook; -pub use traits::Messaging; pub use manager::MessagingManager; +pub use traits::Messaging; diff --git a/src/messaging/discord.rs b/src/messaging/discord.rs index 98f0f510c..10b3c2b88 100644 --- a/src/messaging/discord.rs +++ b/src/messaging/discord.rs @@ -9,7 +9,7 @@ use arc_swap::ArcSwap; use async_trait::async_trait; use serenity::all::{ ChannelId, ChannelType, Context, CreateAttachment, CreateMessage, CreateThread, EditMessage, - EventHandler, GatewayIntents, GetMessages, Http, Message, MessageId, Ready, ReactionType, + EventHandler, GatewayIntents, GetMessages, Http, Message, MessageId, ReactionType, Ready, ShardManager, User, UserId, }; use std::collections::HashMap; @@ -30,10 +30,7 @@ pub struct DiscordAdapter { } impl DiscordAdapter { - pub fn new( - token: impl Into, - permissions: Arc>, - ) -> Self { + pub fn new(token: impl Into, permissions: Arc>) -> Self { Self { token: token.into(), permissions, @@ -138,15 +135,15 @@ impl Messaging for DiscordAdapter { let thread_result = match message_id { Some(source_message_id) => { - let builder = CreateThread::new(&thread_name) - .kind(ChannelType::PublicThread); + let builder = + CreateThread::new(&thread_name).kind(ChannelType::PublicThread); channel_id .create_thread_from_message(&*http, source_message_id, builder) .await } None => { - let builder = CreateThread::new(&thread_name) - .kind(ChannelType::PublicThread); + let builder = + CreateThread::new(&thread_name).kind(ChannelType::PublicThread); channel_id.create_thread(&*http, builder).await } }; @@ -178,7 +175,12 @@ impl Messaging for DiscordAdapter { } } } - OutboundResponse::File { filename, data, mime_type: _, caption } => { + OutboundResponse::File { + filename, + data, + mime_type: _, + caption, + } => { self.stop_typing(&message.id).await; let attachment = CreateAttachment::bytes(data, &filename); @@ -200,7 +202,11 
@@ impl Messaging for DiscordAdapter { .context("missing discord_message_id for reaction")?; channel_id - .create_reaction(&*http, MessageId::new(message_id), ReactionType::Unicode(emoji)) + .create_reaction( + &*http, + MessageId::new(message_id), + ReactionType::Unicode(emoji), + ) .await .context("failed to add reaction")?; } @@ -267,11 +273,7 @@ impl Messaging for DiscordAdapter { Ok(()) } - async fn broadcast( - &self, - target: &str, - response: OutboundResponse, - ) -> crate::Result<()> { + async fn broadcast(&self, target: &str, response: OutboundResponse) -> crate::Result<()> { let http = self.get_http().await?; let channel_id = ChannelId::new( @@ -330,12 +332,18 @@ impl Messaging for DiscordAdapter { let resolved_content = resolve_mentions(&message.content, &message.mentions); - let display_name = message.author.global_name.as_deref() + let display_name = message + .author + .global_name + .as_deref() .unwrap_or(&message.author.name); // Include reply-to attribution if this message is a reply let author = if let Some(referenced) = &message.referenced_message { - let reply_author = referenced.author.global_name.as_deref() + let reply_author = referenced + .author + .global_name + .as_deref() .unwrap_or(&referenced.author.name); format!("{display_name} (replying to {reply_author})") } else { @@ -417,7 +425,9 @@ impl EventHandler for Handler { // DM filter: if no guild_id, it's a DM — only allow listed users if message.guild_id.is_none() { if permissions.dm_allowed_users.is_empty() - || !permissions.dm_allowed_users.contains(&message.author.id.get()) + || !permissions + .dm_allowed_users + .contains(&message.author.id.get()) { return; } @@ -444,8 +454,8 @@ impl EventHandler for Handler { .and_then(|v| v.as_u64()); let direct_match = allowed_channels.contains(&message.channel_id.get()); - let parent_match = parent_channel_id - .is_some_and(|pid| allowed_channels.contains(&pid)); + let parent_match = + parent_channel_id.is_some_and(|pid| 
allowed_channels.contains(&pid)); if !direct_match && !parent_match { return; @@ -540,11 +550,17 @@ async fn build_metadata(ctx: &Context, message: &Message) -> HashMap global display name > username let display_name = if let Some(member) = &message.member { member.nick.clone().unwrap_or_else(|| { - message.author.global_name.clone() + message + .author + .global_name + .clone() .unwrap_or_else(|| message.author.name.clone()) }) } else { - message.author.global_name.clone() + message + .author + .global_name + .clone() .unwrap_or_else(|| message.author.name.clone()) }; metadata.insert("sender_display_name".into(), display_name.into()); @@ -565,7 +581,10 @@ async fn build_metadata(ctx: &Context, message: &Message) -> HashMap HashMap crate::Result { let adapters = self.adapters.read().await; for (name, adapter) in adapters.iter() { - let stream = adapter.start().await + let stream = adapter + .start() + .await .with_context(|| format!("failed to start adapter '{name}'"))?; Self::spawn_forwarder(name.clone(), stream, self.fan_in_tx.clone()); } drop(adapters); - let receiver = self.fan_in_rx.write().await.take() + let receiver = self + .fan_in_rx + .write() + .await + .take() .context("start() already called")?; - Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(receiver))) + Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new( + receiver, + ))) } /// Register and start a new adapter at runtime. 
@@ -78,7 +86,9 @@ impl MessagingManager { let adapter: Arc = Arc::new(adapter); - let stream = adapter.start().await + let stream = adapter + .start() + .await .with_context(|| format!("failed to start adapter '{name}'"))?; Self::spawn_forwarder(name.clone(), stream, self.fan_in_tx.clone()); diff --git a/src/messaging/slack.rs b/src/messaging/slack.rs index 1d8889eb7..e652899ac 100644 --- a/src/messaging/slack.rs +++ b/src/messaging/slack.rs @@ -244,8 +244,7 @@ impl Messaging for SlackAdapter { bot_token: self.bot_token.clone(), }); - let callbacks = SlackSocketModeListenerCallbacks::new() - .with_push_events(handle_push_event); + let callbacks = SlackSocketModeListenerCallbacks::new().with_push_events(handle_push_event); let listener_environment = Arc::new( SlackClientEventsListenerEnvironment::new(client.clone()) @@ -331,12 +330,18 @@ impl Messaging for SlackAdapter { .context("failed to send slack thread reply")?; } } - OutboundResponse::File { filename, data, mime_type, caption } => { + OutboundResponse::File { + filename, + data, + mime_type, + caption, + } => { // Slack's v2 upload flow: get upload URL, upload bytes, complete let upload_url_response = session - .get_upload_url_external( - &SlackApiFilesGetUploadUrlExternalRequest::new(filename.clone(), data.len()), - ) + .get_upload_url_external(&SlackApiFilesGetUploadUrlExternalRequest::new( + filename.clone(), + data.len(), + )) .await .context("failed to get slack upload URL")?; @@ -350,13 +355,12 @@ impl Messaging for SlackAdapter { .context("failed to upload file to slack")?; let thread_ts = extract_thread_ts(message); - let file_complete = SlackApiFilesComplete::new(upload_url_response.file_id) - .with_title(filename); + let file_complete = + SlackApiFilesComplete::new(upload_url_response.file_id).with_title(filename); - let mut complete_request = SlackApiFilesCompleteUploadExternalRequest::new( - vec![file_complete], - ) - .with_channel_id(channel_id.clone()); + let mut complete_request = + 
SlackApiFilesCompleteUploadExternalRequest::new(vec![file_complete]) + .with_channel_id(channel_id.clone()); complete_request = complete_request.opt_initial_comment(caption); complete_request = complete_request.opt_thread_ts(thread_ts); @@ -367,8 +371,8 @@ impl Messaging for SlackAdapter { .context("failed to complete slack file upload")?; } OutboundResponse::Reaction(emoji) => { - let ts = extract_message_ts(message) - .context("missing slack_message_ts for reaction")?; + let ts = + extract_message_ts(message).context("missing slack_message_ts for reaction")?; let reaction_name = sanitize_reaction_name(&emoji); @@ -470,11 +474,7 @@ impl Messaging for SlackAdapter { Ok(()) } - async fn broadcast( - &self, - target: &str, - response: OutboundResponse, - ) -> crate::Result<()> { + async fn broadcast(&self, target: &str, response: OutboundResponse) -> crate::Result<()> { let (client, token) = self.create_session()?; let session = client.open_session(&token); let channel_id = SlackChannelId(target.to_string()); @@ -539,11 +539,7 @@ impl Messaging for SlackAdapter { HistoryMessage { author, - content: msg - .content - .text - .clone() - .unwrap_or_default(), + content: msg.content.text.clone().unwrap_or_default(), is_bot, } }) diff --git a/src/messaging/telegram.rs b/src/messaging/telegram.rs index 7af00b18b..d83545c3f 100644 --- a/src/messaging/telegram.rs +++ b/src/messaging/telegram.rs @@ -6,13 +6,13 @@ use crate::{Attachment, InboundMessage, MessageContent, OutboundResponse, Status use anyhow::Context as _; use arc_swap::ArcSwap; +use teloxide::Bot; use teloxide::payloads::setters::*; use teloxide::requests::{Request, Requester}; use teloxide::types::{ - ChatAction, ChatId, InputFile, MediaKind, MessageId, MessageKind, ReactionType, ReplyParameters, - UpdateKind, UserId, + ChatAction, ChatId, InputFile, MediaKind, MessageId, MessageKind, ReactionType, + ReplyParameters, UpdateKind, UserId, }; -use teloxide::Bot; use std::collections::HashMap; use std::sync::Arc; @@ 
-48,10 +48,7 @@ const MAX_MESSAGE_LENGTH: usize = 4096; const STREAM_EDIT_INTERVAL: std::time::Duration = std::time::Duration::from_millis(1000); impl TelegramAdapter { - pub fn new( - token: impl Into, - permissions: Arc>, - ) -> Self { + pub fn new(token: impl Into, permissions: Arc>) -> Self { let token = token.into(); let bot = Bot::new(&token); Self { @@ -249,7 +246,10 @@ impl Messaging for TelegramAdapter { .context("failed to send telegram message")?; } } - OutboundResponse::ThreadReply { thread_name: _, text } => { + OutboundResponse::ThreadReply { + thread_name: _, + text, + } => { self.stop_typing(&message.conversation_id).await; // Telegram doesn't have named threads. Reply to the source message instead. @@ -380,8 +380,10 @@ impl Messaging for TelegramAdapter { // Send one immediately, then repeat every 4 seconds. let handle = tokio::spawn(async move { loop { - if let Err(error) = - bot.send_chat_action(chat_id, ChatAction::Typing).send().await + if let Err(error) = bot + .send_chat_action(chat_id, ChatAction::Typing) + .send() + .await { tracing::debug!(%error, "failed to send typing indicator"); break; @@ -463,10 +465,7 @@ fn has_attachments(message: &teloxide::types::Message) -> bool { } /// Build `MessageContent` from a Telegram message. 
-fn build_content( - message: &teloxide::types::Message, - text: &Option, -) -> MessageContent { +fn build_content(message: &teloxide::types::Message, text: &Option) -> MessageContent { let attachments = extract_attachments(message); if attachments.is_empty() { @@ -635,10 +634,7 @@ fn build_metadata( metadata.insert("reply_to_text".into(), truncated.into()); } if let Some(from) = &reply.from { - metadata.insert( - "reply_to_author".into(), - build_display_name(from).into(), - ); + metadata.insert("reply_to_author".into(), build_display_name(from).into()); } } diff --git a/src/messaging/traits.rs b/src/messaging/traits.rs index c42147553..56631bfca 100644 --- a/src/messaging/traits.rs +++ b/src/messaging/traits.rs @@ -2,8 +2,8 @@ use crate::error::Result; use crate::{InboundMessage, OutboundResponse, StatusUpdate}; -use std::pin::Pin; use futures::Stream; +use std::pin::Pin; /// Message stream type. pub type InboundStream = Pin + Send>>; @@ -21,17 +21,17 @@ pub struct HistoryMessage { pub trait Messaging: Send + Sync + 'static { /// Unique name for this adapter. fn name(&self) -> &str; - + /// Start the adapter and return inbound message stream. fn start(&self) -> impl std::future::Future> + Send; - + /// Send a response to a message. fn respond( &self, message: &InboundMessage, response: OutboundResponse, ) -> impl std::future::Future> + Send; - + /// Send a status update. fn send_status( &self, @@ -40,7 +40,7 @@ pub trait Messaging: Send + Sync + 'static { ) -> impl std::future::Future> + Send { async { Ok(()) } } - + /// Broadcast a message. fn broadcast( &self, @@ -61,10 +61,10 @@ pub trait Messaging: Send + Sync + 'static { let _ = (message, limit); async { Ok(Vec::new()) } } - + /// Health check. fn health_check(&self) -> impl std::future::Future> + Send; - + /// Graceful shutdown. 
fn shutdown(&self) -> impl std::future::Future> + Send { async { Ok(()) } @@ -75,21 +75,23 @@ pub trait Messaging: Send + Sync + 'static { /// Use this when you need `Arc` for storing different adapters. pub trait MessagingDyn: Send + Sync + 'static { fn name(&self) -> &str; - - fn start<'a>(&'a self) -> Pin> + Send + 'a>>; - + + fn start<'a>( + &'a self, + ) -> Pin> + Send + 'a>>; + fn respond<'a>( &'a self, message: &'a InboundMessage, response: OutboundResponse, ) -> Pin> + Send + 'a>>; - + fn send_status<'a>( &'a self, message: &'a InboundMessage, status: StatusUpdate, ) -> Pin> + Send + 'a>>; - + fn broadcast<'a>( &'a self, target: &'a str, @@ -101,10 +103,13 @@ pub trait MessagingDyn: Send + Sync + 'static { message: &'a InboundMessage, limit: usize, ) -> Pin>> + Send + 'a>>; - - fn health_check<'a>(&'a self) -> Pin> + Send + 'a>>; - - fn shutdown<'a>(&'a self) -> Pin> + Send + 'a>>; + + fn health_check<'a>( + &'a self, + ) -> Pin> + Send + 'a>>; + + fn shutdown<'a>(&'a self) + -> Pin> + Send + 'a>>; } /// Blanket implementation: any type implementing Messaging automatically implements MessagingDyn. 
@@ -112,11 +117,13 @@ impl MessagingDyn for T { fn name(&self) -> &str { Messaging::name(self) } - - fn start<'a>(&'a self) -> Pin> + Send + 'a>> { + + fn start<'a>( + &'a self, + ) -> Pin> + Send + 'a>> { Box::pin(Messaging::start(self)) } - + fn respond<'a>( &'a self, message: &'a InboundMessage, @@ -124,7 +131,7 @@ impl MessagingDyn for T { ) -> Pin> + Send + 'a>> { Box::pin(Messaging::respond(self, message, response)) } - + fn send_status<'a>( &'a self, message: &'a InboundMessage, @@ -132,7 +139,7 @@ impl MessagingDyn for T { ) -> Pin> + Send + 'a>> { Box::pin(Messaging::send_status(self, message, status)) } - + fn broadcast<'a>( &'a self, target: &'a str, @@ -148,12 +155,16 @@ impl MessagingDyn for T { ) -> Pin>> + Send + 'a>> { Box::pin(Messaging::fetch_history(self, message, limit)) } - - fn health_check<'a>(&'a self) -> Pin> + Send + 'a>> { + + fn health_check<'a>( + &'a self, + ) -> Pin> + Send + 'a>> { Box::pin(Messaging::health_check(self)) } - - fn shutdown<'a>(&'a self) -> Pin> + Send + 'a>> { + + fn shutdown<'a>( + &'a self, + ) -> Pin> + Send + 'a>> { Box::pin(Messaging::shutdown(self)) } } diff --git a/src/messaging/webhook.rs b/src/messaging/webhook.rs index 431e899e0..49fc57660 100644 --- a/src/messaging/webhook.rs +++ b/src/messaging/webhook.rs @@ -9,9 +9,9 @@ use crate::messaging::traits::{InboundStream, Messaging}; use crate::{InboundMessage, MessageContent, OutboundResponse}; use anyhow::Context as _; +use axum::Router; use axum::extract::{Json, State}; use axum::http::StatusCode; -use axum::Router; use axum::routing::{get, post}; use serde::{Deserialize, Serialize}; From a0d4ffb2a4e8d23e248fce1c12356b4ab11d663d Mon Sep 17 00:00:00 2001 From: nmb Date: Tue, 17 Feb 2026 18:40:00 +0100 Subject: [PATCH 08/11] refactor(runtime): align daemon, hooks, and model plumbing - update startup/runtime wiring across daemon, main, hooks, and LLM integration\n- improve opencode worker/server flow and related test coverage\n- include matching docs updates for 
config and roadmap --- build.rs | 12 +- docs/content/docs/(configuration)/config.mdx | 2 + docs/content/docs/(deployment)/roadmap.mdx | 2 +- src/conversation.rs | 2 +- src/cron/scheduler.rs | 18 +- src/cron/store.rs | 6 +- src/daemon.rs | 44 ++-- src/db.rs | 24 +- src/hooks.rs | 4 +- src/hooks/cortex.rs | 6 +- src/hooks/spacebot.rs | 17 +- src/identity/files.rs | 21 +- src/lib.rs | 38 +++- src/llm/manager.rs | 68 ++++-- src/llm/model.rs | 193 ++++++++++------ src/main.rs | 159 +++++++++---- src/opencode/server.rs | 92 ++++---- src/opencode/worker.rs | 139 ++++++++---- src/prompts/engine.rs | 2 +- src/settings.rs | 2 +- src/skills.rs | 94 +++++--- src/update.rs | 11 +- tests/bulletin.rs | 59 +++-- tests/context_dump.rs | 227 ++++++++++++++----- tests/opencode_stream.rs | 21 +- 25 files changed, 834 insertions(+), 429 deletions(-) diff --git a/build.rs b/build.rs index ff9f77571..2507c62f5 100644 --- a/build.rs +++ b/build.rs @@ -15,7 +15,9 @@ fn main() { // Skip if bun isn't installed or node_modules is missing (CI without frontend deps) if !interface_dir.join("node_modules").exists() { - eprintln!("cargo:warning=interface/node_modules not found, skipping frontend build. Run `bun install` in interface/"); + eprintln!( + "cargo:warning=interface/node_modules not found, skipping frontend build. Run `bun install` in interface/" + ); ensure_dist_dir(); return; } @@ -28,10 +30,14 @@ fn main() { match status { Ok(s) if s.success() => {} Ok(s) => { - eprintln!("cargo:warning=frontend build exited with {s}, the binary will serve a stale or empty UI"); + eprintln!( + "cargo:warning=frontend build exited with {s}, the binary will serve a stale or empty UI" + ); } Err(e) => { - eprintln!("cargo:warning=failed to run `bun run build`: {e}. Install bun to build the frontend."); + eprintln!( + "cargo:warning=failed to run `bun run build`: {e}. Install bun to build the frontend." 
+ ); ensure_dist_dir(); } } diff --git a/docs/content/docs/(configuration)/config.mdx b/docs/content/docs/(configuration)/config.mdx index 88ca4630c..334f02f4b 100644 --- a/docs/content/docs/(configuration)/config.mdx +++ b/docs/content/docs/(configuration)/config.mdx @@ -24,6 +24,7 @@ spacebot --config /path/to.toml # CLI override anthropic_key = "env:ANTHROPIC_API_KEY" openai_key = "env:OPENAI_API_KEY" openrouter_key = "env:OPENROUTER_API_KEY" +ollama_key = "env:OLLAMA_API_KEY" # --- Instance Defaults --- # All agents inherit these. Individual agents can override any field. @@ -168,6 +169,7 @@ Model names include the provider as a prefix: | Anthropic | `anthropic/` | `anthropic/claude-sonnet-4-20250514` | | OpenAI | `openai/` | `openai/gpt-4o` | | OpenRouter | `openrouter//` | `openrouter/anthropic/claude-sonnet-4-20250514` | +| Ollama Cloud | `ollama/` | `ollama/gpt-oss:20b` | You can mix providers across process types. See [Routing](/docs/routing) for the full routing system. diff --git a/docs/content/docs/(deployment)/roadmap.mdx b/docs/content/docs/(deployment)/roadmap.mdx index 6711bdd2d..52401c0a8 100644 --- a/docs/content/docs/(deployment)/roadmap.mdx +++ b/docs/content/docs/(deployment)/roadmap.mdx @@ -18,7 +18,7 @@ The full message-in → LLM → response-out pipeline is wired end-to-end across - **Config** — hierarchical TOML with `Config`, `AgentConfig`, `ResolvedAgentConfig`, `Binding`, `MessagingConfig`. File watcher with event filtering and content hash debounce for hot-reload. 
- **Multi-agent** — per-agent database isolation, `Agent` struct bundles all dependencies - **Database connections** — SQLite + LanceDB + redb per-agent, migrations for all tables -- **LLM** — `SpacebotModel` implements Rig's `CompletionModel`, routes through `LlmManager` via HTTP with retries and fallback chains across 11 providers (Anthropic, OpenAI, OpenRouter, Z.ai, Groq, Together, Fireworks, DeepSeek, xAI, Mistral, OpenCode Zen) +- **LLM** — `SpacebotModel` implements Rig's `CompletionModel`, routes through `LlmManager` via HTTP with retries and fallback chains across 12 providers (Anthropic, OpenAI, OpenRouter, Ollama Cloud, Z.ai, Groq, Together, Fireworks, DeepSeek, xAI, Mistral, OpenCode Zen) - **Model routing** — `RoutingConfig` with process-type defaults, task overrides, fallback chains - **Memory** — full stack: types, SQLite store (CRUD + graph), LanceDB (embeddings + vector + FTS), fastembed, hybrid search (RRF fusion). `memory_type` filter wired end-to-end through SearchConfig. `total_cmp` for safe sorting. - **Memory maintenance** — decay + prune implemented diff --git a/src/conversation.rs b/src/conversation.rs index 47a3dae34..0097a7d42 100644 --- a/src/conversation.rs +++ b/src/conversation.rs @@ -1,8 +1,8 @@ //! Conversation history and context management. pub mod channels; -pub mod history; pub mod context; +pub mod history; pub use channels::ChannelStore; pub use history::{ConversationLogger, ProcessRunLogger, TimelineItem}; diff --git a/src/cron/scheduler.rs b/src/cron/scheduler.rs index 4e1b277ff..ebe33bd7d 100644 --- a/src/cron/scheduler.rs +++ b/src/cron/scheduler.rs @@ -14,7 +14,7 @@ use chrono::Timelike; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::RwLock; -use tokio::time::{interval, Duration}; +use tokio::time::{Duration, interval}; /// A cron job definition loaded from the database. 
#[derive(Debug, Clone)] @@ -165,9 +165,7 @@ impl Scheduler { // Look up interval before entering the loop let interval_secs = { let j = jobs.read().await; - j.get(&job_id) - .map(|j| j.interval_secs) - .unwrap_or(3600) + j.get(&job_id).map(|j| j.interval_secs).unwrap_or(3600) }; let mut ticker = interval(Duration::from_secs(interval_secs)); @@ -323,9 +321,9 @@ impl Scheduler { if let Some(job) = job { if !job.enabled { - return Err(crate::error::Error::Other( - anyhow::anyhow!("cron job is disabled"), - )); + return Err(crate::error::Error::Other(anyhow::anyhow!( + "cron job is disabled" + ))); } tracing::info!(cron_id = %job_id, "cron job triggered manually"); @@ -453,11 +451,7 @@ async fn run_cron_job(job: &CronJob, context: &CronContext) -> Result<()> { } else { None }; - if let Err(error) = context - .store - .log_execution(&job.id, true, summary) - .await - { + if let Err(error) = context.store.log_execution(&job.id, true, summary).await { tracing::warn!(%error, "failed to log cron execution"); } diff --git a/src/cron/store.rs b/src/cron/store.rs index f66882bf9..e61fd8fce 100644 --- a/src/cron/store.rs +++ b/src/cron/store.rs @@ -170,7 +170,11 @@ impl CronStore { } /// Load execution history for a specific cron job. - pub async fn load_executions(&self, cron_id: &str, limit: i64) -> Result> { + pub async fn load_executions( + &self, + cron_id: &str, + limit: i64, + ) -> Result> { let rows = sqlx::query( r#" SELECT id, executed_at, success, result_summary diff --git a/src/daemon.rs b/src/daemon.rs index 61a4d9732..cf6bdb15a 100644 --- a/src/daemon.rs +++ b/src/daemon.rs @@ -24,13 +24,8 @@ pub enum IpcCommand { #[serde(tag = "result", rename_all = "snake_case")] pub enum IpcResponse { Ok, - Status { - pid: u32, - uptime_seconds: u64, - }, - Error { - message: String, - }, + Status { pid: u32, uptime_seconds: u64 }, + Error { message: String }, } /// Paths for daemon runtime files, all derived from the instance directory. 
@@ -84,8 +79,12 @@ pub fn is_running(paths: &DaemonPaths) -> Option { /// Daemonize the current process. Returns in the child; the parent prints /// a message and exits. pub fn daemonize(paths: &DaemonPaths) -> anyhow::Result<()> { - std::fs::create_dir_all(&paths.log_dir) - .with_context(|| format!("failed to create log directory: {}", paths.log_dir.display()))?; + std::fs::create_dir_all(&paths.log_dir).with_context(|| { + format!( + "failed to create log directory: {}", + paths.log_dir.display() + ) + })?; let stdout = std::fs::OpenOptions::new() .create(true) @@ -105,7 +104,9 @@ pub fn daemonize(paths: &DaemonPaths) -> anyhow::Result<()> { .stdout(stdout) .stderr(stderr); - daemonize.start().map_err(|error| anyhow!("failed to daemonize: {error}"))?; + daemonize + .start() + .map_err(|error| anyhow!("failed to daemonize: {error}"))?; Ok(()) } @@ -140,9 +141,7 @@ pub fn init_foreground_tracing(debug: bool) { tracing_subscriber::EnvFilter::new("info") }; - tracing_subscriber::fmt() - .with_env_filter(filter) - .init(); + tracing_subscriber::fmt().with_env_filter(filter).init(); } /// Start the IPC server. 
Returns a shutdown receiver that the main event @@ -152,8 +151,9 @@ pub async fn start_ipc_server( ) -> anyhow::Result<(watch::Receiver, tokio::task::JoinHandle<()>)> { // Clean up any stale socket file if paths.socket.exists() { - std::fs::remove_file(&paths.socket) - .with_context(|| format!("failed to remove stale socket: {}", paths.socket.display()))?; + std::fs::remove_file(&paths.socket).with_context(|| { + format!("failed to remove stale socket: {}", paths.socket.display()) + })?; } let listener = UnixListener::bind(&paths.socket) @@ -170,7 +170,9 @@ pub async fn start_ipc_server( let shutdown_tx = shutdown_tx.clone(); let uptime = start_time.elapsed(); tokio::spawn(async move { - if let Err(error) = handle_ipc_connection(stream, &shutdown_tx, uptime).await { + if let Err(error) = + handle_ipc_connection(stream, &shutdown_tx, uptime).await + { tracing::warn!(%error, "IPC connection handler failed"); } }); @@ -213,12 +215,10 @@ async fn handle_ipc_connection( shutdown_tx.send(true).ok(); IpcResponse::Ok } - IpcCommand::Status => { - IpcResponse::Status { - pid: std::process::id(), - uptime_seconds: uptime.as_secs(), - } - } + IpcCommand::Status => IpcResponse::Status { + pid: std::process::id(), + uptime_seconds: uptime.as_secs(), + }, }; let mut response_bytes = serde_json::to_vec(&response)?; diff --git a/src/db.rs b/src/db.rs index 6e5032d60..5a1a3c54e 100644 --- a/src/db.rs +++ b/src/db.rs @@ -9,10 +9,10 @@ use std::path::Path; pub struct Db { /// SQLite pool for relational data. pub sqlite: SqlitePool, - + /// LanceDB connection for vector storage. pub lance: lancedb::Connection, - + /// Redb database for key-value config. 
pub redb: Arc, } @@ -27,35 +27,39 @@ impl Db { let sqlite = SqlitePool::connect(&sqlite_url) .await .with_context(|| "failed to connect to SQLite")?; - + // Run migrations sqlx::migrate!("./migrations") .run(&sqlite) .await .with_context(|| "failed to run database migrations")?; - + // LanceDB let lance_path = data_dir.join("lancedb"); - std::fs::create_dir_all(&lance_path) - .with_context(|| format!("failed to create LanceDB directory: {}", lance_path.display()))?; - + std::fs::create_dir_all(&lance_path).with_context(|| { + format!( + "failed to create LanceDB directory: {}", + lance_path.display() + ) + })?; + let lance = lancedb::connect(lance_path.to_str().unwrap_or("./lancedb")) .execute() .await .map_err(|e| DbError::LanceConnect(e.to_string()))?; - + // Redb let redb_path = data_dir.join("config.redb"); let redb = redb::Database::create(&redb_path) .with_context(|| format!("failed to create redb at: {}", redb_path.display()))?; - + Ok(Self { sqlite, lance, redb: Arc::new(redb), }) } - + /// Close all database connections gracefully. pub async fn close(self) { self.sqlite.close().await; diff --git a/src/hooks.rs b/src/hooks.rs index cc46ae59b..b1b02c0fc 100644 --- a/src/hooks.rs +++ b/src/hooks.rs @@ -1,7 +1,7 @@ //! Prompt hooks for observing and controlling agent behavior. 
-pub mod spacebot; pub mod cortex; +pub mod spacebot; -pub use spacebot::SpacebotHook; pub use cortex::CortexHook; +pub use spacebot::SpacebotHook; diff --git a/src/hooks/cortex.rs b/src/hooks/cortex.rs index 564a0ec49..538762f69 100644 --- a/src/hooks/cortex.rs +++ b/src/hooks/cortex.rs @@ -24,11 +24,7 @@ impl PromptHook for CortexHook where M: CompletionModel, { - async fn on_completion_call( - &self, - _prompt: &Message, - _history: &[Message], - ) -> HookAction { + async fn on_completion_call(&self, _prompt: &Message, _history: &[Message]) -> HookAction { // Cortex observes but doesn't intervene tracing::trace!("cortex: completion call observed"); HookAction::Continue diff --git a/src/hooks/spacebot.rs b/src/hooks/spacebot.rs index 55b6bab73..6beeb3660 100644 --- a/src/hooks/spacebot.rs +++ b/src/hooks/spacebot.rs @@ -63,7 +63,8 @@ impl SpacebotHook { // Google API keys Regex::new(r"AIza[0-9A-Za-z_-]{35}").expect("hardcoded regex"), // Discord bot tokens (base64 user ID . timestamp . HMAC) - Regex::new(r"[MN][A-Za-z0-9]{23,}\.[A-Za-z0-9_-]{6}\.[A-Za-z0-9_-]{27,}").expect("hardcoded regex"), + Regex::new(r"[MN][A-Za-z0-9]{23,}\.[A-Za-z0-9_-]{6}\.[A-Za-z0-9_-]{27,}") + .expect("hardcoded regex"), // Slack bot tokens Regex::new(r"xoxb-[0-9]{10,}-[0-9A-Za-z-]+").expect("hardcoded regex"), // Slack app tokens @@ -89,11 +90,7 @@ impl PromptHook for SpacebotHook where M: CompletionModel, { - async fn on_completion_call( - &self, - _prompt: &Message, - _history: &[Message], - ) -> HookAction { + async fn on_completion_call(&self, _prompt: &Message, _history: &[Message]) -> HookAction { // Log the completion call but don't block it tracing::debug!( process_id = %self.process_id, @@ -178,12 +175,16 @@ where leak_prefix = %&leak[..leak.len().min(8)], "secret leak detected in tool output, terminating agent" ); - return HookAction::Terminate { reason: "Tool output contained a secret. 
Agent terminated to prevent exfiltration.".into() }; + return HookAction::Terminate { + reason: "Tool output contained a secret. Agent terminated to prevent exfiltration." + .into(), + }; } // Cap the result stored in the broadcast event to avoid blowing up // event subscribers with multi-MB tool results. - let capped_result = crate::tools::truncate_output(result, crate::tools::MAX_TOOL_OUTPUT_BYTES); + let capped_result = + crate::tools::truncate_output(result, crate::tools::MAX_TOOL_OUTPUT_BYTES); let event = ProcessEvent::ToolCompleted { agent_id: self.agent_id.clone(), process_id: self.process_id.clone(), diff --git a/src/identity/files.rs b/src/identity/files.rs index 29ceb843f..e5b5c8d2e 100644 --- a/src/identity/files.rs +++ b/src/identity/files.rs @@ -47,9 +47,18 @@ impl Identity { /// Default identity file templates for new agents. const DEFAULT_IDENTITY_FILES: &[(&str, &str)] = &[ - ("SOUL.md", "\n"), - ("IDENTITY.md", "\n"), - ("USER.md", "\n"), + ( + "SOUL.md", + "\n", + ), + ( + "IDENTITY.md", + "\n", + ), + ( + "USER.md", + "\n", + ), ]; /// Write template identity files into an agent's workspace if they don't already exist. 
@@ -59,9 +68,9 @@ pub async fn scaffold_identity_files(workspace: &Path) -> crate::error::Result<( for (filename, content) in DEFAULT_IDENTITY_FILES { let target = workspace.join(filename); if !target.exists() { - tokio::fs::write(&target, content) - .await - .with_context(|| format!("failed to write identity template: {}", target.display()))?; + tokio::fs::write(&target, content).await.with_context(|| { + format!("failed to write identity template: {}", target.display()) + })?; tracing::info!(path = %target.display(), "wrote identity template"); } } diff --git a/src/lib.rs b/src/lib.rs index d97ef5901..6f6ff7717 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,10 +4,10 @@ pub mod agent; pub mod api; pub mod config; pub mod conversation; +pub mod cron; pub mod daemon; pub mod db; pub mod error; -pub mod cron; pub mod hooks; pub mod identity; pub mod llm; @@ -103,6 +103,12 @@ pub enum ProcessEvent { channel_id: ChannelId, conclusion: String, }, + BranchFailed { + agent_id: AgentId, + branch_id: BranchId, + channel_id: ChannelId, + error: String, + }, WorkerStarted { agent_id: AgentId, worker_id: WorkerId, @@ -180,8 +186,12 @@ pub struct AgentDeps { } impl AgentDeps { - pub fn memory_search(&self) -> &Arc { &self.memory_search } - pub fn llm_manager(&self) -> &Arc { &self.llm_manager } + pub fn memory_search(&self) -> &Arc { + &self.memory_search + } + pub fn llm_manager(&self) -> &Arc { + &self.llm_manager + } /// Load the current routing config snapshot. pub fn routing(&self) -> arc_swap::Guard> { @@ -297,9 +307,21 @@ pub enum StatusUpdate { Thinking, /// Cancel the typing indicator (e.g. when the skip tool fires). 
StopTyping, - ToolStarted { tool_name: String }, - ToolCompleted { tool_name: String }, - BranchStarted { branch_id: BranchId }, - WorkerStarted { worker_id: WorkerId, task: String }, - WorkerCompleted { worker_id: WorkerId, result: String }, + ToolStarted { + tool_name: String, + }, + ToolCompleted { + tool_name: String, + }, + BranchStarted { + branch_id: BranchId, + }, + WorkerStarted { + worker_id: WorkerId, + task: String, + }, + WorkerCompleted { + worker_id: WorkerId, + result: String, + }, } diff --git a/src/llm/manager.rs b/src/llm/manager.rs index 80c476e2e..40356acc5 100644 --- a/src/llm/manager.rs +++ b/src/llm/manager.rs @@ -38,27 +38,65 @@ impl LlmManager { /// Get the appropriate API key for a provider. pub fn get_api_key(&self, provider: &str) -> Result { match provider { - "anthropic" => self.config.anthropic_key.clone() + "anthropic" => self + .config + .anthropic_key + .clone() .ok_or_else(|| LlmError::MissingProviderKey("anthropic".into()).into()), - "openai" => self.config.openai_key.clone() + "openai" => self + .config + .openai_key + .clone() .ok_or_else(|| LlmError::MissingProviderKey("openai".into()).into()), - "openrouter" => self.config.openrouter_key.clone() + "openrouter" => self + .config + .openrouter_key + .clone() .ok_or_else(|| LlmError::MissingProviderKey("openrouter".into()).into()), - "zhipu" => self.config.zhipu_key.clone() + "ollama" => self + .config + .ollama_key + .clone() + .ok_or_else(|| LlmError::MissingProviderKey("ollama".into()).into()), + "zhipu" => self + .config + .zhipu_key + .clone() .ok_or_else(|| LlmError::MissingProviderKey("zhipu".into()).into()), - "groq" => self.config.groq_key.clone() + "groq" => self + .config + .groq_key + .clone() .ok_or_else(|| LlmError::MissingProviderKey("groq".into()).into()), - "together" => self.config.together_key.clone() + "together" => self + .config + .together_key + .clone() .ok_or_else(|| LlmError::MissingProviderKey("together".into()).into()), - "fireworks" => 
self.config.fireworks_key.clone() + "fireworks" => self + .config + .fireworks_key + .clone() .ok_or_else(|| LlmError::MissingProviderKey("fireworks".into()).into()), - "deepseek" => self.config.deepseek_key.clone() + "deepseek" => self + .config + .deepseek_key + .clone() .ok_or_else(|| LlmError::MissingProviderKey("deepseek".into()).into()), - "xai" => self.config.xai_key.clone() + "xai" => self + .config + .xai_key + .clone() .ok_or_else(|| LlmError::MissingProviderKey("xai".into()).into()), - "mistral" => self.config.mistral_key.clone() + "mistral" => self + .config + .mistral_key + .clone() .ok_or_else(|| LlmError::MissingProviderKey("mistral".into()).into()), - "opencode-zen" => self.config.opencode_zen_key.clone() + "opencode-zen" => self + .config + .opencode_zen_key + .clone() .ok_or_else(|| LlmError::MissingProviderKey("opencode-zen".into()).into()), _ => Err(LlmError::UnknownProvider(provider.into()).into()), } @@ -81,7 +119,9 @@ impl LlmManager { /// Record that a model hit a rate limit. pub async fn record_rate_limit(&self, model_name: &str) { - self.rate_limited.write().await + self.rate_limited + .write() + .await .insert(model_name.to_string(), Instant::now()); tracing::warn!(model = %model_name, "model rate limited, entering cooldown"); } @@ -98,7 +138,9 @@ impl LlmManager { /// Clean up expired rate limit entries. 
pub async fn cleanup_rate_limits(&self, cooldown_secs: u64) { - self.rate_limited.write().await + self.rate_limited + .write() + .await .retain(|_, limited_at| limited_at.elapsed().as_secs() < cooldown_secs); } } diff --git a/src/llm/model.rs b/src/llm/model.rs index 6394a4354..e8e2c156a 100644 --- a/src/llm/model.rs +++ b/src/llm/model.rs @@ -2,12 +2,10 @@ use crate::llm::manager::LlmManager; use crate::llm::routing::{ - self, RoutingConfig, MAX_FALLBACK_ATTEMPTS, MAX_RETRIES_PER_MODEL, RETRY_BASE_DELAY_MS, + self, MAX_FALLBACK_ATTEMPTS, MAX_RETRIES_PER_MODEL, RETRY_BASE_DELAY_MS, RoutingConfig, }; -use rig::completion::{ - self, CompletionError, CompletionModel, CompletionRequest, GetTokenUsage, -}; +use rig::completion::{self, CompletionError, CompletionModel, CompletionRequest, GetTokenUsage}; use rig::message::{ AssistantContent, DocumentSourceKind, Image, Message, MimeType, Text, ToolCall, ToolFunction, ToolResult, UserContent, @@ -50,9 +48,15 @@ pub struct SpacebotModel { } impl SpacebotModel { - pub fn provider(&self) -> &str { &self.provider } - pub fn model_name(&self) -> &str { &self.model_name } - pub fn full_model_name(&self) -> &str { &self.full_model_name } + pub fn provider(&self) -> &str { + &self.provider + } + pub fn model_name(&self) -> &str { + &self.model_name + } + pub fn full_model_name(&self) -> &str { + &self.full_model_name + } /// Attach routing config for fallback behavior. 
pub fn with_routing(mut self, routing: RoutingConfig) -> Self { @@ -69,6 +73,7 @@ impl SpacebotModel { "anthropic" => self.call_anthropic(request).await, "openai" => self.call_openai(request).await, "openrouter" => self.call_openrouter(request).await, + "ollama" => self.call_ollama(request).await, "zhipu" => self.call_zhipu(request).await, "groq" => self.call_groq(request).await, "together" => self.call_together(request).await, @@ -93,10 +98,7 @@ impl SpacebotModel { &self, model_name: &str, request: &CompletionRequest, - ) -> Result< - completion::CompletionResponse, - (CompletionError, bool), - > { + ) -> Result, (CompletionError, bool)> { let model = if model_name == self.full_model_name { self.clone() } else { @@ -203,11 +205,16 @@ impl CompletionModel for SpacebotModel { "primary model in rate-limit cooldown, skipping to fallbacks" ); } else { - match self.attempt_with_retries(&self.full_model_name, &request).await { + match self + .attempt_with_retries(&self.full_model_name, &request) + .await + { Ok(response) => return Ok(response), Err((error, was_rate_limit)) => { if was_rate_limit { - self.llm_manager.record_rate_limit(&self.full_model_name).await; + self.llm_manager + .record_rate_limit(&self.full_model_name) + .await; } if fallbacks.is_empty() { // No fallbacks — this is the final error @@ -224,7 +231,11 @@ impl CompletionModel for SpacebotModel { // Try fallback chain, each with their own retry loop for (index, fallback_name) in fallbacks.iter().take(MAX_FALLBACK_ATTEMPTS).enumerate() { - if self.llm_manager.is_rate_limited(fallback_name, cooldown).await { + if self + .llm_manager + .is_rate_limited(fallback_name, cooldown) + .await + { tracing::debug!( fallback = %fallback_name, "fallback model in cooldown, skipping" @@ -324,15 +335,17 @@ impl SpacebotModel { .map_err(|e| CompletionError::ProviderError(e.to_string()))?; let status = response.status(); - let response_text = response - .text() - .await - .map_err(|e| 
CompletionError::ProviderError(format!("failed to read response body: {e}")))?; - - let response_body: serde_json::Value = serde_json::from_str(&response_text) - .map_err(|e| CompletionError::ProviderError(format!( - "Anthropic response ({status}) is not valid JSON: {e}\nBody: {}", truncate_body(&response_text) - )))?; + let response_text = response.text().await.map_err(|e| { + CompletionError::ProviderError(format!("failed to read response body: {e}")) + })?; + + let response_body: serde_json::Value = + serde_json::from_str(&response_text).map_err(|e| { + CompletionError::ProviderError(format!( + "Anthropic response ({status}) is not valid JSON: {e}\nBody: {}", + truncate_body(&response_text) + )) + })?; if !status.is_success() { let message = response_body["error"]["message"] @@ -409,15 +422,17 @@ impl SpacebotModel { .map_err(|e| CompletionError::ProviderError(e.to_string()))?; let status = response.status(); - let response_text = response - .text() - .await - .map_err(|e| CompletionError::ProviderError(format!("failed to read response body: {e}")))?; - - let response_body: serde_json::Value = serde_json::from_str(&response_text) - .map_err(|e| CompletionError::ProviderError(format!( - "OpenAI response ({status}) is not valid JSON: {e}\nBody: {}", truncate_body(&response_text) - )))?; + let response_text = response.text().await.map_err(|e| { + CompletionError::ProviderError(format!("failed to read response body: {e}")) + })?; + + let response_body: serde_json::Value = + serde_json::from_str(&response_text).map_err(|e| { + CompletionError::ProviderError(format!( + "OpenAI response ({status}) is not valid JSON: {e}\nBody: {}", + truncate_body(&response_text) + )) + })?; if !status.is_success() { let message = response_body["error"]["message"] @@ -496,15 +511,17 @@ impl SpacebotModel { .map_err(|e| CompletionError::ProviderError(e.to_string()))?; let status = response.status(); - let response_text = response - .text() - .await - .map_err(|e| 
CompletionError::ProviderError(format!("failed to read response body: {e}")))?; - - let response_body: serde_json::Value = serde_json::from_str(&response_text) - .map_err(|e| CompletionError::ProviderError(format!( - "OpenRouter response ({status}) is not valid JSON: {e}\nBody: {}", truncate_body(&response_text) - )))?; + let response_text = response.text().await.map_err(|e| { + CompletionError::ProviderError(format!("failed to read response body: {e}")) + })?; + + let response_body: serde_json::Value = + serde_json::from_str(&response_text).map_err(|e| { + CompletionError::ProviderError(format!( + "OpenRouter response ({status}) is not valid JSON: {e}\nBody: {}", + truncate_body(&response_text) + )) + })?; if !status.is_success() { let message = response_body["error"]["message"] @@ -582,15 +599,17 @@ impl SpacebotModel { .map_err(|e| CompletionError::ProviderError(e.to_string()))?; let status = response.status(); - let response_text = response - .text() - .await - .map_err(|e| CompletionError::ProviderError(format!("failed to read response body: {e}")))?; - - let response_body: serde_json::Value = serde_json::from_str(&response_text) - .map_err(|e| CompletionError::ProviderError(format!( - "Z.ai response ({status}) is not valid JSON: {e}\nBody: {}", truncate_body(&response_text) - )))?; + let response_text = response.text().await.map_err(|e| { + CompletionError::ProviderError(format!("failed to read response body: {e}")) + })?; + + let response_body: serde_json::Value = + serde_json::from_str(&response_text).map_err(|e| { + CompletionError::ProviderError(format!( + "Z.ai response ({status}) is not valid JSON: {e}\nBody: {}", + truncate_body(&response_text) + )) + })?; if !status.is_success() { let message = response_body["error"]["message"] @@ -604,6 +623,19 @@ impl SpacebotModel { parse_openai_response(response_body, "Z.ai") } + async fn call_ollama( + &self, + request: CompletionRequest, + ) -> Result, CompletionError> { + self.call_openai_compatible( + request, 
+ "ollama", + "Ollama", + "https://ollama.com/v1/chat/completions", + ) + .await + } + /// Generic OpenAI-compatible API call. /// Used by providers that implement the OpenAI chat completions format. async fn call_openai_compatible( @@ -672,15 +704,17 @@ impl SpacebotModel { .map_err(|e| CompletionError::ProviderError(e.to_string()))?; let status = response.status(); - let response_text = response - .text() - .await - .map_err(|e| CompletionError::ProviderError(format!("failed to read response body: {e}")))?; - - let response_body: serde_json::Value = serde_json::from_str(&response_text) - .map_err(|e| CompletionError::ProviderError(format!( - "{provider_display_name} response ({status}) is not valid JSON: {e}\nBody: {}", truncate_body(&response_text) - )))?; + let response_text = response.text().await.map_err(|e| { + CompletionError::ProviderError(format!("failed to read response body: {e}")) + })?; + + let response_body: serde_json::Value = + serde_json::from_str(&response_text).map_err(|e| { + CompletionError::ProviderError(format!( + "{provider_display_name} response ({status}) is not valid JSON: {e}\nBody: {}", + truncate_body(&response_text) + )) + })?; if !status.is_success() { let message = response_body["error"]["message"] @@ -703,7 +737,8 @@ impl SpacebotModel { "groq", "Groq", "https://api.groq.com/openai/v1/chat/completions", - ).await + ) + .await } async fn call_together( @@ -715,7 +750,8 @@ impl SpacebotModel { "together", "Together AI", "https://api.together.xyz/v1/chat/completions", - ).await + ) + .await } async fn call_fireworks( @@ -727,7 +763,8 @@ impl SpacebotModel { "fireworks", "Fireworks AI", "https://api.fireworks.ai/inference/v1/chat/completions", - ).await + ) + .await } async fn call_deepseek( @@ -739,7 +776,8 @@ impl SpacebotModel { "deepseek", "DeepSeek", "https://api.deepseek.com/v1/chat/completions", - ).await + ) + .await } async fn call_xai( @@ -751,7 +789,8 @@ impl SpacebotModel { "xai", "xAI", 
"https://api.x.ai/v1/chat/completions", - ).await + ) + .await } async fn call_mistral( @@ -763,7 +802,8 @@ impl SpacebotModel { "mistral", "Mistral AI", "https://api.mistral.ai/v1/chat/completions", - ).await + ) + .await } async fn call_opencode_zen( @@ -775,7 +815,8 @@ impl SpacebotModel { "opencode-zen", "OpenCode Zen", "https://opencode.ai/zen/v1/chat/completions", - ).await + ) + .await } } @@ -1003,7 +1044,10 @@ fn make_tool_call(id: String, name: String, arguments: serde_json::Value) -> Too ToolCall { id, call_id: None, - function: ToolFunction { name: name.trim().to_string(), arguments }, + function: ToolFunction { + name: name.trim().to_string(), + arguments, + }, signature: None, additional_params: None, } @@ -1028,8 +1072,9 @@ fn parse_anthropic_response( let id = block["id"].as_str().unwrap_or("").to_string(); let name = block["name"].as_str().unwrap_or("").to_string(); let arguments = block["input"].clone(); - assistant_content - .push(AssistantContent::ToolCall(make_tool_call(id, name, arguments))); + assistant_content.push(AssistantContent::ToolCall(make_tool_call( + id, name, arguments, + ))); } _ => {} } @@ -1081,13 +1126,15 @@ fn parse_openai_response( .as_str() .and_then(|s| serde_json::from_str(s).ok()) .unwrap_or(serde_json::json!({})); - assistant_content - .push(AssistantContent::ToolCall(make_tool_call(id, name, arguments))); + assistant_content.push(AssistantContent::ToolCall(make_tool_call( + id, name, arguments, + ))); } } - let result_choice = OneOrMany::many(assistant_content) - .map_err(|_| CompletionError::ResponseError(format!("empty response from {provider_label}")))?; + let result_choice = OneOrMany::many(assistant_content).map_err(|_| { + CompletionError::ResponseError(format!("empty response from {provider_label}")) + })?; let input_tokens = body["usage"]["prompt_tokens"].as_u64().unwrap_or(0); let output_tokens = body["usage"]["completion_tokens"].as_u64().unwrap_or(0); diff --git a/src/main.rs b/src/main.rs index 
8e8af3a91..efeffb891 100644 --- a/src/main.rs +++ b/src/main.rs @@ -89,8 +89,7 @@ fn cmd_start( config_path.clone() } else if spacebot::config::Config::needs_onboarding() { // Returns Some(path) if CLI wizard ran, None if user chose the UI. - spacebot::config::run_onboarding() - .with_context(|| "onboarding failed")? + spacebot::config::run_onboarding().with_context(|| "onboarding failed")? } else { None }; @@ -199,7 +198,10 @@ fn cmd_status() -> anyhow::Result<()> { runtime.block_on(async { match spacebot::daemon::send_command(&paths, spacebot::daemon::IpcCommand::Status).await { - Ok(spacebot::daemon::IpcResponse::Status { pid, uptime_seconds }) => { + Ok(spacebot::daemon::IpcResponse::Status { + pid, + uptime_seconds, + }) => { let hours = uptime_seconds / 3600; let minutes = (uptime_seconds % 3600) / 60; let seconds = uptime_seconds % 60; @@ -225,20 +227,18 @@ fn cmd_status() -> anyhow::Result<()> { Ok(()) } -fn load_config(config_path: &Option) -> anyhow::Result { +fn load_config( + config_path: &Option, +) -> anyhow::Result { if let Some(path) = config_path { spacebot::config::Config::load_from_path(path) .with_context(|| format!("failed to load config from {}", path.display())) } else { - spacebot::config::Config::load() - .with_context(|| "failed to load configuration") + spacebot::config::Config::load().with_context(|| "failed to load configuration") } } -async fn run( - config: spacebot::config::Config, - foreground: bool, -) -> anyhow::Result<()> { +async fn run(config: spacebot::config::Config, foreground: bool) -> anyhow::Result<()> { let paths = spacebot::daemon::DaemonPaths::new(&config.instance_dir); tracing::info!("starting spacebot"); @@ -253,7 +253,9 @@ async fn run( let (provider_tx, mut provider_rx) = mpsc::channel::(1); // Start HTTP API server if enabled - let api_state = Arc::new(spacebot::api::ApiState::new_with_provider_sender(provider_tx)); + let api_state = Arc::new(spacebot::api::ApiState::new_with_provider_sender( + provider_tx, + )); 
// Start background update checker spacebot::update::spawn_update_checker(api_state.update_status.clone()); @@ -279,8 +281,10 @@ async fn run( tracing::info!("No LLM provider keys configured. Starting in setup mode."); if foreground { eprintln!("No LLM provider keys configured."); - eprintln!("Please add a provider key via the web UI at http://{}:{}", - config.api.bind, config.api.port); + eprintln!( + "Please add a provider key via the web UI at http://{}:{}", + config.api.bind, config.api.port + ); } } @@ -289,21 +293,20 @@ async fn run( let llm_manager = Arc::new( spacebot::llm::LlmManager::new(config.llm.clone()) .await - .with_context(|| "failed to initialize LLM manager")? + .with_context(|| "failed to initialize LLM manager")?, ); // Shared embedding model (stateless, agent-agnostic) let embedding_cache_dir = config.instance_dir.join("embedding_cache"); let embedding_model = Arc::new( spacebot::memory::EmbeddingModel::new(&embedding_cache_dir) - .context("failed to initialize embedding model")? 
+ .context("failed to initialize embedding model")?, ); tracing::info!("shared resources initialized"); // Initialize the language for all text lookups (must happen before PromptEngine/tools) - spacebot::prompts::text::init("en") - .with_context(|| "failed to initialize language")?; + spacebot::prompts::text::init("en").with_context(|| "failed to initialize language")?; // Create the PromptEngine with bundled templates (no file watching, no user overrides) let prompt_engine = spacebot::prompts::PromptEngine::new("en") @@ -387,7 +390,10 @@ async fn run( } if foreground { - eprintln!("spacebot running in foreground (pid {})", std::process::id()); + eprintln!( + "spacebot running in foreground (pid {})", + std::process::id() + ); } else { tracing::info!(pid = std::process::id(), "spacebot daemon started"); } @@ -722,7 +728,11 @@ async fn initialize_agents( cron_schedulers_for_shutdown: &mut Vec>, ingestion_handles: &mut Vec>, cortex_handles: &mut Vec>, - watcher_agents: &mut Vec<(String, std::path::PathBuf, Arc)>, + watcher_agents: &mut Vec<( + String, + std::path::PathBuf, + Arc, + )>, discord_permissions: &mut Option>>, slack_permissions: &mut Option>>, telegram_permissions: &mut Option>>, @@ -733,34 +743,65 @@ async fn initialize_agents( tracing::info!(agent_id = %agent_config.id, "initializing agent"); // Ensure agent directories exist - std::fs::create_dir_all(&agent_config.workspace) - .with_context(|| format!("failed to create workspace: {}", agent_config.workspace.display()))?; - std::fs::create_dir_all(&agent_config.data_dir) - .with_context(|| format!("failed to create data dir: {}", agent_config.data_dir.display()))?; - std::fs::create_dir_all(&agent_config.archives_dir) - .with_context(|| format!("failed to create archives dir: {}", agent_config.archives_dir.display()))?; - std::fs::create_dir_all(&agent_config.ingest_dir()) - .with_context(|| format!("failed to create ingest dir: {}", agent_config.ingest_dir().display()))?; - 
std::fs::create_dir_all(&agent_config.logs_dir()) - .with_context(|| format!("failed to create logs dir: {}", agent_config.logs_dir().display()))?; + std::fs::create_dir_all(&agent_config.workspace).with_context(|| { + format!( + "failed to create workspace: {}", + agent_config.workspace.display() + ) + })?; + std::fs::create_dir_all(&agent_config.data_dir).with_context(|| { + format!( + "failed to create data dir: {}", + agent_config.data_dir.display() + ) + })?; + std::fs::create_dir_all(&agent_config.archives_dir).with_context(|| { + format!( + "failed to create archives dir: {}", + agent_config.archives_dir.display() + ) + })?; + std::fs::create_dir_all(&agent_config.ingest_dir()).with_context(|| { + format!( + "failed to create ingest dir: {}", + agent_config.ingest_dir().display() + ) + })?; + std::fs::create_dir_all(&agent_config.logs_dir()).with_context(|| { + format!( + "failed to create logs dir: {}", + agent_config.logs_dir().display() + ) + })?; // Per-agent database connections let db = spacebot::db::Db::connect(&agent_config.data_dir) .await - .with_context(|| format!("failed to connect databases for agent '{}'", agent_config.id))?; + .with_context(|| { + format!( + "failed to connect databases for agent '{}'", + agent_config.id + ) + })?; // Per-agent settings store (redb-backed) let settings_path = agent_config.data_dir.join("settings.redb"); let settings_store = Arc::new( - spacebot::settings::SettingsStore::new(&settings_path) - .with_context(|| format!("failed to initialize settings store for agent '{}'", agent_config.id))? 
+ spacebot::settings::SettingsStore::new(&settings_path).with_context(|| { + format!( + "failed to initialize settings store for agent '{}'", + agent_config.id + ) + })?, ); // Per-agent memory system let memory_store = spacebot::memory::MemoryStore::new(db.sqlite.clone()); let embedding_table = spacebot::memory::EmbeddingTable::open_or_create(&db.lance) .await - .with_context(|| format!("failed to init embeddings for agent '{}'", agent_config.id))?; + .with_context(|| { + format!("failed to init embeddings for agent '{}'", agent_config.id) + })?; // Ensure FTS index exists for full-text search queries if let Err(error) = embedding_table.ensure_fts_index().await { @@ -781,14 +822,18 @@ async fn initialize_agents( // Scaffold identity templates if missing, then load spacebot::identity::scaffold_identity_files(&agent_config.workspace) .await - .with_context(|| format!("failed to scaffold identity files for agent '{}'", agent_config.id))?; + .with_context(|| { + format!( + "failed to scaffold identity files for agent '{}'", + agent_config.id + ) + })?; let identity = spacebot::identity::Identity::load(&agent_config.workspace).await; // Load skills (instance-level, then workspace overrides) - let skills = spacebot::skills::SkillSet::load( - &config.skills_dir(), - &agent_config.skills_dir(), - ).await; + let skills = + spacebot::skills::SkillSet::load(&config.skills_dir(), &agent_config.skills_dir()) + .await; // Build the RuntimeConfig with all hot-reloadable values let runtime_config = Arc::new(spacebot::config::RuntimeConfig::new( @@ -870,20 +915,21 @@ async fn initialize_agents( // Shared Discord permissions (hot-reloadable via file watcher) *discord_permissions = config.messaging.discord.as_ref().map(|discord_config| { - let perms = spacebot::config::DiscordPermissions::from_config(discord_config, &config.bindings); + let perms = + spacebot::config::DiscordPermissions::from_config(discord_config, &config.bindings); Arc::new(ArcSwap::from_pointee(perms)) }); if let 
Some(perms) = &*discord_permissions { api_state.set_discord_permissions(perms.clone()).await; } - - if let Some(discord_config) = &config.messaging.discord { if discord_config.enabled { let adapter = spacebot::messaging::discord::DiscordAdapter::new( &discord_config.token, - discord_permissions.clone().expect("discord permissions initialized when discord is enabled"), + discord_permissions + .clone() + .expect("discord permissions initialized when discord is enabled"), ); new_messaging_manager.register(adapter).await; } @@ -903,7 +949,9 @@ async fn initialize_agents( let adapter = spacebot::messaging::slack::SlackAdapter::new( &slack_config.bot_token, &slack_config.app_token, - slack_permissions.clone().expect("slack permissions initialized when slack is enabled"), + slack_permissions + .clone() + .expect("slack permissions initialized when slack is enabled"), ); new_messaging_manager.register(adapter).await; } @@ -911,7 +959,8 @@ async fn initialize_agents( // Shared Telegram permissions (hot-reloadable via file watcher) *telegram_permissions = config.messaging.telegram.as_ref().map(|telegram_config| { - let perms = spacebot::config::TelegramPermissions::from_config(telegram_config, &config.bindings); + let perms = + spacebot::config::TelegramPermissions::from_config(telegram_config, &config.bindings); Arc::new(ArcSwap::from_pointee(perms)) }); @@ -919,7 +968,9 @@ async fn initialize_agents( if telegram_config.enabled { let adapter = spacebot::messaging::telegram::TelegramAdapter::new( &telegram_config.token, - telegram_permissions.clone().expect("telegram permissions initialized when telegram is enabled"), + telegram_permissions + .clone() + .expect("telegram permissions initialized when telegram is enabled"), ); new_messaging_manager.register(adapter).await; } @@ -936,7 +987,9 @@ async fn initialize_agents( } *messaging_manager = Arc::new(new_messaging_manager); - api_state.set_messaging_manager(messaging_manager.clone()).await; + api_state + 
.set_messaging_manager(messaging_manager.clone()) + .await; // Start all messaging adapters and get the merged inbound stream let new_inbound = messaging_manager @@ -986,7 +1039,10 @@ async fn initialize_agents( let scheduler = Arc::new(spacebot::cron::Scheduler::new(cron_context)); // Make cron store and scheduler available via RuntimeConfig - agent.deps.runtime_config.set_cron(store.clone(), scheduler.clone()); + agent + .deps + .runtime_config + .set_cron(store.clone(), scheduler.clone()); match store.load_all().await { Ok(configs) => { @@ -1032,11 +1088,13 @@ async fn initialize_agents( // Start cortex bulletin loops and association loops for each agent for (agent_id, agent) in agents.iter() { let cortex_logger = spacebot::agent::cortex::CortexLogger::new(agent.db.sqlite.clone()); - let bulletin_handle = spacebot::agent::cortex::spawn_bulletin_loop(agent.deps.clone(), cortex_logger.clone()); + let bulletin_handle = + spacebot::agent::cortex::spawn_bulletin_loop(agent.deps.clone(), cortex_logger.clone()); cortex_handles.push(bulletin_handle); tracing::info!(agent_id = %agent_id, "cortex bulletin loop started"); - let association_handle = spacebot::agent::cortex::spawn_association_loop(agent.deps.clone(), cortex_logger); + let association_handle = + spacebot::agent::cortex::spawn_association_loop(agent.deps.clone(), cortex_logger); cortex_handles.push(association_handle); tracing::info!(agent_id = %agent_id, "cortex association loop started"); } @@ -1047,7 +1105,8 @@ async fn initialize_agents( for (agent_id, agent) in agents.iter() { let browser_config = (**agent.deps.runtime_config.browser_config.load()).clone(); let brave_search_key = (**agent.deps.runtime_config.brave_search_key.load()).clone(); - let conversation_logger = spacebot::conversation::history::ConversationLogger::new(agent.db.sqlite.clone()); + let conversation_logger = + spacebot::conversation::history::ConversationLogger::new(agent.db.sqlite.clone()); let channel_store = 
spacebot::conversation::ChannelStore::new(agent.db.sqlite.clone()); let tool_server = spacebot::tools::create_cortex_chat_tool_server( agent.deps.memory_search.clone(), diff --git a/src/opencode/server.rs b/src/opencode/server.rs index 5ad07c5ed..c047c418c 100644 --- a/src/opencode/server.rs +++ b/src/opencode/server.rs @@ -54,8 +54,8 @@ impl OpenCodeServer { let base_url = format!("http://127.0.0.1:{port}"); let env_config = OpenCodeEnvConfig::new(permissions); - let config_json = serde_json::to_string(&env_config) - .context("failed to serialize OpenCode config")?; + let config_json = + serde_json::to_string(&env_config).context("failed to serialize OpenCode config")?; tracing::info!( directory = %directory.display(), @@ -73,10 +73,13 @@ impl OpenCodeServer { .env("OPENCODE_PORT", port.to_string()) .kill_on_drop(true) .spawn() - .with_context(|| format!( - "failed to spawn OpenCode at '{}' for directory '{}'", - opencode_path, directory.display() - ))?; + .with_context(|| { + format!( + "failed to spawn OpenCode at '{}' for directory '{}'", + opencode_path, + directory.display() + ) + })?; let client = Client::builder() .timeout(std::time::Duration::from_secs(300)) @@ -185,7 +188,8 @@ impl OpenCodeServer { /// Check if the server is healthy. 
async fn health_check(&self) -> anyhow::Result { let url = format!("{}/global/health", self.base_url); - let response = self.client + let response = self + .client .get(&url) .timeout(std::time::Duration::from_secs(5)) .send() @@ -197,7 +201,8 @@ impl OpenCodeServer { // Fallback let url = format!("{}/api/health", self.base_url); - let response = self.client + let response = self + .client .get(&url) .timeout(std::time::Duration::from_secs(5)) .send() @@ -260,10 +265,12 @@ impl OpenCodeServer { .env("OPENCODE_PORT", port.to_string()) .kill_on_drop(true) .spawn() - .with_context(|| format!( - "failed to restart OpenCode server for '{}'", - self.directory.display() - ))?; + .with_context(|| { + format!( + "failed to restart OpenCode server for '{}'", + self.directory.display() + ) + })?; self.port = port; self.base_url = base_url; @@ -311,7 +318,8 @@ impl OpenCodeServer { let url = format!("{}/session", self.base_url); let body = CreateSessionRequest { title }; - let response = self.client + let response = self + .client .post(&url) .query(&[("directory", self.directory.to_str().unwrap_or("."))]) .json(&body) @@ -325,7 +333,9 @@ impl OpenCodeServer { bail!("create session failed ({status}): {text}"); } - response.json::().await + response + .json::() + .await .context("failed to parse session response") } @@ -337,7 +347,8 @@ impl OpenCodeServer { ) -> anyhow::Result { let url = format!("{}/session/{}/message", self.base_url, session_id); - let response = self.client + let response = self + .client .post(&url) .query(&[("directory", self.directory.to_str().unwrap_or("."))]) .json(request) @@ -351,7 +362,9 @@ impl OpenCodeServer { bail!("send prompt failed ({status}): {text}"); } - response.json::().await + response + .json::() + .await .context("failed to parse prompt response") } @@ -363,7 +376,8 @@ impl OpenCodeServer { ) -> anyhow::Result<()> { let url = format!("{}/session/{}/prompt_async", self.base_url, session_id); - let response = self.client + let response = 
self + .client .post(&url) .query(&[("directory", self.directory.to_str().unwrap_or("."))]) .json(request) @@ -384,7 +398,8 @@ impl OpenCodeServer { pub async fn abort_session(&self, session_id: &str) -> anyhow::Result<()> { let url = format!("{}/session/{}/abort", self.base_url, session_id); - let response = self.client + let response = self + .client .post(&url) .query(&[("directory", self.directory.to_str().unwrap_or("."))]) .send() @@ -412,7 +427,8 @@ impl OpenCodeServer { message: None, }; - let response = self.client + let response = self + .client .post(&url) .query(&[("directory", self.directory.to_str().unwrap_or("."))]) .json(&body) @@ -438,7 +454,8 @@ impl OpenCodeServer { let url = format!("{}/question/{}/reply", self.base_url, request_id); let body = QuestionReplyRequest { answers }; - let response = self.client + let response = self + .client .post(&url) .query(&[("directory", self.directory.to_str().unwrap_or("."))]) .json(&body) @@ -460,7 +477,8 @@ impl OpenCodeServer { pub async fn subscribe_events(&self) -> anyhow::Result { let url = format!("{}/event", self.base_url); - let response = self.client + let response = self + .client .get(&url) .query(&[("directory", self.directory.to_str().unwrap_or("."))]) .header("Accept", "text/event-stream") @@ -479,13 +497,11 @@ impl OpenCodeServer { } /// Get messages for a session (for reading final results). 
- pub async fn get_messages( - &self, - session_id: &str, - ) -> anyhow::Result> { + pub async fn get_messages(&self, session_id: &str) -> anyhow::Result> { let url = format!("{}/session/{}/message", self.base_url, session_id); - let response = self.client + let response = self + .client .get(&url) .query(&[("directory", self.directory.to_str().unwrap_or("."))]) .send() @@ -498,7 +514,9 @@ impl OpenCodeServer { bail!("get messages failed ({status}): {text}"); } - response.json::>().await + response + .json::>() + .await .context("failed to parse messages response") } } @@ -552,7 +570,8 @@ impl OpenCodeServerPool { &self, directory: &Path, ) -> anyhow::Result>> { - let canonical = directory.canonicalize() + let canonical = directory + .canonicalize() .with_context(|| format!("directory '{}' does not exist", directory.display()))?; let mut servers = self.servers.lock().await; @@ -574,11 +593,10 @@ impl OpenCodeServerPool { // Not in pool yet. Try reattaching to an existing server on the // deterministic port (left over from a previous spacebot run). 
- if let Some(reattached) = OpenCodeServer::reattach( - canonical.clone(), - &self.opencode_path, - &self.permissions, - ).await { + if let Some(reattached) = + OpenCodeServer::reattach(canonical.clone(), &self.opencode_path, &self.permissions) + .await + { let server = Arc::new(Mutex::new(reattached)); servers.insert(canonical, Arc::clone(&server)); return Ok(server); @@ -592,11 +610,9 @@ impl OpenCodeServerPool { ); } - let server = OpenCodeServer::spawn( - canonical.clone(), - &self.opencode_path, - &self.permissions, - ).await?; + let server = + OpenCodeServer::spawn(canonical.clone(), &self.opencode_path, &self.permissions) + .await?; let server = Arc::new(Mutex::new(server)); servers.insert(canonical, Arc::clone(&server)); diff --git a/src/opencode/worker.rs b/src/opencode/worker.rs index f421e9c72..1466c5292 100644 --- a/src/opencode/worker.rs +++ b/src/opencode/worker.rs @@ -12,7 +12,7 @@ use anyhow::{Context as _, bail}; use futures::StreamExt as _; use std::path::PathBuf; use std::sync::Arc; -use tokio::sync::{broadcast, mpsc, Mutex}; +use tokio::sync::{Mutex, broadcast, mpsc}; use uuid::Uuid; /// An OpenCode-backed worker that drives a coding session via subprocess. @@ -95,20 +95,25 @@ impl OpenCodeWorker { self.send_status("starting OpenCode server"); // Get or create server for this directory - let server = self.server_pool + let server = self + .server_pool .get_or_create(&self.directory) .await - .with_context(|| format!( - "failed to get OpenCode server for '{}'", - self.directory.display() - ))?; + .with_context(|| { + format!( + "failed to get OpenCode server for '{}'", + self.directory.display() + ) + })?; self.send_status("creating session"); // Create a session let session = { let guard = server.lock().await; - guard.create_session(Some(format!("spacebot-worker-{}", self.id))).await? + guard + .create_session(Some(format!("spacebot-worker-{}", self.id))) + .await? 
}; let session_id = session.id.clone(); @@ -141,15 +146,15 @@ impl OpenCodeWorker { self.send_status("sending task to OpenCode"); { let guard = server.lock().await; - guard.send_prompt_async(&session_id, &prompt_request).await?; + guard + .send_prompt_async(&session_id, &prompt_request) + .await?; } // Process SSE events until session goes idle or errors - let result_text = self.process_events( - event_response, - &session_id, - &server, - ).await?; + let result_text = self + .process_events(event_response, &session_id, &server) + .await?; // Interactive follow-up loop if let Some(mut input_rx) = self.input_rx.take() { @@ -176,10 +181,15 @@ impl OpenCodeWorker { { let guard = server.lock().await; - guard.send_prompt_async(&session_id, &follow_up_request).await?; + guard + .send_prompt_async(&session_id, &follow_up_request) + .await?; } - match self.process_events(event_response, &session_id, &server).await { + match self + .process_events(event_response, &session_id, &server) + .await + { Ok(_) => { self.send_status("waiting for follow-up"); } @@ -247,15 +257,18 @@ impl OpenCodeWorker { // Parse SSE lines from buffer while let Some(event) = extract_sse_event(&mut buffer) { - match self.handle_sse_event( - &event, - session_id, - server, - &mut last_text, - &mut current_tool, - &mut has_received_event, - &mut has_assistant_message, - ).await { + match self + .handle_sse_event( + &event, + session_id, + server, + &mut last_text, + &mut current_tool, + &mut has_received_event, + &mut has_assistant_message, + ) + .await + { EventAction::Continue => {} EventAction::Complete => return Ok(last_text.clone()), EventAction::Error(message) => bail!("OpenCode session error: {message}"), @@ -294,7 +307,11 @@ impl OpenCodeWorker { SseEvent::MessagePartUpdated { part, .. } => { *has_received_event = true; match part { - Part::Text { text, session_id: part_session, .. } => { + Part::Text { + text, + session_id: part_session, + .. 
+ } => { if let Some(sid) = part_session { if sid != session_id { return EventAction::Continue; @@ -303,7 +320,12 @@ impl OpenCodeWorker { *has_assistant_message = true; *last_text = text.clone(); } - Part::Tool { tool, state, session_id: part_session, .. } => { + Part::Tool { + tool, + state, + session_id: part_session, + .. + } => { if let Some(sid) = part_session { if sid != session_id { return EventAction::Continue; @@ -326,7 +348,9 @@ impl OpenCodeWorker { } ToolState::Error { error, .. } => { let description = error.as_deref().unwrap_or("unknown"); - self.send_status(&format!("tool error: {tool_name}: {description}")); + self.send_status(&format!( + "tool error: {tool_name}: {description}" + )); } ToolState::Pending { .. } => { // Tool queued, no status update needed @@ -340,7 +364,9 @@ impl OpenCodeWorker { EventAction::Continue } - SseEvent::SessionIdle { session_id: event_session_id } => { + SseEvent::SessionIdle { + session_id: event_session_id, + } => { if event_session_id != session_id { return EventAction::Continue; } @@ -360,7 +386,10 @@ impl OpenCodeWorker { EventAction::Complete } - SseEvent::SessionError { session_id: event_session_id, error } => { + SseEvent::SessionError { + session_id: event_session_id, + error, + } => { if event_session_id.as_deref() != Some(session_id) { return EventAction::Continue; } @@ -400,7 +429,10 @@ impl OpenCodeWorker { // Auto-allow (OPENCODE_CONFIG_CONTENT should prevent most prompts) let guard = server.lock().await; - if let Err(error) = guard.reply_permission(&permission.id, PermissionReply::Once).await { + if let Err(error) = guard + .reply_permission(&permission.id, PermissionReply::Once) + .await + { tracing::warn!( worker_id = %self.id, permission_id = %permission.id, @@ -429,29 +461,35 @@ impl OpenCodeWorker { worker_id: self.id, channel_id: self.channel_id.clone(), question_id: question.id.clone(), - questions: question.questions.iter().map(|q| { - QuestionInfo { + questions: question + .questions + .iter() + 
.map(|q| QuestionInfo { question: q.question.clone(), header: q.header.clone(), options: q.options.clone(), - } - }).collect(), + }) + .collect(), }); // Auto-select first option - let answers: Vec = question.questions.iter().map(|q| { - if let Some(first_option) = q.options.first() { - QuestionAnswer { - label: first_option.label.clone(), - description: first_option.description.clone(), - } - } else { - QuestionAnswer { - label: "continue".to_string(), - description: None, + let answers: Vec = question + .questions + .iter() + .map(|q| { + if let Some(first_option) = q.options.first() { + QuestionAnswer { + label: first_option.label.clone(), + description: first_option.description.clone(), + } + } else { + QuestionAnswer { + label: "continue".to_string(), + description: None, + } } - } - }).collect(); + }) + .collect(); let guard = server.lock().await; if let Err(error) = guard.reply_question(&question.id, answers).await { @@ -466,12 +504,17 @@ impl OpenCodeWorker { EventAction::Continue } - SseEvent::SessionStatus { session_id: event_session_id, status } => { + SseEvent::SessionStatus { + session_id: event_session_id, + status, + } => { if event_session_id != session_id { return EventAction::Continue; } match status { - SessionStatusPayload::Retry { attempt, message, .. } => { + SessionStatusPayload::Retry { + attempt, message, .. 
+ } => { let description = message.as_deref().unwrap_or("rate limited"); self.send_status(&format!("retry attempt {attempt}: {description}")); } diff --git a/src/prompts/engine.rs b/src/prompts/engine.rs index 480291b86..07e94439e 100644 --- a/src/prompts/engine.rs +++ b/src/prompts/engine.rs @@ -1,6 +1,6 @@ use crate::error::Result; use anyhow::Context; -use minijinja::{context, Environment, Value}; +use minijinja::{Environment, Value, context}; use std::collections::HashMap; use std::sync::Arc; diff --git a/src/settings.rs b/src/settings.rs index 4a2b4fb61..8c6d73490 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -2,4 +2,4 @@ pub mod store; -pub use store::{SettingsStore, WorkerLogMode, WORKER_LOG_MODE_KEY}; +pub use store::{SettingsStore, WORKER_LOG_MODE_KEY, WorkerLogMode}; diff --git a/src/skills.rs b/src/skills.rs index 7b6946e48..f8411400a 100644 --- a/src/skills.rs +++ b/src/skills.rs @@ -58,7 +58,9 @@ impl SkillSet { // Instance skills (lowest precedence) if instance_skills_dir.is_dir() { - if let Ok(skills) = load_skills_from_dir(instance_skills_dir, SkillSource::Instance).await { + if let Ok(skills) = + load_skills_from_dir(instance_skills_dir, SkillSource::Instance).await + { for skill in skills { set.skills.insert(skill.name.to_lowercase(), skill); } @@ -67,7 +69,9 @@ impl SkillSet { // Workspace skills (highest precedence, overrides instance) if workspace_skills_dir.is_dir() { - if let Ok(skills) = load_skills_from_dir(workspace_skills_dir, SkillSource::Workspace).await { + if let Ok(skills) = + load_skills_from_dir(workspace_skills_dir, SkillSource::Workspace).await + { for skill in skills { set.skills.insert(skill.name.to_lowercase(), skill); } @@ -194,26 +198,27 @@ async fn load_skills_from_dir(dir: &Path, source: SkillSource) -> anyhow::Result } /// Load a single skill from its SKILL.md file. 
-async fn load_skill(file_path: &Path, base_dir: &Path, source: SkillSource) -> anyhow::Result { +async fn load_skill( + file_path: &Path, + base_dir: &Path, + source: SkillSource, +) -> anyhow::Result { let raw = tokio::fs::read_to_string(file_path) .await .with_context(|| format!("failed to read {}", file_path.display()))?; let (frontmatter, body) = parse_frontmatter(&raw)?; - let name = frontmatter.get("name") - .cloned() - .unwrap_or_else(|| { - // Fall back to directory name if no name in frontmatter - base_dir.file_name() - .and_then(|n| n.to_str()) - .unwrap_or("unknown") - .to_string() - }); + let name = frontmatter.get("name").cloned().unwrap_or_else(|| { + // Fall back to directory name if no name in frontmatter + base_dir + .file_name() + .and_then(|n| n.to_str()) + .unwrap_or("unknown") + .to_string() + }); - let description = frontmatter.get("description") - .cloned() - .unwrap_or_default(); + let description = frontmatter.get("description").cloned().unwrap_or_default(); // Resolve {baseDir} template variable in the body let base_dir_str = base_dir.to_string_lossy(); @@ -277,8 +282,10 @@ fn parse_frontmatter(content: &str) -> anyhow::Result<(HashMap, // Strip surrounding quotes let value = value - .trim_start_matches('"').trim_end_matches('"') - .trim_start_matches('\'').trim_end_matches('\'') + .trim_start_matches('"') + .trim_end_matches('"') + .trim_start_matches('\'') + .trim_end_matches('\'') .to_string(); map.insert(key, value); @@ -308,7 +315,10 @@ mod tests { let (fm, body) = parse_frontmatter(content).unwrap(); assert_eq!(fm.get("name").unwrap(), "weather"); - assert_eq!(fm.get("description").unwrap(), "Get current weather and forecasts (no API key required)."); + assert_eq!( + fm.get("description").unwrap(), + "Get current weather and forecasts (no API key required)." 
+ ); assert_eq!(fm.get("homepage").unwrap(), "https://wttr.in/:help"); assert!(body.starts_with("# Weather")); } @@ -327,7 +337,10 @@ mod tests { let (fm, body) = parse_frontmatter(content).unwrap(); assert_eq!(fm.get("name").unwrap(), "github"); - assert_eq!(fm.get("description").unwrap(), "Interact with GitHub using the gh CLI."); + assert_eq!( + fm.get("description").unwrap(), + "Interact with GitHub using the gh CLI." + ); // metadata line is skipped (starts with {) assert!(!fm.contains_key("metadata")); assert!(body.starts_with("# GitHub Skill")); @@ -353,7 +366,10 @@ mod tests { "#}; let (fm, _body) = parse_frontmatter(content).unwrap(); - assert_eq!(fm.get("description").unwrap(), "A skill with 'quotes' inside"); + assert_eq!( + fm.get("description").unwrap(), + "A skill with 'quotes' inside" + ); } #[test] @@ -366,14 +382,17 @@ mod tests { #[test] fn test_skill_set_channel_prompt() { let mut set = SkillSet::default(); - set.skills.insert("weather".into(), Skill { - name: "weather".into(), - description: "Get weather forecasts".into(), - file_path: PathBuf::from("/skills/weather/SKILL.md"), - base_dir: PathBuf::from("/skills/weather"), - content: "# Weather\n\nUse curl.".into(), - source: SkillSource::Instance, - }); + set.skills.insert( + "weather".into(), + Skill { + name: "weather".into(), + description: "Get weather forecasts".into(), + file_path: PathBuf::from("/skills/weather/SKILL.md"), + base_dir: PathBuf::from("/skills/weather"), + content: "# Weather\n\nUse curl.".into(), + source: SkillSource::Instance, + }, + ); let engine = crate::prompts::PromptEngine::new("en").unwrap(); let prompt = set.render_channel_prompt(&engine); @@ -385,14 +404,17 @@ mod tests { #[test] fn test_skill_set_worker_prompt() { let mut set = SkillSet::default(); - set.skills.insert("weather".into(), Skill { - name: "weather".into(), - description: "Get weather forecasts".into(), - file_path: PathBuf::from("/skills/weather/SKILL.md"), - base_dir: 
PathBuf::from("/skills/weather"), - content: "# Weather\n\nUse curl.".into(), - source: SkillSource::Instance, - }); + set.skills.insert( + "weather".into(), + Skill { + name: "weather".into(), + description: "Get weather forecasts".into(), + file_path: PathBuf::from("/skills/weather/SKILL.md"), + base_dir: PathBuf::from("/skills/weather"), + content: "# Weather\n\nUse curl.".into(), + source: SkillSource::Instance, + }, + ); let engine = crate::prompts::PromptEngine::new("en").unwrap(); let prompt = set.render_worker_prompt("weather", &engine).unwrap(); diff --git a/src/update.rs b/src/update.rs index 7b810f84e..b782fe8dd 100644 --- a/src/update.rs +++ b/src/update.rs @@ -103,7 +103,10 @@ pub async fn check_for_update(status: &SharedUpdateStatus) { match result { Ok(release) => { - let tag = release.tag_name.strip_prefix('v').unwrap_or(&release.tag_name); + let tag = release + .tag_name + .strip_prefix('v') + .unwrap_or(&release.tag_name); let is_newer = is_newer_version(tag, CURRENT_VERSION); next.latest_version = Some(tag.to_string()); @@ -417,11 +420,7 @@ fn resolve_target_image(current_image: &str, new_version: &str) -> String { }; // Determine the variant suffix (slim, full, or default to slim) - let variant = if tag.contains("full") { - "full" - } else { - "slim" - }; + let variant = if tag.contains("full") { "full" } else { "slim" }; format!("{}:v{}-{}", base, new_version, variant) } diff --git a/tests/bulletin.rs b/tests/bulletin.rs index 815e660e6..aa4ebc3a4 100644 --- a/tests/bulletin.rs +++ b/tests/bulletin.rs @@ -12,8 +12,8 @@ use std::sync::Arc; /// Bootstrap an AgentDeps from the real ~/.spacebot config, using the first /// (default) agent's databases and config. 
async fn bootstrap_deps() -> anyhow::Result { - let config = spacebot::config::Config::load() - .context("failed to load ~/.spacebot/config.toml")?; + let config = + spacebot::config::Config::load().context("failed to load ~/.spacebot/config.toml")?; let llm_manager = Arc::new( spacebot::llm::LlmManager::new(config.llm.clone()) @@ -28,9 +28,7 @@ async fn bootstrap_deps() -> anyhow::Result { ); let resolved_agents = config.resolve_agents(); - let agent_config = resolved_agents - .first() - .context("no agents configured")?; + let agent_config = resolved_agents.first().context("no agents configured")?; let db = spacebot::db::Db::connect(&agent_config.data_dir) .await @@ -38,10 +36,9 @@ async fn bootstrap_deps() -> anyhow::Result { let memory_store = spacebot::memory::MemoryStore::new(db.sqlite.clone()); - let embedding_table = - spacebot::memory::EmbeddingTable::open_or_create(&db.lance) - .await - .context("failed to init embedding table")?; + let embedding_table = spacebot::memory::EmbeddingTable::open_or_create(&db.lance) + .await + .context("failed to init embedding table")?; if let Err(error) = embedding_table.ensure_fts_index().await { eprintln!("warning: FTS index creation failed: {error}"); @@ -54,16 +51,18 @@ async fn bootstrap_deps() -> anyhow::Result { )); let identity = spacebot::identity::Identity::load(&agent_config.workspace).await; - let prompts = spacebot::prompts::PromptEngine::new("en") - .context("failed to init prompt engine")?; - let skills = spacebot::skills::SkillSet::load( - &config.skills_dir(), - &agent_config.skills_dir(), - ) - .await; + let prompts = + spacebot::prompts::PromptEngine::new("en").context("failed to init prompt engine")?; + let skills = + spacebot::skills::SkillSet::load(&config.skills_dir(), &agent_config.skills_dir()).await; let runtime_config = Arc::new(spacebot::config::RuntimeConfig::new( - &config.instance_dir, agent_config, &config.defaults, prompts, identity, skills, + &config.instance_dir, + agent_config, + 
&config.defaults, + prompts, + identity, + skills, )); let (event_tx, _) = tokio::sync::broadcast::channel(16); @@ -87,7 +86,16 @@ async fn bootstrap_deps() -> anyhow::Result { fn test_bulletin_prompts_cover_all_memory_types() { // The cortex user prompt in cortex.rs lists types inline. Check the same // set against the canonical list so drift is caught at compile time. - let cortex_user_prompt_types = ["identity", "fact", "decision", "event", "preference", "observation", "goal", "todo"]; + let cortex_user_prompt_types = [ + "identity", + "fact", + "decision", + "event", + "preference", + "observation", + "goal", + "todo", + ]; for memory_type in spacebot::memory::types::MemoryType::ALL { let type_str = memory_type.to_string(); @@ -128,7 +136,10 @@ async fn test_memory_recall_returns_results() { ); } - assert!(!results.is_empty(), "hybrid_search should return results from a populated database"); + assert!( + !results.is_empty(), + "hybrid_search should return results from a populated database" + ); } #[tokio::test] @@ -145,7 +156,10 @@ async fn test_bulletin_generation() { // Verify the bulletin was stored let bulletin = deps.runtime_config.memory_bulletin.load(); - assert!(!bulletin.is_empty(), "bulletin should not be empty after generation"); + assert!( + !bulletin.is_empty(), + "bulletin should not be empty after generation" + ); let word_count = bulletin.split_whitespace().count(); println!("bulletin generated: {word_count} words"); @@ -153,5 +167,8 @@ async fn test_bulletin_generation() { println!("{bulletin}"); println!("---"); - assert!(word_count > 50, "bulletin should have meaningful content (got {word_count} words)"); + assert!( + word_count > 50, + "bulletin should have meaningful content (got {word_count} words)" + ); } diff --git a/tests/context_dump.rs b/tests/context_dump.rs index 69a88ee9f..3e00c92ed 100644 --- a/tests/context_dump.rs +++ b/tests/context_dump.rs @@ -12,8 +12,8 @@ use std::sync::Arc; /// Bootstrap AgentDeps from the real ~/.spacebot 
config (same as bulletin test). async fn bootstrap_deps() -> anyhow::Result<(spacebot::AgentDeps, spacebot::config::Config)> { - let config = spacebot::config::Config::load() - .context("failed to load ~/.spacebot/config.toml")?; + let config = + spacebot::config::Config::load().context("failed to load ~/.spacebot/config.toml")?; let llm_manager = Arc::new( spacebot::llm::LlmManager::new(config.llm.clone()) @@ -28,9 +28,7 @@ async fn bootstrap_deps() -> anyhow::Result<(spacebot::AgentDeps, spacebot::conf ); let resolved_agents = config.resolve_agents(); - let agent_config = resolved_agents - .first() - .context("no agents configured")?; + let agent_config = resolved_agents.first().context("no agents configured")?; let db = spacebot::db::Db::connect(&agent_config.data_dir) .await @@ -38,10 +36,9 @@ async fn bootstrap_deps() -> anyhow::Result<(spacebot::AgentDeps, spacebot::conf let memory_store = spacebot::memory::MemoryStore::new(db.sqlite.clone()); - let embedding_table = - spacebot::memory::EmbeddingTable::open_or_create(&db.lance) - .await - .context("failed to init embedding table")?; + let embedding_table = spacebot::memory::EmbeddingTable::open_or_create(&db.lance) + .await + .context("failed to init embedding table")?; if let Err(error) = embedding_table.ensure_fts_index().await { eprintln!("warning: FTS index creation failed: {error}"); @@ -54,16 +51,18 @@ async fn bootstrap_deps() -> anyhow::Result<(spacebot::AgentDeps, spacebot::conf )); let identity = spacebot::identity::Identity::load(&agent_config.workspace).await; - let prompts = spacebot::prompts::PromptEngine::new("en") - .context("failed to init prompt engine")?; - let skills = spacebot::skills::SkillSet::load( - &config.skills_dir(), - &agent_config.skills_dir(), - ) - .await; + let prompts = + spacebot::prompts::PromptEngine::new("en").context("failed to init prompt engine")?; + let skills = + spacebot::skills::SkillSet::load(&config.skills_dir(), &agent_config.skills_dir()).await; let 
runtime_config = Arc::new(spacebot::config::RuntimeConfig::new( - &config.instance_dir, agent_config, &config.defaults, prompts, identity, skills, + &config.instance_dir, + agent_config, + &config.defaults, + prompts, + identity, + skills, )); let (event_tx, _) = tokio::sync::broadcast::channel(16); @@ -162,10 +161,13 @@ async fn dump_channel_context() { print_stats("System prompt", &prompt); // Build the actual channel tool server with real tools registered - let conversation_logger = spacebot::conversation::ConversationLogger::new(deps.sqlite_pool.clone()); + let conversation_logger = + spacebot::conversation::ConversationLogger::new(deps.sqlite_pool.clone()); let channel_store = spacebot::conversation::ChannelStore::new(deps.sqlite_pool.clone()); let channel_id: spacebot::ChannelId = Arc::from("test-channel"); - let status_block = Arc::new(tokio::sync::RwLock::new(spacebot::agent::status::StatusBlock::new())); + let status_block = Arc::new(tokio::sync::RwLock::new( + spacebot::agent::status::StatusBlock::new(), + )); let (response_tx, _response_rx) = tokio::sync::mpsc::channel(16); let state = spacebot::agent::channel::ChannelState { @@ -212,7 +214,10 @@ async fn dump_channel_context() { println!("\n--- TOTAL CHANNEL CONTEXT: ~{} tokens ---", total / 4); let routing = rc.routing.load(); - println!("Model: {}", routing.resolve(spacebot::ProcessType::Channel, None)); + println!( + "Model: {}", + routing.resolve(spacebot::ProcessType::Channel, None) + ); println!("Max turns: {}", **rc.max_turns.load()); assert!(!prompt.is_empty()); @@ -236,7 +241,8 @@ async fn dump_branch_context() { print_stats("System prompt", &branch_prompt); // Build the actual branch tool server - let conversation_logger = spacebot::conversation::ConversationLogger::new(deps.sqlite_pool.clone()); + let conversation_logger = + spacebot::conversation::ConversationLogger::new(deps.sqlite_pool.clone()); let channel_store = spacebot::conversation::ChannelStore::new(deps.sqlite_pool.clone()); let 
branch_tool_server = spacebot::tools::create_branch_tool_server( deps.memory_search.clone(), @@ -260,7 +266,10 @@ async fn dump_branch_context() { println!("\n--- TOTAL BRANCH CONTEXT: ~{} tokens ---", total / 4); let routing = rc.routing.load(); - println!("Model: {}", routing.resolve(spacebot::ProcessType::Branch, None)); + println!( + "Model: {}", + routing.resolve(spacebot::ProcessType::Branch, None) + ); println!("Max turns: {}", **rc.branch_max_turns.load()); println!("History: cloned from channel at fork time (full conversation context)"); @@ -319,7 +328,10 @@ async fn dump_worker_context() { println!("\n--- TOTAL WORKER CONTEXT: ~{} tokens ---", total / 4); let routing = rc.routing.load(); - println!("Model: {}", routing.resolve(spacebot::ProcessType::Worker, None)); + println!( + "Model: {}", + routing.resolve(spacebot::ProcessType::Worker, None) + ); println!("Turns per segment: 25"); println!("History: fresh (empty). Workers have no channel context."); @@ -341,12 +353,16 @@ async fn dump_all_contexts() { let bulletin_success = spacebot::agent::cortex::generate_bulletin(&deps).await; if bulletin_success { let bulletin = rc.memory_bulletin.load(); - println!("Bulletin generated: {} words", bulletin.split_whitespace().count()); + println!( + "Bulletin generated: {} words", + bulletin.split_whitespace().count() + ); } else { println!("Bulletin generation failed (may not have memories or LLM keys)"); } - let conversation_logger = spacebot::conversation::ConversationLogger::new(deps.sqlite_pool.clone()); + let conversation_logger = + spacebot::conversation::ConversationLogger::new(deps.sqlite_pool.clone()); let channel_store = spacebot::conversation::ChannelStore::new(deps.sqlite_pool.clone()); // ── Channel ── @@ -360,7 +376,9 @@ async fn dump_all_contexts() { active_branches: Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())), active_workers: Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())), worker_inputs: 
Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())), - status_block: Arc::new(tokio::sync::RwLock::new(spacebot::agent::status::StatusBlock::new())), + status_block: Arc::new(tokio::sync::RwLock::new( + spacebot::agent::status::StatusBlock::new(), + )), deps: deps.clone(), conversation_logger: conversation_logger.clone(), channel_store: channel_store.clone(), @@ -370,14 +388,24 @@ async fn dump_all_contexts() { let channel_tool_server = rig::tool::server::ToolServer::new().run(); let skip_flag = spacebot::tools::new_skip_flag(); spacebot::tools::add_channel_tools( - &channel_tool_server, state, response_tx, "test", skip_flag, None, - ).await.expect("failed to add channel tools"); + &channel_tool_server, + state, + response_tx, + "test", + skip_flag, + None, + ) + .await + .expect("failed to add channel tools"); let channel_tool_defs = channel_tool_server.get_tool_defs(None).await.unwrap(); let channel_tools_text = format_tool_defs(&channel_tool_defs); print_section("CHANNEL SYSTEM PROMPT (with bulletin)", &channel_prompt); print_stats("System prompt", &channel_prompt); - print_section(&format!("CHANNEL TOOLS ({} tools)", channel_tool_defs.len()), &channel_tools_text); + print_section( + &format!("CHANNEL TOOLS ({} tools)", channel_tool_defs.len()), + &channel_tools_text, + ); print_stats("Tool definitions", &channel_tools_text); let channel_total = channel_prompt.len() + channel_tools_text.len(); println!("--- TOTAL CHANNEL: ~{} tokens ---", channel_total / 4); @@ -396,7 +424,10 @@ async fn dump_all_contexts() { print_section("BRANCH SYSTEM PROMPT", &branch_prompt); print_stats("System prompt", &branch_prompt); - print_section(&format!("BRANCH TOOLS ({} tools)", branch_tool_defs.len()), &branch_tools_text); + print_section( + &format!("BRANCH TOOLS ({} tools)", branch_tool_defs.len()), + &branch_tools_text, + ); print_stats("Tool definitions", &branch_tools_text); let branch_total = branch_prompt.len() + branch_tools_text.len(); println!("--- TOTAL 
BRANCH: ~{} tokens ---", branch_total / 4); @@ -421,7 +452,10 @@ async fn dump_all_contexts() { print_section("WORKER SYSTEM PROMPT", &worker_prompt); print_stats("System prompt", &worker_prompt); - print_section(&format!("WORKER TOOLS ({} tools)", worker_tool_defs.len()), &worker_tools_text); + print_section( + &format!("WORKER TOOLS ({} tools)", worker_tool_defs.len()), + &worker_tools_text, + ); print_stats("Tool definitions", &worker_tools_text); let worker_total = worker_prompt.len() + worker_tools_text.len(); println!("--- TOTAL WORKER: ~{} tokens ---", worker_total / 4); @@ -433,23 +467,53 @@ async fn dump_all_contexts() { let routing = rc.routing.load(); println!("\nRouting:"); - println!(" Channel: {}", routing.resolve(spacebot::ProcessType::Channel, None)); - println!(" Branch: {}", routing.resolve(spacebot::ProcessType::Branch, None)); - println!(" Worker: {}", routing.resolve(spacebot::ProcessType::Worker, None)); + println!( + " Channel: {}", + routing.resolve(spacebot::ProcessType::Channel, None) + ); + println!( + " Branch: {}", + routing.resolve(spacebot::ProcessType::Branch, None) + ); + println!( + " Worker: {}", + routing.resolve(spacebot::ProcessType::Worker, None) + ); println!("\nContext budget (initial turn, before any history):"); - println!(" Channel: ~{} tokens (prompt: ~{}, tools: ~{})", - channel_total / 4, channel_prompt.len() / 4, channel_tools_text.len() / 4); - println!(" Branch: ~{} tokens (prompt: ~{}, tools: ~{}) + cloned channel history", - branch_total / 4, branch_prompt.len() / 4, branch_tools_text.len() / 4); - println!(" Worker: ~{} tokens (prompt: ~{}, tools: ~{})", - worker_total / 4, worker_prompt.len() / 4, worker_tools_text.len() / 4); + println!( + " Channel: ~{} tokens (prompt: ~{}, tools: ~{})", + channel_total / 4, + channel_prompt.len() / 4, + channel_tools_text.len() / 4 + ); + println!( + " Branch: ~{} tokens (prompt: ~{}, tools: ~{}) + cloned channel history", + branch_total / 4, + branch_prompt.len() / 4, + 
branch_tools_text.len() / 4 + ); + println!( + " Worker: ~{} tokens (prompt: ~{}, tools: ~{})", + worker_total / 4, + worker_prompt.len() / 4, + worker_tools_text.len() / 4 + ); let context_window = **rc.context_window.load(); println!("\nContext window: {} tokens", context_window); - println!(" Channel headroom: ~{} tokens for history", context_window - channel_total / 4); - println!(" Branch headroom: ~{} tokens for history", context_window - branch_total / 4); - println!(" Worker headroom: ~{} tokens for history", context_window - worker_total / 4); + println!( + " Channel headroom: ~{} tokens for history", + context_window - channel_total / 4 + ); + println!( + " Branch headroom: ~{} tokens for history", + context_window - branch_total / 4 + ); + println!( + " Worker headroom: ~{} tokens for history", + context_window - worker_total / 4 + ); println!("\nTurn limits:"); println!(" Channel: {} max turns", **rc.max_turns.load()); @@ -458,26 +522,74 @@ async fn dump_all_contexts() { let compaction = rc.compaction.load(); println!("\nCompaction thresholds:"); - println!(" Background: {:.0}%", compaction.background_threshold * 100.0); - println!(" Aggressive: {:.0}%", compaction.aggressive_threshold * 100.0); - println!(" Emergency: {:.0}%", compaction.emergency_threshold * 100.0); + println!( + " Background: {:.0}%", + compaction.background_threshold * 100.0 + ); + println!( + " Aggressive: {:.0}%", + compaction.aggressive_threshold * 100.0 + ); + println!( + " Emergency: {:.0}%", + compaction.emergency_threshold * 100.0 + ); println!("\nTool counts:"); - println!(" Channel: {} tools ({})", + println!( + " Channel: {} tools ({})", channel_tool_defs.len(), - channel_tool_defs.iter().map(|d| d.name.as_str()).collect::>().join(", ")); - println!(" Branch: {} tools ({})", + channel_tool_defs + .iter() + .map(|d| d.name.as_str()) + .collect::>() + .join(", ") + ); + println!( + " Branch: {} tools ({})", branch_tool_defs.len(), - branch_tool_defs.iter().map(|d| 
d.name.as_str()).collect::>().join(", ")); - println!(" Worker: {} tools ({})", + branch_tool_defs + .iter() + .map(|d| d.name.as_str()) + .collect::>() + .join(", ") + ); + println!( + " Worker: {} tools ({})", worker_tool_defs.len(), - worker_tool_defs.iter().map(|d| d.name.as_str()).collect::>().join(", ")); + worker_tool_defs + .iter() + .map(|d| d.name.as_str()) + .collect::>() + .join(", ") + ); let identity = rc.identity.load(); println!("\nIdentity files:"); - println!(" SOUL.md: {}", if identity.soul.is_some() { "loaded" } else { "empty" }); - println!(" IDENTITY.md: {}", if identity.identity.is_some() { "loaded" } else { "empty" }); - println!(" USER.md: {}", if identity.user.is_some() { "loaded" } else { "empty" }); + println!( + " SOUL.md: {}", + if identity.soul.is_some() { + "loaded" + } else { + "empty" + } + ); + println!( + " IDENTITY.md: {}", + if identity.identity.is_some() { + "loaded" + } else { + "empty" + } + ); + println!( + " USER.md: {}", + if identity.user.is_some() { + "loaded" + } else { + "empty" + } + ); let skills = rc.skills.load(); if skills.is_empty() { @@ -488,5 +600,8 @@ async fn dump_all_contexts() { } let bulletin = rc.memory_bulletin.load(); - println!("\nMemory bulletin: {} words", bulletin.split_whitespace().count()); + println!( + "\nMemory bulletin: {} words", + bulletin.split_whitespace().count() + ); } diff --git a/tests/opencode_stream.rs b/tests/opencode_stream.rs index 7b38fb335..c3f2b6f89 100644 --- a/tests/opencode_stream.rs +++ b/tests/opencode_stream.rs @@ -57,7 +57,10 @@ async fn stream_events_from_live_server() { .await .expect("failed to subscribe to events"); - assert!(event_response.status().is_success(), "event subscription failed"); + assert!( + event_response.status().is_success(), + "event subscription failed" + ); // 2. 
Create a session let session: Session = client @@ -124,10 +127,7 @@ async fn stream_events_from_live_server() { let bytes = match chunk { Ok(b) => b, Err(e) => { - panic!( - "bytes_stream error after {} events: {e}", - events.len() - ); + panic!("bytes_stream error after {} events: {e}", events.len()); } }; @@ -159,7 +159,10 @@ async fn stream_events_from_live_server() { let envelope: SseEventEnvelope = match serde_json::from_str(&json_str) { Ok(e) => e, Err(err) => { - eprintln!("parse error: {err} on: {}", &json_str[..json_str.len().min(200)]); + eprintln!( + "parse error: {err} on: {}", + &json_str[..json_str.len().min(200)] + ); continue; } }; @@ -178,7 +181,11 @@ async fn stream_events_from_live_server() { } } SseEvent::MessagePartUpdated { part, .. } => { - if let Part::Text { session_id: Some(sid), .. } = part { + if let Part::Text { + session_id: Some(sid), + .. + } = part + { if sid == &session_id { saw_text = true; } From 27ae74d93d925fa33375585ba4f181be3574a76e Mon Sep 17 00:00:00 2001 From: nmb Date: Tue, 17 Feb 2026 16:30:46 +0100 Subject: [PATCH 09/11] =?UTF-8?q?=E2=9C=A8=20feat(channel):=20add=20live?= =?UTF-8?q?=20state=20tracking=20and=20enhance=20channel=20management?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- interface/src/hooks/useChannelLiveState.ts | 27 ++ src/agent/channel.rs | 525 ++++++++++++++------- src/agent/worker.rs | 103 ++-- src/api/state.rs | 184 +++++--- src/conversation/channels.rs | 52 +- 5 files changed, 605 insertions(+), 286 deletions(-) diff --git a/interface/src/hooks/useChannelLiveState.ts b/interface/src/hooks/useChannelLiveState.ts index fa14e9cae..a63721995 100644 --- a/interface/src/hooks/useChannelLiveState.ts +++ b/interface/src/hooks/useChannelLiveState.ts @@ -1,6 +1,7 @@ import { useCallback, useEffect, useRef, useState } from "react"; import { api, + type BranchFailedEvent, type BranchCompletedEvent, type BranchStartedEvent, type InboundMessageEvent, @@ -397,6 
+398,31 @@ export function useChannelLiveState(channels: ChannelInfo[]) { }); }, [updateItem]); + const handleBranchFailed = useCallback((data: unknown) => { + const event = data as BranchFailedEvent; + + // Remove from active branches + setLiveStates((prev) => { + const state = prev[event.channel_id]; + if (!state?.branches[event.branch_id]) return prev; + const { [event.branch_id]: _, ...remainingBranches } = state.branches; + return { + ...prev, + [event.channel_id]: { ...state, branches: remainingBranches }, + }; + }); + + // Update timeline item with failure summary + updateItem(event.channel_id, event.branch_id, (item) => { + if (item.type !== "branch_run") return item; + return { + ...item, + conclusion: `Branch failed: ${event.error}`, + completed_at: new Date().toISOString(), + }; + }); + }, [updateItem]); + const handleToolStarted = useCallback((data: unknown) => { const event = data as ToolStartedEvent; const channelId = event.channel_id; @@ -612,6 +638,7 @@ export function useChannelLiveState(channels: ChannelInfo[]) { worker_completed: handleWorkerCompleted, branch_started: handleBranchStarted, branch_completed: handleBranchCompleted, + branch_failed: handleBranchFailed, tool_started: handleToolStarted, tool_completed: handleToolCompleted, }; diff --git a/src/agent/channel.rs b/src/agent/channel.rs index af132b954..863bc1681 100644 --- a/src/agent/channel.rs +++ b/src/agent/channel.rs @@ -1,24 +1,27 @@ //! Channel: User-facing conversation process. 
+use crate::agent::branch::Branch; use crate::agent::compactor::Compactor; -use crate::error::{AgentError, Result}; -use crate::llm::SpacebotModel; -use crate::conversation::{ChannelStore, ConversationLogger, ProcessRunLogger}; -use crate::{ChannelId, WorkerId, BranchId, ProcessId, ProcessType, AgentDeps, InboundMessage, ProcessEvent, OutboundResponse}; -use crate::hooks::SpacebotHook; use crate::agent::status::StatusBlock; use crate::agent::worker::Worker; -use crate::agent::branch::Branch; +use crate::conversation::{ChannelStore, ConversationLogger, ProcessRunLogger}; +use crate::error::{AgentError, Result}; +use crate::hooks::SpacebotHook; +use crate::llm::SpacebotModel; +use crate::{ + AgentDeps, BranchId, ChannelId, InboundMessage, OutboundResponse, ProcessEvent, ProcessId, + ProcessType, WorkerId, +}; use rig::agent::AgentBuilder; use rig::completion::{CompletionModel, Prompt}; use rig::message::{ImageMediaType, MimeType, UserContent}; use rig::one_or_many::OneOrMany; use rig::tool::server::ToolServer; +use std::collections::HashMap; use std::collections::HashSet; use std::sync::Arc; -use tokio::sync::{mpsc, RwLock}; use tokio::sync::broadcast; -use std::collections::HashMap; +use tokio::sync::{RwLock, mpsc}; /// Shared state that channel tools need to act on the channel. /// @@ -49,7 +52,12 @@ impl ChannelState { /// Returns an error message if the worker is not found. 
pub async fn cancel_worker(&self, worker_id: WorkerId) -> std::result::Result<(), String> { let handle = self.worker_handles.write().await.remove(&worker_id); - let removed = self.active_workers.write().await.remove(&worker_id).is_some(); + let removed = self + .active_workers + .write() + .await + .remove(&worker_id) + .is_some(); self.worker_inputs.write().await.remove(&worker_id); if let Some(handle) = handle { @@ -132,7 +140,13 @@ impl Channel { logs_dir: std::path::PathBuf, ) -> (Self, mpsc::Sender) { let process_id = ProcessId::Channel(id.clone()); - let hook = SpacebotHook::new(deps.agent_id.clone(), process_id, ProcessType::Channel, Some(id.clone()), deps.event_tx.clone()); + let hook = SpacebotHook::new( + deps.agent_id.clone(), + process_id, + ProcessType::Channel, + Some(id.clone()), + deps.event_tx.clone(), + ); let status_block = Arc::new(RwLock::new(StatusBlock::new())); let history = Arc::new(RwLock::new(Vec::new())); let active_branches = Arc::new(RwLock::new(HashMap::new())); @@ -143,11 +157,7 @@ impl Channel { let process_run_logger = ProcessRunLogger::new(deps.sqlite_pool.clone()); let channel_store = ChannelStore::new(deps.sqlite_pool.clone()); - let compactor = Compactor::new( - id.clone(), - deps.clone(), - history.clone(), - ); + let compactor = Compactor::new(id.clone(), deps.clone(), history.clone()); let state = ChannelState { channel_id: id.clone(), @@ -189,10 +199,10 @@ impl Channel { coalesce_buffer: Vec::new(), coalesce_deadline: None, }; - + (channel, message_tx) } - + /// Run the channel event loop. 
pub async fn run(mut self) -> Result<()> { tracing::info!(channel_id = %self.id, "channel started"); @@ -261,7 +271,11 @@ impl Channel { /// - System re-trigger messages (always process immediately) /// - Messages when coalescing is disabled /// - Messages in DMs when multi_user_only is true - fn should_coalesce(&self, message: &InboundMessage, config: &crate::config::CoalesceConfig) -> bool { + fn should_coalesce( + &self, + message: &InboundMessage, + config: &crate::config::CoalesceConfig, + ) -> bool { if !config.enabled { return false; } @@ -278,7 +292,9 @@ impl Channel { fn is_dm(&self) -> bool { // Check conversation_id pattern for DM indicators if let Some(ref conv_id) = self.conversation_id { - conv_id.contains(":dm:") || conv_id.starts_with("discord:dm:") || conv_id.starts_with("slack:dm:") + conv_id.contains(":dm:") + || conv_id.starts_with("discord:dm:") + || conv_id.starts_with("slack:dm:") } else { // If no conversation_id set yet, default to not DM (safer) false @@ -288,20 +304,21 @@ impl Channel { /// Update the coalesce deadline based on buffer size and config. 
async fn update_coalesce_deadline(&mut self, config: &crate::config::CoalesceConfig) { let now = tokio::time::Instant::now(); - + if let Some(first_message) = self.coalesce_buffer.first() { - let elapsed_since_first = chrono::Utc::now().signed_duration_since(first_message.timestamp); + let elapsed_since_first = + chrono::Utc::now().signed_duration_since(first_message.timestamp); let elapsed_millis = elapsed_since_first.num_milliseconds().max(0) as u64; - + let max_wait_ms = config.max_wait_ms; let debounce_ms = config.debounce_ms; - + // If we have enough messages to trigger coalescing (min_messages threshold) if self.coalesce_buffer.len() >= config.min_messages { // Cap at max_wait from the first message let remaining_wait_ms = max_wait_ms.saturating_sub(elapsed_millis); let max_deadline = now + std::time::Duration::from_millis(remaining_wait_ms); - + // If no deadline set yet, use debounce window // Otherwise, keep existing deadline (don't extend past max_wait) if self.coalesce_deadline.is_none() { @@ -327,11 +344,11 @@ impl Channel { if self.coalesce_buffer.is_empty() { return Ok(()); } - + self.coalesce_deadline = None; - + let messages: Vec = std::mem::take(&mut self.coalesce_buffer); - + if messages.len() == 1 { // Single message - process normally let message = messages.into_iter().next().unwrap(); @@ -349,38 +366,48 @@ impl Channel { /// with a coalesce hint telling the LLM this is a fast-moving conversation. 
async fn handle_message_batch(&mut self, messages: Vec) -> Result<()> { let message_count = messages.len(); - let first_timestamp = messages.first().map(|m| m.timestamp).unwrap_or_else(chrono::Utc::now); - let last_timestamp = messages.last().map(|m| m.timestamp).unwrap_or(first_timestamp); + let first_timestamp = messages + .first() + .map(|m| m.timestamp) + .unwrap_or_else(chrono::Utc::now); + let last_timestamp = messages + .last() + .map(|m| m.timestamp) + .unwrap_or(first_timestamp); let elapsed = last_timestamp.signed_duration_since(first_timestamp); let elapsed_secs = elapsed.num_milliseconds() as f64 / 1000.0; - + tracing::info!( channel_id = %self.id, message_count, elapsed_secs, "handling batched messages" ); - + // Count unique senders for the hint - let unique_senders: std::collections::HashSet<_> = messages - .iter() - .map(|m| &m.sender_id) - .collect(); + let unique_senders: std::collections::HashSet<_> = + messages.iter().map(|m| &m.sender_id).collect(); let unique_sender_count = unique_senders.len(); - + // Track conversation_id from the first message if self.conversation_id.is_none() { if let Some(first) = messages.first() { self.conversation_id = Some(first.conversation_id.clone()); } } - + // Capture conversation context from the first message if self.conversation_context.is_none() { if let Some(first) = messages.first() { let prompt_engine = self.deps.runtime_config.prompts.load(); - let server_name = first.metadata.get("discord_guild_name").and_then(|v| v.as_str()); - let channel_name = first.metadata.get("discord_channel_name").and_then(|v| v.as_str()); + let server_name = first + .metadata + .get("discord_guild_name") + .and_then(|v| v.as_str()); + let channel_name = first + .metadata + .get("discord_channel_name") + .and_then(|v| v.as_str()); self.conversation_context = Some( prompt_engine .render_conversation_context(&first.source, server_name, channel_name) @@ -388,25 +415,26 @@ impl Channel { ); } } - + // Persist each message to 
conversation log (individual audit trail) let mut user_contents: Vec = Vec::new(); let mut conversation_id = String::new(); - + for message in &messages { if message.source != "system" { - let sender_name = message.metadata + let sender_name = message + .metadata .get("sender_display_name") .and_then(|v| v.as_str()) .unwrap_or(&message.sender_id); - + let (raw_text, attachments) = match &message.content { crate::MessageContent::Text(text) => (text.clone(), Vec::new()), crate::MessageContent::Media { text, attachments } => { (text.clone().unwrap_or_default(), attachments.clone()) } }; - + self.state.conversation_logger.log_user_message( &self.state.channel_id, sender_name, @@ -414,15 +442,17 @@ impl Channel { &raw_text, &message.metadata, ); - self.state.channel_store.upsert( - &message.conversation_id, - &message.metadata, - ); - + self.state + .channel_store + .upsert(&message.conversation_id, &message.metadata); + conversation_id = message.conversation_id.clone(); - + // Format with relative timestamp - let relative_secs = message.timestamp.signed_duration_since(first_timestamp).num_seconds(); + let relative_secs = message + .timestamp + .signed_duration_since(first_timestamp) + .num_seconds(); let relative_text = if relative_secs < 1 { "just now".to_string() } else if relative_secs < 60 { @@ -430,14 +460,16 @@ impl Channel { } else { format!("{}m ago", relative_secs / 60) }; - - let display_name = message.metadata + + let display_name = message + .metadata .get("sender_display_name") .and_then(|v| v.as_str()) .unwrap_or(&message.sender_id); - - let formatted_text = format!("[{}] ({}): {}", display_name, relative_text, raw_text); - + + let formatted_text = + format!("[{}] ({}): {}", display_name, relative_text, raw_text); + // Download attachments for this message if !attachments.is_empty() { let attachment_content = download_attachments(&self.deps, &attachments).await; @@ -445,16 +477,17 @@ impl Channel { user_contents.push(content); } } - + 
user_contents.push(UserContent::text(formatted_text)); } } - + // Combine all user content into a single text let combined_text = format!( "[{} messages arrived rapidly in this channel]\n\n{}", message_count, - user_contents.iter() + user_contents + .iter() .filter_map(|c| match c { UserContent::Text(t) => Some(t.text.clone()), _ => None, @@ -462,33 +495,33 @@ impl Channel { .collect::>() .join("\n") ); - + // Build system prompt with coalesce hint - let system_prompt = self.build_system_prompt_with_coalesce( - message_count, - elapsed_secs, - unique_sender_count, - ).await; - + let system_prompt = self + .build_system_prompt_with_coalesce(message_count, elapsed_secs, unique_sender_count) + .await; + // Run agent turn - let (result, skip_flag) = self.run_agent_turn( - &combined_text, - &system_prompt, - &conversation_id, - Vec::new(), // Attachments already formatted into text - ).await?; - + let (result, skip_flag) = self + .run_agent_turn( + &combined_text, + &system_prompt, + &conversation_id, + Vec::new(), // Attachments already formatted into text + ) + .await?; + self.handle_agent_result(result, &skip_flag).await; - + // Check compaction if let Err(error) = self.compactor.check_and_compact().await { tracing::warn!(channel_id = %self.id, %error, "compaction check failed"); } - + // Increment message counter for memory persistence self.message_count += message_count; self.check_memory_persistence().await; - + Ok(()) } @@ -501,32 +534,32 @@ impl Channel { ) -> String { let rc = &self.deps.runtime_config; let prompt_engine = rc.prompts.load(); - + let identity_context = rc.identity.load().render(); let memory_bulletin = rc.memory_bulletin.load(); let skills = rc.skills.load(); let skills_prompt = skills.render_channel_prompt(&prompt_engine); - + let browser_enabled = rc.browser_config.load().enabled; let web_search_enabled = rc.brave_search_key.load().is_some(); let opencode_enabled = rc.opencode.load().enabled; let worker_capabilities = prompt_engine 
.render_worker_capabilities(browser_enabled, web_search_enabled, opencode_enabled) .expect("failed to render worker capabilities"); - + let status_text = { let status = self.state.status_block.read().await; status.render() }; - + // Render coalesce hint let elapsed_str = format!("{:.1}s", elapsed_secs); let coalesce_hint = prompt_engine .render_coalesce_hint(message_count, &elapsed_str, unique_senders) .ok(); - + let empty_to_none = |s: String| if s.is_empty() { None } else { Some(s) }; - + prompt_engine .render_channel_prompt( empty_to_none(identity_context), @@ -556,7 +589,7 @@ impl Channel { if self.conversation_id.is_none() { self.conversation_id = Some(message.conversation_id.clone()); } - + let (raw_text, attachments) = match &message.content { crate::MessageContent::Text(text) => (text.clone(), Vec::new()), crate::MessageContent::Media { text, attachments } => { @@ -574,7 +607,8 @@ impl Channel { // Persist user messages (skip system re-triggers) if message.source != "system" { - let sender_name = message.metadata + let sender_name = message + .metadata .get("sender_display_name") .and_then(|v| v.as_str()) .unwrap_or(&message.sender_id); @@ -585,17 +619,22 @@ impl Channel { &raw_text, &message.metadata, ); - self.state.channel_store.upsert( - &message.conversation_id, - &message.metadata, - ); + self.state + .channel_store + .upsert(&message.conversation_id, &message.metadata); } // Capture conversation context from the first message (platform, channel, server) if self.conversation_context.is_none() { let prompt_engine = self.deps.runtime_config.prompts.load(); - let server_name = message.metadata.get("discord_guild_name").and_then(|v| v.as_str()); - let channel_name = message.metadata.get("discord_channel_name").and_then(|v| v.as_str()); + let server_name = message + .metadata + .get("discord_guild_name") + .and_then(|v| v.as_str()); + let channel_name = message + .metadata + .get("discord_channel_name") + .and_then(|v| v.as_str()); 
self.conversation_context = Some( prompt_engine .render_conversation_context(&message.source, server_name, channel_name) @@ -605,12 +644,14 @@ impl Channel { let system_prompt = self.build_system_prompt().await; - let (result, skip_flag) = self.run_agent_turn( - &user_text, - &system_prompt, - &message.conversation_id, - attachment_content, - ).await?; + let (result, skip_flag) = self + .run_agent_turn( + &user_text, + &system_prompt, + &message.conversation_id, + attachment_content, + ) + .await?; self.handle_agent_result(result, &skip_flag).await; @@ -624,7 +665,7 @@ impl Channel { self.message_count += 1; self.check_memory_persistence().await; } - + Ok(()) } @@ -674,7 +715,10 @@ impl Channel { system_prompt: &str, conversation_id: &str, attachment_content: Vec, - ) -> Result<(std::result::Result, crate::tools::SkipFlag)> { + ) -> Result<( + std::result::Result, + crate::tools::SkipFlag, + )> { let skip_flag = crate::tools::new_skip_flag(); if let Err(error) = crate::tools::add_channel_tools( @@ -684,7 +728,9 @@ impl Channel { conversation_id, skip_flag.clone(), self.deps.cron_tool.clone(), - ).await { + ) + .await + { tracing::error!(%error, "failed to add channel tools"); return Err(AgentError::Other(error.into()).into()); } @@ -702,13 +748,17 @@ impl Channel { .tool_server_handle(self.tool_server.clone()) .build(); - let _ = self.response_tx.send(OutboundResponse::Status(crate::StatusUpdate::Thinking)).await; + let _ = self + .response_tx + .send(OutboundResponse::Status(crate::StatusUpdate::Thinking)) + .await; // Inject attachments as a user message before the text prompt if !attachment_content.is_empty() { let mut history = self.state.history.write().await; - let content = OneOrMany::many(attachment_content) - .unwrap_or_else(|_| OneOrMany::one(UserContent::text("[attachment processing failed]"))); + let content = OneOrMany::many(attachment_content).unwrap_or_else(|_| { + OneOrMany::one(UserContent::text("[attachment processing failed]")) + }); 
history.push(rig::message::Message::User { content }); drop(history); } @@ -721,7 +771,8 @@ impl Channel { guard.clone() }; - let result = agent.prompt(user_text) + let result = agent + .prompt(user_text) .with_history(&mut history) .with_hook(self.hook.clone()) .await; @@ -756,8 +807,14 @@ impl Channel { // directly. Some models respond with text instead of tool calls. let text = response.trim(); if !text.is_empty() { - self.state.conversation_logger.log_bot_message(&self.state.channel_id, text); - if let Err(error) = self.response_tx.send(OutboundResponse::Text(text.to_string())).await { + self.state + .conversation_logger + .log_bot_message(&self.state.channel_id, text); + if let Err(error) = self + .response_tx + .send(OutboundResponse::Text(text.to_string())) + .await + { tracing::error!(%error, channel_id = %self.id, "failed to send fallback reply"); } } @@ -777,9 +834,12 @@ impl Channel { } // Ensure typing indicator is always cleaned up, even on error paths - let _ = self.response_tx.send(OutboundResponse::Status(crate::StatusUpdate::StopTyping)).await; + let _ = self + .response_tx + .send(OutboundResponse::Status(crate::StatusUpdate::StopTyping)) + .await; } - + /// Handle a process event (branch results, worker completions, status updates). async fn handle_event(&mut self, event: ProcessEvent) -> Result<()> { // Only process events targeted at this channel @@ -797,10 +857,19 @@ impl Channel { let run_logger = &self.state.process_run_logger; match &event { - ProcessEvent::BranchStarted { branch_id, channel_id, description, .. } => { + ProcessEvent::BranchStarted { + branch_id, + channel_id, + description, + .. + } => { run_logger.log_branch_started(channel_id, *branch_id, description); } - ProcessEvent::BranchResult { branch_id, conclusion, .. } => { + ProcessEvent::BranchResult { + branch_id, + conclusion, + .. 
+ } => { run_logger.log_branch_completed(*branch_id, conclusion); // Remove from active branches @@ -822,13 +891,42 @@ impl Channel { tracing::info!(branch_id = %branch_id, "branch result incorporated"); } } - ProcessEvent::WorkerStarted { worker_id, channel_id, task, .. } => { + ProcessEvent::BranchFailed { + branch_id, error, .. + } => { + run_logger.log_branch_failed(*branch_id, error); + + // Remove from active branches + let mut branches = self.state.active_branches.write().await; + branches.remove(branch_id); + + // Memory persistence branches complete silently on failure too. + // They are best-effort and should not re-trigger the channel. + if self.memory_persistence_branches.remove(branch_id) { + tracing::warn!(branch_id = %branch_id, %error, "memory persistence branch failed"); + } else { + tracing::warn!(branch_id = %branch_id, %error, "branch failed"); + } + } + ProcessEvent::WorkerStarted { + worker_id, + channel_id, + task, + .. + } => { run_logger.log_worker_started(channel_id.as_ref(), *worker_id, task); } - ProcessEvent::WorkerStatus { worker_id, status, .. } => { + ProcessEvent::WorkerStatus { + worker_id, status, .. + } => { run_logger.log_worker_status(*worker_id, status); } - ProcessEvent::WorkerComplete { worker_id, result, notify, .. } => { + ProcessEvent::WorkerComplete { + worker_id, + result, + notify, + .. 
+ } => { run_logger.log_worker_completed(*worker_id, result); let mut workers = self.state.active_workers.write().await; @@ -844,7 +942,7 @@ impl Channel { history.push(rig::message::Message::from(worker_message)); should_retrigger = true; } - + tracing::info!(worker_id = %worker_id, "worker completed"); } _ => {} @@ -853,8 +951,11 @@ impl Channel { // Re-trigger the channel LLM so it can process the result and respond if should_retrigger { if let Some(conversation_id) = &self.conversation_id { - let retrigger_message = self.deps.runtime_config - .prompts.load() + let retrigger_message = self + .deps + .runtime_config + .prompts + .load() .render_system_retrigger() .expect("failed to render retrigger message"); @@ -873,10 +974,10 @@ impl Channel { } } } - + Ok(()) } - + /// Get the current status block as a string. pub async fn get_status(&self) -> String { let status = self.state.status_block.read().await; @@ -933,8 +1034,14 @@ pub async fn spawn_branch_from_state( ) .expect("failed to render branch prompt"); - spawn_branch(state, &description, &description, &system_prompt, &description) - .await + spawn_branch( + state, + &description, + &description, + &system_prompt, + &description, + ) + .await } /// Spawn a silent memory persistence branch. @@ -954,8 +1061,14 @@ async fn spawn_memory_persistence_branch( .render_system_memory_persistence() .expect("failed to render memory persistence prompt"); - spawn_branch(state, "memory persistence", &prompt, &system_prompt, "persisting memories...") - .await + spawn_branch( + state, + "memory persistence", + &prompt, + &system_prompt, + "persisting memories...", + ) + .await } /// Shared branch spawning logic. 
@@ -1004,9 +1117,25 @@ async fn spawn_branch( let branch_id = branch.id; let prompt = prompt.to_owned(); + let event_tx = state.deps.event_tx.clone(); + let agent_id = state.deps.agent_id.clone(); + let channel_id = state.channel_id.clone(); + let (start_tx, start_rx) = tokio::sync::oneshot::channel::<()>(); let handle = tokio::spawn(async move { + // Ensure the branch is registered in channel state before it can emit + // terminal events (success/failure), avoiding add/remove races. + let _ = start_rx.await; if let Err(error) = branch.run(&prompt).await { + let error_message = error.to_string(); + event_tx + .send(crate::ProcessEvent::BranchFailed { + agent_id, + branch_id, + channel_id, + error: error_message.clone(), + }) + .ok(); tracing::error!(branch_id = %branch_id, %error, "branch failed"); } }); @@ -1021,12 +1150,17 @@ async fn spawn_branch( status.add_branch(branch_id, status_label); } - state.deps.event_tx.send(crate::ProcessEvent::BranchStarted { - agent_id: state.deps.agent_id.clone(), - branch_id, - channel_id: state.channel_id.clone(), - description: status_label.to_string(), - }).ok(); + state + .deps + .event_tx + .send(crate::ProcessEvent::BranchStarted { + agent_id: state.deps.agent_id.clone(), + branch_id, + channel_id: state.channel_id.clone(), + description: status_label.to_string(), + }) + .ok(); + start_tx.send(()).ok(); tracing::info!(branch_id = %branch_id, description = %status_label, "branch spawned"); @@ -1079,7 +1213,7 @@ pub async fn spawn_worker_from_state( } else { worker_system_prompt }; - + let worker = if interactive { let (worker, input_tx) = Worker::new_interactive( Some(state.channel_id.clone()), @@ -1092,7 +1226,11 @@ pub async fn spawn_worker_from_state( state.logs_dir.clone(), ); let worker_id = worker.id; - state.worker_inputs.write().await.insert(worker_id, input_tx); + state + .worker_inputs + .write() + .await + .insert(worker_id, input_tx); worker } else { Worker::new( @@ -1106,7 +1244,7 @@ pub async fn 
spawn_worker_from_state( state.logs_dir.clone(), ) }; - + let worker_id = worker.id; let handle = spawn_worker_task( @@ -1124,15 +1262,19 @@ pub async fn spawn_worker_from_state( status.add_worker(worker_id, &task, false); } - state.deps.event_tx.send(crate::ProcessEvent::WorkerStarted { - agent_id: state.deps.agent_id.clone(), - worker_id, - channel_id: Some(state.channel_id.clone()), - task: task.clone(), - }).ok(); + state + .deps + .event_tx + .send(crate::ProcessEvent::WorkerStarted { + agent_id: state.deps.agent_id.clone(), + worker_id, + channel_id: Some(state.channel_id.clone()), + task: task.clone(), + }) + .ok(); tracing::info!(worker_id = %worker_id, task = %task, "worker spawned"); - + Ok(worker_id) } @@ -1153,6 +1295,8 @@ pub async fn spawn_opencode_worker_from_state( let rc = &state.deps.runtime_config; let opencode_config = rc.opencode.load(); + let routing = rc.routing.load(); + let model_name = routing.resolve(ProcessType::Worker, None).to_string(); if !opencode_config.enabled { return Err(AgentError::Other(anyhow::anyhow!( @@ -1172,8 +1316,12 @@ pub async fn spawn_opencode_worker_from_state( state.deps.event_tx.clone(), ); let worker_id = worker.id; - state.worker_inputs.write().await.insert(worker_id, input_tx); - worker + state + .worker_inputs + .write() + .await + .insert(worker_id, input_tx); + worker.with_model(model_name.clone()) } else { crate::opencode::OpenCodeWorker::new( Some(state.channel_id.clone()), @@ -1183,6 +1331,7 @@ pub async fn spawn_opencode_worker_from_state( server_pool, state.deps.event_tx.clone(), ) + .with_model(model_name) }; let worker_id = worker.id; @@ -1206,12 +1355,16 @@ pub async fn spawn_opencode_worker_from_state( status.add_worker(worker_id, &opencode_task, false); } - state.deps.event_tx.send(crate::ProcessEvent::WorkerStarted { - agent_id: state.deps.agent_id.clone(), - worker_id, - channel_id: Some(state.channel_id.clone()), - task: opencode_task, - }).ok(); + state + .deps + .event_tx + 
.send(crate::ProcessEvent::WorkerStarted { + agent_id: state.deps.agent_id.clone(), + worker_id, + channel_id: Some(state.channel_id.clone()), + task: opencode_task, + }) + .ok(); tracing::info!(worker_id = %worker_id, task = %task, "OpenCode worker spawned"); @@ -1261,22 +1414,30 @@ fn format_user_message(raw_text: &str, message: &InboundMessage) -> String { return raw_text.to_string(); } - let display_name = message.metadata + let display_name = message + .metadata .get("sender_display_name") .and_then(|v| v.as_str()) .unwrap_or(&message.sender_id); - let bot_tag = if message.metadata.get("sender_is_bot").and_then(|v| v.as_bool()).unwrap_or(false) { + let bot_tag = if message + .metadata + .get("sender_is_bot") + .and_then(|v| v.as_bool()) + .unwrap_or(false) + { " (bot)" } else { "" }; - let reply_context = message.metadata + let reply_context = message + .metadata .get("reply_to_author") .and_then(|v| v.as_str()) .map(|author| { - let content_preview = message.metadata + let content_preview = message + .metadata .get("reply_to_content") .and_then(|v| v.as_str()) .unwrap_or(""); @@ -1298,15 +1459,22 @@ fn format_user_message(raw_text: &str, message: &InboundMessage) -> String { /// channel's workers would leak into sibling channels (e.g. threads). fn event_is_for_channel(event: &ProcessEvent, channel_id: &ChannelId) -> bool { match event { - ProcessEvent::BranchResult { channel_id: event_channel, .. } => { - event_channel == channel_id - } - ProcessEvent::WorkerComplete { channel_id: event_channel, .. } => { - event_channel.as_ref() == Some(channel_id) - } - ProcessEvent::WorkerStatus { channel_id: event_channel, .. } => { - event_channel.as_ref() == Some(channel_id) - } + ProcessEvent::BranchResult { + channel_id: event_channel, + .. + } => event_channel == channel_id, + ProcessEvent::BranchFailed { + channel_id: event_channel, + .. + } => event_channel == channel_id, + ProcessEvent::WorkerComplete { + channel_id: event_channel, + .. 
+ } => event_channel.as_ref() == Some(channel_id), + ProcessEvent::WorkerStatus { + channel_id: event_channel, + .. + } => event_channel.as_ref() == Some(channel_id), // Status block updates, tool events, etc. — match on agent_id which // is already filtered by the event bus subscription. Let them through. _ => true, @@ -1318,8 +1486,13 @@ const IMAGE_MIME_PREFIXES: &[&str] = &["image/jpeg", "image/png", "image/gif", " /// Text-based MIME types where we inline the content. const TEXT_MIME_PREFIXES: &[&str] = &[ - "text/", "application/json", "application/xml", "application/javascript", - "application/typescript", "application/toml", "application/yaml", + "text/", + "application/json", + "application/xml", + "application/javascript", + "application/typescript", + "application/toml", + "application/yaml", ]; /// Download attachments and convert them to LLM-ready UserContent parts. @@ -1334,15 +1507,20 @@ async fn download_attachments( let mut parts = Vec::new(); for attachment in attachments { - let is_image = IMAGE_MIME_PREFIXES.iter().any(|p| attachment.mime_type.starts_with(p)); - let is_text = TEXT_MIME_PREFIXES.iter().any(|p| attachment.mime_type.starts_with(p)); + let is_image = IMAGE_MIME_PREFIXES + .iter() + .any(|p| attachment.mime_type.starts_with(p)); + let is_text = TEXT_MIME_PREFIXES + .iter() + .any(|p| attachment.mime_type.starts_with(p)); let content = if is_image { download_image_attachment(http, attachment).await } else if is_text { download_text_attachment(http, attachment).await } else { - let size_str = attachment.size_bytes + let size_str = attachment + .size_bytes .map(|s| format!("{:.1} KB", s as f64 / 1024.0)) .unwrap_or_else(|| "unknown size".into()); UserContent::text(format!( @@ -1366,7 +1544,10 @@ async fn download_image_attachment( Ok(r) => r, Err(error) => { tracing::warn!(%error, filename = %attachment.filename, "failed to download image"); - return UserContent::text(format!("[Failed to download image: {}]", attachment.filename)); + 
return UserContent::text(format!( + "[Failed to download image: {}]", + attachment.filename + )); } }; @@ -1374,7 +1555,10 @@ async fn download_image_attachment( Ok(b) => b, Err(error) => { tracing::warn!(%error, filename = %attachment.filename, "failed to read image bytes"); - return UserContent::text(format!("[Failed to download image: {}]", attachment.filename)); + return UserContent::text(format!( + "[Failed to download image: {}]", + attachment.filename + )); } }; @@ -1401,7 +1585,10 @@ async fn download_text_attachment( Ok(r) => r, Err(error) => { tracing::warn!(%error, filename = %attachment.filename, "failed to download text file"); - return UserContent::text(format!("[Failed to download file: {}]", attachment.filename)); + return UserContent::text(format!( + "[Failed to download file: {}]", + attachment.filename + )); } }; @@ -1415,7 +1602,11 @@ async fn download_text_attachment( // Truncate very large files to avoid blowing up context let truncated = if content.len() > 50_000 { - format!("{}...\n[truncated — {} bytes total]", &content[..50_000], content.len()) + format!( + "{}...\n[truncated — {} bytes total]", + &content[..50_000], + content.len() + ) } else { content }; diff --git a/src/agent/worker.rs b/src/agent/worker.rs index 721d21f62..ef1c74bf6 100644 --- a/src/agent/worker.rs +++ b/src/agent/worker.rs @@ -3,10 +3,10 @@ use crate::agent::compactor::estimate_history_tokens; use crate::config::BrowserConfig; use crate::error::Result; -use crate::llm::routing::is_context_overflow_error; -use crate::llm::SpacebotModel; -use crate::{WorkerId, ChannelId, ProcessId, ProcessType, AgentDeps}; use crate::hooks::SpacebotHook; +use crate::llm::SpacebotModel; +use crate::llm::routing::is_context_overflow_error; +use crate::{AgentDeps, ChannelId, ProcessId, ProcessType, WorkerId}; use rig::agent::AgentBuilder; use rig::completion::{CompletionModel, Prompt}; use std::fmt::Write as _; @@ -73,9 +73,15 @@ impl Worker { ) -> Self { let id = Uuid::new_v4(); let 
process_id = ProcessId::Worker(id); - let hook = SpacebotHook::new(deps.agent_id.clone(), process_id, ProcessType::Worker, channel_id.clone(), deps.event_tx.clone()); + let hook = SpacebotHook::new( + deps.agent_id.clone(), + process_id, + ProcessType::Worker, + channel_id.clone(), + deps.event_tx.clone(), + ); let (status_tx, status_rx) = watch::channel("starting".to_string()); - + Self { id, channel_id, @@ -93,7 +99,7 @@ impl Worker { status_rx, } } - + /// Create a new interactive worker. pub fn new_interactive( channel_id: Option, @@ -107,10 +113,16 @@ impl Worker { ) -> (Self, mpsc::Sender) { let id = Uuid::new_v4(); let process_id = ProcessId::Worker(id); - let hook = SpacebotHook::new(deps.agent_id.clone(), process_id, ProcessType::Worker, channel_id.clone(), deps.event_tx.clone()); + let hook = SpacebotHook::new( + deps.agent_id.clone(), + process_id, + ProcessType::Worker, + channel_id.clone(), + deps.event_tx.clone(), + ); let (status_tx, status_rx) = watch::channel("starting".to_string()); let (input_tx, input_rx) = mpsc::channel(32); - + let worker = Self { id, channel_id, @@ -127,14 +139,14 @@ impl Worker { status_tx, status_rx, }; - + (worker, input_tx) } - + /// Check if the worker can transition to a new state. pub fn can_transition_to(&self, target: WorkerState) -> bool { use WorkerState::*; - + matches!( (self.state, target), (Running, WaitingForInput) @@ -144,19 +156,21 @@ impl Worker { | (WaitingForInput, Failed) ) } - + /// Transition to a new state. 
pub fn transition_to(&mut self, new_state: WorkerState) -> Result<()> { if !self.can_transition_to(new_state) { - return Err(crate::error::AgentError::InvalidStateTransition( - format!("can't transition from {:?} to {:?}", self.state, new_state) - ).into()); + return Err(crate::error::AgentError::InvalidStateTransition(format!( + "can't transition from {:?} to {:?}", + self.state, new_state + )) + .into()); } - + self.state = new_state; Ok(()) } - + /// Run the worker's LLM agent loop until completion. /// /// Runs in segments of 25 turns. After each segment, checks context usage @@ -166,7 +180,7 @@ impl Worker { pub async fn run(mut self) -> Result { self.status_tx.send_modify(|s| *s = "running".to_string()); self.hook.send_status("running"); - + tracing::info!(worker_id = %self.id, task = %self.task, "worker starting"); // Create per-worker ToolServer with task tools @@ -204,7 +218,8 @@ impl Worker { let result = loop { segments_run += 1; - match agent.prompt(&prompt) + match agent + .prompt(&prompt) .with_history(&mut history) .with_hook(self.hook.clone()) .await @@ -216,7 +231,8 @@ impl Worker { overflow_retries = 0; self.maybe_compact_history(&mut history).await; prompt = "Continue where you left off. Do not repeat completed work.".into(); - self.hook.send_status(&format!("working (segment {segments_run})")); + self.hook + .send_status(&format!("working (segment {segments_run})")); tracing::debug!( worker_id = %self.id, @@ -252,7 +268,8 @@ impl Worker { self.force_compact_history(&mut history).await; prompt = "Continue where you left off. Do not repeat completed work. \ Your previous attempt exceeded the context limit, so older history \ - has been compacted.".into(); + has been compacted." 
+ .into(); } Err(error) => { self.state = WorkerState::Failed; @@ -280,7 +297,8 @@ impl Worker { let mut follow_up_overflow_retries = 0; let follow_up_ok = loop { - match agent.prompt(&follow_up_prompt) + match agent + .prompt(&follow_up_prompt) .with_history(&mut history) .with_hook(self.hook.clone()) .await @@ -328,13 +346,13 @@ impl Worker { self.state = WorkerState::Done; self.hook.send_status("completed"); - + // Write success log based on the worker log mode setting let log_mode = self.get_worker_log_mode(); if log_mode != crate::settings::WorkerLogMode::ErrorsOnly { self.write_success_log(&history); } - + tracing::info!(worker_id = %self.id, "worker completed"); Ok(result) } @@ -353,7 +371,8 @@ impl Worker { return; } - self.compact_history(history, 0.50, "worker history compacted").await; + self.compact_history(history, 0.50, "worker history compacted") + .await; } /// Aggressive compaction for context overflow recovery. @@ -362,7 +381,12 @@ impl Worker { /// usage and removes 75% of messages. Used when the provider has already /// rejected the request for exceeding context limits. async fn force_compact_history(&self, history: &mut Vec) { - self.compact_history(history, 0.75, "worker history force-compacted (overflow recovery)").await; + self.compact_history( + history, + 0.75, + "worker history force-compacted (overflow recovery)", + ) + .await; } /// Compact worker history by removing a fraction of the oldest messages. @@ -381,7 +405,9 @@ impl Worker { let estimated = estimate_history_tokens(history); let usage = estimated as f32 / context_window as f32; - let remove_count = ((total as f32 * fraction) as usize).max(1).min(total.saturating_sub(2)); + let remove_count = ((total as f32 * fraction) as usize) + .max(1) + .min(total.saturating_sub(2)); let removed: Vec = history.drain(..remove_count).collect(); let recap = build_worker_recap(&removed); @@ -399,12 +425,12 @@ impl Worker { "{log_message}" ); } - + /// Check if worker is in a terminal state. 
pub fn is_done(&self) -> bool { matches!(self.state, WorkerState::Done | WorkerState::Failed) } - + /// Check if worker is interactive. pub fn is_interactive(&self) -> bool { self.input_rx.is_some() @@ -427,7 +453,7 @@ impl Worker { /// For AllSeparate mode, uses "failed" or "successful" subdirectories. fn get_log_directory(&self, is_success: bool) -> PathBuf { let mode = self.get_worker_log_mode(); - + match mode { crate::settings::WorkerLogMode::AllSeparate => { let subdir = if is_success { "successful" } else { "failed" }; @@ -457,13 +483,13 @@ impl Worker { let _ = writeln!(log); let _ = writeln!(log, "--- Task ---"); let _ = writeln!(log, "{}", self.task); - + if let Some(err) = error { let _ = writeln!(log); let _ = writeln!(log, "--- Error ---"); let _ = writeln!(log, "{err}"); } - + let _ = writeln!(log); let _ = writeln!(log, "--- History ({} messages) ---", history.len()); @@ -541,8 +567,8 @@ impl Worker { let log = self.build_log_content(history, None); // Best-effort write - if let Err(write_error) = std::fs::create_dir_all(&log_dir) - .and_then(|()| std::fs::write(&path, &log)) + if let Err(write_error) = + std::fs::create_dir_all(&log_dir).and_then(|()| std::fs::write(&path, &log)) { tracing::warn!( worker_id = %self.id, @@ -571,8 +597,8 @@ impl Worker { let log = self.build_log_content(history, Some(error)); // Best-effort write - if let Err(write_error) = std::fs::create_dir_all(&log_dir) - .and_then(|()| std::fs::write(&path, &log)) + if let Err(write_error) = + std::fs::create_dir_all(&log_dir).and_then(|()| std::fs::write(&path, &log)) { tracing::warn!( worker_id = %self.id, @@ -596,7 +622,7 @@ impl Worker { /// retains full context of what it already did after compaction. fn build_worker_recap(messages: &[rig::message::Message]) -> String { let mut recap = String::new(); - + for message in messages { match message { rig::message::Message::Assistant { content, .. 
} => { @@ -637,7 +663,8 @@ fn build_worker_recap(messages: &[rig::message::Message]) -> String { fn extract_last_assistant_text(history: &[rig::message::Message]) -> Option { for message in history.iter().rev() { if let rig::message::Message::Assistant { content, .. } = message { - let texts: Vec = content.iter() + let texts: Vec = content + .iter() .filter_map(|c| { if let rig::message::AssistantContent::Text(t) = c { Some(t.text.clone()) diff --git a/src/api/state.rs b/src/api/state.rs index bd021cfd4..5f1faf932 100644 --- a/src/api/state.rs +++ b/src/api/state.rs @@ -17,7 +17,7 @@ use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; use std::time::Instant; -use tokio::sync::{broadcast, mpsc, RwLock}; +use tokio::sync::{RwLock, broadcast, mpsc}; /// Summary of an agent's configuration, exposed via the API. #[derive(Debug, Clone, Serialize)] @@ -130,6 +130,13 @@ pub enum ApiEvent { branch_id: String, conclusion: String, }, + /// A branch failed. + BranchFailed { + agent_id: String, + channel_id: String, + branch_id: String, + error: String, + }, /// A tool call started on a process. ToolStarted { agent_id: String, @@ -149,7 +156,9 @@ pub enum ApiEvent { } impl ApiState { - pub fn new_with_provider_sender(provider_setup_tx: mpsc::Sender) -> Self { + pub fn new_with_provider_sender( + provider_setup_tx: mpsc::Sender, + ) -> Self { let (event_tx, _) = broadcast::channel(512); Self { started_at: Instant::now(), @@ -188,10 +197,7 @@ impl ApiState { /// Remove a channel's status block when it's dropped. pub async fn unregister_channel_status(&self, channel_id: &str) { - self.channel_status_blocks - .write() - .await - .remove(channel_id); + self.channel_status_blocks.write().await.remove(channel_id); } /// Register a channel's state for API-driven cancellation. @@ -218,65 +224,129 @@ impl ApiState { Ok(event) => { // Translate ProcessEvents into typed ApiEvents match &event { - ProcessEvent::WorkerStarted { worker_id, channel_id, task, .. 
} => { - api_tx.send(ApiEvent::WorkerStarted { - agent_id: agent_id.clone(), - channel_id: channel_id.as_deref().map(|s| s.to_string()), - worker_id: worker_id.to_string(), - task: task.clone(), - }).ok(); + ProcessEvent::WorkerStarted { + worker_id, + channel_id, + task, + .. + } => { + api_tx + .send(ApiEvent::WorkerStarted { + agent_id: agent_id.clone(), + channel_id: channel_id.as_deref().map(|s| s.to_string()), + worker_id: worker_id.to_string(), + task: task.clone(), + }) + .ok(); + } + ProcessEvent::BranchStarted { + branch_id, + channel_id, + description, + .. + } => { + api_tx + .send(ApiEvent::BranchStarted { + agent_id: agent_id.clone(), + channel_id: channel_id.to_string(), + branch_id: branch_id.to_string(), + description: description.clone(), + }) + .ok(); } - ProcessEvent::BranchStarted { branch_id, channel_id, description, .. } => { - api_tx.send(ApiEvent::BranchStarted { - agent_id: agent_id.clone(), - channel_id: channel_id.to_string(), - branch_id: branch_id.to_string(), - description: description.clone(), - }).ok(); + ProcessEvent::WorkerStatus { + worker_id, + channel_id, + status, + .. + } => { + api_tx + .send(ApiEvent::WorkerStatusUpdate { + agent_id: agent_id.clone(), + channel_id: channel_id.as_deref().map(|s| s.to_string()), + worker_id: worker_id.to_string(), + status: status.clone(), + }) + .ok(); } - ProcessEvent::WorkerStatus { worker_id, channel_id, status, .. } => { - api_tx.send(ApiEvent::WorkerStatusUpdate { - agent_id: agent_id.clone(), - channel_id: channel_id.as_deref().map(|s| s.to_string()), - worker_id: worker_id.to_string(), - status: status.clone(), - }).ok(); + ProcessEvent::WorkerComplete { + worker_id, + channel_id, + result, + .. + } => { + api_tx + .send(ApiEvent::WorkerCompleted { + agent_id: agent_id.clone(), + channel_id: channel_id.as_deref().map(|s| s.to_string()), + worker_id: worker_id.to_string(), + result: result.clone(), + }) + .ok(); } - ProcessEvent::WorkerComplete { worker_id, channel_id, result, .. 
} => { - api_tx.send(ApiEvent::WorkerCompleted { - agent_id: agent_id.clone(), - channel_id: channel_id.as_deref().map(|s| s.to_string()), - worker_id: worker_id.to_string(), - result: result.clone(), - }).ok(); + ProcessEvent::BranchResult { + branch_id, + channel_id, + conclusion, + .. + } => { + api_tx + .send(ApiEvent::BranchCompleted { + agent_id: agent_id.clone(), + channel_id: channel_id.to_string(), + branch_id: branch_id.to_string(), + conclusion: conclusion.clone(), + }) + .ok(); } - ProcessEvent::BranchResult { branch_id, channel_id, conclusion, .. } => { - api_tx.send(ApiEvent::BranchCompleted { - agent_id: agent_id.clone(), - channel_id: channel_id.to_string(), - branch_id: branch_id.to_string(), - conclusion: conclusion.clone(), - }).ok(); + ProcessEvent::BranchFailed { + branch_id, + channel_id, + error, + .. + } => { + api_tx + .send(ApiEvent::BranchFailed { + agent_id: agent_id.clone(), + channel_id: channel_id.to_string(), + branch_id: branch_id.to_string(), + error: error.clone(), + }) + .ok(); } - ProcessEvent::ToolStarted { process_id, channel_id, tool_name, .. } => { + ProcessEvent::ToolStarted { + process_id, + channel_id, + tool_name, + .. + } => { let (process_type, id_str) = process_id_info(process_id); - api_tx.send(ApiEvent::ToolStarted { - agent_id: agent_id.clone(), - channel_id: channel_id.as_deref().map(|s| s.to_string()), - process_type, - process_id: id_str, - tool_name: tool_name.clone(), - }).ok(); + api_tx + .send(ApiEvent::ToolStarted { + agent_id: agent_id.clone(), + channel_id: channel_id.as_deref().map(|s| s.to_string()), + process_type, + process_id: id_str, + tool_name: tool_name.clone(), + }) + .ok(); } - ProcessEvent::ToolCompleted { process_id, channel_id, tool_name, .. } => { + ProcessEvent::ToolCompleted { + process_id, + channel_id, + tool_name, + .. 
+ } => { let (process_type, id_str) = process_id_info(process_id); - api_tx.send(ApiEvent::ToolCompleted { - agent_id: agent_id.clone(), - channel_id: channel_id.as_deref().map(|s| s.to_string()), - process_type, - process_id: id_str, - tool_name: tool_name.clone(), - }).ok(); + api_tx + .send(ApiEvent::ToolCompleted { + agent_id: agent_id.clone(), + channel_id: channel_id.as_deref().map(|s| s.to_string()), + process_type, + process_id: id_str, + tool_name: tool_name.clone(), + }) + .ok(); } _ => {} } diff --git a/src/conversation/channels.rs b/src/conversation/channels.rs index e56074e0f..dffe54669 100644 --- a/src/conversation/channels.rs +++ b/src/conversation/channels.rs @@ -35,11 +35,7 @@ impl ChannelStore { /// Extracts platform from the channel ID prefix (e.g. "discord" from /// "discord:123:456"). Updates display_name and platform_meta if the /// channel already exists. Fire-and-forget. - pub fn upsert( - &self, - channel_id: &str, - metadata: &HashMap, - ) { + pub fn upsert(&self, channel_id: &str, metadata: &HashMap) { let pool = self.pool.clone(); let channel_id = channel_id.to_string(); let platform = extract_platform(&channel_id); @@ -73,12 +69,11 @@ impl ChannelStore { let channel_id = channel_id.to_string(); tokio::spawn(async move { - if let Err(error) = sqlx::query( - "UPDATE channels SET last_activity_at = CURRENT_TIMESTAMP WHERE id = ?" 
- ) - .bind(&channel_id) - .execute(&pool) - .await + if let Err(error) = + sqlx::query("UPDATE channels SET last_activity_at = CURRENT_TIMESTAMP WHERE id = ?") + .bind(&channel_id) + .execute(&pool) + .await { tracing::warn!(%error, %channel_id, "failed to touch channel"); } @@ -109,21 +104,27 @@ impl ChannelStore { // Exact name match if let Some(channel) = channels.iter().find(|c| { - c.display_name.as_ref().is_some_and(|n| n.to_lowercase() == name_lower) + c.display_name + .as_ref() + .is_some_and(|n| n.to_lowercase() == name_lower) }) { return Ok(Some(channel.clone())); } // Prefix match if let Some(channel) = channels.iter().find(|c| { - c.display_name.as_ref().is_some_and(|n| n.to_lowercase().starts_with(&name_lower)) + c.display_name + .as_ref() + .is_some_and(|n| n.to_lowercase().starts_with(&name_lower)) }) { return Ok(Some(channel.clone())); } // Contains match if let Some(channel) = channels.iter().find(|c| { - c.display_name.as_ref().is_some_and(|n| n.to_lowercase().contains(&name_lower)) + c.display_name + .as_ref() + .is_some_and(|n| n.to_lowercase().contains(&name_lower)) }) { return Ok(Some(channel.clone())); } @@ -153,14 +154,17 @@ impl ChannelStore { /// Resolve a channel's display name by ID. 
pub async fn resolve_name(&self, channel_id: &str) -> Option { - self.get(channel_id).await.ok().flatten().and_then(|c| c.display_name) + self.get(channel_id) + .await + .ok() + .flatten() + .and_then(|c| c.display_name) } } fn row_to_channel_info(row: sqlx::sqlite::SqliteRow) -> ChannelInfo { let platform_meta_str: Option = row.try_get("platform_meta").ok().flatten(); - let platform_meta = platform_meta_str - .and_then(|s| serde_json::from_str(&s).ok()); + let platform_meta = platform_meta_str.and_then(|s| serde_json::from_str(&s).ok()); ChannelInfo { id: row.try_get("id").unwrap_or_default(), @@ -168,8 +172,12 @@ fn row_to_channel_info(row: sqlx::sqlite::SqliteRow) -> ChannelInfo { display_name: row.try_get("display_name").ok().flatten(), platform_meta, is_active: row.try_get::("is_active").unwrap_or(1) == 1, - created_at: row.try_get("created_at").unwrap_or_else(|_| chrono::Utc::now()), - last_activity_at: row.try_get("last_activity_at").unwrap_or_else(|_| chrono::Utc::now()), + created_at: row + .try_get("created_at") + .unwrap_or_else(|_| chrono::Utc::now()), + last_activity_at: row + .try_get("last_activity_at") + .unwrap_or_else(|_| chrono::Utc::now()), } } @@ -224,11 +232,7 @@ fn extract_platform_meta( } } "slack" => { - for key in [ - "slack_workspace_id", - "slack_channel_id", - "slack_thread_ts", - ] { + for key in ["slack_workspace_id", "slack_channel_id", "slack_thread_ts"] { if let Some(value) = metadata.get(key) { meta.insert(key.to_string(), value.clone()); } From d85135d780bb524301e7b17cd6530bd2414b2647 Mon Sep 17 00:00:00 2001 From: nmb Date: Tue, 17 Feb 2026 16:30:51 +0100 Subject: [PATCH 10/11] =?UTF-8?q?=E2=9C=A8=20feat(api):=20improve=20server?= =?UTF-8?q?=20request=20handling=20and=20client=20methods?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- interface/src/api/client.ts | 10 + src/api/server.rs | 920 ++++++++++++++++++++++++------------ 2 files changed, 615 insertions(+), 315 deletions(-) 
diff --git a/interface/src/api/client.ts b/interface/src/api/client.ts index 306a2e076..69e4cf4d3 100644 --- a/interface/src/api/client.ts +++ b/interface/src/api/client.ts @@ -85,6 +85,14 @@ export interface BranchCompletedEvent { conclusion: string; } +export interface BranchFailedEvent { + type: "branch_failed"; + agent_id: string; + channel_id: string; + branch_id: string; + error: string; +} + export interface ToolStartedEvent { type: "tool_started"; agent_id: string; @@ -112,6 +120,7 @@ export type ApiEvent = | WorkerCompletedEvent | BranchStartedEvent | BranchCompletedEvent + | BranchFailedEvent | ToolStartedEvent | ToolCompletedEvent; @@ -645,6 +654,7 @@ export interface ProviderStatus { anthropic: boolean; openai: boolean; openrouter: boolean; + ollama: boolean; zhipu: boolean; groq: boolean; together: boolean; diff --git a/src/api/server.rs b/src/api/server.rs index 75753556f..1ff0d0d18 100644 --- a/src/api/server.rs +++ b/src/api/server.rs @@ -5,14 +5,14 @@ use crate::agent::cortex::{CortexEvent, CortexLogger}; use crate::agent::cortex_chat::{CortexChatEvent, CortexChatMessage, CortexChatStore}; use crate::conversation::channels::ChannelStore; use crate::conversation::history::{ProcessRunLogger, TimelineItem}; -use crate::memory::types::{Association, Memory, MemorySearchResult, MemoryType}; use crate::memory::search::{SearchConfig, SearchMode, SearchSort}; +use crate::memory::types::{Association, Memory, MemorySearchResult, MemoryType}; +use axum::Router; use axum::extract::{Query, State}; -use axum::http::{header, StatusCode, Uri}; +use axum::http::{StatusCode, Uri, header}; use axum::response::{Html, IntoResponse, Json, Response, Sse}; use axum::routing::{delete, get, post, put}; -use axum::Router; use futures::stream::Stream; use rust_embed::Embed; use serde::{Deserialize, Serialize}; @@ -447,27 +447,50 @@ pub async fn start_http_server( .route("/agents/memories", get(list_memories)) .route("/agents/memories/search", get(search_memories)) 
.route("/agents/memories/graph", get(memory_graph)) - .route("/agents/memories/graph/neighbors", get(memory_graph_neighbors)) + .route( + "/agents/memories/graph/neighbors", + get(memory_graph_neighbors), + ) .route("/cortex/events", get(cortex_events)) .route("/cortex-chat/messages", get(cortex_chat_messages)) .route("/cortex-chat/send", post(cortex_chat_send)) .route("/agents/profile", get(get_agent_profile)) .route("/agents/identity", get(get_identity).put(update_identity)) - .route("/agents/config", get(get_agent_config).put(update_agent_config)) - .route("/agents/cron", get(list_cron_jobs).post(create_or_update_cron).delete(delete_cron)) + .route( + "/agents/config", + get(get_agent_config).put(update_agent_config), + ) + .route( + "/agents/cron", + get(list_cron_jobs) + .post(create_or_update_cron) + .delete(delete_cron), + ) .route("/agents/cron/executions", get(cron_executions)) .route("/agents/cron/trigger", post(trigger_cron)) .route("/agents/cron/toggle", put(toggle_cron)) .route("/channels/cancel", post(cancel_process)) - .route("/agents/ingest/files", get(list_ingest_files).delete(delete_ingest_file)) + .route( + "/agents/ingest/files", + get(list_ingest_files).delete(delete_ingest_file), + ) .route("/agents/ingest/upload", post(upload_ingest_file)) .route("/providers", get(get_providers).put(update_provider)) .route("/providers/{provider}", delete(delete_provider)) .route("/models", get(get_models)) .route("/models/refresh", post(refresh_models)) .route("/messaging/status", get(messaging_status)) - .route("/bindings", get(list_bindings).post(create_binding).put(update_binding).delete(delete_binding)) - .route("/settings", get(get_global_settings).put(update_global_settings)) + .route( + "/bindings", + get(list_bindings) + .post(create_binding) + .put(update_binding) + .delete(delete_binding), + ) + .route( + "/settings", + get(get_global_settings).put(update_global_settings), + ) .route("/config/raw", get(get_raw_config).put(update_raw_config)) 
.route("/update/check", get(update_check).post(update_check_now)) .route("/update/apply", post(update_apply)); @@ -515,7 +538,9 @@ async fn status(State(state): State>) -> Json { /// List all configured agents with their config summaries. async fn list_agents(State(state): State>) -> Json { let agents = state.agent_configs.load(); - Json(AgentsResponse { agents: agents.as_ref().clone() }) + Json(AgentsResponse { + agents: agents.as_ref().clone(), + }) } /// Get overview stats for an agent: memory breakdown, channels, cron, cortex. @@ -595,7 +620,8 @@ async fn agent_overview( // Latest bulletin text let latest_bulletin = bulletin_events.first().and_then(|e| { e.details.as_ref().and_then(|d| { - d.get("bulletin_text").and_then(|v| v.as_str().map(|s| s.to_string())) + d.get("bulletin_text") + .and_then(|v| v.as_str().map(|s| s.to_string())) }) }); @@ -639,12 +665,24 @@ async fn agent_overview( for row in branch_activity { let date: String = row.get("date"); let count: i64 = row.get("count"); - map.entry(date.clone()).or_insert_with(|| ActivityDayCount { date, branches: 0, workers: 0 }).branches = count; + map.entry(date.clone()) + .or_insert_with(|| ActivityDayCount { + date, + branches: 0, + workers: 0, + }) + .branches = count; } for row in worker_activity { let date: String = row.get("date"); let count: i64 = row.get("count"); - map.entry(date.clone()).or_insert_with(|| ActivityDayCount { date, branches: 0, workers: 0 }).workers = count; + map.entry(date.clone()) + .or_insert_with(|| ActivityDayCount { + date, + branches: 0, + workers: 0, + }) + .workers = count; } let mut days: Vec<_> = map.into_values().collect(); days.sort_by(|a, b| a.date.cmp(&b.date)); @@ -688,7 +726,9 @@ struct AgentOverviewQuery { } /// Get instance-wide overview for the main dashboard. 
-async fn instance_overview(State(state): State>) -> Result, StatusCode> { +async fn instance_overview( + State(state): State>, +) -> Result, StatusCode> { let uptime = state.started_at.elapsed(); let pools = state.agent_pools.load(); let configs = state.agent_configs.load(); @@ -697,7 +737,7 @@ async fn instance_overview(State(state): State>) -> Result>) -> Result>) -> Result = Vec::with_capacity(14); for i in 0..14 { - let date = (chrono::Utc::now() - chrono::Duration::days(13 - i as i64)).format("%Y-%m-%d").to_string(); + let date = (chrono::Utc::now() - chrono::Duration::days(13 - i as i64)) + .format("%Y-%m-%d") + .to_string(); activity_sparkline.push(*activity_map.get(&date).unwrap_or(&0)); } @@ -804,6 +844,7 @@ async fn events_sse( ApiEvent::WorkerCompleted { .. } => "worker_completed", ApiEvent::BranchStarted { .. } => "branch_started", ApiEvent::BranchCompleted { .. } => "branch_completed", + ApiEvent::BranchFailed { .. } => "branch_failed", ApiEvent::ToolStarted { .. } => "tool_started", ApiEvent::ToolCompleted { .. } => "tool_completed", }; @@ -857,7 +898,9 @@ async fn list_channels(State(state): State>) -> Json { let has_more = items.len() as i64 > limit; - let items = if has_more { items[items.len() - limit as usize..].to_vec() } else { items }; + let items = if has_more { + items[items.len() - limit as usize..].to_vec() + } else { + items + }; return Json(MessagesResponse { items, has_more }); } Ok(_) => continue, @@ -899,7 +949,10 @@ async fn channel_messages( } } - Json(MessagesResponse { items: vec![], has_more: false }) + Json(MessagesResponse { + items: vec![], + has_more: false, + }) } /// Get live status (active workers, branches, completed items) for all channels. 
@@ -984,7 +1037,8 @@ async fn list_memories( // Fetch limit + offset so we can paginate, then slice let fetch_limit = limit + query.offset as i64; - let all = store.get_sorted(sort, fetch_limit, memory_type) + let all = store + .get_sorted(sort, fetch_limit, memory_type) .await .map_err(|error| { tracing::warn!(%error, agent_id = %query.agent_id, "failed to list memories"); @@ -1071,7 +1125,8 @@ async fn memory_graph( let memory_type = query.memory_type.as_deref().and_then(parse_memory_type); let fetch_limit = limit + query.offset as i64; - let all = store.get_sorted(sort, fetch_limit, memory_type) + let all = store + .get_sorted(sort, fetch_limit, memory_type) .await .map_err(|error| { tracing::warn!(%error, agent_id = %query.agent_id, "failed to load graph nodes"); @@ -1082,14 +1137,19 @@ async fn memory_graph( let nodes: Vec = all.into_iter().skip(query.offset).collect(); let node_ids: Vec = nodes.iter().map(|m| m.id.clone()).collect(); - let edges = store.get_associations_between(&node_ids) + let edges = store + .get_associations_between(&node_ids) .await .map_err(|error| { tracing::warn!(%error, agent_id = %query.agent_id, "failed to load graph edges"); StatusCode::INTERNAL_SERVER_ERROR })?; - Ok(Json(MemoryGraphResponse { nodes, edges, total })) + Ok(Json(MemoryGraphResponse { + nodes, + edges, + total, + })) } #[derive(Deserialize)] @@ -1118,7 +1178,8 @@ async fn memory_graph_neighbors( let store = memory_search.store(); let depth = query.depth.min(3); - let exclude_ids: Vec = query.exclude + let exclude_ids: Vec = query + .exclude .as_deref() .unwrap_or("") .split(',') @@ -1181,7 +1242,10 @@ async fn cortex_chat_messages( StatusCode::INTERNAL_SERVER_ERROR })?; - Ok(Json(CortexChatMessagesResponse { messages, thread_id })) + Ok(Json(CortexChatMessagesResponse { + messages, + thread_id, + })) } /// Send a message to cortex chat. Returns an SSE stream with activity events. 
@@ -1272,7 +1336,9 @@ async fn get_identity( Query(query): Query, ) -> Result, StatusCode> { let workspaces = state.agent_workspaces.load(); - let workspace = workspaces.get(&query.agent_id).ok_or(StatusCode::NOT_FOUND)?; + let workspace = workspaces + .get(&query.agent_id) + .ok_or(StatusCode::NOT_FOUND)?; let identity = crate::identity::Identity::load(workspace).await; @@ -1290,7 +1356,9 @@ async fn update_identity( axum::Json(request): axum::Json, ) -> Result, StatusCode> { let workspaces = state.agent_workspaces.load(); - let workspace = workspaces.get(&request.agent_id).ok_or(StatusCode::NOT_FOUND)?; + let workspace = workspaces + .get(&request.agent_id) + .ok_or(StatusCode::NOT_FOUND)?; if let Some(soul) = &request.soul { tokio::fs::write(workspace.join("SOUL.md"), soul) @@ -1438,7 +1506,8 @@ async fn update_agent_config( })?; // Parse with toml_edit to preserve formatting - let mut doc = config_content.parse::() + let mut doc = config_content + .parse::() .map_err(|error| { tracing::warn!(%error, "failed to parse config.toml"); StatusCode::INTERNAL_SERVER_ERROR @@ -1509,18 +1578,28 @@ async fn update_agent_config( } } - get_agent_config(State(state), Query(AgentConfigQuery { agent_id: request.agent_id })).await + get_agent_config( + State(state), + Query(AgentConfigQuery { + agent_id: request.agent_id, + }), + ) + .await } /// Find the index of an agent table in the [[agents]] array, or create a new one. 
-fn find_or_create_agent_table(doc: &mut toml_edit::DocumentMut, agent_id: &str) -> Result { +fn find_or_create_agent_table( + doc: &mut toml_edit::DocumentMut, + agent_id: &str, +) -> Result { // Create agents array if it doesn't exist if doc.get("agents").is_none() { doc["agents"] = toml_edit::Item::ArrayOfTables(toml_edit::ArrayOfTables::new()); } // Get the agents array - let agents = doc.get_mut("agents") + let agents = doc + .get_mut("agents") .and_then(|a| a.as_array_of_tables_mut()) .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?; @@ -1542,7 +1621,10 @@ fn find_or_create_agent_table(doc: &mut toml_edit::DocumentMut, agent_id: &str) } /// Get a mutable reference to an agent's table in the [[agents]] array. -fn get_agent_table_mut(doc: &mut toml_edit::DocumentMut, agent_idx: usize) -> Result<&mut toml_edit::Table, StatusCode> { +fn get_agent_table_mut( + doc: &mut toml_edit::DocumentMut, + agent_idx: usize, +) -> Result<&mut toml_edit::Table, StatusCode> { doc.get_mut("agents") .and_then(|a| a.as_array_of_tables_mut()) .and_then(|arr| arr.get_mut(agent_idx)) @@ -1550,93 +1632,193 @@ fn get_agent_table_mut(doc: &mut toml_edit::DocumentMut, agent_idx: usize) -> Re } /// Get or create a subtable within an agent's config (e.g., [agents.routing]). 
-fn get_or_create_subtable<'a>(agent: &'a mut toml_edit::Table, key: &str) -> &'a mut toml_edit::Table { +fn get_or_create_subtable<'a>( + agent: &'a mut toml_edit::Table, + key: &str, +) -> &'a mut toml_edit::Table { if !agent.contains_key(key) { agent[key] = toml_edit::Item::Table(toml_edit::Table::new()); } agent[key].as_table_mut().expect("just created as table") } -fn update_routing_table(doc: &mut toml_edit::DocumentMut, agent_idx: usize, routing: &RoutingUpdate) -> Result<(), StatusCode> { +fn update_routing_table( + doc: &mut toml_edit::DocumentMut, + agent_idx: usize, + routing: &RoutingUpdate, +) -> Result<(), StatusCode> { let agent = get_agent_table_mut(doc, agent_idx)?; let table = get_or_create_subtable(agent, "routing"); - if let Some(ref v) = routing.channel { table["channel"] = toml_edit::value(v.as_str()); } - if let Some(ref v) = routing.branch { table["branch"] = toml_edit::value(v.as_str()); } - if let Some(ref v) = routing.worker { table["worker"] = toml_edit::value(v.as_str()); } - if let Some(ref v) = routing.compactor { table["compactor"] = toml_edit::value(v.as_str()); } - if let Some(ref v) = routing.cortex { table["cortex"] = toml_edit::value(v.as_str()); } - if let Some(v) = routing.rate_limit_cooldown_secs { table["rate_limit_cooldown_secs"] = toml_edit::value(v as i64); } + if let Some(ref v) = routing.channel { + table["channel"] = toml_edit::value(v.as_str()); + } + if let Some(ref v) = routing.branch { + table["branch"] = toml_edit::value(v.as_str()); + } + if let Some(ref v) = routing.worker { + table["worker"] = toml_edit::value(v.as_str()); + } + if let Some(ref v) = routing.compactor { + table["compactor"] = toml_edit::value(v.as_str()); + } + if let Some(ref v) = routing.cortex { + table["cortex"] = toml_edit::value(v.as_str()); + } + if let Some(v) = routing.rate_limit_cooldown_secs { + table["rate_limit_cooldown_secs"] = toml_edit::value(v as i64); + } Ok(()) } -fn update_tuning_table(doc: &mut toml_edit::DocumentMut, 
agent_idx: usize, tuning: &TuningUpdate) -> Result<(), StatusCode> { +fn update_tuning_table( + doc: &mut toml_edit::DocumentMut, + agent_idx: usize, + tuning: &TuningUpdate, +) -> Result<(), StatusCode> { let agent = get_agent_table_mut(doc, agent_idx)?; - if let Some(v) = tuning.max_concurrent_branches { agent["max_concurrent_branches"] = toml_edit::value(v as i64); } - if let Some(v) = tuning.max_concurrent_workers { agent["max_concurrent_workers"] = toml_edit::value(v as i64); } - if let Some(v) = tuning.max_turns { agent["max_turns"] = toml_edit::value(v as i64); } - if let Some(v) = tuning.branch_max_turns { agent["branch_max_turns"] = toml_edit::value(v as i64); } - if let Some(v) = tuning.context_window { agent["context_window"] = toml_edit::value(v as i64); } - if let Some(v) = tuning.history_backfill_count { agent["history_backfill_count"] = toml_edit::value(v as i64); } + if let Some(v) = tuning.max_concurrent_branches { + agent["max_concurrent_branches"] = toml_edit::value(v as i64); + } + if let Some(v) = tuning.max_concurrent_workers { + agent["max_concurrent_workers"] = toml_edit::value(v as i64); + } + if let Some(v) = tuning.max_turns { + agent["max_turns"] = toml_edit::value(v as i64); + } + if let Some(v) = tuning.branch_max_turns { + agent["branch_max_turns"] = toml_edit::value(v as i64); + } + if let Some(v) = tuning.context_window { + agent["context_window"] = toml_edit::value(v as i64); + } + if let Some(v) = tuning.history_backfill_count { + agent["history_backfill_count"] = toml_edit::value(v as i64); + } Ok(()) } -fn update_compaction_table(doc: &mut toml_edit::DocumentMut, agent_idx: usize, compaction: &CompactionUpdate) -> Result<(), StatusCode> { +fn update_compaction_table( + doc: &mut toml_edit::DocumentMut, + agent_idx: usize, + compaction: &CompactionUpdate, +) -> Result<(), StatusCode> { let agent = get_agent_table_mut(doc, agent_idx)?; let table = get_or_create_subtable(agent, "compaction"); - if let Some(v) = 
compaction.background_threshold { table["background_threshold"] = toml_edit::value(v as f64); } - if let Some(v) = compaction.aggressive_threshold { table["aggressive_threshold"] = toml_edit::value(v as f64); } - if let Some(v) = compaction.emergency_threshold { table["emergency_threshold"] = toml_edit::value(v as f64); } + if let Some(v) = compaction.background_threshold { + table["background_threshold"] = toml_edit::value(v as f64); + } + if let Some(v) = compaction.aggressive_threshold { + table["aggressive_threshold"] = toml_edit::value(v as f64); + } + if let Some(v) = compaction.emergency_threshold { + table["emergency_threshold"] = toml_edit::value(v as f64); + } Ok(()) } -fn update_cortex_table(doc: &mut toml_edit::DocumentMut, agent_idx: usize, cortex: &CortexUpdate) -> Result<(), StatusCode> { +fn update_cortex_table( + doc: &mut toml_edit::DocumentMut, + agent_idx: usize, + cortex: &CortexUpdate, +) -> Result<(), StatusCode> { let agent = get_agent_table_mut(doc, agent_idx)?; let table = get_or_create_subtable(agent, "cortex"); - if let Some(v) = cortex.tick_interval_secs { table["tick_interval_secs"] = toml_edit::value(v as i64); } - if let Some(v) = cortex.worker_timeout_secs { table["worker_timeout_secs"] = toml_edit::value(v as i64); } - if let Some(v) = cortex.branch_timeout_secs { table["branch_timeout_secs"] = toml_edit::value(v as i64); } - if let Some(v) = cortex.circuit_breaker_threshold { table["circuit_breaker_threshold"] = toml_edit::value(v as i64); } - if let Some(v) = cortex.bulletin_interval_secs { table["bulletin_interval_secs"] = toml_edit::value(v as i64); } - if let Some(v) = cortex.bulletin_max_words { table["bulletin_max_words"] = toml_edit::value(v as i64); } - if let Some(v) = cortex.bulletin_max_turns { table["bulletin_max_turns"] = toml_edit::value(v as i64); } + if let Some(v) = cortex.tick_interval_secs { + table["tick_interval_secs"] = toml_edit::value(v as i64); + } + if let Some(v) = cortex.worker_timeout_secs { + 
table["worker_timeout_secs"] = toml_edit::value(v as i64); + } + if let Some(v) = cortex.branch_timeout_secs { + table["branch_timeout_secs"] = toml_edit::value(v as i64); + } + if let Some(v) = cortex.circuit_breaker_threshold { + table["circuit_breaker_threshold"] = toml_edit::value(v as i64); + } + if let Some(v) = cortex.bulletin_interval_secs { + table["bulletin_interval_secs"] = toml_edit::value(v as i64); + } + if let Some(v) = cortex.bulletin_max_words { + table["bulletin_max_words"] = toml_edit::value(v as i64); + } + if let Some(v) = cortex.bulletin_max_turns { + table["bulletin_max_turns"] = toml_edit::value(v as i64); + } Ok(()) } -fn update_coalesce_table(doc: &mut toml_edit::DocumentMut, agent_idx: usize, coalesce: &CoalesceUpdate) -> Result<(), StatusCode> { +fn update_coalesce_table( + doc: &mut toml_edit::DocumentMut, + agent_idx: usize, + coalesce: &CoalesceUpdate, +) -> Result<(), StatusCode> { let agent = get_agent_table_mut(doc, agent_idx)?; let table = get_or_create_subtable(agent, "coalesce"); - if let Some(v) = coalesce.enabled { table["enabled"] = toml_edit::value(v); } - if let Some(v) = coalesce.debounce_ms { table["debounce_ms"] = toml_edit::value(v as i64); } - if let Some(v) = coalesce.max_wait_ms { table["max_wait_ms"] = toml_edit::value(v as i64); } - if let Some(v) = coalesce.min_messages { table["min_messages"] = toml_edit::value(v as i64); } - if let Some(v) = coalesce.multi_user_only { table["multi_user_only"] = toml_edit::value(v); } + if let Some(v) = coalesce.enabled { + table["enabled"] = toml_edit::value(v); + } + if let Some(v) = coalesce.debounce_ms { + table["debounce_ms"] = toml_edit::value(v as i64); + } + if let Some(v) = coalesce.max_wait_ms { + table["max_wait_ms"] = toml_edit::value(v as i64); + } + if let Some(v) = coalesce.min_messages { + table["min_messages"] = toml_edit::value(v as i64); + } + if let Some(v) = coalesce.multi_user_only { + table["multi_user_only"] = toml_edit::value(v); + } Ok(()) } -fn 
update_memory_persistence_table(doc: &mut toml_edit::DocumentMut, agent_idx: usize, memory_persistence: &MemoryPersistenceUpdate) -> Result<(), StatusCode> { +fn update_memory_persistence_table( + doc: &mut toml_edit::DocumentMut, + agent_idx: usize, + memory_persistence: &MemoryPersistenceUpdate, +) -> Result<(), StatusCode> { let agent = get_agent_table_mut(doc, agent_idx)?; let table = get_or_create_subtable(agent, "memory_persistence"); - if let Some(v) = memory_persistence.enabled { table["enabled"] = toml_edit::value(v); } - if let Some(v) = memory_persistence.message_interval { table["message_interval"] = toml_edit::value(v as i64); } + if let Some(v) = memory_persistence.enabled { + table["enabled"] = toml_edit::value(v); + } + if let Some(v) = memory_persistence.message_interval { + table["message_interval"] = toml_edit::value(v as i64); + } Ok(()) } -fn update_browser_table(doc: &mut toml_edit::DocumentMut, agent_idx: usize, browser: &BrowserUpdate) -> Result<(), StatusCode> { +fn update_browser_table( + doc: &mut toml_edit::DocumentMut, + agent_idx: usize, + browser: &BrowserUpdate, +) -> Result<(), StatusCode> { let agent = get_agent_table_mut(doc, agent_idx)?; let table = get_or_create_subtable(agent, "browser"); - if let Some(v) = browser.enabled { table["enabled"] = toml_edit::value(v); } - if let Some(v) = browser.headless { table["headless"] = toml_edit::value(v); } - if let Some(v) = browser.evaluate_enabled { table["evaluate_enabled"] = toml_edit::value(v); } + if let Some(v) = browser.enabled { + table["enabled"] = toml_edit::value(v); + } + if let Some(v) = browser.headless { + table["headless"] = toml_edit::value(v); + } + if let Some(v) = browser.evaluate_enabled { + table["evaluate_enabled"] = toml_edit::value(v); + } Ok(()) } /// Update instance-level Discord config at [messaging.discord]. 
-fn update_discord_table(doc: &mut toml_edit::DocumentMut, discord: &DiscordUpdate) -> Result<(), StatusCode> { - let messaging = doc.get_mut("messaging") +fn update_discord_table( + doc: &mut toml_edit::DocumentMut, + discord: &DiscordUpdate, +) -> Result<(), StatusCode> { + let messaging = doc + .get_mut("messaging") .and_then(|m| m.as_table_mut()) .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?; - let discord_table = messaging.get_mut("discord") + let discord_table = messaging + .get_mut("discord") .and_then(|d| d.as_table_mut()) .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?; @@ -1684,13 +1866,10 @@ async fn cortex_events( StatusCode::INTERNAL_SERVER_ERROR })?; - let total = logger - .count_events(event_type_ref) - .await - .map_err(|error| { - tracing::warn!(%error, agent_id = %query.agent_id, "failed to count cortex events"); - StatusCode::INTERNAL_SERVER_ERROR - })?; + let total = logger.count_events(event_type_ref).await.map_err(|error| { + tracing::warn!(%error, agent_id = %query.agent_id, "failed to count cortex events"); + StatusCode::INTERNAL_SERVER_ERROR + })?; Ok(Json(CortexEventsResponse { events, total })) } @@ -1795,13 +1974,10 @@ async fn list_cron_jobs( let stores = state.cron_stores.load(); let store = stores.get(&query.agent_id).ok_or(StatusCode::NOT_FOUND)?; - let configs = store - .load_all_unfiltered() - .await - .map_err(|error| { - tracing::warn!(%error, agent_id = %query.agent_id, "failed to load cron jobs"); - StatusCode::INTERNAL_SERVER_ERROR - })?; + let configs = store.load_all_unfiltered().await.map_err(|error| { + tracing::warn!(%error, agent_id = %query.agent_id, "failed to load cron jobs"); + StatusCode::INTERNAL_SERVER_ERROR + })?; let mut jobs = Vec::new(); for config in configs { @@ -1864,7 +2040,9 @@ async fn create_or_update_cron( let schedulers = state.cron_schedulers.load(); let store = stores.get(&request.agent_id).ok_or(StatusCode::NOT_FOUND)?; - let scheduler = schedulers.get(&request.agent_id).ok_or(StatusCode::NOT_FOUND)?; + 
let scheduler = schedulers + .get(&request.agent_id) + .ok_or(StatusCode::NOT_FOUND)?; let active_hours = match (request.active_start_hour, request.active_end_hour) { (Some(start), Some(end)) => Some((start, end)), @@ -1907,7 +2085,9 @@ async fn delete_cron( let store = stores.get(&query.agent_id).ok_or(StatusCode::NOT_FOUND)?; let schedulers = state.cron_schedulers.load(); - let scheduler = schedulers.get(&query.agent_id).ok_or(StatusCode::NOT_FOUND)?; + let scheduler = schedulers + .get(&query.agent_id) + .ok_or(StatusCode::NOT_FOUND)?; // Unregister from scheduler first scheduler.unregister(&query.cron_id).await; @@ -1930,7 +2110,9 @@ async fn trigger_cron( Json(request): Json, ) -> Result, StatusCode> { let schedulers = state.cron_schedulers.load(); - let scheduler = schedulers.get(&request.agent_id).ok_or(StatusCode::NOT_FOUND)?; + let scheduler = schedulers + .get(&request.agent_id) + .ok_or(StatusCode::NOT_FOUND)?; scheduler.trigger_now(&request.cron_id).await.map_err(|error| { tracing::warn!(%error, agent_id = %request.agent_id, cron_id = %request.cron_id, "failed to trigger cron job"); @@ -1952,7 +2134,9 @@ async fn toggle_cron( let store = stores.get(&request.agent_id).ok_or(StatusCode::NOT_FOUND)?; let schedulers = state.cron_schedulers.load(); - let scheduler = schedulers.get(&request.agent_id).ok_or(StatusCode::NOT_FOUND)?; + let scheduler = schedulers + .get(&request.agent_id) + .ok_or(StatusCode::NOT_FOUND)?; // Update in database first store.update_enabled(&request.cron_id, request.enabled).await.map_err(|error| { @@ -1966,7 +2150,11 @@ async fn toggle_cron( StatusCode::INTERNAL_SERVER_ERROR })?; - let status = if request.enabled { "enabled" } else { "disabled" }; + let status = if request.enabled { + "enabled" + } else { + "disabled" + }; Ok(Json(CronActionResponse { success: true, message: format!("Cron job '{}' {}", request.cron_id, status), @@ -1994,13 +2182,19 @@ async fn cancel_process( Json(request): Json, ) -> Result, StatusCode> { let 
states = state.channel_states.read().await; - let channel_state = states.get(&request.channel_id).ok_or(StatusCode::NOT_FOUND)?; + let channel_state = states + .get(&request.channel_id) + .ok_or(StatusCode::NOT_FOUND)?; match request.process_type.as_str() { "worker" => { - let worker_id: crate::WorkerId = request.process_id.parse() + let worker_id: crate::WorkerId = request + .process_id + .parse() .map_err(|_| StatusCode::BAD_REQUEST)?; - channel_state.cancel_worker(worker_id).await + channel_state + .cancel_worker(worker_id) + .await .map_err(|_| StatusCode::NOT_FOUND)?; Ok(Json(CancelProcessResponse { success: true, @@ -2008,9 +2202,13 @@ async fn cancel_process( })) } "branch" => { - let branch_id: crate::BranchId = request.process_id.parse() + let branch_id: crate::BranchId = request + .process_id + .parse() .map_err(|_| StatusCode::BAD_REQUEST)?; - channel_state.cancel_branch(branch_id).await + channel_state + .cancel_branch(branch_id) + .await .map_err(|_| StatusCode::NOT_FOUND)?; Ok(Json(CancelProcessResponse { success: true, @@ -2028,6 +2226,7 @@ struct ProviderStatus { anthropic: bool, openai: bool, openrouter: bool, + ollama: bool, zhipu: bool, groq: bool, together: bool, @@ -2062,7 +2261,20 @@ async fn get_providers( let config_path = state.config_path.read().await.clone(); // Check which providers have keys by reading the config - let (anthropic, openai, openrouter, zhipu, groq, together, fireworks, deepseek, xai, mistral, opencode_zen) = if config_path.exists() { + let ( + anthropic, + openai, + openrouter, + ollama, + zhipu, + groq, + together, + fireworks, + deepseek, + xai, + mistral, + opencode_zen, + ) = if config_path.exists() { let content = tokio::fs::read_to_string(&config_path) .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; @@ -2091,6 +2303,7 @@ async fn get_providers( has_key("anthropic_key", "ANTHROPIC_API_KEY"), has_key("openai_key", "OPENAI_API_KEY"), has_key("openrouter_key", "OPENROUTER_API_KEY"), + has_key("ollama_key", 
"OLLAMA_API_KEY"), has_key("zhipu_key", "ZHIPU_API_KEY"), has_key("groq_key", "GROQ_API_KEY"), has_key("together_key", "TOGETHER_API_KEY"), @@ -2106,6 +2319,7 @@ async fn get_providers( std::env::var("ANTHROPIC_API_KEY").is_ok(), std::env::var("OPENAI_API_KEY").is_ok(), std::env::var("OPENROUTER_API_KEY").is_ok(), + std::env::var("OLLAMA_API_KEY").is_ok(), std::env::var("ZHIPU_API_KEY").is_ok(), std::env::var("GROQ_API_KEY").is_ok(), std::env::var("TOGETHER_API_KEY").is_ok(), @@ -2121,6 +2335,7 @@ async fn get_providers( anthropic, openai, openrouter, + ollama, zhipu, groq, together, @@ -2130,9 +2345,10 @@ async fn get_providers( mistral, opencode_zen, }; - let has_any = providers.anthropic - || providers.openai - || providers.openrouter + let has_any = providers.anthropic + || providers.openai + || providers.openrouter + || providers.ollama || providers.zhipu || providers.groq || providers.together @@ -2153,6 +2369,7 @@ async fn update_provider( "anthropic" => "anthropic_key", "openai" => "openai_key", "openrouter" => "openrouter_key", + "ollama" => "ollama_key", "zhipu" => "zhipu_key", "groq" => "groq_key", "together" => "together_key", @@ -2210,8 +2427,7 @@ async fn update_provider( .and_then(|v| v.as_str()) .unwrap_or("anthropic/claude-sonnet-4-20250514"); - let current_provider = - crate::llm::routing::provider_from_model(current_channel); + let current_provider = crate::llm::routing::provider_from_model(current_channel); // Check if the current routing provider has a key configured let has_key_for_current = match current_provider { @@ -2230,6 +2446,11 @@ async fn update_provider( .and_then(|l| l.get("openrouter_key")) .and_then(|v| v.as_str()) .is_some_and(|s| !s.is_empty()), + "ollama" => doc + .get("llm") + .and_then(|l| l.get("ollama_key")) + .and_then(|v| v.as_str()) + .is_some_and(|s| !s.is_empty()), "zhipu" => doc .get("llm") .and_then(|l| l.get("zhipu_key")) @@ -2247,19 +2468,19 @@ async fn update_provider( }; if should_set_routing { - let routing = - 
crate::llm::routing::defaults_for_provider(&request.provider); + let routing = crate::llm::routing::defaults_for_provider(&request.provider); if doc.get("defaults").is_none() { doc["defaults"] = toml_edit::Item::Table(toml_edit::Table::new()); } - + if let Some(defaults) = doc.get_mut("defaults").and_then(|d| d.as_table_mut()) { if defaults.get("routing").is_none() { defaults["routing"] = toml_edit::Item::Table(toml_edit::Table::new()); } - - if let Some(routing_table) = defaults.get_mut("routing").and_then(|r| r.as_table_mut()) { + + if let Some(routing_table) = defaults.get_mut("routing").and_then(|r| r.as_table_mut()) + { routing_table["channel"] = toml_edit::value(&routing.channel); routing_table["branch"] = toml_edit::value(&routing.branch); routing_table["worker"] = toml_edit::value(&routing.worker); @@ -2306,6 +2527,7 @@ async fn delete_provider( "anthropic" => "anthropic_key", "openai" => "openai_key", "openrouter" => "openrouter_key", + "ollama" => "ollama_key", "zhipu" => "zhipu_key", "groq" => "groq_key", "together" => "together_key", @@ -2363,7 +2585,7 @@ struct ModelInfo { id: String, /// Human-readable name name: String, - /// Provider ID ("anthropic", "openrouter", "openai", "zhipu") + /// Provider ID ("anthropic", "openrouter", "openai", "ollama", "zhipu") provider: String, /// Context window size in tokens, if known context_window: Option, @@ -2536,6 +2758,35 @@ fn curated_models() -> Vec { context_window: Some(200_000), curated: true, }, + // Ollama Cloud (OpenAI-compatible) + ModelInfo { + id: "ollama/gpt-oss:120b".into(), + name: "gpt-oss 120B".into(), + provider: "ollama".into(), + context_window: None, + curated: true, + }, + ModelInfo { + id: "ollama/gpt-oss:20b".into(), + name: "gpt-oss 20B".into(), + provider: "ollama".into(), + context_window: None, + curated: true, + }, + ModelInfo { + id: "ollama/qwen3:30b-a3b".into(), + name: "Qwen3 30B A3B".into(), + provider: "ollama".into(), + context_window: None, + curated: true, + }, + ModelInfo { 
+ id: "ollama/devstral:24b".into(), + name: "Devstral 24B".into(), + provider: "ollama".into(), + context_window: None, + curated: true, + }, // Z.ai (GLM) ModelInfo { id: "zhipu/glm-4-plus".into(), @@ -2795,9 +3046,7 @@ fn curated_models() -> Vec { /// In-memory cache for dynamically fetched models. static DYNAMIC_MODELS: std::sync::LazyLock< tokio::sync::RwLock<(Vec, std::time::Instant)>, -> = std::sync::LazyLock::new(|| { - tokio::sync::RwLock::new((Vec::new(), std::time::Instant::now())) -}); +> = std::sync::LazyLock::new(|| tokio::sync::RwLock::new((Vec::new(), std::time::Instant::now()))); const MODEL_CACHE_TTL: std::time::Duration = std::time::Duration::from_secs(3600); @@ -2892,6 +3141,9 @@ async fn configured_providers(config_path: &std::path::Path) -> Vec<&'static str if has_key("openrouter_key", "OPENROUTER_API_KEY") { providers.push("openrouter"); } + if has_key("ollama_key", "OLLAMA_API_KEY") { + providers.push("ollama"); + } if has_key("zhipu_key", "ZHIPU_API_KEY") { providers.push("zhipu"); } @@ -2921,9 +3173,7 @@ async fn configured_providers(config_path: &std::path::Path) -> Vec<&'static str } /// Fetch available models from OpenRouter's API. 
-async fn fetch_openrouter_models( - config_path: &std::path::Path, -) -> anyhow::Result> { +async fn fetch_openrouter_models(config_path: &std::path::Path) -> anyhow::Result> { let content = tokio::fs::read_to_string(config_path).await?; let doc: toml_edit::DocumentMut = content.parse()?; @@ -3046,13 +3296,17 @@ async fn upload_ingest_file( mut multipart: axum::extract::Multipart, ) -> Result, StatusCode> { let workspaces = state.agent_workspaces.load(); - let workspace = workspaces.get(&query.agent_id).ok_or(StatusCode::NOT_FOUND)?; + let workspace = workspaces + .get(&query.agent_id) + .ok_or(StatusCode::NOT_FOUND)?; let ingest_dir = workspace.join("ingest"); - tokio::fs::create_dir_all(&ingest_dir).await.map_err(|error| { - tracing::warn!(%error, "failed to create ingest directory"); - StatusCode::INTERNAL_SERVER_ERROR - })?; + tokio::fs::create_dir_all(&ingest_dir) + .await + .map_err(|error| { + tracing::warn!(%error, "failed to create ingest directory"); + StatusCode::INTERNAL_SERVER_ERROR + })?; let mut uploaded = Vec::new(); @@ -3089,7 +3343,12 @@ async fn upload_ingest_file( .extension() .and_then(|e| e.to_str()) .unwrap_or("txt"); - let unique = format!("{}-{}.{}", stem, &uuid::Uuid::new_v4().to_string()[..8], ext); + let unique = format!( + "{}-{}.{}", + stem, + &uuid::Uuid::new_v4().to_string()[..8], + ext + ); ingest_dir.join(unique) } else { target @@ -3194,10 +3453,7 @@ async fn messaging_status( .get("token") .and_then(|v| v.as_str()) .is_some_and(|s| !s.is_empty()); - let enabled = d - .get("enabled") - .and_then(|v| v.as_bool()) - .unwrap_or(false); + let enabled = d.get("enabled").and_then(|v| v.as_bool()).unwrap_or(false); PlatformStatus { configured: has_token, enabled: has_token && enabled, @@ -3220,10 +3476,7 @@ async fn messaging_status( .get("app_token") .and_then(|v| v.as_str()) .is_some_and(|t| !t.is_empty()); - let enabled = s - .get("enabled") - .and_then(|v| v.as_bool()) - .unwrap_or(false); + let enabled = 
s.get("enabled").and_then(|v| v.as_bool()).unwrap_or(false); PlatformStatus { configured: has_bot_token && has_app_token, enabled: has_bot_token && has_app_token && enabled, @@ -3238,10 +3491,7 @@ async fn messaging_status( .get("messaging") .and_then(|m| m.get("webhook")) .map(|w| { - let enabled = w - .get("enabled") - .and_then(|v| v.as_bool()) - .unwrap_or(false); + let enabled = w.get("enabled").and_then(|v| v.as_bool()).unwrap_or(false); PlatformStatus { configured: true, enabled, @@ -3307,12 +3557,7 @@ async fn list_bindings( let filtered: Vec = bindings .into_iter() - .filter(|b| { - query - .agent_id - .as_ref() - .map_or(true, |id| &b.agent_id == id) - }) + .filter(|b| query.agent_id.as_ref().map_or(true, |id| &b.agent_id == id)) .map(|b| BindingResponse { agent_id: b.agent_id, channel: b.channel, @@ -3381,10 +3626,12 @@ async fn create_binding( } let content = if config_path.exists() { - tokio::fs::read_to_string(&config_path).await.map_err(|error| { - tracing::warn!(%error, "failed to read config.toml"); - StatusCode::INTERNAL_SERVER_ERROR - })? + tokio::fs::read_to_string(&config_path) + .await + .map_err(|error| { + tracing::warn!(%error, "failed to read config.toml"); + StatusCode::INTERNAL_SERVER_ERROR + })? 
} else { String::new() }; @@ -3406,11 +3653,15 @@ async fn create_binding( if doc.get("messaging").is_none() { doc["messaging"] = toml_edit::Item::Table(toml_edit::Table::new()); } - let messaging = doc["messaging"].as_table_mut().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?; + let messaging = doc["messaging"] + .as_table_mut() + .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?; if !messaging.contains_key("discord") { messaging["discord"] = toml_edit::Item::Table(toml_edit::Table::new()); } - let discord = messaging["discord"].as_table_mut().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?; + let discord = messaging["discord"] + .as_table_mut() + .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?; discord["enabled"] = toml_edit::value(true); discord["token"] = toml_edit::value(token.as_str()); new_discord_token = Some(token.clone()); @@ -3422,11 +3673,15 @@ async fn create_binding( if doc.get("messaging").is_none() { doc["messaging"] = toml_edit::Item::Table(toml_edit::Table::new()); } - let messaging = doc["messaging"].as_table_mut().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?; + let messaging = doc["messaging"] + .as_table_mut() + .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?; if !messaging.contains_key("slack") { messaging["slack"] = toml_edit::Item::Table(toml_edit::Table::new()); } - let slack = messaging["slack"].as_table_mut().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?; + let slack = messaging["slack"] + .as_table_mut() + .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?; slack["enabled"] = toml_edit::value(true); slack["bot_token"] = toml_edit::value(bot_token.as_str()); slack["app_token"] = toml_edit::value(app_token); @@ -3510,8 +3765,10 @@ async fn create_binding( // Rebuild Discord permissions if let Some(discord_config) = &new_config.messaging.discord { - let new_perms = - crate::config::DiscordPermissions::from_config(discord_config, &new_config.bindings); + let new_perms = crate::config::DiscordPermissions::from_config( + discord_config, + &new_config.bindings, + ); let perms = 
state.discord_permissions.read().await; if let Some(arc_swap) = perms.as_ref() { arc_swap.store(std::sync::Arc::new(new_perms)); @@ -3540,10 +3797,15 @@ async fn create_binding( None => { drop(perms_guard); let perms = crate::config::DiscordPermissions::from_config( - new_config.messaging.discord.as_ref().expect("discord config exists when token is provided"), + new_config + .messaging + .discord + .as_ref() + .expect("discord config exists when token is provided"), &new_config.bindings, ); - let arc_swap = std::sync::Arc::new(arc_swap::ArcSwap::from_pointee(perms)); + let arc_swap = + std::sync::Arc::new(arc_swap::ArcSwap::from_pointee(perms)); state.set_discord_permissions(arc_swap.clone()).await; arc_swap } @@ -3563,16 +3825,22 @@ async fn create_binding( None => { drop(perms_guard); let perms = crate::config::SlackPermissions::from_config( - new_config.messaging.slack.as_ref().expect("slack config exists when tokens are provided"), + new_config + .messaging + .slack + .as_ref() + .expect("slack config exists when tokens are provided"), &new_config.bindings, ); - let arc_swap = std::sync::Arc::new(arc_swap::ArcSwap::from_pointee(perms)); + let arc_swap = + std::sync::Arc::new(arc_swap::ArcSwap::from_pointee(perms)); state.set_slack_permissions(arc_swap.clone()).await; arc_swap } } }; - let adapter = crate::messaging::slack::SlackAdapter::new(&bot_token, &app_token, slack_perms); + let adapter = + crate::messaging::slack::SlackAdapter::new(&bot_token, &app_token, slack_perms); if let Err(error) = manager.register_and_start(adapter).await { tracing::error!(%error, "failed to hot-start slack adapter"); } @@ -3631,7 +3899,7 @@ struct UpdateBindingRequest { original_workspace_id: Option, #[serde(default)] original_chat_id: Option, - + // New values agent_id: String, channel: String, @@ -3662,10 +3930,12 @@ async fn update_binding( return Err(StatusCode::NOT_FOUND); } - let content = tokio::fs::read_to_string(&config_path).await.map_err(|error| { - 
tracing::warn!(%error, "failed to read config.toml"); - StatusCode::INTERNAL_SERVER_ERROR - })?; + let content = tokio::fs::read_to_string(&config_path) + .await + .map_err(|error| { + tracing::warn!(%error, "failed to read config.toml"); + StatusCode::INTERNAL_SERVER_ERROR + })?; let mut doc: toml_edit::DocumentMut = content.parse().map_err(|error| { tracing::warn!(%error, "failed to parse config.toml"); @@ -3723,16 +3993,18 @@ async fn update_binding( }; // Update the binding in place - let binding = bindings_array.get_mut(idx).ok_or(StatusCode::INTERNAL_SERVER_ERROR)?; - + let binding = bindings_array + .get_mut(idx) + .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?; + binding["agent_id"] = toml_edit::value(&request.agent_id); binding["channel"] = toml_edit::value(&request.channel); - + // Clear and set optional fields binding.remove("guild_id"); binding.remove("workspace_id"); binding.remove("chat_id"); - + if let Some(ref guild_id) = request.guild_id { if !guild_id.is_empty() { binding["guild_id"] = toml_edit::value(guild_id); @@ -3748,7 +4020,7 @@ async fn update_binding( binding["chat_id"] = toml_edit::value(chat_id); } } - + // Update arrays if !request.channel_ids.is_empty() { let mut arr = toml_edit::Array::new(); @@ -3759,7 +4031,7 @@ async fn update_binding( } else { binding.remove("channel_ids"); } - + if !request.dm_allowed_users.is_empty() { let mut arr = toml_edit::Array::new(); for id in &request.dm_allowed_users { @@ -3792,8 +4064,10 @@ async fn update_binding( drop(bindings_guard); if let Some(discord_config) = &new_config.messaging.discord { - let new_perms = - crate::config::DiscordPermissions::from_config(discord_config, &new_config.bindings); + let new_perms = crate::config::DiscordPermissions::from_config( + discord_config, + &new_config.bindings, + ); let perms = state.discord_permissions.read().await; if let Some(arc_swap) = perms.as_ref() { arc_swap.store(std::sync::Arc::new(new_perms)); @@ -3826,10 +4100,12 @@ async fn delete_binding( return 
Err(StatusCode::NOT_FOUND); } - let content = tokio::fs::read_to_string(&config_path).await.map_err(|error| { - tracing::warn!(%error, "failed to read config.toml"); - StatusCode::INTERNAL_SERVER_ERROR - })?; + let content = tokio::fs::read_to_string(&config_path) + .await + .map_err(|error| { + tracing::warn!(%error, "failed to read config.toml"); + StatusCode::INTERNAL_SERVER_ERROR + })?; let mut doc: toml_edit::DocumentMut = content.parse().map_err(|error| { tracing::warn!(%error, "failed to parse config.toml"); @@ -3910,8 +4186,10 @@ async fn delete_binding( drop(bindings_guard); if let Some(discord_config) = &new_config.messaging.discord { - let new_perms = - crate::config::DiscordPermissions::from_config(discord_config, &new_config.bindings); + let new_perms = crate::config::DiscordPermissions::from_config( + discord_config, + &new_config.bindings, + ); let perms = state.discord_permissions.read().await; if let Some(arc_swap) = perms.as_ref() { arc_swap.store(std::sync::Arc::new(new_perms)); @@ -4001,117 +4279,132 @@ async fn get_global_settings( State(state): State>, ) -> Result, StatusCode> { let config_path = state.config_path.read().await.clone(); - - let (brave_search_key, api_enabled, api_port, api_bind, worker_log_mode, opencode) = if config_path.exists() { - let content = tokio::fs::read_to_string(&config_path) - .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - let doc: toml_edit::DocumentMut = content - .parse() - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - - let brave_search = doc - .get("defaults") - .and_then(|d| d.get("brave_search_key")) - .and_then(|v| v.as_str()) - .map(|s| { - if let Some(var) = s.strip_prefix("env:") { - std::env::var(var).ok() - } else { - Some(s.to_string()) - } - }) - .flatten(); - - let api_enabled = doc - .get("api") - .and_then(|a| a.get("enabled")) - .and_then(|v| v.as_bool()) - .unwrap_or(true); - - let api_port = doc - .get("api") - .and_then(|a| a.get("port")) - .and_then(|v| v.as_integer()) - 
.and_then(|i| u16::try_from(i).ok()) - .unwrap_or(19898); - - let api_bind = doc - .get("api") - .and_then(|a| a.get("bind")) - .and_then(|v| v.as_str()) - .unwrap_or("127.0.0.1") - .to_string(); - - let worker_log_mode = doc - .get("defaults") - .and_then(|d| d.get("worker_log_mode")) - .and_then(|v| v.as_str()) - .unwrap_or("errors_only") - .to_string(); - - let opencode_table = doc.get("defaults").and_then(|d| d.get("opencode")); - let opencode_perms = opencode_table.and_then(|o| o.get("permissions")); - let opencode = OpenCodeSettingsResponse { - enabled: opencode_table - .and_then(|o| o.get("enabled")) - .and_then(|v| v.as_bool()) - .unwrap_or(false), - path: opencode_table - .and_then(|o| o.get("path")) + + let (brave_search_key, api_enabled, api_port, api_bind, worker_log_mode, opencode) = + if config_path.exists() { + let content = tokio::fs::read_to_string(&config_path) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + let doc: toml_edit::DocumentMut = content + .parse() + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + let brave_search = doc + .get("defaults") + .and_then(|d| d.get("brave_search_key")) .and_then(|v| v.as_str()) - .unwrap_or("opencode") - .to_string(), - max_servers: opencode_table - .and_then(|o| o.get("max_servers")) - .and_then(|v| v.as_integer()) - .and_then(|i| usize::try_from(i).ok()) - .unwrap_or(5), - server_startup_timeout_secs: opencode_table - .and_then(|o| o.get("server_startup_timeout_secs")) - .and_then(|v| v.as_integer()) - .and_then(|i| u64::try_from(i).ok()) - .unwrap_or(30), - max_restart_retries: opencode_table - .and_then(|o| o.get("max_restart_retries")) + .map(|s| { + if let Some(var) = s.strip_prefix("env:") { + std::env::var(var).ok() + } else { + Some(s.to_string()) + } + }) + .flatten(); + + let api_enabled = doc + .get("api") + .and_then(|a| a.get("enabled")) + .and_then(|v| v.as_bool()) + .unwrap_or(true); + + let api_port = doc + .get("api") + .and_then(|a| a.get("port")) .and_then(|v| 
v.as_integer()) - .and_then(|i| u32::try_from(i).ok()) - .unwrap_or(5), - permissions: OpenCodePermissionsResponse { - edit: opencode_perms - .and_then(|p| p.get("edit")) - .and_then(|v| v.as_str()) - .unwrap_or("allow") - .to_string(), - bash: opencode_perms - .and_then(|p| p.get("bash")) - .and_then(|v| v.as_str()) - .unwrap_or("allow") - .to_string(), - webfetch: opencode_perms - .and_then(|p| p.get("webfetch")) + .and_then(|i| u16::try_from(i).ok()) + .unwrap_or(19898); + + let api_bind = doc + .get("api") + .and_then(|a| a.get("bind")) + .and_then(|v| v.as_str()) + .unwrap_or("127.0.0.1") + .to_string(); + + let worker_log_mode = doc + .get("defaults") + .and_then(|d| d.get("worker_log_mode")) + .and_then(|v| v.as_str()) + .unwrap_or("errors_only") + .to_string(); + + let opencode_table = doc.get("defaults").and_then(|d| d.get("opencode")); + let opencode_perms = opencode_table.and_then(|o| o.get("permissions")); + let opencode = OpenCodeSettingsResponse { + enabled: opencode_table + .and_then(|o| o.get("enabled")) + .and_then(|v| v.as_bool()) + .unwrap_or(false), + path: opencode_table + .and_then(|o| o.get("path")) .and_then(|v| v.as_str()) - .unwrap_or("allow") + .unwrap_or("opencode") .to_string(), - }, + max_servers: opencode_table + .and_then(|o| o.get("max_servers")) + .and_then(|v| v.as_integer()) + .and_then(|i| usize::try_from(i).ok()) + .unwrap_or(5), + server_startup_timeout_secs: opencode_table + .and_then(|o| o.get("server_startup_timeout_secs")) + .and_then(|v| v.as_integer()) + .and_then(|i| u64::try_from(i).ok()) + .unwrap_or(30), + max_restart_retries: opencode_table + .and_then(|o| o.get("max_restart_retries")) + .and_then(|v| v.as_integer()) + .and_then(|i| u32::try_from(i).ok()) + .unwrap_or(5), + permissions: OpenCodePermissionsResponse { + edit: opencode_perms + .and_then(|p| p.get("edit")) + .and_then(|v| v.as_str()) + .unwrap_or("allow") + .to_string(), + bash: opencode_perms + .and_then(|p| p.get("bash")) + .and_then(|v| v.as_str()) + 
.unwrap_or("allow") + .to_string(), + webfetch: opencode_perms + .and_then(|p| p.get("webfetch")) + .and_then(|v| v.as_str()) + .unwrap_or("allow") + .to_string(), + }, + }; + + ( + brave_search, + api_enabled, + api_port, + api_bind, + worker_log_mode, + opencode, + ) + } else { + ( + None, + true, + 19898, + "127.0.0.1".to_string(), + "errors_only".to_string(), + OpenCodeSettingsResponse { + enabled: false, + path: "opencode".to_string(), + max_servers: 5, + server_startup_timeout_secs: 30, + max_restart_retries: 5, + permissions: OpenCodePermissionsResponse { + edit: "allow".to_string(), + bash: "allow".to_string(), + webfetch: "allow".to_string(), + }, + }, + ) }; - (brave_search, api_enabled, api_port, api_bind, worker_log_mode, opencode) - } else { - (None, true, 19898, "127.0.0.1".to_string(), "errors_only".to_string(), OpenCodeSettingsResponse { - enabled: false, - path: "opencode".to_string(), - max_servers: 5, - server_startup_timeout_secs: 30, - max_restart_retries: 5, - permissions: OpenCodePermissionsResponse { - edit: "allow".to_string(), - bash: "allow".to_string(), - webfetch: "allow".to_string(), - }, - }) - }; - Ok(Json(GlobalSettingsResponse { brave_search_key: brave_search_key, api_enabled, @@ -4127,7 +4420,7 @@ async fn update_global_settings( Json(request): Json, ) -> Result, StatusCode> { let config_path = state.config_path.read().await.clone(); - + let content = if config_path.exists() { tokio::fs::read_to_string(&config_path) .await @@ -4135,13 +4428,13 @@ async fn update_global_settings( } else { String::new() }; - + let mut doc: toml_edit::DocumentMut = content .parse() .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - + let mut requires_restart = false; - + // Update brave_search_key if let Some(key) = request.brave_search_key { if doc.get("defaults").is_none() { @@ -4155,15 +4448,15 @@ async fn update_global_settings( doc["defaults"]["brave_search_key"] = toml_edit::value(key); } } - + // Update API settings (requires restart) if 
request.api_enabled.is_some() || request.api_port.is_some() || request.api_bind.is_some() { requires_restart = true; - + if doc.get("api").is_none() { doc["api"] = toml_edit::Item::Table(toml_edit::Table::new()); } - + if let Some(enabled) = request.api_enabled { doc["api"]["enabled"] = toml_edit::value(enabled); } @@ -4174,7 +4467,7 @@ async fn update_global_settings( doc["api"]["bind"] = toml_edit::value(bind); } } - + // Update worker_log_mode if let Some(mode) = request.worker_log_mode { // Validate the mode @@ -4185,13 +4478,13 @@ async fn update_global_settings( requires_restart: false, })); } - + if doc.get("defaults").is_none() { doc["defaults"] = toml_edit::Item::Table(toml_edit::Table::new()); } doc["defaults"]["worker_log_mode"] = toml_edit::value(mode); } - + // Update OpenCode settings if let Some(opencode) = request.opencode { if doc.get("defaults").is_none() { @@ -4200,7 +4493,7 @@ async fn update_global_settings( if doc["defaults"].get("opencode").is_none() { doc["defaults"]["opencode"] = toml_edit::Item::Table(toml_edit::Table::new()); } - + if let Some(enabled) = opencode.enabled { doc["defaults"]["opencode"]["enabled"] = toml_edit::value(enabled); } @@ -4211,14 +4504,16 @@ async fn update_global_settings( doc["defaults"]["opencode"]["max_servers"] = toml_edit::value(max_servers as i64); } if let Some(timeout) = opencode.server_startup_timeout_secs { - doc["defaults"]["opencode"]["server_startup_timeout_secs"] = toml_edit::value(timeout as i64); + doc["defaults"]["opencode"]["server_startup_timeout_secs"] = + toml_edit::value(timeout as i64); } if let Some(retries) = opencode.max_restart_retries { doc["defaults"]["opencode"]["max_restart_retries"] = toml_edit::value(retries as i64); } if let Some(permissions) = opencode.permissions { if doc["defaults"]["opencode"].get("permissions").is_none() { - doc["defaults"]["opencode"]["permissions"] = toml_edit::Item::Table(toml_edit::Table::new()); + doc["defaults"]["opencode"]["permissions"] = + 
toml_edit::Item::Table(toml_edit::Table::new()); } if let Some(edit) = permissions.edit { doc["defaults"]["opencode"]["permissions"]["edit"] = toml_edit::value(edit); @@ -4235,13 +4530,13 @@ async fn update_global_settings( tokio::fs::write(&config_path, doc.to_string()) .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - + let message = if requires_restart { "Settings updated. API server changes require a restart to take effect.".to_string() } else { "Settings updated successfully.".to_string() }; - + Ok(Json(GlobalSettingsUpdateResponse { success: true, message, @@ -4385,12 +4680,7 @@ async fn static_handler(uri: Uri) -> Response { // SPA fallback if let Some(content) = InterfaceAssets::get("index.html") { - return Html( - std::str::from_utf8(&content.data) - .unwrap_or("") - .to_string(), - ) - .into_response(); + return Html(std::str::from_utf8(&content.data).unwrap_or("").to_string()).into_response(); } (StatusCode::NOT_FOUND, "not found").into_response() From a5435a1e011977976c06f28d1afeb91827be0041 Mon Sep 17 00:00:00 2001 From: nmb Date: Tue, 17 Feb 2026 18:39:41 +0100 Subject: [PATCH 11/11] feat(interface): add workers live panel - replace the workers placeholder route with a live worker monitoring panel\n- show active worker status, current tool usage, runtime, and cancel actions\n- add recent worker result cards for quick progress visibility --- interface/src/router.tsx | 8 +- interface/src/routes/AgentWorkers.tsx | 294 ++++++++++++++++++++++++++ 2 files changed, 298 insertions(+), 4 deletions(-) create mode 100644 interface/src/routes/AgentWorkers.tsx diff --git a/interface/src/router.tsx b/interface/src/router.tsx index f9e2713d9..d686ffbc4 100644 --- a/interface/src/router.tsx +++ b/interface/src/router.tsx @@ -18,6 +18,7 @@ import {AgentMemories} from "@/routes/AgentMemories"; import {AgentConfig} from "@/routes/AgentConfig"; import {AgentCron} from "@/routes/AgentCron"; import {AgentIngest} from "@/routes/AgentIngest"; +import {AgentWorkers} 
from "@/routes/AgentWorkers"; import {Settings} from "@/routes/Settings"; import {useLiveContext} from "@/hooks/useLiveContext"; import {AgentTabs} from "@/components/AgentTabs"; @@ -170,13 +171,12 @@ const agentWorkersRoute = createRoute({ path: "/agents/$agentId/workers", component: function AgentWorkersPage() { const {agentId} = agentWorkersRoute.useParams(); + const {liveStates, channels} = useLiveContext(); return (
-
-

- Workers control interface coming soon -

+
+
// AgentWorkers.tsx — live dashboard for a single agent's worker processes.
//
// NOTE(review): this module was recovered from a patch whose angle-bracketed
// markup (all JSX elements and generic type arguments) was stripped by an
// HTML sanitizer during extraction. Every import, interface, helper, hook,
// and handler below is verbatim from the surviving source; stripped generics
// are restored from usage. The render tree, classNames, and Link targets are
// a reconstruction from the surviving text fragments and MUST be confirmed
// against the original file.
import { useMemo, useState } from "react";
import { Link } from "@tanstack/react-router";
import { AnimatePresence, motion } from "framer-motion";
import { api, type ChannelInfo, type TimelineWorkerRun } from "@/api/client";
import { LiveDuration } from "@/components/LiveDuration";
import type { ActiveWorker, ChannelLiveState } from "@/hooks/useChannelLiveState";
import { formatTimeAgo, formatTimestamp } from "@/lib/format";
import { Badge, Button } from "@/ui";

interface AgentWorkersProps {
	agentId: string;
	channels: ChannelInfo[];
	// Keyed by channel id — restored generic; original read `Record;` after tag-stripping.
	liveStates: Record<string, ChannelLiveState>;
}

// An active worker joined with the channel it belongs to, for display.
interface ActiveWorkerView extends ActiveWorker {
	channel: ChannelInfo;
}

// A completed timeline worker run joined with its channel, for display.
interface CompletedWorkerView extends TimelineWorkerRun {
	channel: ChannelInfo;
}

/**
 * Map a free-form worker status string onto a Badge variant.
 * Checks run in priority order: failure > success > in-progress > queued.
 */
function statusVariant(status: string): "default" | "outline" | "amber" | "red" | "green" {
	const normalized = status.toLowerCase();
	if (normalized.includes("fail") || normalized.includes("error") || normalized.includes("killed")) {
		return "red";
	}
	if (normalized.includes("done") || normalized.includes("complete") || normalized.includes("success")) {
		return "green";
	}
	if (normalized.includes("run") || normalized.includes("start") || normalized.includes("tool")) {
		return "amber";
	}
	if (normalized.includes("wait") || normalized.includes("queue")) {
		return "outline";
	}
	return "default";
}

/**
 * Collapse a worker result to a single whitespace-normalized line,
 * truncated to 220 characters with a trailing ellipsis.
 */
function summarizeResult(result: string | null): string {
	if (!result) return "No result text captured.";
	const compact = result.replace(/\s+/g, " ").trim();
	if (compact.length <= 220) return compact;
	return `${compact.slice(0, 220)}...`;
}

/**
 * Live panel listing every active worker for an agent, in-flight tool calls,
 * and the most recent completed worker runs across the agent's channels.
 */
export function AgentWorkers({ agentId, channels, liveStates }: AgentWorkersProps) {
	// Worker ids with an in-flight cancel request (restored generic; original read `useState>({})`).
	const [cancellingIds, setCancellingIds] = useState<Record<string, boolean>>({});

	// Only the channels belonging to this agent.
	const agentChannels = useMemo(
		() => channels.filter((channel) => channel.agent_id === agentId),
		[agentId, channels],
	);

	// Flatten all live workers across the agent's channels, oldest first.
	const activeWorkers = useMemo(() => {
		return agentChannels
			.flatMap((channel) =>
				Object.values(liveStates[channel.id]?.workers ?? {}).map((worker) => ({
					...worker,
					channel,
				})),
			)
			.sort((a, b) => a.startedAt - b.startedAt);
	}, [agentChannels, liveStates]);

	// Up to 24 most recent completed worker runs from loaded timelines, newest first.
	const completedWorkers = useMemo(() => {
		return agentChannels
			.flatMap((channel) => {
				const timeline = liveStates[channel.id]?.timeline ?? [];
				return timeline
					.filter((item): item is TimelineWorkerRun => item.type === "worker_run")
					.filter((item) => Boolean(item.completed_at || item.status === "done" || item.result))
					.map((item) => ({ ...item, channel }));
			})
			.sort((a, b) => {
				const aTime = new Date(a.completed_at ?? a.started_at).getTime();
				const bTime = new Date(b.completed_at ?? b.started_at).getTime();
				return bTime - aTime;
			})
			.slice(0, 24);
	}, [agentChannels, liveStates]);

	// Aggregate stats for the header.
	const activeChannelCount = useMemo(
		() => new Set(activeWorkers.map((worker) => worker.channel.id)).size,
		[activeWorkers],
	);
	const workersUsingTools = useMemo(
		() => activeWorkers.filter((worker) => worker.currentTool !== null).length,
		[activeWorkers],
	);
	const totalToolCalls = useMemo(
		() => activeWorkers.reduce((sum, worker) => sum + worker.toolCalls, 0),
		[activeWorkers],
	);

	// Fire-and-forget cancel; the id is tracked only while the request is in flight.
	const handleCancel = (worker: ActiveWorkerView) => {
		setCancellingIds((previous) => ({ ...previous, [worker.id]: true }));
		api.cancelProcess(worker.channel.id, "worker", worker.id)
			.catch((error) => {
				console.warn("Failed to cancel worker:", error);
			})
			.finally(() => {
				setCancellingIds((previous) => {
					const { [worker.id]: _, ...remaining } = previous;
					return remaining;
				});
			});
	};

	// NOTE(review): everything below is reconstructed markup — element nesting,
	// classNames, and Link props are inferred from surviving text fragments.
	return (
		<div className="flex flex-col gap-6 p-6">
			{/* Header with aggregate stats. */}
			<div className="flex items-start justify-between gap-4">
				<div>
					<h2 className="text-lg font-semibold">Workers Live Panel</h2>
					<p className="text-sm text-ink-faint">
						Monitor every active worker, current tool usage, and recent outputs across this agent.
					</p>
				</div>
				<div className="flex gap-6">
					<div className="text-center">
						<div className="text-xs text-ink-faint">Active</div>
						<div className="text-xl font-semibold">{activeWorkers.length}</div>
					</div>
					<div className="text-center">
						<div className="text-xs text-ink-faint">Channels</div>
						<div className="text-xl font-semibold">{activeChannelCount}</div>
					</div>
					<div className="text-center">
						<div className="text-xs text-ink-faint">Tool Calls</div>
						<div className="text-xl font-semibold">{totalToolCalls}</div>
					</div>
				</div>
			</div>

			{/* Active workers list with live status and cancel controls. */}
			<section className="flex flex-col gap-3">
				<div className="flex items-center justify-between">
					<h3 className="font-medium">Active Workers</h3>
					<Badge variant={workersUsingTools > 0 ? "amber" : "outline"} size="sm">
						{workersUsingTools} using tools
					</Badge>
				</div>
				{activeWorkers.length === 0 ? (
					<div className="rounded-lg border border-dashed py-8 text-center">
						<div className="font-medium">No workers are active right now.</div>
						<div className="text-sm text-ink-faint">
							Spawn a worker from a channel and it will appear here with live status updates.
						</div>
					</div>
				) : (
					<div className="flex flex-col gap-3">
						<AnimatePresence initial={false}>
							{activeWorkers.map((worker) => (
								<motion.div
									key={worker.id}
									layout
									initial={{ opacity: 0 }}
									animate={{ opacity: 1 }}
									exit={{ opacity: 0 }}
									className="rounded-lg border p-4"
								>
									<div className="flex items-start justify-between gap-3">
										<div className="min-w-0">
											<div className="truncate font-medium">{worker.task}</div>
											<div className="text-xs text-ink-faint">{worker.id}</div>
										</div>
										<Button
											size="sm"
											onClick={() => handleCancel(worker)}
											disabled={Boolean(cancellingIds[worker.id])}
										>
											{cancellingIds[worker.id] ? "Cancelling..." : "Cancel"}
										</Button>
									</div>
									<div className="mt-2 flex items-center gap-2">
										<Badge variant={statusVariant(worker.status)} size="sm">
											{worker.status}
										</Badge>
										<LiveDuration startedAt={worker.startedAt} />
										<span className="text-xs text-ink-faint">{worker.toolCalls} tool calls</span>
									</div>
									<div className="mt-2">
										<div className="text-xs text-ink-faint">Current Tool</div>
										<span className="text-sm">
											{worker.currentTool ?? "Waiting for next tool call"}
										</span>
									</div>
									<div className="mt-2 flex items-center justify-between text-xs text-ink-faint">
										{/* NOTE(review): route path/params inferred — confirm against router config. */}
										<Link
											to="/channels/$channelId"
											params={{ channelId: worker.channel.id }}
											className="hover:underline"
										>
											{worker.channel.display_name ?? worker.channel.id}
										</Link>
										<span>{formatTimestamp(worker.startedAt)}</span>
									</div>
								</motion.div>
							))}
						</AnimatePresence>
					</div>
				)}
			</section>

			{/* Workers currently inside a tool call. */}
			<section className="flex flex-col gap-3">
				<div>
					<h3 className="font-medium">Live Telemetry</h3>
					<p className="text-sm text-ink-faint">Active workers currently inside a tool call.</p>
				</div>
				{activeWorkers.filter((worker) => worker.currentTool).length === 0 ? (
					<div className="text-sm text-ink-faint">No tool calls in-flight.</div>
				) : (
					activeWorkers
						.filter((worker) => worker.currentTool)
						.map((worker) => (
							<div key={worker.id} className="rounded-lg border p-3">
								<div className="flex items-center justify-between">
									<div className="font-mono text-sm">{worker.currentTool}</div>
									<LiveDuration startedAt={worker.startedAt} />
								</div>
								<div className="mt-1 truncate text-xs text-ink-faint">{worker.task}</div>
							</div>
						))
				)}
			</section>

			{/* Most recent completed worker runs. */}
			<section className="flex flex-col gap-3">
				<div>
					<h3 className="font-medium">Recent Worker Results</h3>
					<p className="text-sm text-ink-faint">
						Most recent worker completions from loaded channel history.
					</p>
				</div>
				{completedWorkers.length === 0 ? (
					<div className="text-sm text-ink-faint">No recent worker completions yet.</div>
				) : (
					completedWorkers.map((workerRun) => (
						<div key={workerRun.id} className="rounded-lg border p-3">
							<div className="flex items-center gap-2">
								<Badge variant={statusVariant(workerRun.status)} size="sm">
									{workerRun.status}
								</Badge>
								<span className="text-xs text-ink-faint">
									{workerRun.channel.display_name ?? workerRun.channel.id}
								</span>
								<span className="text-xs text-ink-faint">
									{formatTimeAgo(workerRun.completed_at ?? workerRun.started_at)}
								</span>
							</div>
							<div className="mt-1 truncate font-medium">{workerRun.task}</div>
							<div className="mt-1 text-sm text-ink-faint">
								{summarizeResult(workerRun.result)}
							</div>
						</div>
					))
				)}
			</section>
		</div>
	);
}