From a697f1c5b0b25f1fd52eff7804e76334ab7d332e Mon Sep 17 00:00:00 2001 From: nmb Date: Tue, 17 Feb 2026 16:30:29 +0100 Subject: [PATCH 1/8] =?UTF-8?q?=E2=9C=A8=20feat(llm):=20add=20Ollama=20Clo?= =?UTF-8?q?ud=20provider=20support?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 2 +- interface/src/lib/providerIcons.tsx | 1 + src/config.rs | 43 +++++++++++++++++++---------- src/llm/providers.rs | 14 ++++++++-- src/llm/routing.rs | 15 ++++++++++ 5 files changed, 57 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index fa7b8920c..b791364e4 100644 --- a/README.md +++ b/README.md @@ -311,7 +311,7 @@ Read the full vision in [docs/spacedrive.md](docs/spacedrive.md). ### Prerequisites - **Rust** 1.85+ ([rustup](https://rustup.rs/)) -- An LLM API key from any supported provider (Anthropic, OpenAI, OpenRouter, Z.ai, Groq, Together, Fireworks, DeepSeek, xAI, Mistral, or OpenCode Zen) +- An LLM API key from any supported provider (Anthropic, OpenAI, OpenRouter, Ollama Cloud, Z.ai, Groq, Together, Fireworks, DeepSeek, xAI, Mistral, or OpenCode Zen) ### Build and Run diff --git a/interface/src/lib/providerIcons.tsx b/interface/src/lib/providerIcons.tsx index 8d6666bd6..76ba6fdd1 100644 --- a/interface/src/lib/providerIcons.tsx +++ b/interface/src/lib/providerIcons.tsx @@ -60,6 +60,7 @@ export function ProviderIcon({ provider, className = "text-ink-faint", size = 24 anthropic: Anthropic, openai: OpenAI, openrouter: OpenRouter, + ollama: OpenRouter, groq: Groq, mistral: Mistral, deepseek: DeepSeek, diff --git a/src/config.rs b/src/config.rs index 94c289a86..d930df304 100644 --- a/src/config.rs +++ b/src/config.rs @@ -55,6 +55,7 @@ pub struct LlmConfig { pub anthropic_key: Option, pub openai_key: Option, pub openrouter_key: Option, + pub ollama_key: Option, pub zhipu_key: Option, pub groq_key: Option, pub together_key: Option, @@ -68,9 +69,10 @@ pub struct LlmConfig { impl LlmConfig { /// Check if any 
provider key is configured. pub fn has_any_key(&self) -> bool { - self.anthropic_key.is_some() - || self.openai_key.is_some() - || self.openrouter_key.is_some() + self.anthropic_key.is_some() + || self.openai_key.is_some() + || self.openrouter_key.is_some() + || self.ollama_key.is_some() || self.zhipu_key.is_some() || self.groq_key.is_some() || self.together_key.is_some() @@ -869,6 +871,7 @@ struct TomlLlmConfig { anthropic_key: Option, openai_key: Option, openrouter_key: Option, + ollama_key: Option, zhipu_key: Option, groq_key: Option, together_key: Option, @@ -1146,6 +1149,7 @@ impl Config { std::env::var("ANTHROPIC_API_KEY").is_err() && std::env::var("OPENAI_API_KEY").is_err() && std::env::var("OPENROUTER_API_KEY").is_err() + && std::env::var("OLLAMA_API_KEY").is_err() && std::env::var("OPENCODE_ZEN_API_KEY").is_err() } @@ -1183,6 +1187,7 @@ impl Config { anthropic_key: std::env::var("ANTHROPIC_API_KEY").ok(), openai_key: std::env::var("OPENAI_API_KEY").ok(), openrouter_key: std::env::var("OPENROUTER_API_KEY").ok(), + ollama_key: std::env::var("OLLAMA_API_KEY").ok(), zhipu_key: std::env::var("ZHIPU_API_KEY").ok(), groq_key: std::env::var("GROQ_API_KEY").ok(), together_key: std::env::var("TOGETHER_API_KEY").ok(), @@ -1239,8 +1244,8 @@ impl Config { /// Validate a raw TOML string as a valid Spacebot config. /// Returns Ok(()) if the config is structurally valid, or an error describing what's wrong. pub fn validate_toml(content: &str) -> Result<()> { - let toml_config: TomlConfig = toml::from_str(content) - .context("failed to parse config TOML")?; + let toml_config: TomlConfig = + toml::from_str(content).context("failed to parse config TOML")?; // Run full conversion to catch semantic errors (env resolution, defaults, etc.) 
let instance_dir = Self::default_instance_dir(); Self::from_toml(toml_config, instance_dir)?; @@ -1267,6 +1272,12 @@ impl Config { .as_deref() .and_then(resolve_env_value) .or_else(|| std::env::var("OPENROUTER_API_KEY").ok()), + ollama_key: toml + .llm + .ollama_key + .as_deref() + .and_then(resolve_env_value) + .or_else(|| std::env::var("OLLAMA_API_KEY").ok()), zhipu_key: toml .llm .zhipu_key @@ -1939,7 +1950,9 @@ pub fn spawn_file_watcher( // Only forward data modification events, not metadata/access changes use notify::EventKind; match &event.kind { - EventKind::Create(_) | EventKind::Modify(notify::event::ModifyKind::Data(_)) | EventKind::Remove(_) => { + EventKind::Create(_) + | EventKind::Modify(notify::event::ModifyKind::Data(_)) + | EventKind::Remove(_) => { let _ = tx.send(event); } // Also forward Any/Other modify events (some backends don't distinguish) @@ -2248,6 +2261,7 @@ pub fn run_onboarding() -> anyhow::Result> { "Anthropic", "OpenRouter", "OpenAI", + "Ollama Cloud", "Z.ai (GLM)", "Groq", "Together AI", @@ -2267,14 +2281,15 @@ pub fn run_onboarding() -> anyhow::Result> { 0 => ("Anthropic API key", "anthropic_key", "anthropic"), 1 => ("OpenRouter API key", "openrouter_key", "openrouter"), 2 => ("OpenAI API key", "openai_key", "openai"), - 3 => ("Z.ai (GLM) API key", "zhipu_key", "zhipu"), - 4 => ("Groq API key", "groq_key", "groq"), - 5 => ("Together AI API key", "together_key", "together"), - 6 => ("Fireworks AI API key", "fireworks_key", "fireworks"), - 7 => ("DeepSeek API key", "deepseek_key", "deepseek"), - 8 => ("xAI API key", "xai_key", "xai"), - 9 => ("Mistral AI API key", "mistral_key", "mistral"), - 10 => ("OpenCode Zen API key", "opencode_zen_key", "opencode-zen"), + 3 => ("Ollama Cloud API key", "ollama_key", "ollama"), + 4 => ("Z.ai (GLM) API key", "zhipu_key", "zhipu"), + 5 => ("Groq API key", "groq_key", "groq"), + 6 => ("Together AI API key", "together_key", "together"), + 7 => ("Fireworks AI API key", "fireworks_key", "fireworks"), + 
8 => ("DeepSeek API key", "deepseek_key", "deepseek"), + 9 => ("xAI API key", "xai_key", "xai"), + 10 => ("Mistral AI API key", "mistral_key", "mistral"), + 11 => ("OpenCode Zen API key", "opencode_zen_key", "opencode-zen"), _ => unreachable!(), }; diff --git a/src/llm/providers.rs b/src/llm/providers.rs index e3ce47ccc..8e0672149 100644 --- a/src/llm/providers.rs +++ b/src/llm/providers.rs @@ -8,18 +8,26 @@ pub async fn init_providers(config: &LlmConfig) -> Result<()> { // Provider clients are initialized lazily through LlmManager // This module exists for any provider-specific setup that needs to happen // during system startup - + if config.anthropic_key.is_some() { tracing::info!("Anthropic provider configured"); } - + if config.openai_key.is_some() { tracing::info!("OpenAI provider configured"); } + if config.openrouter_key.is_some() { + tracing::info!("OpenRouter provider configured"); + } + + if config.ollama_key.is_some() { + tracing::info!("Ollama provider configured"); + } + if config.opencode_zen_key.is_some() { tracing::info!("OpenCode Zen provider configured"); } - + Ok(()) } diff --git a/src/llm/routing.rs b/src/llm/routing.rs index 5671516cb..d6d61dd8a 100644 --- a/src/llm/routing.rs +++ b/src/llm/routing.rs @@ -153,6 +153,20 @@ pub fn defaults_for_provider(provider: &str) -> RoutingConfig { rate_limit_cooldown_secs: 60, } } + "ollama" => { + let channel: String = "ollama/gpt-oss:120b".into(); + let worker: String = "ollama/gpt-oss:20b".into(); + RoutingConfig { + channel: channel.clone(), + branch: channel.clone(), + worker: worker.clone(), + compactor: worker.clone(), + cortex: worker.clone(), + task_overrides: HashMap::from([("coding".into(), channel.clone())]), + fallbacks: HashMap::from([(channel, vec![worker])]), + rate_limit_cooldown_secs: 60, + } + } "zhipu" => { let channel: String = "zhipu/glm-4-plus".into(); let worker: String = "zhipu/glm-4-flash".into(); @@ -277,6 +291,7 @@ pub fn provider_to_prefix(provider: &str) -> &str { match 
provider { "openrouter" => "openrouter/", "openai" => "openai/", + "ollama" => "ollama/", "anthropic" => "anthropic/", "zhipu" => "zhipu/", "groq" => "groq/", From 02d2a07eab94e61234bf182e580d322508d0dd2f Mon Sep 17 00:00:00 2001 From: nmb Date: Tue, 17 Feb 2026 16:30:57 +0100 Subject: [PATCH 2/8] =?UTF-8?q?=E2=9C=A8=20feat(ui):=20update=20settings?= =?UTF-8?q?=20for=20new=20provider=20and=20options?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- interface/src/routes/Settings.tsx | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/interface/src/routes/Settings.tsx b/interface/src/routes/Settings.tsx index 716e1cd4d..5fd3129ad 100644 --- a/interface/src/routes/Settings.tsx +++ b/interface/src/routes/Settings.tsx @@ -95,6 +95,13 @@ const PROVIDERS = [ placeholder: "sk-...", envVar: "OPENAI_API_KEY", }, + { + id: "ollama", + name: "Ollama Cloud", + description: "Hosted Ollama models via OpenAI-compatible API", + placeholder: "ollama_...", + envVar: "OLLAMA_API_KEY", + }, { id: "zhipu", name: "Z.ai (GLM)", From 96226ca02f9d2e0e6abc84860991e405595a5fdf Mon Sep 17 00:00:00 2001 From: nmb Date: Tue, 17 Feb 2026 16:30:54 +0100 Subject: [PATCH 3/8] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor(history):=20i?= =?UTF-8?q?mprove=20conversation=20pruning=20and=20management?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/conversation/history.rs | 101 ++++++++++++++++++++++++------------ 1 file changed, 69 insertions(+), 32 deletions(-) diff --git a/src/conversation/history.rs b/src/conversation/history.rs index aeb9b54ea..1916e0ef6 100644 --- a/src/conversation/history.rs +++ b/src/conversation/history.rs @@ -79,7 +79,7 @@ impl ConversationLogger { tokio::spawn(async move { if let Err(error) = sqlx::query( "INSERT INTO conversation_messages (id, channel_id, role, content) \ - VALUES (?, ?, 'assistant', ?)" + VALUES (?, ?, 'assistant', ?)", ) .bind(&id) .bind(&channel_id) @@ -103,7 
+103,7 @@ impl ConversationLogger { FROM conversation_messages \ WHERE channel_id = ? \ ORDER BY created_at DESC \ - LIMIT ?" + LIMIT ?", ) .bind(channel_id.as_ref()) .bind(limit) @@ -121,7 +121,9 @@ impl ConversationLogger { sender_id: row.try_get("sender_id").ok(), content: row.try_get("content").unwrap_or_default(), metadata: row.try_get("metadata").ok(), - created_at: row.try_get("created_at").unwrap_or_else(|_| chrono::Utc::now()), + created_at: row + .try_get("created_at") + .unwrap_or_else(|_| chrono::Utc::now()), }) .collect(); @@ -142,7 +144,7 @@ impl ConversationLogger { FROM conversation_messages \ WHERE channel_id = ? \ ORDER BY created_at DESC \ - LIMIT ?" + LIMIT ?", ) .bind(channel_id) .bind(limit) @@ -160,14 +162,15 @@ impl ConversationLogger { sender_id: row.try_get("sender_id").ok(), content: row.try_get("content").unwrap_or_default(), metadata: row.try_get("metadata").ok(), - created_at: row.try_get("created_at").unwrap_or_else(|_| chrono::Utc::now()), + created_at: row + .try_get("created_at") + .unwrap_or_else(|_| chrono::Utc::now()), }) .collect(); messages.reverse(); Ok(messages) } - } /// A unified timeline item combining messages, branch runs, and worker runs. @@ -213,7 +216,12 @@ impl ProcessRunLogger { } /// Record a branch starting. Fire-and-forget. - pub fn log_branch_started(&self, channel_id: &ChannelId, branch_id: BranchId, description: &str) { + pub fn log_branch_started( + &self, + channel_id: &ChannelId, + branch_id: BranchId, + description: &str, + ) { let pool = self.pool.clone(); let id = branch_id.to_string(); let channel_id = channel_id.to_string(); @@ -221,7 +229,7 @@ impl ProcessRunLogger { tokio::spawn(async move { if let Err(error) = sqlx::query( - "INSERT INTO branch_runs (id, channel_id, description) VALUES (?, ?, ?)" + "INSERT INTO branch_runs (id, channel_id, description) VALUES (?, ?, ?)", ) .bind(&id) .bind(&channel_id) @@ -254,22 +262,46 @@ impl ProcessRunLogger { }); } - /// Record a worker starting. 
Fire-and-forget. - pub fn log_worker_started(&self, channel_id: Option<&ChannelId>, worker_id: WorkerId, task: &str) { + /// Record a branch failing and mark its run as completed. Fire-and-forget. + pub fn log_branch_failed(&self, branch_id: BranchId, error: &str) { let pool = self.pool.clone(); - let id = worker_id.to_string(); - let channel_id = channel_id.map(|c| c.to_string()); - let task = task.to_string(); + let id = branch_id.to_string(); + let summary = format!("Branch failed: {error}"); tokio::spawn(async move { if let Err(error) = sqlx::query( - "INSERT INTO worker_runs (id, channel_id, task) VALUES (?, ?, ?)" + "UPDATE branch_runs SET conclusion = ?, completed_at = CURRENT_TIMESTAMP WHERE id = ?" ) + .bind(&summary) .bind(&id) - .bind(&channel_id) - .bind(&task) .execute(&pool) .await + { + tracing::warn!(%error, branch_id = %id, "failed to persist branch failure"); + } + }); + } + + /// Record a worker starting. Fire-and-forget. + pub fn log_worker_started( + &self, + channel_id: Option<&ChannelId>, + worker_id: WorkerId, + task: &str, + ) { + let pool = self.pool.clone(); + let id = worker_id.to_string(); + let channel_id = channel_id.map(|c| c.to_string()); + let task = task.to_string(); + + tokio::spawn(async move { + if let Err(error) = + sqlx::query("INSERT INTO worker_runs (id, channel_id, task) VALUES (?, ?, ?)") + .bind(&id) + .bind(&channel_id) + .bind(&task) + .execute(&pool) + .await { tracing::warn!(%error, worker_id = %id, "failed to persist worker start"); } @@ -283,13 +315,11 @@ impl ProcessRunLogger { let status = status.to_string(); tokio::spawn(async move { - if let Err(error) = sqlx::query( - "UPDATE worker_runs SET status = ? WHERE id = ?" - ) - .bind(&status) - .bind(&id) - .execute(&pool) - .await + if let Err(error) = sqlx::query("UPDATE worker_runs SET status = ? 
WHERE id = ?") + .bind(&status) + .bind(&id) + .execute(&pool) + .await { tracing::warn!(%error, worker_id = %id, "failed to persist worker status"); } @@ -327,7 +357,11 @@ impl ProcessRunLogger { limit: i64, before: Option<&str>, ) -> crate::error::Result> { - let before_clause = if before.is_some() { "AND timestamp < ?3" } else { "" }; + let before_clause = if before.is_some() { + "AND timestamp < ?3" + } else { + "" + }; let query_str = format!( "SELECT * FROM ( \ @@ -348,9 +382,7 @@ impl ProcessRunLogger { ) WHERE 1=1 {before_clause} ORDER BY timestamp DESC LIMIT ?2" ); - let mut query = sqlx::query(&query_str) - .bind(channel_id) - .bind(limit); + let mut query = sqlx::query(&query_str).bind(channel_id).bind(limit); if let Some(before_ts) = before { query = query.bind(before_ts); @@ -372,7 +404,8 @@ impl ProcessRunLogger { sender_name: row.try_get("sender_name").ok(), sender_id: row.try_get("sender_id").ok(), content: row.try_get("content").unwrap_or_default(), - created_at: row.try_get::, _>("timestamp") + created_at: row + .try_get::, _>("timestamp") .map(|t| t.to_rfc3339()) .unwrap_or_default(), }), @@ -380,10 +413,12 @@ impl ProcessRunLogger { id: row.try_get("id").unwrap_or_default(), description: row.try_get("description").unwrap_or_default(), conclusion: row.try_get("conclusion").ok(), - started_at: row.try_get::, _>("timestamp") + started_at: row + .try_get::, _>("timestamp") .map(|t| t.to_rfc3339()) .unwrap_or_default(), - completed_at: row.try_get::, _>("completed_at") + completed_at: row + .try_get::, _>("completed_at") .ok() .map(|t| t.to_rfc3339()), }), @@ -392,10 +427,12 @@ impl ProcessRunLogger { task: row.try_get("task").unwrap_or_default(), result: row.try_get("result").ok(), status: row.try_get("status").unwrap_or_default(), - started_at: row.try_get::, _>("timestamp") + started_at: row + .try_get::, _>("timestamp") .map(|t| t.to_rfc3339()) .unwrap_or_default(), - completed_at: row.try_get::, _>("completed_at") + completed_at: row + 
.try_get::, _>("completed_at") .ok() .map(|t| t.to_rfc3339()), }), From 3c06b7382e54d9888f70b256ac8b4558327b69ab Mon Sep 17 00:00:00 2001 From: nmb Date: Tue, 17 Feb 2026 16:31:00 +0100 Subject: [PATCH 4/8] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor(agent):=20imp?= =?UTF-8?q?rove=20cortex,=20branching,=20and=20worker=20coordination?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/agent.rs | 4 +-- src/agent/branch.rs | 42 ++++++++++++++++---------- src/agent/compactor.rs | 35 +++++++++------------- src/agent/cortex.rs | 64 +++++++++++++++++++++++++--------------- src/agent/cortex_chat.rs | 8 +++-- src/agent/ingestion.rs | 60 ++++++++++++++++++++++--------------- src/agent/status.rs | 20 +++++++++++++ 7 files changed, 146 insertions(+), 87 deletions(-) diff --git a/src/agent.rs b/src/agent.rs index 1733f4ef7..dc277fad6 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -1,10 +1,10 @@ //! Agent processes: channels, branches, workers, compactor, cortex. 
-pub mod channel; pub mod branch; -pub mod worker; +pub mod channel; pub mod compactor; pub mod cortex; pub mod cortex_chat; pub mod ingestion; pub mod status; +pub mod worker; diff --git a/src/agent/branch.rs b/src/agent/branch.rs index 11d0ff986..7b8c4e2d2 100644 --- a/src/agent/branch.rs +++ b/src/agent/branch.rs @@ -2,10 +2,10 @@ use crate::agent::compactor::estimate_history_tokens; use crate::error::Result; -use crate::llm::routing::is_context_overflow_error; -use crate::llm::SpacebotModel; -use crate::{BranchId, ChannelId, ProcessId, ProcessType, AgentDeps, ProcessEvent}; use crate::hooks::SpacebotHook; +use crate::llm::SpacebotModel; +use crate::llm::routing::is_context_overflow_error; +use crate::{AgentDeps, BranchId, ChannelId, ProcessEvent, ProcessId, ProcessType}; use rig::agent::AgentBuilder; use rig::completion::{CompletionModel, Prompt}; use rig::tool::server::ToolServerHandle; @@ -44,8 +44,14 @@ impl Branch { ) -> Self { let id = Uuid::new_v4(); let process_id = ProcessId::Branch(id); - let hook = SpacebotHook::new(deps.agent_id.clone(), process_id, ProcessType::Branch, Some(channel_id.clone()), deps.event_tx.clone()); - + let hook = SpacebotHook::new( + deps.agent_id.clone(), + process_id, + ProcessType::Branch, + Some(channel_id.clone()), + deps.event_tx.clone(), + ); + Self { id, channel_id, @@ -58,7 +64,7 @@ impl Branch { max_turns, } } - + /// Run the branch's LLM agent loop and return a conclusion. /// /// Each branch has its own isolated ToolServer with `memory_save` and @@ -70,7 +76,7 @@ impl Branch { /// be large, making them susceptible to overflow on the first LLM call. 
pub async fn run(mut self, prompt: impl Into) -> Result { let prompt = prompt.into(); - + tracing::info!( branch_id = %self.id, channel_id = %self.channel_id, @@ -97,15 +103,17 @@ impl Branch { let mut overflow_retries = 0; let conclusion = loop { - match agent.prompt(&current_prompt) + match agent + .prompt(&current_prompt) .with_history(&mut self.history) .with_hook(self.hook.clone()) .await { Ok(response) => break response, Err(rig::completion::PromptError::MaxTurnsError { .. }) => { - let partial = extract_last_assistant_text(&self.history) - .unwrap_or_else(|| "Branch exhausted its turns without a final conclusion.".into()); + let partial = extract_last_assistant_text(&self.history).unwrap_or_else(|| { + "Branch exhausted its turns without a final conclusion.".into() + }); tracing::warn!(branch_id = %self.id, "branch hit max turns, returning partial result"); break partial; } @@ -133,7 +141,8 @@ impl Branch { "branch context overflow, compacting and retrying" ); self.force_compact_history(); - current_prompt = "Continue where you left off. Older context has been compacted.".into(); + current_prompt = + "Continue where you left off. Older context has been compacted.".into(); } Err(error) => { tracing::error!(branch_id = %self.id, %error, "branch LLM call failed"); @@ -149,9 +158,9 @@ impl Branch { channel_id: self.channel_id.clone(), conclusion: conclusion.clone(), }); - + tracing::info!(branch_id = %self.id, "branch completed"); - + Ok(conclusion) } @@ -192,7 +201,9 @@ impl Branch { return; } - let remove_count = ((total as f32 * fraction) as usize).max(1).min(total.saturating_sub(2)); + let remove_count = ((total as f32 * fraction) as usize) + .max(1) + .min(total.saturating_sub(2)); self.history.drain(..remove_count); let marker = format!( @@ -207,7 +218,8 @@ impl Branch { fn extract_last_assistant_text(history: &[rig::message::Message]) -> Option { for message in history.iter().rev() { if let rig::message::Message::Assistant { content, ..
} = message { - let texts: Vec = content.iter() + let texts: Vec = content + .iter() .filter_map(|c| { if let rig::message::AssistantContent::Text(t) = c { Some(t.text.clone()) diff --git a/src/agent/compactor.rs b/src/agent/compactor.rs index 76b414187..e68eeda56 100644 --- a/src/agent/compactor.rs +++ b/src/agent/compactor.rs @@ -25,11 +25,7 @@ pub struct Compactor { impl Compactor { /// Create a new compactor for a channel. - pub fn new( - channel_id: ChannelId, - deps: AgentDeps, - history: Arc>>, - ) -> Self { + pub fn new(channel_id: ChannelId, deps: AgentDeps, history: Arc>>) -> Self { Self { channel_id, deps, @@ -117,12 +113,7 @@ impl Compactor { .expect("failed to render compactor prompt"); tokio::spawn(async move { - let result = run_compaction( - &deps, - &compactor_prompt, - &history, - fraction, - ).await; + let result = run_compaction(&deps, &compactor_prompt, &history, fraction).await; match result { Ok(turns_compacted) => { @@ -191,7 +182,9 @@ async fn run_compaction( let (removed_messages, remove_count) = { let mut hist = history.write().await; let total = hist.len(); - let remove_count = ((total as f32 * fraction) as usize).max(1).min(total.saturating_sub(2)); + let remove_count = ((total as f32 * fraction) as usize) + .max(1) + .min(total.saturating_sub(2)); if remove_count == 0 { return Ok(0); } @@ -205,12 +198,14 @@ async fn run_compaction( // 3. 
Run the compaction LLM to produce summary + extracted memories let routing = deps.runtime_config.routing.load(); let model_name = routing.resolve(ProcessType::Worker, None).to_string(); - let model = SpacebotModel::make(&deps.llm_manager, &model_name) - .with_routing((**routing).clone()); + let model = + SpacebotModel::make(&deps.llm_manager, &model_name).with_routing((**routing).clone()); // Give the compaction worker memory_save so it can directly persist memories let tool_server: ToolServerHandle = ToolServer::new() - .tool(crate::tools::MemorySaveTool::new(deps.memory_search.clone())) + .tool(crate::tools::MemorySaveTool::new( + deps.memory_search.clone(), + )) .run(); let agent = AgentBuilder::new(model) @@ -220,7 +215,8 @@ async fn run_compaction( .build(); let mut compaction_history = Vec::new(); - let response = agent.prompt(&transcript) + let response = agent + .prompt(&transcript) .with_history(&mut compaction_history) .await; @@ -294,9 +290,7 @@ fn estimate_assistant_content_chars(content: &AssistantContent) -> usize { AssistantContent::ToolCall(tc) => { tc.function.name.len() + tc.function.arguments.to_string().len() } - AssistantContent::Reasoning(r) => { - r.reasoning.iter().map(|s| s.len()).sum() - } + AssistantContent::Reasoning(r) => r.reasoning.iter().map(|s| s.len()).sum(), AssistantContent::Image(_) => 500, } } @@ -339,8 +333,7 @@ fn render_messages_as_transcript(messages: &[Message]) -> String { AssistantContent::ToolCall(tc) => { output.push_str(&format!( "[Tool Call: {}({})]\n", - tc.function.name, - tc.function.arguments + tc.function.name, tc.function.arguments )); } _ => {} diff --git a/src/agent/cortex.rs b/src/agent/cortex.rs index 0b0a20e68..c6220f9c0 100644 --- a/src/agent/cortex.rs +++ b/src/agent/cortex.rs @@ -10,11 +10,11 @@ //! health monitoring and memory consolidation. 
use crate::error::Result; +use crate::hooks::CortexHook; use crate::llm::SpacebotModel; use crate::memory::search::{SearchConfig, SearchMode, SearchSort}; use crate::memory::types::{Association, MemoryType, RelationType}; use crate::{AgentDeps, ProcessEvent, ProcessType}; -use crate::hooks::CortexHook; use rig::agent::AgentBuilder; use rig::completion::{CompletionModel, Prompt}; @@ -180,9 +180,7 @@ impl CortexEventRow { id: self.id, event_type: self.event_type, summary: self.summary, - details: self - .details - .and_then(|d| serde_json::from_str(&d).ok()), + details: self.details.and_then(|d| serde_json::from_str(&d).ok()), created_at: self.created_at.and_utc().to_rfc3339(), } } @@ -275,7 +273,11 @@ async fn run_bulletin_loop(deps: &AgentDeps, logger: &CortexLogger) -> anyhow::R ); logger.log( "bulletin_failed", - &format!("Bulletin generation failed, retrying (attempt {}/{})", attempt + 1, MAX_RETRIES), + &format!( + "Bulletin generation failed, retrying (attempt {}/{})", + attempt + 1, + MAX_RETRIES + ), Some(serde_json::json!({ "attempt": attempt + 1, "max_retries": MAX_RETRIES })), ); tokio::time::sleep(Duration::from_secs(RETRY_DELAY_SECS)).await; @@ -401,7 +403,12 @@ async fn gather_bulletin_sections(deps: &AgentDeps) -> String { "- [{}] (importance: {:.1}) {}\n", result.memory.memory_type, result.memory.importance, - result.memory.content.lines().next().unwrap_or(&result.memory.content), + result + .memory + .content + .lines() + .next() + .unwrap_or(&result.memory.content), )); } output.push('\n'); @@ -458,9 +465,7 @@ pub async fn generate_bulletin(deps: &AgentDeps, logger: &CortexLogger) -> bool SpacebotModel::make(&deps.llm_manager, &model_name).with_routing((**routing).clone()); // No tools needed — the LLM just synthesizes the pre-gathered data - let agent = AgentBuilder::new(model) - .preamble(&bulletin_prompt) - .build(); + let agent = AgentBuilder::new(model).preamble(&bulletin_prompt).build(); let synthesis_prompt = prompt_engine 
.render_system_cortex_synthesis(cortex_config.bulletin_max_words, &raw_sections) @@ -587,17 +592,24 @@ async fn generate_profile(deps: &AgentDeps, logger: &CortexLogger) { // Gather context: identity + current bulletin let identity_context = { let rendered = deps.runtime_config.identity.load().render(); - if rendered.is_empty() { None } else { Some(rendered) } + if rendered.is_empty() { + None + } else { + Some(rendered) + } }; let memory_bulletin = { let bulletin = deps.runtime_config.memory_bulletin.load(); - if bulletin.is_empty() { None } else { Some(bulletin.as_ref().clone()) } + if bulletin.is_empty() { + None + } else { + Some(bulletin.as_ref().clone()) + } }; - let synthesis_prompt = match prompt_engine.render_system_profile_synthesis( - identity_context.as_deref(), - memory_bulletin.as_deref(), - ) { + let synthesis_prompt = match prompt_engine + .render_system_profile_synthesis(identity_context.as_deref(), memory_bulletin.as_deref()) + { Ok(p) => p, Err(error) => { tracing::warn!(%error, "failed to render profile synthesis prompt"); @@ -610,9 +622,7 @@ async fn generate_profile(deps: &AgentDeps, logger: &CortexLogger) { let model = SpacebotModel::make(&deps.llm_manager, &model_name).with_routing((**routing).clone()); - let agent = AgentBuilder::new(model) - .preamble(&profile_prompt) - .build(); + let agent = AgentBuilder::new(model).preamble(&profile_prompt).build(); match agent.prompt(&synthesis_prompt).await { Ok(response) => { @@ -679,7 +689,9 @@ async fn generate_profile(deps: &AgentDeps, logger: &CortexLogger) { tracing::warn!(%error, raw = %cleaned, "failed to parse profile LLM response as JSON"); logger.log( "profile_failed", - &format!("Profile generation failed: could not parse LLM response — {error}"), + &format!( + "Profile generation failed: could not parse LLM response — {error}" + ), Some(serde_json::json!({ "error": error.to_string(), "raw_response": cleaned, @@ -711,7 +723,10 @@ async fn generate_profile(deps: &AgentDeps, logger: 
&CortexLogger) { /// Scans memories for embedding similarity and creates association edges /// between related memories. On first run, backfills all existing memories. /// Subsequent runs only process memories created since the last pass. -pub fn spawn_association_loop(deps: AgentDeps, logger: CortexLogger) -> tokio::task::JoinHandle<()> { +pub fn spawn_association_loop( + deps: AgentDeps, + logger: CortexLogger, +) -> tokio::task::JoinHandle<()> { tokio::spawn(async move { if let Err(error) = run_association_loop(&deps, &logger).await { tracing::error!(%error, "cortex association loop exited with error"); @@ -727,7 +742,10 @@ async fn run_association_loop(deps: &AgentDeps, logger: &CortexLogger) -> anyhow // Backfill: process all existing memories on first run let backfill_count = run_association_pass(deps, logger, None).await; - tracing::info!(associations_created = backfill_count, "association backfill complete"); + tracing::info!( + associations_created = backfill_count, + "association backfill complete" + ); let mut last_pass_at = chrono::Utc::now(); @@ -812,8 +830,8 @@ async fn run_association_pass( }; // Weight: map similarity range to 0.5-1.0 - let weight = 0.5 + (similarity - similarity_threshold) - / (1.0 - similarity_threshold) * 0.5; + let weight = + 0.5 + (similarity - similarity_threshold) / (1.0 - similarity_threshold) * 0.5; let association = Association::new(memory_id, &target_id, relation_type) .with_weight(weight.clamp(0.0, 1.0)); diff --git a/src/agent/cortex_chat.rs b/src/agent/cortex_chat.rs index a9e6edb62..6a95f1b4d 100644 --- a/src/agent/cortex_chat.rs +++ b/src/agent/cortex_chat.rs @@ -39,7 +39,10 @@ pub enum CortexChatEvent { /// A tool call started. ToolStarted { tool: String }, /// A tool call completed. - ToolCompleted { tool: String, result_preview: String }, + ToolCompleted { + tool: String, + result_preview: String, + }, /// The full response is ready. Done { full_text: String }, /// An error occurred. 
@@ -378,8 +381,7 @@ impl CortexChatSession { .. } => { if let Some(result) = result { - transcript - .push_str(&format!("*[Worker: {task}]*: {result}\n\n")); + transcript.push_str(&format!("*[Worker: {task}]*: {result}\n\n")); } } } diff --git a/src/agent/ingestion.rs b/src/agent/ingestion.rs index f59263cf4..97a887fbf 100644 --- a/src/agent/ingestion.rs +++ b/src/agent/ingestion.rs @@ -8,10 +8,10 @@ //! content. If the server restarts mid-file, already-completed chunks are //! skipped on the next run. -use crate::config::IngestionConfig; -use crate::llm::SpacebotModel; use crate::AgentDeps; use crate::ProcessType; +use crate::config::IngestionConfig; +use crate::llm::SpacebotModel; use anyhow::Context as _; use rig::agent::AgentBuilder; @@ -29,10 +29,7 @@ use std::time::Duration; /// Runs until the returned JoinHandle is dropped or aborted. Scans the ingest /// directory on a timer, processes any text files found, and deletes them after /// successful ingestion. -pub fn spawn_ingestion_loop( - ingest_dir: PathBuf, - deps: AgentDeps, -) -> tokio::task::JoinHandle<()> { +pub fn spawn_ingestion_loop(ingest_dir: PathBuf, deps: AgentDeps) -> tokio::task::JoinHandle<()> { tokio::spawn(async move { if let Err(error) = run_ingestion_loop(&ingest_dir, &deps).await { tracing::error!(%error, "ingestion loop exited with error"); @@ -129,13 +126,22 @@ fn is_text_file(path: &Path) -> bool { matches!( ext.to_lowercase().as_str(), - "txt" | "md" | "markdown" - | "json" | "jsonl" - | "csv" | "tsv" + "txt" + | "md" + | "markdown" + | "json" + | "jsonl" + | "csv" + | "tsv" | "log" - | "xml" | "yaml" | "yml" | "toml" - | "rst" | "org" - | "html" | "htm" + | "xml" + | "yaml" + | "yml" + | "toml" + | "rst" + | "org" + | "html" + | "htm" ) } @@ -182,7 +188,14 @@ async fn process_file( let remaining = total_chunks - completed.len(); // Record file-level tracking (idempotent — skips if already exists from a previous run) - upsert_ingestion_file(&deps.sqlite_pool, &hash, filename, 
file_size, total_chunks as i64).await?; + upsert_ingestion_file( + &deps.sqlite_pool, + &hash, + filename, + file_size, + total_chunks as i64, + ) + .await?; if !completed.is_empty() { tracing::info!( @@ -264,12 +277,9 @@ async fn process_file( // -- Progress tracking queries -------------------------------------------------- /// Load the set of chunk indices already completed for a given content hash. -async fn load_completed_chunks( - pool: &SqlitePool, - hash: &str, -) -> anyhow::Result> { +async fn load_completed_chunks(pool: &SqlitePool, hash: &str) -> anyhow::Result> { let rows = sqlx::query_scalar::<_, i64>( - "SELECT chunk_index FROM ingestion_progress WHERE content_hash = ?" + "SELECT chunk_index FROM ingestion_progress WHERE content_hash = ?", ) .bind(hash) .fetch_all(pool) @@ -417,13 +427,17 @@ async fn process_chunk( let routing = deps.runtime_config.routing.load(); let model_name = routing.resolve(ProcessType::Branch, None).to_string(); - let model = SpacebotModel::make(&deps.llm_manager, &model_name) - .with_routing((**routing).clone()); + let model = + SpacebotModel::make(&deps.llm_manager, &model_name).with_routing((**routing).clone()); - let conversation_logger = crate::conversation::history::ConversationLogger::new(deps.sqlite_pool.clone()); + let conversation_logger = + crate::conversation::history::ConversationLogger::new(deps.sqlite_pool.clone()); let channel_store = crate::conversation::ChannelStore::new(deps.sqlite_pool.clone()); - let tool_server: ToolServerHandle = - crate::tools::create_branch_tool_server(deps.memory_search.clone(), conversation_logger, channel_store); + let tool_server: ToolServerHandle = crate::tools::create_branch_tool_server( + deps.memory_search.clone(), + conversation_logger, + channel_store, + ); let agent = AgentBuilder::new(model) .preamble(&ingestion_prompt) diff --git a/src/agent/status.rs b/src/agent/status.rs index 1b8e35003..d316e48f9 100644 --- a/src/agent/status.rs +++ b/src/agent/status.rs @@ -118,6 
+118,26 @@ impl StatusBlock { self.completed_items.remove(0); } } + ProcessEvent::BranchFailed { + branch_id, error, .. + } => { + // Remove from active branches, add to completed with error summary. + if let Some(pos) = self.active_branches.iter().position(|b| b.id == *branch_id) { + let branch = self.active_branches.remove(pos); + self.completed_items.push(CompletedItem { + id: branch_id.to_string(), + item_type: CompletedItemType::Branch, + description: branch.description, + completed_at: Utc::now(), + result_summary: format!("Branch failed: {error}"), + }); + } + + // Keep only last 10 completed items + if self.completed_items.len() > 10 { + self.completed_items.remove(0); + } + } _ => {} } } From 5a54ba900a9fa9f28671e90ac24db0ca7cb8bf73 Mon Sep 17 00:00:00 2001 From: nmb Date: Tue, 17 Feb 2026 16:31:06 +0100 Subject: [PATCH 5/8] =?UTF-8?q?=E2=9C=A8=20feat(tools):=20enhance=20error?= =?UTF-8?q?=20handling=20and=20add=20worker=20spawning?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/tools.rs | 101 +++++++++++++++++++++--------------- src/tools/branch_tool.rs | 7 ++- src/tools/browser.rs | 62 +++++++--------------- src/tools/cancel.rs | 21 ++++++-- src/tools/channel_recall.rs | 42 ++++++++++----- src/tools/cron.rs | 8 ++- src/tools/exec.rs | 36 ++++++++----- src/tools/file.rs | 12 +++-- src/tools/memory_recall.rs | 26 ++++++++-- src/tools/memory_save.rs | 19 ++++--- src/tools/reply.rs | 3 +- src/tools/route.rs | 13 +++-- src/tools/send_file.rs | 12 ++--- src/tools/shell.rs | 51 +++++++++++------- src/tools/skip.rs | 7 +-- src/tools/spawn_worker.rs | 20 ++++--- 16 files changed, 262 insertions(+), 178 deletions(-) diff --git a/src/tools.rs b/src/tools.rs index c0ec320fc..c08af5eaf 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -21,45 +21,56 @@ //! **Cortex ToolServer** (one per agent): //! 
- `memory_save` — registered at startup -pub mod reply; pub mod branch_tool; -pub mod spawn_worker; -pub mod route; +pub mod browser; pub mod cancel; -pub mod skip; -pub mod react; -pub mod memory_save; -pub mod memory_recall; +pub mod channel_recall; +pub mod cron; +pub mod exec; +pub mod file; pub mod memory_delete; +pub mod memory_recall; +pub mod memory_save; +pub mod react; +pub mod reply; +pub mod route; +pub mod send_file; pub mod set_status; pub mod shell; -pub mod file; -pub mod exec; -pub mod browser; +pub mod skip; +pub mod spawn_worker; pub mod web_search; -pub mod channel_recall; -pub mod cron; -pub mod send_file; -pub use reply::{ReplyTool, ReplyArgs, ReplyOutput, ReplyError}; -pub use branch_tool::{BranchTool, BranchArgs, BranchOutput, BranchError}; -pub use spawn_worker::{SpawnWorkerTool, SpawnWorkerArgs, SpawnWorkerOutput, SpawnWorkerError}; -pub use route::{RouteTool, RouteArgs, RouteOutput, RouteError}; -pub use cancel::{CancelTool, CancelArgs, CancelOutput, CancelError}; -pub use skip::{SkipTool, SkipArgs, SkipOutput, SkipError, SkipFlag, new_skip_flag}; -pub use react::{ReactTool, ReactArgs, ReactOutput, ReactError}; -pub use memory_save::{MemorySaveTool, MemorySaveArgs, MemorySaveOutput, MemorySaveError, AssociationInput}; -pub use memory_recall::{MemoryRecallTool, MemoryRecallArgs, MemoryRecallOutput, MemoryRecallError, MemoryOutput}; -pub use memory_delete::{MemoryDeleteTool, MemoryDeleteArgs, MemoryDeleteOutput, MemoryDeleteError}; -pub use set_status::{SetStatusTool, SetStatusArgs, SetStatusOutput, SetStatusError}; -pub use shell::{ShellTool, ShellArgs, ShellOutput, ShellError, ShellResult}; -pub use file::{FileTool, FileArgs, FileOutput, FileError, FileEntryOutput, FileEntry, FileType}; -pub use exec::{ExecTool, ExecArgs, ExecOutput, ExecError, ExecResult, EnvVar}; -pub use browser::{BrowserTool, BrowserArgs, BrowserOutput, BrowserError, BrowserAction, ActKind, ElementSummary, TabInfo}; -pub use web_search::{WebSearchTool, WebSearchArgs, 
WebSearchOutput, WebSearchError, SearchResult}; -pub use channel_recall::{ChannelRecallTool, ChannelRecallArgs, ChannelRecallOutput, ChannelRecallError}; -pub use cron::{CronTool, CronArgs, CronOutput, CronError}; -pub use send_file::{SendFileTool, SendFileArgs, SendFileOutput, SendFileError}; +pub use branch_tool::{BranchArgs, BranchError, BranchOutput, BranchTool}; +pub use browser::{ + ActKind, BrowserAction, BrowserArgs, BrowserError, BrowserOutput, BrowserTool, ElementSummary, + TabInfo, +}; +pub use cancel::{CancelArgs, CancelError, CancelOutput, CancelTool}; +pub use channel_recall::{ + ChannelRecallArgs, ChannelRecallError, ChannelRecallOutput, ChannelRecallTool, +}; +pub use cron::{CronArgs, CronError, CronOutput, CronTool}; +pub use exec::{EnvVar, ExecArgs, ExecError, ExecOutput, ExecResult, ExecTool}; +pub use file::{FileArgs, FileEntry, FileEntryOutput, FileError, FileOutput, FileTool, FileType}; +pub use memory_delete::{ + MemoryDeleteArgs, MemoryDeleteError, MemoryDeleteOutput, MemoryDeleteTool, +}; +pub use memory_recall::{ + MemoryOutput, MemoryRecallArgs, MemoryRecallError, MemoryRecallOutput, MemoryRecallTool, +}; +pub use memory_save::{ + AssociationInput, MemorySaveArgs, MemorySaveError, MemorySaveOutput, MemorySaveTool, +}; +pub use react::{ReactArgs, ReactError, ReactOutput, ReactTool}; +pub use reply::{ReplyArgs, ReplyError, ReplyOutput, ReplyTool}; +pub use route::{RouteArgs, RouteError, RouteOutput, RouteTool}; +pub use send_file::{SendFileArgs, SendFileError, SendFileOutput, SendFileTool}; +pub use set_status::{SetStatusArgs, SetStatusError, SetStatusOutput, SetStatusTool}; +pub use shell::{ShellArgs, ShellError, ShellOutput, ShellResult, ShellTool}; +pub use skip::{SkipArgs, SkipError, SkipFlag, SkipOutput, SkipTool, new_skip_flag}; +pub use spawn_worker::{SpawnWorkerArgs, SpawnWorkerError, SpawnWorkerOutput, SpawnWorkerTool}; +pub use web_search::{SearchResult, WebSearchArgs, WebSearchError, WebSearchOutput, WebSearchTool}; use 
crate::agent::channel::ChannelState; use crate::config::BrowserConfig; @@ -116,18 +127,24 @@ pub async fn add_channel_tools( skip_flag: SkipFlag, cron_tool: Option, ) -> Result<(), rig::tool::server::ToolServerError> { - handle.add_tool(ReplyTool::new( - response_tx.clone(), - conversation_id, - state.conversation_logger.clone(), - state.channel_id.clone(), - )).await?; + handle + .add_tool(ReplyTool::new( + response_tx.clone(), + conversation_id, + state.conversation_logger.clone(), + state.channel_id.clone(), + )) + .await?; handle.add_tool(BranchTool::new(state.clone())).await?; handle.add_tool(SpawnWorkerTool::new(state.clone())).await?; handle.add_tool(RouteTool::new(state.clone())).await?; handle.add_tool(CancelTool::new(state)).await?; - handle.add_tool(SkipTool::new(skip_flag, response_tx.clone())).await?; - handle.add_tool(SendFileTool::new(response_tx.clone())).await?; + handle + .add_tool(SkipTool::new(skip_flag, response_tx.clone())) + .await?; + handle + .add_tool(SendFileTool::new(response_tx.clone())) + .await?; handle.add_tool(ReactTool::new(response_tx)).await?; if let Some(cron) = cron_tool { handle.add_tool(cron).await?; @@ -196,7 +213,9 @@ pub fn create_worker_tool_server( .tool(ShellTool::new(instance_dir.clone(), workspace.clone())) .tool(FileTool::new(workspace.clone())) .tool(ExecTool::new(instance_dir, workspace)) - .tool(SetStatusTool::new(agent_id, worker_id, channel_id, event_tx)); + .tool(SetStatusTool::new( + agent_id, worker_id, channel_id, event_tx, + )); if browser_config.enabled { server = server.tool(BrowserTool::new(browser_config, screenshot_dir)); diff --git a/src/tools/branch_tool.rs b/src/tools/branch_tool.rs index db543aa2d..6f19aeed0 100644 --- a/src/tools/branch_tool.rs +++ b/src/tools/branch_tool.rs @@ -1,7 +1,7 @@ //! Branch tool for forking context and thinking (channel only). 
-use crate::agent::channel::{ChannelState, spawn_branch_from_state}; use crate::BranchId; +use crate::agent::channel::{ChannelState, spawn_branch_from_state}; use rig::completion::ToolDefinition; use rig::tool::Tool; use schemars::JsonSchema; @@ -75,7 +75,10 @@ impl Tool for BranchTool { Ok(BranchOutput { branch_id, spawned: true, - message: format!("Branch {branch_id} spawned. It will investigate: {}", args.description), + message: format!( + "Branch {branch_id} spawned. It will investigate: {}", + args.description + ), }) } } diff --git a/src/tools/browser.rs b/src/tools/browser.rs index 932fd1c0f..150bd87a6 100644 --- a/src/tools/browser.rs +++ b/src/tools/browser.rs @@ -384,9 +384,9 @@ impl BrowserTool { builder = builder.chrome_executable(path); } - let chrome_config = builder - .build() - .map_err(|error| BrowserError::new(format!("failed to build browser config: {error}")))?; + let chrome_config = builder.build().map_err(|error| { + BrowserError::new(format!("failed to build browser config: {error}")) + })?; tracing::info!( headless = self.config.headless, @@ -398,9 +398,7 @@ impl BrowserTool { .await .map_err(|error| BrowserError::new(format!("failed to launch browser: {error}")))?; - let handler_task = tokio::spawn(async move { - while handler.next().await.is_some() {} - }); + let handler_task = tokio::spawn(async move { while handler.next().await.is_some() {} }); state.browser = Some(browser); state._handler_task = Some(handler_task); @@ -428,7 +426,10 @@ impl BrowserTool { state.element_refs.clear(); state.next_ref = 0; - Ok(BrowserOutput::success(format!("Navigated to {url}")).with_page_info(title, current_url)) + Ok( + BrowserOutput::success(format!("Navigated to {url}")) + .with_page_info(title, current_url), + ) } async fn handle_open(&self, url: Option) -> Result { @@ -500,10 +501,7 @@ impl BrowserTool { }) } - async fn handle_focus( - &self, - target_id: Option, - ) -> Result { + async fn handle_focus(&self, target_id: Option) -> Result { let 
Some(target_id) = target_id else { return Err(BrowserError::new("target_id is required for focus action")); }; @@ -520,9 +518,7 @@ impl BrowserTool { state.element_refs.clear(); state.next_ref = 0; - Ok(BrowserOutput::success(format!( - "Focused tab {target_id}" - ))) + Ok(BrowserOutput::success(format!("Focused tab {target_id}"))) } async fn handle_close_tab( @@ -652,9 +648,7 @@ impl BrowserTool { key: Option, ) -> Result { let Some(act_kind) = act_kind else { - return Err(BrowserError::new( - "act_kind is required for act action", - )); + return Err(BrowserError::new("act_kind is required for act action")); }; let state = self.state.lock().await; @@ -692,8 +686,7 @@ impl BrowserTool { return Err(BrowserError::new("key is required for act:press_key")); }; if element_ref.is_some() { - let element = - self.resolve_element_ref(&state, page, element_ref).await?; + let element = self.resolve_element_ref(&state, page, element_ref).await?; element .press_key(&key) .await @@ -713,12 +706,9 @@ impl BrowserTool { } ActKind::ScrollIntoView => { let element = self.resolve_element_ref(&state, page, element_ref).await?; - element - .scroll_into_view() - .await - .map_err(|error| { - BrowserError::new(format!("scroll_into_view failed: {error}")) - })?; + element.scroll_into_view().await.map_err(|error| { + BrowserError::new(format!("scroll_into_view failed: {error}")) + })?; Ok(BrowserOutput::success("Scrolled element into view")) } ActKind::Focus => { @@ -745,9 +735,7 @@ impl BrowserTool { element .screenshot(CaptureScreenshotFormat::Png) .await - .map_err(|error| { - BrowserError::new(format!("element screenshot failed: {error}")) - })? + .map_err(|error| BrowserError::new(format!("element screenshot failed: {error}")))? 
} else { let params = ScreenshotParams::builder() .format(CaptureScreenshotFormat::Png) @@ -793,10 +781,7 @@ impl BrowserTool { }) } - async fn handle_evaluate( - &self, - script: Option, - ) -> Result { + async fn handle_evaluate(&self, script: Option) -> Result { if !self.config.evaluate_enabled { return Err(BrowserError::new( "JavaScript evaluation is disabled in browser config (set evaluate_enabled = true)", @@ -804,9 +789,7 @@ impl BrowserTool { } let Some(script) = script else { - return Err(BrowserError::new( - "script is required for evaluate action", - )); + return Err(BrowserError::new("script is required for evaluate action")); }; let state = self.state.lock().await; @@ -942,9 +925,7 @@ impl BrowserTool { element_ref: Option, ) -> Result { let Some(ref_id) = element_ref else { - return Err(BrowserError::new( - "element_ref is required for this action", - )); + return Err(BrowserError::new("element_ref is required for this action")); }; let elem_ref = state.element_refs.get(&ref_id).ok_or_else(|| { @@ -966,10 +947,7 @@ impl BrowserTool { } /// Dispatch a key press event to the page via CDP Input domain. 
-async fn dispatch_key_press( - page: &chromiumoxide::Page, - key: &str, -) -> Result<(), BrowserError> { +async fn dispatch_key_press(page: &chromiumoxide::Page, key: &str) -> Result<(), BrowserError> { let key_down = DispatchKeyEventParams::builder() .r#type(DispatchKeyEventType::KeyDown) .key(key) diff --git a/src/tools/cancel.rs b/src/tools/cancel.rs index c80177a27..ea1fd950c 100644 --- a/src/tools/cancel.rs +++ b/src/tools/cancel.rs @@ -85,22 +85,33 @@ impl Tool for CancelTool { async fn call(&self, args: Self::Args) -> Result { match args.process_type.as_str() { "branch" => { - let branch_id = args.process_id.parse::() + let branch_id = args + .process_id + .parse::() .map_err(|e| CancelError(format!("Invalid branch ID: {e}")))?; - self.state.cancel_branch(branch_id).await + self.state + .cancel_branch(branch_id) + .await .map_err(CancelError)?; } "worker" => { - let worker_id = args.process_id.parse::() + let worker_id = args + .process_id + .parse::() .map_err(|e| CancelError(format!("Invalid worker ID: {e}")))?; - self.state.cancel_worker(worker_id).await + self.state + .cancel_worker(worker_id) + .await .map_err(CancelError)?; } other => return Err(CancelError(format!("Unknown process type: {other}"))), } let message = if let Some(reason) = &args.reason { - format!("{} {} cancelled: {reason}", args.process_type, args.process_id) + format!( + "{} {} cancelled: {reason}", + args.process_type, args.process_id + ) } else { format!("{} {} cancelled.", args.process_type, args.process_id) }; diff --git a/src/tools/channel_recall.rs b/src/tools/channel_recall.rs index 7d3751467..3759082bd 100644 --- a/src/tools/channel_recall.rs +++ b/src/tools/channel_recall.rs @@ -20,7 +20,10 @@ pub struct ChannelRecallTool { impl ChannelRecallTool { pub fn new(conversation_logger: ConversationLogger, channel_store: ChannelStore) -> Self { - Self { conversation_logger, channel_store } + Self { + conversation_logger, + channel_store, + } } } @@ -118,7 +121,8 @@ impl Tool for 
ChannelRecallTool { let limit = args.limit.min(MAX_TRANSCRIPT_MESSAGES).max(1); // Resolve channel name to ID - let found = self.channel_store + let found = self + .channel_store .find_by_name(&channel_query) .await .map_err(|e| ChannelRecallError(format!("Failed to search channels: {e}")))?; @@ -133,19 +137,21 @@ impl Tool for ChannelRecallTool { }; // Load transcript - let messages = self.conversation_logger + let messages = self + .conversation_logger .load_channel_transcript(&channel.id, limit) .await .map_err(|e| ChannelRecallError(format!("Failed to load transcript: {e}")))?; - let transcript: Vec = messages.iter().map(|message| { - TranscriptMessage { + let transcript: Vec = messages + .iter() + .map(|message| TranscriptMessage { role: message.role.clone(), sender: message.sender_name.clone(), content: message.content.clone(), timestamp: message.created_at.to_rfc3339(), - } - }).collect(); + }) + .collect(); let summary = format_transcript(&channel.display_name, &channel.id, &transcript); @@ -162,18 +168,20 @@ impl Tool for ChannelRecallTool { impl ChannelRecallTool { async fn list_channels(&self) -> std::result::Result { - let channels = self.channel_store + let channels = self + .channel_store .list_active() .await .map_err(|e| ChannelRecallError(format!("Failed to list channels: {e}")))?; - let entries: Vec = channels.iter().map(|channel| { - ChannelListEntry { + let entries: Vec = channels + .iter() + .map(|channel| ChannelListEntry { channel_id: channel.id.clone(), channel_name: channel.display_name.clone(), last_activity: channel.last_activity_at.to_rfc3339(), - } - }).collect(); + }) + .collect(); let summary = format_channel_list(&entries); @@ -201,14 +209,20 @@ fn format_transcript( } let label = channel_name.as_deref().unwrap_or(channel_id); - let mut output = format!("## Transcript from #{label} ({} messages)\n\n", messages.len()); + let mut output = format!( + "## Transcript from #{label} ({} messages)\n\n", + messages.len() + ); for message in 
messages { let sender = match &message.sender { Some(name) => name.as_str(), None => "assistant", }; - output.push_str(&format!("**{}** ({}): {}\n\n", sender, message.role, message.content)); + output.push_str(&format!( + "**{}** ({}): {}\n\n", + sender, message.role, message.content + )); } output diff --git a/src/tools/cron.rs b/src/tools/cron.rs index c8d56e5bc..20e62d673 100644 --- a/src/tools/cron.rs +++ b/src/tools/cron.rs @@ -139,7 +139,9 @@ impl Tool for CronTool { impl CronTool { async fn create(&self, args: CronArgs) -> Result { - let id = args.id.ok_or_else(|| CronError("'id' is required for create".into()))?; + let id = args + .id + .ok_or_else(|| CronError("'id' is required for create".into()))?; let prompt = args .prompt .ok_or_else(|| CronError("'prompt' is required for create".into()))?; @@ -205,7 +207,9 @@ impl CronTool { prompt: config.prompt, interval_secs: config.interval_secs, delivery_target: config.delivery_target, - active_hours: config.active_hours.map(|(s, e)| format!("{s:02}:00-{e:02}:00")), + active_hours: config + .active_hours + .map(|(s, e)| format!("{s:02}:00-{e:02}:00")), }) .collect(); diff --git a/src/tools/exec.rs b/src/tools/exec.rs index 0a66cae55..f74b7d052 100644 --- a/src/tools/exec.rs +++ b/src/tools/exec.rs @@ -19,7 +19,10 @@ pub struct ExecTool { impl ExecTool { /// Create a new exec tool with the given instance directory for path blocking. pub fn new(instance_dir: PathBuf, workspace: PathBuf) -> Self { - Self { instance_dir, workspace } + Self { + instance_dir, + workspace, + } } /// Check if program arguments reference sensitive instance paths. 
@@ -42,9 +45,13 @@ impl ExecTool { // Block references to sensitive files by name for file in super::shell::SENSITIVE_FILES { if all_args.contains(file) { - if all_args.contains(instance_str.as_ref()) || !all_args.contains(workspace_str.as_ref()) { + if all_args.contains(instance_str.as_ref()) + || !all_args.contains(workspace_str.as_ref()) + { return Err(ExecError { - message: format!("Cannot access {file} — instance configuration is protected."), + message: format!( + "Cannot access {file} — instance configuration is protected." + ), exit_code: -1, }); } @@ -178,7 +185,10 @@ impl Tool for ExecTool { if let Some(ref dir) = args.working_dir { let path = std::path::Path::new(dir); let canonical = path.canonicalize().unwrap_or_else(|_| path.to_path_buf()); - let workspace_canonical = self.workspace.canonicalize().unwrap_or_else(|_| self.workspace.clone()); + let workspace_canonical = self + .workspace + .canonicalize() + .unwrap_or_else(|_| self.workspace.clone()); if !canonical.starts_with(&workspace_canonical) { return Err(ExecError { message: format!( @@ -216,8 +226,7 @@ impl Tool for ExecTool { cmd.env(env_var.key, env_var.value); } - cmd.stdout(Stdio::piped()) - .stderr(Stdio::piped()); + cmd.stdout(Stdio::piped()).stderr(Stdio::piped()); let timeout = tokio::time::Duration::from_secs(args.timeout_seconds); @@ -301,13 +310,14 @@ pub async fn exec( cmd.stdout(Stdio::piped()).stderr(Stdio::piped()); - let output = tokio::time::timeout( - tokio::time::Duration::from_secs(60), - cmd.output(), - ) - .await - .map_err(|_| crate::error::AgentError::Other(anyhow::anyhow!("Execution timed out").into()))? - .map_err(|e| crate::error::AgentError::Other(anyhow::anyhow!("Failed to execute: {e}").into()))?; + let output = tokio::time::timeout(tokio::time::Duration::from_secs(60), cmd.output()) + .await + .map_err(|_| { + crate::error::AgentError::Other(anyhow::anyhow!("Execution timed out").into()) + })? 
+ .map_err(|e| { + crate::error::AgentError::Other(anyhow::anyhow!("Failed to execute: {e}").into()) + })?; Ok(ExecResult { success: output.status.success(), diff --git a/src/tools/file.rs b/src/tools/file.rs index 82897136a..b260ca95b 100644 --- a/src/tools/file.rs +++ b/src/tools/file.rs @@ -35,7 +35,10 @@ impl FileTool { // existing ancestor and append the remaining components. let canonical = best_effort_canonicalize(&resolved); - let workspace_canonical = self.workspace.canonicalize().unwrap_or_else(|_| self.workspace.clone()); + let workspace_canonical = self + .workspace + .canonicalize() + .unwrap_or_else(|_| self.workspace.clone()); if !canonical.starts_with(&workspace_canonical) { return Err(FileError(format!( @@ -285,7 +288,10 @@ async fn do_file_list(path: &Path) -> Result { if total_count > max_entries { entries.push(FileEntryOutput { - name: format!("... and {} more entries (listing capped at {max_entries})", total_count - max_entries), + name: format!( + "... and {} more entries (listing capped at {max_entries})", + total_count - max_entries + ), entry_type: "notice".to_string(), size: 0, }); @@ -301,8 +307,6 @@ async fn do_file_list(path: &Path) -> Result { }) } - - /// File entry metadata (legacy). #[derive(Debug, Clone)] pub struct FileEntry { diff --git a/src/tools/memory_recall.rs b/src/tools/memory_recall.rs index a2aa795a7..9739a7cfd 100644 --- a/src/tools/memory_recall.rs +++ b/src/tools/memory_recall.rs @@ -95,7 +95,11 @@ fn parse_memory_type(s: &str) -> std::result::Result Ok(MemoryType::Todo), other => Err(MemoryRecallError(format!( "unknown memory_type \"{other}\". 
Valid types: {}", - crate::memory::MemoryType::ALL.iter().map(|t| t.to_string()).collect::>().join(", ") + crate::memory::MemoryType::ALL + .iter() + .map(|t| t.to_string()) + .collect::>() + .join(", ") ))), } } @@ -296,7 +300,10 @@ pub async fn memory_recall( sort_by: None, }; - let output = tool.call(args).await.map_err(|e| crate::error::AgentError::Other(anyhow::anyhow!(e)))?; + let output = tool + .call(args) + .await + .map_err(|e| crate::error::AgentError::Other(anyhow::anyhow!(e)))?; // Convert back to Memory type for backward compatibility let store = memory_search.store(); @@ -319,7 +326,10 @@ mod tests { fn test_parse_search_mode_valid() { assert_eq!(parse_search_mode("hybrid").unwrap(), SearchMode::Hybrid); assert_eq!(parse_search_mode("recent").unwrap(), SearchMode::Recent); - assert_eq!(parse_search_mode("important").unwrap(), SearchMode::Important); + assert_eq!( + parse_search_mode("important").unwrap(), + SearchMode::Important + ); assert_eq!(parse_search_mode("typed").unwrap(), SearchMode::Typed); } @@ -332,8 +342,14 @@ mod tests { #[test] fn test_parse_search_sort_valid() { assert_eq!(parse_search_sort("recent").unwrap(), SearchSort::Recent); - assert_eq!(parse_search_sort("importance").unwrap(), SearchSort::Importance); - assert_eq!(parse_search_sort("most_accessed").unwrap(), SearchSort::MostAccessed); + assert_eq!( + parse_search_sort("importance").unwrap(), + SearchSort::Importance + ); + assert_eq!( + parse_search_sort("most_accessed").unwrap(), + SearchSort::MostAccessed + ); } #[test] diff --git a/src/tools/memory_save.rs b/src/tools/memory_save.rs index d965f1675..db2546da3 100644 --- a/src/tools/memory_save.rs +++ b/src/tools/memory_save.rs @@ -1,15 +1,14 @@ //! Memory save tool for channels and branches. 
use crate::error::Result; -use crate::memory::{Memory, MemorySearch, MemoryType}; use crate::memory::types::{Association, CreateAssociationInput}; +use crate::memory::{Memory, MemorySearch, MemoryType}; use rig::completion::ToolDefinition; use rig::tool::Tool; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::sync::Arc; - /// Tool for saving memories to the store. #[derive(Debug, Clone)] pub struct MemorySaveTool { @@ -207,8 +206,8 @@ impl Tool for MemorySaveTool { _ => crate::memory::types::RelationType::RelatedTo, }; - let association = - Association::new(&memory.id, &assoc.target_id, relation_type).with_weight(assoc.weight); + let association = Association::new(&memory.id, &assoc.target_id, relation_type) + .with_weight(assoc.weight); if let Err(error) = store.create_association(&association).await { tracing::warn!( @@ -236,7 +235,12 @@ impl Tool for MemorySaveTool { // Ensure the FTS index exists so full_text_search queries work. // Safe to call repeatedly — no-ops if the index already exists. 
- if let Err(error) = self.memory_search.embedding_table().ensure_fts_index().await { + if let Err(error) = self + .memory_search + .embedding_table() + .ensure_fts_index() + .await + { tracing::warn!(%error, "failed to ensure FTS index after memory save"); } @@ -264,6 +268,9 @@ pub async fn save_fact( associations: vec![], }; - let output = tool.call(args).await.map_err(|e| crate::error::AgentError::Other(anyhow::anyhow!(e)))?; + let output = tool + .call(args) + .await + .map_err(|e| crate::error::AgentError::Other(anyhow::anyhow!(e)))?; Ok(output.memory_id) } diff --git a/src/tools/reply.rs b/src/tools/reply.rs index 67b2fed2e..978b31c54 100644 --- a/src/tools/reply.rs +++ b/src/tools/reply.rs @@ -100,7 +100,8 @@ impl Tool for ReplyTool { "reply tool called" ); - self.conversation_logger.log_bot_message(&self.channel_id, &args.content); + self.conversation_logger + .log_bot_message(&self.channel_id, &args.content); let response = match args.thread_name { Some(ref name) => { diff --git a/src/tools/route.rs b/src/tools/route.rs index 8cbecd3b2..aa631d866 100644 --- a/src/tools/route.rs +++ b/src/tools/route.rs @@ -1,7 +1,7 @@ //! Route tool for sending follow-ups to active workers. 
-use crate::agent::channel::ChannelState; use crate::WorkerId; +use crate::agent::channel::ChannelState; use rig::completion::ToolDefinition; use rig::tool::Tool; use schemars::JsonSchema; @@ -74,7 +74,9 @@ impl Tool for RouteTool { } async fn call(&self, args: Self::Args) -> std::result::Result { - let worker_id = args.worker_id.parse::() + let worker_id = args + .worker_id + .parse::() .map_err(|e| RouteError(format!("Invalid worker ID: {e}")))?; // Look up the input sender for this worker @@ -87,10 +89,11 @@ impl Tool for RouteTool { drop(inputs); // Deliver the message - input_tx.send(args.message).await - .map_err(|_| RouteError(format!( + input_tx.send(args.message).await.map_err(|_| { + RouteError(format!( "Worker {worker_id} has stopped accepting input (channel closed)" - )))?; + )) + })?; tracing::info!( worker_id = %worker_id, diff --git a/src/tools/send_file.rs b/src/tools/send_file.rs index 7a2ea8756..cfe556f4b 100644 --- a/src/tools/send_file.rs +++ b/src/tools/send_file.rs @@ -85,9 +85,9 @@ impl Tool for SendFileTool { return Err(SendFileError("file_path must be an absolute path".into())); } - let metadata = tokio::fs::metadata(&path) - .await - .map_err(|error| SendFileError(format!("can't read file '{}': {error}", path.display())))?; + let metadata = tokio::fs::metadata(&path).await.map_err(|error| { + SendFileError(format!("can't read file '{}': {error}", path.display())) + })?; if !metadata.is_file() { return Err(SendFileError(format!("'{}' is not a file", path.display()))); @@ -101,9 +101,9 @@ impl Tool for SendFileTool { ))); } - let data = tokio::fs::read(&path) - .await - .map_err(|error| SendFileError(format!("failed to read '{}': {error}", path.display())))?; + let data = tokio::fs::read(&path).await.map_err(|error| { + SendFileError(format!("failed to read '{}': {error}", path.display())) + })?; let filename = path .file_name() diff --git a/src/tools/shell.rs b/src/tools/shell.rs index 91b714c05..5ae769c8f 100644 --- a/src/tools/shell.rs 
+++ b/src/tools/shell.rs @@ -22,6 +22,7 @@ pub const SECRET_ENV_VARS: &[&str] = &[ "ANTHROPIC_API_KEY", "OPENAI_API_KEY", "OPENROUTER_API_KEY", + "OLLAMA_API_KEY", "DISCORD_BOT_TOKEN", "SLACK_BOT_TOKEN", "SLACK_APP_TOKEN", @@ -40,7 +41,10 @@ pub struct ShellTool { impl ShellTool { /// Create a new shell tool with the given instance directory for path blocking. pub fn new(instance_dir: PathBuf, workspace: PathBuf) -> Self { - Self { instance_dir, workspace } + Self { + instance_dir, + workspace, + } } /// Check if a command references sensitive instance paths or secret env vars. @@ -58,7 +62,8 @@ impl ShellTool { return Err(ShellError { message: "ACCESS DENIED: Cannot access the instance directory — it contains \ protected configuration and data. Do not attempt to reproduce or \ - guess its contents. Inform the user that this path is restricted.".to_string(), + guess its contents. Inform the user that this path is restricted." + .to_string(), exit_code: -1, }); } @@ -101,9 +106,14 @@ impl ShellTool { // Block broad env dumps that would expose secrets if command.contains("printenv") { let trimmed = command.trim(); - if trimmed == "printenv" || trimmed.ends_with("| printenv") || trimmed.contains("printenv |") || trimmed.contains("printenv >") { + if trimmed == "printenv" + || trimmed.ends_with("| printenv") + || trimmed.contains("printenv |") + || trimmed.contains("printenv >") + { return Err(ShellError { - message: "Cannot dump all environment variables — they may contain secrets.".to_string(), + message: "Cannot dump all environment variables — they may contain secrets." + .to_string(), exit_code: -1, }); } @@ -112,7 +122,8 @@ impl ShellTool { let trimmed = command.trim(); if trimmed == "env" || trimmed.starts_with("env |") || trimmed.starts_with("env >") { return Err(ShellError { - message: "Cannot dump all environment variables — they may contain secrets.".to_string(), + message: "Cannot dump all environment variables — they may contain secrets." 
+ .to_string(), exit_code: -1, }); } @@ -212,7 +223,10 @@ impl Tool for ShellTool { if let Some(ref dir) = args.working_dir { let path = std::path::Path::new(dir); let canonical = path.canonicalize().unwrap_or_else(|_| path.to_path_buf()); - let workspace_canonical = self.workspace.canonicalize().unwrap_or_else(|_| self.workspace.clone()); + let workspace_canonical = self + .workspace + .canonicalize() + .unwrap_or_else(|_| self.workspace.clone()); if !canonical.starts_with(&workspace_canonical) { return Err(ShellError { message: format!( @@ -241,8 +255,7 @@ impl Tool for ShellTool { cmd.current_dir(&self.workspace); } - cmd.stdout(Stdio::piped()) - .stderr(Stdio::piped()); + cmd.stdout(Stdio::piped()).stderr(Stdio::piped()); // Set timeout let timeout = tokio::time::Duration::from_secs(args.timeout_seconds); @@ -306,7 +319,10 @@ fn format_shell_output(exit_code: i32, stdout: &str, stderr: &str) -> String { /// System-internal shell execution that bypasses path restrictions. /// Used by the system itself, not LLM-facing. -pub async fn shell(command: &str, working_dir: Option<&std::path::Path>) -> crate::error::Result { +pub async fn shell( + command: &str, + working_dir: Option<&std::path::Path>, +) -> crate::error::Result { let mut cmd = if cfg!(target_os = "windows") { let mut c = Command::new("cmd"); c.arg("/C").arg(command); @@ -323,13 +339,14 @@ pub async fn shell(command: &str, working_dir: Option<&std::path::Path>) -> crat cmd.stdout(Stdio::piped()).stderr(Stdio::piped()); - let output = tokio::time::timeout( - tokio::time::Duration::from_secs(60), - cmd.output(), - ) - .await - .map_err(|_| crate::error::AgentError::Other(anyhow::anyhow!("Command timed out").into()))? 
- .map_err(|e| crate::error::AgentError::Other(anyhow::anyhow!("Failed to execute command: {e}").into()))?; + let output = tokio::time::timeout(tokio::time::Duration::from_secs(60), cmd.output()) + .await + .map_err(|_| crate::error::AgentError::Other(anyhow::anyhow!("Command timed out").into()))? + .map_err(|e| { + crate::error::AgentError::Other( + anyhow::anyhow!("Failed to execute command: {e}").into(), + ) + })?; Ok(ShellResult { success: output.status.success(), @@ -354,5 +371,3 @@ impl ShellResult { format_shell_output(self.exit_code, &self.stdout, &self.stderr) } } - - diff --git a/src/tools/skip.rs b/src/tools/skip.rs index 57c5a9258..481b5fb8b 100644 --- a/src/tools/skip.rs +++ b/src/tools/skip.rs @@ -84,9 +84,10 @@ impl Tool for SkipTool { self.flag.store(true, Ordering::Relaxed); // Cancel the typing indicator so it doesn't linger - let _ = self.response_tx.send( - OutboundResponse::Status(crate::StatusUpdate::StopTyping) - ).await; + let _ = self + .response_tx + .send(OutboundResponse::Status(crate::StatusUpdate::StopTyping)) + .await; let reason = args.reason.as_deref().unwrap_or("no reason given"); tracing::info!(reason, "skip tool called, suppressing response"); diff --git a/src/tools/spawn_worker.rs b/src/tools/spawn_worker.rs index 69047e8fc..35157df78 100644 --- a/src/tools/spawn_worker.rs +++ b/src/tools/spawn_worker.rs @@ -1,7 +1,9 @@ //! Spawn worker tool for creating new workers. 
-use crate::agent::channel::{ChannelState, spawn_worker_from_state, spawn_opencode_worker_from_state}; use crate::WorkerId; +use crate::agent::channel::{ + ChannelState, spawn_opencode_worker_from_state, spawn_worker_from_state, +}; use rig::completion::ToolDefinition; use rig::tool::Tool; use schemars::JsonSchema; @@ -144,17 +146,13 @@ impl Tool for SpawnWorkerTool { let is_opencode = args.worker_type.as_deref() == Some("opencode"); let worker_id = if is_opencode { - let directory = args.directory.as_deref() - .ok_or_else(|| SpawnWorkerError("directory is required for opencode workers".into()))?; + let directory = args.directory.as_deref().ok_or_else(|| { + SpawnWorkerError("directory is required for opencode workers".into()) + })?; - spawn_opencode_worker_from_state( - &self.state, - &args.task, - directory, - args.interactive, - ) - .await - .map_err(|e| SpawnWorkerError(format!("{e}")))? + spawn_opencode_worker_from_state(&self.state, &args.task, directory, args.interactive) + .await + .map_err(|e| SpawnWorkerError(format!("{e}")))? } else { spawn_worker_from_state( &self.state, From 40edd14ea31ecc987b38427cdc0fda9618a192ef Mon Sep 17 00:00:00 2001 From: nmb Date: Tue, 17 Feb 2026 16:31:09 +0100 Subject: [PATCH 6/8] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor(memory):=20im?= =?UTF-8?q?prove=20store=20operations=20and=20search?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/memory.rs | 16 ++-- src/memory/embedding.rs | 14 ++-- src/memory/lance.rs | 117 +++++++++++++------------- src/memory/maintenance.rs | 37 ++++----- src/memory/search.rs | 170 +++++++++++++++++++++++++------------- src/memory/store.rs | 132 ++++++++++++++++++----------- 6 files changed, 291 insertions(+), 195 deletions(-) diff --git a/src/memory.rs b/src/memory.rs index 937951ca1..5609fbc12 100644 --- a/src/memory.rs +++ b/src/memory.rs @@ -1,14 +1,14 @@ //! Memory storage and retrieval system. 
-pub mod store; -pub mod types; -pub mod search; -pub mod lance; pub mod embedding; +pub mod lance; pub mod maintenance; +pub mod search; +pub mod store; +pub mod types; -pub use store::MemoryStore; -pub use types::{Memory, MemoryType, Association, RelationType}; -pub use search::{MemorySearch, SearchConfig, SearchMode, SearchSort, curate_results}; -pub use lance::EmbeddingTable; pub use embedding::EmbeddingModel; +pub use lance::EmbeddingTable; +pub use search::{MemorySearch, SearchConfig, SearchMode, SearchSort, curate_results}; +pub use store::MemoryStore; +pub use types::{Association, Memory, MemoryType, RelationType}; diff --git a/src/memory/embedding.rs b/src/memory/embedding.rs index 4714248e5..bedb3c371 100644 --- a/src/memory/embedding.rs +++ b/src/memory/embedding.rs @@ -22,12 +22,15 @@ impl EmbeddingModel { let model = fastembed::TextEmbedding::try_new(options) .map_err(|e| LlmError::EmbeddingFailed(e.to_string()))?; - Ok(Self { model: Arc::new(model) }) + Ok(Self { + model: Arc::new(model), + }) } /// Generate embeddings for multiple texts (blocking). pub fn embed(&self, texts: Vec) -> Result>> { - self.model.embed(texts, None) + self.model + .embed(texts, None) .map_err(|e| LlmError::EmbeddingFailed(e.to_string()).into()) } @@ -42,8 +45,9 @@ impl EmbeddingModel { let text = text.to_string(); let model = self.model.clone(); let result = tokio::task::spawn_blocking(move || { - model.embed(vec![text], None) - .map_err(|e| crate::Error::Llm(crate::error::LlmError::EmbeddingFailed(e.to_string()))) + model.embed(vec![text], None).map_err(|e| { + crate::Error::Llm(crate::error::LlmError::EmbeddingFailed(e.to_string())) + }) }) .await .map_err(|e| crate::Error::Other(anyhow::anyhow!("embedding task failed: {}", e)))??; @@ -52,8 +56,6 @@ impl EmbeddingModel { } } - - /// Async function to embed text using a shared model. 
pub async fn embed_text(model: &Arc, text: &str) -> Result> { model.embed_one(text).await diff --git a/src/memory/lance.rs b/src/memory/lance.rs index 5338111ce..fb0fcc443 100644 --- a/src/memory/lance.rs +++ b/src/memory/lance.rs @@ -1,9 +1,9 @@ //! LanceDB table management and embedding storage with HNSW vector index and FTS. use crate::error::{DbError, Result}; -use arrow_array::{Array, RecordBatchIterator}; use arrow_array::cast::AsArray; use arrow_array::types::Float32Type; +use arrow_array::{Array, RecordBatchIterator}; use futures::TryStreamExt; use std::sync::Arc; @@ -33,54 +33,49 @@ impl EmbeddingTable { Err(_) => { // Create new table with empty batch let schema = Self::schema(); - + // Create empty RecordBatchIterator - let batches = RecordBatchIterator::new( - vec![].into_iter().map(Ok), - Arc::new(schema), - ); - + let batches = + RecordBatchIterator::new(vec![].into_iter().map(Ok), Arc::new(schema)); + let table = connection .create_table(TABLE_NAME, Box::new(batches)) .execute() .await .map_err(|e| DbError::LanceDb(e.to_string()))?; - + Ok(Self { table }) } } } - + /// Store an embedding with content for a memory. /// The content is stored for FTS search capability. 
- pub async fn store( - &self, - memory_id: &str, - content: &str, - embedding: &[f32], - ) -> Result<()> { + pub async fn store(&self, memory_id: &str, content: &str, embedding: &[f32]) -> Result<()> { if embedding.len() != EMBEDDING_DIM as usize { return Err(DbError::LanceDb(format!( "Embedding dimension mismatch: expected {}, got {}", EMBEDDING_DIM, embedding.len() - )).into()); + )) + .into()); } - + use arrow_array::{FixedSizeListArray, RecordBatch, StringArray}; - + let schema = Self::schema(); - + // Build arrays for the record batch let id_array = StringArray::from(vec![memory_id]); let content_array = StringArray::from(vec![content]); - + // Convert embedding to FixedSizeListArray - let embedding_array = arrow_array::FixedSizeListArray::from_iter_primitive::( - vec![Some(embedding.iter().map(|v| Some(*v)).collect::>())], - EMBEDDING_DIM, - ); - + let embedding_array = + arrow_array::FixedSizeListArray::from_iter_primitive::( + vec![Some(embedding.iter().map(|v| Some(*v)).collect::>())], + EMBEDDING_DIM, + ); + let batch = RecordBatch::try_new( Arc::new(schema), vec![ @@ -90,22 +85,19 @@ impl EmbeddingTable { ], ) .map_err(|e| DbError::LanceDb(e.to_string()))?; - + // Create iterator for IntoArrow trait - let batches = RecordBatchIterator::new( - vec![Ok(batch)], - Arc::new(Self::schema()), - ); - + let batches = RecordBatchIterator::new(vec![Ok(batch)], Arc::new(Self::schema())); + self.table .add(Box::new(batches)) .execute() .await .map_err(|e| DbError::LanceDb(e.to_string()))?; - + Ok(()) } - + /// Delete an embedding by memory ID. pub async fn delete(&self, memory_id: &str) -> Result<()> { let predicate = format!("id = '{}'", memory_id); @@ -113,10 +105,10 @@ impl EmbeddingTable { .delete(&predicate) .await .map_err(|e| DbError::LanceDb(e.to_string()))?; - + Ok(()) } - + /// Vector similarity search using cosine distance. /// Returns (memory_id, distance) pairs sorted by distance (ascending). 
pub async fn vector_search( @@ -129,11 +121,12 @@ impl EmbeddingTable { "Query embedding dimension mismatch: expected {}, got {}", EMBEDDING_DIM, query_embedding.len() - )).into()); + )) + .into()); } - + use lancedb::query::{ExecutableQuery, QueryBase}; - + // Use query() API with nearest_to for vector search let results: Vec = self .table @@ -147,13 +140,16 @@ impl EmbeddingTable { .try_collect() .await .map_err(|e| DbError::LanceDb(e.to_string()))?; - + let mut matches = Vec::new(); for batch in results { - if let (Some(id_col), Some(dist_col)) = (batch.column_by_name("id"), batch.column_by_name("_distance")) { + if let (Some(id_col), Some(dist_col)) = ( + batch.column_by_name("id"), + batch.column_by_name("_distance"), + ) { let ids: &arrow_array::StringArray = id_col.as_string::(); let dists: &arrow_array::PrimitiveArray = dist_col.as_primitive(); - + for i in 0..ids.len() { if ids.is_valid(i) && dists.is_valid(i) { let id = ids.value(i).to_string(); @@ -163,10 +159,10 @@ impl EmbeddingTable { } } } - + Ok(matches) } - + /// Find memories similar to a given memory by its embedding. /// Returns (memory_id, similarity) pairs where similarity = 1.0 - cosine_distance. /// Results exclude the source memory itself. @@ -235,12 +231,14 @@ impl EmbeddingTable { /// Returns (memory_id, score) pairs sorted by score (descending). 
pub async fn text_search(&self, query: &str, limit: usize) -> Result> { use lancedb::query::{ExecutableQuery, QueryBase}; - + // Use full_text_search on the content column let results: Vec = self .table .query() - .full_text_search(lance_index::scalar::FullTextSearchQuery::new(query.to_string())) + .full_text_search(lance_index::scalar::FullTextSearchQuery::new( + query.to_string(), + )) .select(lancedb::query::Select::columns(&["id", "_score"])) .limit(limit) .execute() @@ -249,13 +247,15 @@ impl EmbeddingTable { .try_collect() .await .map_err(|e| DbError::LanceDb(e.to_string()))?; - + let mut matches = Vec::new(); for batch in results { - if let (Some(id_col), Some(score_col)) = (batch.column_by_name("id"), batch.column_by_name("_score")) { + if let (Some(id_col), Some(score_col)) = + (batch.column_by_name("id"), batch.column_by_name("_score")) + { let ids: &arrow_array::StringArray = id_col.as_string::(); let scores: &arrow_array::PrimitiveArray = score_col.as_primitive(); - + for i in 0..ids.len() { if ids.is_valid(i) && scores.is_valid(i) { let id = ids.value(i).to_string(); @@ -265,10 +265,10 @@ impl EmbeddingTable { } } } - + Ok(matches) } - + /// Create HNSW vector index and FTS index for better performance. /// Should be called after enough data accumulates. pub async fn create_indexes(&self) -> Result<()> { @@ -278,19 +278,20 @@ impl EmbeddingTable { .execute() .await .map_err(|e| DbError::LanceDb(format!("Failed to create vector index: {}", e)))?; - + self.ensure_fts_index().await?; - + Ok(()) } - + /// Ensure the FTS index exists on the content column. /// /// LanceDB requires an inverted index for `full_text_search()` queries. /// This is safe to call multiple times — if the index already exists, the /// error is silently ignored. 
pub async fn ensure_fts_index(&self) -> Result<()> { - match self.table + match self + .table .create_index(&["content"], lancedb::index::Index::FTS(Default::default())) .execute() .await @@ -311,7 +312,7 @@ impl EmbeddingTable { } } } - + /// Get the Arrow schema for the embeddings table. fn schema() -> arrow_schema::Schema { arrow_schema::Schema::new(vec![ @@ -320,7 +321,11 @@ impl EmbeddingTable { arrow_schema::Field::new( "embedding", arrow_schema::DataType::FixedSizeList( - Arc::new(arrow_schema::Field::new("item", arrow_schema::DataType::Float32, true)), + Arc::new(arrow_schema::Field::new( + "item", + arrow_schema::DataType::Float32, + true, + )), EMBEDDING_DIM, ), false, diff --git a/src/memory/maintenance.rs b/src/memory/maintenance.rs index 6a8bbdf3b..2dc378ec9 100644 --- a/src/memory/maintenance.rs +++ b/src/memory/maintenance.rs @@ -35,16 +35,16 @@ pub async fn run_maintenance( config: &MaintenanceConfig, ) -> Result { let mut report = MaintenanceReport::default(); - + // Apply decay to all non-identity memories report.decayed = apply_decay(memory_store, config.decay_rate).await?; - + // Prune old, low-importance memories report.pruned = prune_memories(memory_store, config).await?; - + // Merge near-duplicate memories report.merged = merge_similar_memories(memory_store, config.merge_similarity_threshold).await?; - + Ok(report) } @@ -56,17 +56,17 @@ async fn apply_decay(memory_store: &MemoryStore, decay_rate: f32) -> Result Result 0.01 { memory.importance = new_importance.clamp(0.0, 1.0); memory.updated_at = now; @@ -87,19 +87,16 @@ async fn apply_decay(memory_store: &MemoryStore, decay_rate: f32) -> Result Result { +async fn prune_memories(memory_store: &MemoryStore, config: &MaintenanceConfig) -> Result { let now = chrono::Utc::now(); let min_age = chrono::Duration::days(config.min_age_days); let cutoff_date = now - min_age; - + // Get all memories below threshold that are old enough let candidates = sqlx::query( r#" @@ -107,21 +104,21 @@ async fn 
prune_memories( WHERE importance < ? AND memory_type != 'identity' AND created_at < ? - "# + "#, ) .bind(config.prune_threshold) .bind(cutoff_date) .fetch_all(memory_store.pool()) .await?; - + let mut pruned_count = 0; - + for row in candidates { let id: String = sqlx::Row::try_get(&row, "id")?; memory_store.delete(&id).await?; pruned_count += 1; } - + Ok(pruned_count) } diff --git a/src/memory/search.rs b/src/memory/search.rs index 571f7d7b9..b4477d34a 100644 --- a/src/memory/search.rs +++ b/src/memory/search.rs @@ -71,17 +71,17 @@ impl MemorySearch { embedding_model, } } - + /// Get a reference to the memory store. pub fn store(&self) -> &MemoryStore { &self.store } - + /// Get a reference to the embedding table. pub fn embedding_table(&self) -> &EmbeddingTable { &self.embedding_table } - + /// Get a reference to the embedding model. pub fn embedding_model(&self) -> &EmbeddingModel { &self.embedding_model @@ -91,7 +91,7 @@ impl MemorySearch { pub fn embedding_model_arc(&self) -> &Arc { &self.embedding_model } - + /// Unified search entry point. Dispatches to the appropriate strategy /// based on `config.mode`. pub async fn search( @@ -152,16 +152,23 @@ impl MemorySearch { let mut vector_results = Vec::new(); let mut fts_results = Vec::new(); let mut graph_results = Vec::new(); - + // 1. Full-text search via LanceDB // FTS requires an inverted index. If the index doesn't exist yet (empty // table, first run) this will fail — fall back to vector + graph search. - match self.embedding_table.text_search(query, config.max_results_per_source).await { + match self + .embedding_table + .text_search(query, config.max_results_per_source) + .await + { Ok(fts_matches) => { for (memory_id, score) in fts_matches { if let Some(memory) = self.store.load(&memory_id).await? 
{ if !memory.forgotten { - fts_results.push(ScoredMemory { memory, score: score as f64 }); + fts_results.push(ScoredMemory { + memory, + score: score as f64, + }); } } } @@ -170,16 +177,23 @@ impl MemorySearch { tracing::debug!(%error, "FTS search unavailable, falling back to vector + graph"); } } - + // 2. Vector similarity search via LanceDB let query_embedding = self.embedding_model.embed_one(query).await?; - match self.embedding_table.vector_search(&query_embedding, config.max_results_per_source).await { + match self + .embedding_table + .vector_search(&query_embedding, config.max_results_per_source) + .await + { Ok(vector_matches) => { for (memory_id, distance) in vector_matches { let similarity = 1.0 - distance; if let Some(memory) = self.store.load(&memory_id).await? { if !memory.forgotten { - vector_results.push(ScoredMemory { memory, score: similarity as f64 }); + vector_results.push(ScoredMemory { + memory, + score: similarity as f64, + }); } } } @@ -188,39 +202,40 @@ impl MemorySearch { tracing::debug!(%error, "vector search unavailable, falling back to graph only"); } } - + // 3. 
Graph traversal from high-importance memories // Get identity and high-importance memories as starting points let seed_memories = self.store.get_high_importance(0.8, 20).await?; - + for seed in seed_memories { // Check if seed is semantically related to query via simple keyword matching - if query.to_lowercase().split_whitespace().any(|term| { - seed.content.to_lowercase().contains(term) - }) { - graph_results.push(ScoredMemory { - memory: seed.clone(), - score: seed.importance as f64 + if query + .to_lowercase() + .split_whitespace() + .any(|term| seed.content.to_lowercase().contains(term)) + { + graph_results.push(ScoredMemory { + memory: seed.clone(), + score: seed.importance as f64, }); - + // Traverse graph to find related memories - self.traverse_graph(&seed.id, config.max_graph_depth, &mut graph_results).await?; + self.traverse_graph(&seed.id, config.max_graph_depth, &mut graph_results) + .await?; } } - + // 4. Merge results using Reciprocal Rank Fusion (RRF) - let fused_results = reciprocal_rank_fusion( - &vector_results, - &fts_results, - &graph_results, - config.rrf_k, - ); - + let fused_results = + reciprocal_rank_fusion(&vector_results, &fts_results, &graph_results, config.rrf_k); + // Convert to MemorySearchResult with ranks, applying optional type filter let results: Vec = fused_results .into_iter() .filter(|scored| { - config.memory_type.is_none_or(|t| scored.memory.memory_type == t) + config + .memory_type + .is_none_or(|t| scored.memory.memory_type == t) }) .enumerate() .map(|(rank, scored)| MemorySearchResult { @@ -231,10 +246,10 @@ impl MemorySearch { .filter(|r| r.score >= config.min_score) .take(config.max_results_per_source) .collect(); - + Ok(results) } - + /// Traverse the memory graph to find related memories (iterative to avoid async recursion). 
async fn traverse_graph( &self, @@ -243,21 +258,21 @@ impl MemorySearch { results: &mut Vec, ) -> Result<()> { use std::collections::VecDeque; - + // Queue of (memory_id, current_depth) let mut queue: VecDeque<(String, usize)> = VecDeque::new(); let mut visited: std::collections::HashSet = std::collections::HashSet::new(); - + queue.push_back((start_id.to_string(), 0)); visited.insert(start_id.to_string()); - + while let Some((current_id, depth)) = queue.pop_front() { if depth > max_depth { continue; } - + let associations = self.store.get_associations(¤t_id).await?; - + for assoc in associations { // Get the related memory let related_id = if assoc.source_id == current_id { @@ -265,12 +280,12 @@ impl MemorySearch { } else { &assoc.source_id }; - + if visited.contains(related_id) { continue; } visited.insert(related_id.clone()); - + if let Some(memory) = self.store.load(related_id).await? { if memory.forgotten { continue; @@ -283,19 +298,25 @@ impl MemorySearch { RelationType::Contradicts => 0.5, RelationType::PartOf => 0.8, }; - + let score = memory.importance as f64 * assoc.weight as f64 * type_multiplier; - - results.push(ScoredMemory { memory: memory.clone(), score }); - + + results.push(ScoredMemory { + memory: memory.clone(), + score, + }); + // Add to queue for RelatedTo and PartOf relations - if matches!(assoc.relation_type, RelationType::RelatedTo | RelationType::PartOf) { + if matches!( + assoc.relation_type, + RelationType::RelatedTo | RelationType::PartOf + ) { queue.push_back((related_id.clone(), depth + 1)); } } } } - + Ok(()) } } @@ -355,44 +376,50 @@ fn reciprocal_rank_fusion( ) -> Vec { // Build a map of memory ID to RRF score let mut rrf_scores: HashMap = HashMap::new(); - + // Add vector results for (rank, scored) in vector_results.iter().enumerate() { let rrf_score = 1.0 / (k + (rank as f64 + 1.0)); - let entry = rrf_scores.entry(scored.memory.id.clone()) + let entry = rrf_scores + .entry(scored.memory.id.clone()) .or_insert((0.0, 
scored.memory.clone())); entry.0 += rrf_score; } - + // Add FTS results for (rank, scored) in fts_results.iter().enumerate() { let rrf_score = 1.0 / (k + (rank as f64 + 1.0)); - let entry = rrf_scores.entry(scored.memory.id.clone()) + let entry = rrf_scores + .entry(scored.memory.id.clone()) .or_insert((0.0, scored.memory.clone())); entry.0 += rrf_score; } - + // Add graph results for (rank, scored) in graph_results.iter().enumerate() { let rrf_score = 1.0 / (k + (rank as f64 + 1.0)); - let entry = rrf_scores.entry(scored.memory.id.clone()) + let entry = rrf_scores + .entry(scored.memory.id.clone()) .or_insert((0.0, scored.memory.clone())); entry.0 += rrf_score; } - + // Convert to vec and sort by RRF score let mut fused: Vec = rrf_scores .into_iter() .map(|(_, (score, memory))| ScoredMemory { memory, score }) .collect(); - + fused.sort_by(|a, b| b.score.total_cmp(&a.score)); - + fused } /// Curate search results to return only the most relevant. -pub fn curate_results(results: &[MemorySearchResult], max_results: usize) -> Vec<&MemorySearchResult> { +pub fn curate_results( + results: &[MemorySearchResult], + max_results: usize, +) -> Vec<&MemorySearchResult> { results.iter().take(max_results).collect() } @@ -482,11 +509,36 @@ mod tests { let mut memories = Vec::new(); let types_and_importance = [ - ("user identity info", MemoryType::Identity, 1.0, now - Duration::days(30)), - ("recent event", MemoryType::Event, 0.4, now - Duration::hours(1)), - ("important decision", MemoryType::Decision, 0.9, now - Duration::days(2)), - ("casual observation", MemoryType::Observation, 0.2, now - Duration::days(7)), - ("user preference", MemoryType::Preference, 0.7, now - Duration::days(1)), + ( + "user identity info", + MemoryType::Identity, + 1.0, + now - Duration::days(30), + ), + ( + "recent event", + MemoryType::Event, + 0.4, + now - Duration::hours(1), + ), + ( + "important decision", + MemoryType::Decision, + 0.9, + now - Duration::days(2), + ), + ( + "casual observation", + 
MemoryType::Observation, + 0.2, + now - Duration::days(7), + ), + ( + "user preference", + MemoryType::Preference, + 0.7, + now - Duration::days(1), + ), ]; for (content, memory_type, importance, created_at) in types_and_importance { diff --git a/src/memory/store.rs b/src/memory/store.rs index 695c4a1aa..ca6f2cff4 100644 --- a/src/memory/store.rs +++ b/src/memory/store.rs @@ -27,12 +27,12 @@ impl MemoryStore { pub fn new(pool: SqlitePool) -> Arc { Arc::new(Self { pool }) } - + /// Get a reference to the SQLite pool. pub(crate) fn pool(&self) -> &SqlitePool { &self.pool } - + /// Save a new memory to the store. pub async fn save(&self, memory: &Memory) -> Result<()> { sqlx::query( @@ -40,7 +40,7 @@ impl MemoryStore { INSERT INTO memories (id, content, memory_type, importance, created_at, updated_at, last_accessed_at, access_count, source, channel_id, forgotten) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - "# + "#, ) .bind(&memory.id) .bind(&memory.content) @@ -56,10 +56,10 @@ impl MemoryStore { .execute(&self.pool) .await .with_context(|| format!("failed to save memory {}", memory.id))?; - + Ok(()) } - + /// Load a memory by ID. pub async fn load(&self, id: &str) -> Result> { let row = sqlx::query( @@ -68,16 +68,16 @@ impl MemoryStore { last_accessed_at, access_count, source, channel_id, forgotten FROM memories WHERE id = ? - "# + "#, ) .bind(id) .fetch_optional(&self.pool) .await .with_context(|| format!("failed to load memory {}", id))?; - + Ok(row.map(|row| row_to_memory(&row))) } - + /// Update an existing memory. pub async fn update(&self, memory: &Memory) -> Result<()> { sqlx::query( @@ -87,7 +87,7 @@ impl MemoryStore { last_accessed_at = ?, access_count = ?, source = ?, channel_id = ?, forgotten = ? WHERE id = ? 
- "# + "#, ) .bind(&memory.content) .bind(memory.memory_type.to_string()) @@ -102,10 +102,10 @@ impl MemoryStore { .execute(&self.pool) .await .with_context(|| format!("failed to update memory {}", memory.id))?; - + Ok(()) } - + /// Delete a memory by ID. pub async fn delete(&self, id: &str) -> Result<()> { sqlx::query("DELETE FROM memories WHERE id = ?") @@ -113,35 +113,35 @@ impl MemoryStore { .execute(&self.pool) .await .with_context(|| format!("failed to delete memory {}", id))?; - + Ok(()) } - + /// Record access to a memory, updating last_accessed_at and access_count. pub async fn record_access(&self, id: &str) -> Result<()> { let now = chrono::Utc::now(); - + sqlx::query( r#" UPDATE memories SET last_accessed_at = ?, access_count = access_count + 1 WHERE id = ? - "# + "#, ) .bind(now) .bind(id) .execute(&self.pool) .await .with_context(|| format!("failed to record access for memory {}", id))?; - + Ok(()) } - + /// Mark a memory as forgotten. The memory stays in the database but is /// excluded from search results and recall. pub async fn forget(&self, id: &str) -> Result { let result = sqlx::query( - "UPDATE memories SET forgotten = 1, updated_at = ? WHERE id = ? AND forgotten = 0" + "UPDATE memories SET forgotten = 1, updated_at = ? WHERE id = ? AND forgotten = 0", ) .bind(chrono::Utc::now()) .bind(id) @@ -160,7 +160,7 @@ impl MemoryStore { VALUES (?, ?, ?, ?, ?, ?) ON CONFLICT(source_id, target_id, relation_type) DO UPDATE SET weight = excluded.weight - "# + "#, ) .bind(&association.id) .bind(&association.source_id) @@ -176,10 +176,10 @@ impl MemoryStore { association.source_id, association.target_id ) })?; - + Ok(()) } - + /// Get all associations for a memory (both incoming and outgoing). pub async fn get_associations(&self, memory_id: &str) -> Result> { let rows = sqlx::query( @@ -187,25 +187,28 @@ impl MemoryStore { SELECT id, source_id, target_id, relation_type, weight, created_at FROM associations WHERE source_id = ? OR target_id = ? 
- "# + "#, ) .bind(memory_id) .bind(memory_id) .fetch_all(&self.pool) .await .with_context(|| format!("failed to get associations for memory {}", memory_id))?; - + let associations = rows .into_iter() .map(|row| row_to_association(&row)) .collect(); - + Ok(associations) } /// Get all associations where both source and target are in the provided set. /// Used by the graph view to fetch edges between a known set of visible nodes. - pub async fn get_associations_between(&self, memory_ids: &[String]) -> Result> { + pub async fn get_associations_between( + &self, + memory_ids: &[String], + ) -> Result> { if memory_ids.is_empty() { return Ok(Vec::new()); } @@ -233,7 +236,10 @@ impl MemoryStore { .await .context("failed to get associations between memory set")?; - Ok(rows.into_iter().map(|row| row_to_association(&row)).collect()) + Ok(rows + .into_iter() + .map(|row| row_to_association(&row)) + .collect()) } /// Get neighbors of a memory: all associations plus the connected memories. @@ -296,11 +302,11 @@ impl MemoryStore { Ok((neighbors, all_associations)) } - + /// Get memories by type. pub async fn get_by_type(&self, memory_type: MemoryType, limit: i64) -> Result> { let type_str = memory_type.to_string(); - + let rows = sqlx::query( r#" SELECT id, content, memory_type, importance, created_at, updated_at, @@ -309,17 +315,17 @@ impl MemoryStore { WHERE memory_type = ? AND forgotten = 0 ORDER BY importance DESC, updated_at DESC LIMIT ? - "# + "#, ) .bind(&type_str) .bind(limit) .fetch_all(&self.pool) .await .with_context(|| format!("failed to get memories by type {:?}", memory_type))?; - + Ok(rows.into_iter().map(|row| row_to_memory(&row)).collect()) } - + /// Get high-importance memories for injection into context. pub async fn get_high_importance(&self, threshold: f32, limit: i64) -> Result> { let rows = sqlx::query( @@ -330,14 +336,14 @@ impl MemoryStore { WHERE importance >= ? AND forgotten = 0 ORDER BY importance DESC, updated_at DESC LIMIT ? 
- "# + "#, ) .bind(threshold) .bind(limit) .fetch_all(&self.pool) .await .with_context(|| "failed to get high importance memories")?; - + Ok(rows.into_iter().map(|row| row_to_memory(&row)).collect()) } @@ -422,17 +428,23 @@ impl MemoryStore { fn row_to_memory(row: &sqlx::sqlite::SqliteRow) -> Memory { let mem_type_str: String = row.try_get("memory_type").unwrap_or_default(); let memory_type = parse_memory_type(&mem_type_str); - + let channel_id: Option = row.try_get("channel_id").ok(); - + Memory { id: row.try_get("id").unwrap_or_default(), content: row.try_get("content").unwrap_or_default(), memory_type, importance: row.try_get("importance").unwrap_or(0.5), - created_at: row.try_get("created_at").unwrap_or_else(|_| chrono::Utc::now()), - updated_at: row.try_get("updated_at").unwrap_or_else(|_| chrono::Utc::now()), - last_accessed_at: row.try_get("last_accessed_at").unwrap_or_else(|_| chrono::Utc::now()), + created_at: row + .try_get("created_at") + .unwrap_or_else(|_| chrono::Utc::now()), + updated_at: row + .try_get("updated_at") + .unwrap_or_else(|_| chrono::Utc::now()), + last_accessed_at: row + .try_get("last_accessed_at") + .unwrap_or_else(|_| chrono::Utc::now()), access_count: row.try_get("access_count").unwrap_or(0), source: row.try_get("source").ok(), channel_id: channel_id.map(|id| Arc::from(id) as crate::ChannelId), @@ -459,14 +471,16 @@ fn parse_memory_type(s: &str) -> MemoryType { fn row_to_association(row: &sqlx::sqlite::SqliteRow) -> Association { let relation_type_str: String = row.try_get("relation_type").unwrap_or_default(); let relation_type = parse_relation_type(&relation_type_str); - + Association { id: row.try_get("id").unwrap_or_default(), source_id: row.try_get("source_id").unwrap_or_default(), target_id: row.try_get("target_id").unwrap_or_default(), relation_type, weight: row.try_get("weight").unwrap_or(0.5), - created_at: row.try_get("created_at").unwrap_or_else(|_| chrono::Utc::now()), + created_at: row + .try_get("created_at") + 
.unwrap_or_else(|_| chrono::Utc::now()), } } @@ -520,11 +534,28 @@ mod tests { let store = MemoryStore::connect_in_memory().await; let now = Utc::now(); - let old = insert_memory_at(&store, "old", MemoryType::Fact, 0.5, now - Duration::hours(3)).await; - let mid = insert_memory_at(&store, "mid", MemoryType::Fact, 0.5, now - Duration::hours(1)).await; + let old = insert_memory_at( + &store, + "old", + MemoryType::Fact, + 0.5, + now - Duration::hours(3), + ) + .await; + let mid = insert_memory_at( + &store, + "mid", + MemoryType::Fact, + 0.5, + now - Duration::hours(1), + ) + .await; let new = insert_memory_at(&store, "new", MemoryType::Fact, 0.5, now).await; - let results = store.get_sorted(SearchSort::Recent, 10, None).await.unwrap(); + let results = store + .get_sorted(SearchSort::Recent, 10, None) + .await + .unwrap(); assert_eq!(results.len(), 3); assert_eq!(results[0].id, new.id); assert_eq!(results[1].id, mid.id); @@ -540,7 +571,10 @@ mod tests { let high = insert_memory_at(&store, "high", MemoryType::Fact, 0.9, now).await; let medium = insert_memory_at(&store, "medium", MemoryType::Fact, 0.5, now).await; - let results = store.get_sorted(SearchSort::Importance, 10, None).await.unwrap(); + let results = store + .get_sorted(SearchSort::Importance, 10, None) + .await + .unwrap(); assert_eq!(results[0].id, high.id); assert_eq!(results[1].id, medium.id); assert_eq!(results[2].id, low.id); @@ -560,7 +594,10 @@ mod tests { store.record_access(&b.id).await.unwrap(); } - let results = store.get_sorted(SearchSort::MostAccessed, 10, None).await.unwrap(); + let results = store + .get_sorted(SearchSort::MostAccessed, 10, None) + .await + .unwrap(); assert_eq!(results[0].id, b.id); assert_eq!(results[1].id, a.id); } @@ -612,7 +649,10 @@ mod tests { let forgotten = insert_memory_at(&store, "forgotten", MemoryType::Fact, 0.5, now).await; store.forget(&forgotten.id).await.unwrap(); - let results = store.get_sorted(SearchSort::Recent, 10, None).await.unwrap(); + let results = 
store + .get_sorted(SearchSort::Recent, 10, None) + .await + .unwrap(); assert_eq!(results.len(), 1); assert_eq!(results[0].id, visible.id); } From 92533cc2658ee41ba2f338dba9b621f93752bf9e Mon Sep 17 00:00:00 2001 From: nmb Date: Tue, 17 Feb 2026 16:31:20 +0100 Subject: [PATCH 7/8] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor(messaging):?= =?UTF-8?q?=20improve=20platform=20integrations?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/messaging.rs | 6 ++-- src/messaging/discord.rs | 72 +++++++++++++++++++++++++-------------- src/messaging/manager.rs | 18 +++++++--- src/messaging/slack.rs | 44 +++++++++++------------- src/messaging/telegram.rs | 32 ++++++++--------- src/messaging/traits.rs | 61 +++++++++++++++++++-------------- src/messaging/webhook.rs | 2 +- 7 files changed, 135 insertions(+), 100 deletions(-) diff --git a/src/messaging.rs b/src/messaging.rs index ded286bc6..76ecc7a2f 100644 --- a/src/messaging.rs +++ b/src/messaging.rs @@ -1,11 +1,11 @@ //! Messaging adapters (Discord, Slack, Telegram, Webhook). 
-pub mod traits; -pub mod manager; pub mod discord; +pub mod manager; pub mod slack; pub mod telegram; +pub mod traits; pub mod webhook; -pub use traits::Messaging; pub use manager::MessagingManager; +pub use traits::Messaging; diff --git a/src/messaging/discord.rs b/src/messaging/discord.rs index 98f0f510c..10b3c2b88 100644 --- a/src/messaging/discord.rs +++ b/src/messaging/discord.rs @@ -9,7 +9,7 @@ use arc_swap::ArcSwap; use async_trait::async_trait; use serenity::all::{ ChannelId, ChannelType, Context, CreateAttachment, CreateMessage, CreateThread, EditMessage, - EventHandler, GatewayIntents, GetMessages, Http, Message, MessageId, Ready, ReactionType, + EventHandler, GatewayIntents, GetMessages, Http, Message, MessageId, ReactionType, Ready, ShardManager, User, UserId, }; use std::collections::HashMap; @@ -30,10 +30,7 @@ pub struct DiscordAdapter { } impl DiscordAdapter { - pub fn new( - token: impl Into, - permissions: Arc>, - ) -> Self { + pub fn new(token: impl Into, permissions: Arc>) -> Self { Self { token: token.into(), permissions, @@ -138,15 +135,15 @@ impl Messaging for DiscordAdapter { let thread_result = match message_id { Some(source_message_id) => { - let builder = CreateThread::new(&thread_name) - .kind(ChannelType::PublicThread); + let builder = + CreateThread::new(&thread_name).kind(ChannelType::PublicThread); channel_id .create_thread_from_message(&*http, source_message_id, builder) .await } None => { - let builder = CreateThread::new(&thread_name) - .kind(ChannelType::PublicThread); + let builder = + CreateThread::new(&thread_name).kind(ChannelType::PublicThread); channel_id.create_thread(&*http, builder).await } }; @@ -178,7 +175,12 @@ impl Messaging for DiscordAdapter { } } } - OutboundResponse::File { filename, data, mime_type: _, caption } => { + OutboundResponse::File { + filename, + data, + mime_type: _, + caption, + } => { self.stop_typing(&message.id).await; let attachment = CreateAttachment::bytes(data, &filename); @@ -200,7 +202,11 
@@ impl Messaging for DiscordAdapter { .context("missing discord_message_id for reaction")?; channel_id - .create_reaction(&*http, MessageId::new(message_id), ReactionType::Unicode(emoji)) + .create_reaction( + &*http, + MessageId::new(message_id), + ReactionType::Unicode(emoji), + ) .await .context("failed to add reaction")?; } @@ -267,11 +273,7 @@ impl Messaging for DiscordAdapter { Ok(()) } - async fn broadcast( - &self, - target: &str, - response: OutboundResponse, - ) -> crate::Result<()> { + async fn broadcast(&self, target: &str, response: OutboundResponse) -> crate::Result<()> { let http = self.get_http().await?; let channel_id = ChannelId::new( @@ -330,12 +332,18 @@ impl Messaging for DiscordAdapter { let resolved_content = resolve_mentions(&message.content, &message.mentions); - let display_name = message.author.global_name.as_deref() + let display_name = message + .author + .global_name + .as_deref() .unwrap_or(&message.author.name); // Include reply-to attribution if this message is a reply let author = if let Some(referenced) = &message.referenced_message { - let reply_author = referenced.author.global_name.as_deref() + let reply_author = referenced + .author + .global_name + .as_deref() .unwrap_or(&referenced.author.name); format!("{display_name} (replying to {reply_author})") } else { @@ -417,7 +425,9 @@ impl EventHandler for Handler { // DM filter: if no guild_id, it's a DM — only allow listed users if message.guild_id.is_none() { if permissions.dm_allowed_users.is_empty() - || !permissions.dm_allowed_users.contains(&message.author.id.get()) + || !permissions + .dm_allowed_users + .contains(&message.author.id.get()) { return; } @@ -444,8 +454,8 @@ impl EventHandler for Handler { .and_then(|v| v.as_u64()); let direct_match = allowed_channels.contains(&message.channel_id.get()); - let parent_match = parent_channel_id - .is_some_and(|pid| allowed_channels.contains(&pid)); + let parent_match = + parent_channel_id.is_some_and(|pid| 
allowed_channels.contains(&pid)); if !direct_match && !parent_match { return; @@ -540,11 +550,17 @@ async fn build_metadata(ctx: &Context, message: &Message) -> HashMap global display name > username let display_name = if let Some(member) = &message.member { member.nick.clone().unwrap_or_else(|| { - message.author.global_name.clone() + message + .author + .global_name + .clone() .unwrap_or_else(|| message.author.name.clone()) }) } else { - message.author.global_name.clone() + message + .author + .global_name + .clone() .unwrap_or_else(|| message.author.name.clone()) }; metadata.insert("sender_display_name".into(), display_name.into()); @@ -565,7 +581,10 @@ async fn build_metadata(ctx: &Context, message: &Message) -> HashMap HashMap crate::Result { let adapters = self.adapters.read().await; for (name, adapter) in adapters.iter() { - let stream = adapter.start().await + let stream = adapter + .start() + .await .with_context(|| format!("failed to start adapter '{name}'"))?; Self::spawn_forwarder(name.clone(), stream, self.fan_in_tx.clone()); } drop(adapters); - let receiver = self.fan_in_rx.write().await.take() + let receiver = self + .fan_in_rx + .write() + .await + .take() .context("start() already called")?; - Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(receiver))) + Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new( + receiver, + ))) } /// Register and start a new adapter at runtime. 
@@ -78,7 +86,9 @@ impl MessagingManager { let adapter: Arc = Arc::new(adapter); - let stream = adapter.start().await + let stream = adapter + .start() + .await .with_context(|| format!("failed to start adapter '{name}'"))?; Self::spawn_forwarder(name.clone(), stream, self.fan_in_tx.clone()); diff --git a/src/messaging/slack.rs b/src/messaging/slack.rs index 1d8889eb7..e652899ac 100644 --- a/src/messaging/slack.rs +++ b/src/messaging/slack.rs @@ -244,8 +244,7 @@ impl Messaging for SlackAdapter { bot_token: self.bot_token.clone(), }); - let callbacks = SlackSocketModeListenerCallbacks::new() - .with_push_events(handle_push_event); + let callbacks = SlackSocketModeListenerCallbacks::new().with_push_events(handle_push_event); let listener_environment = Arc::new( SlackClientEventsListenerEnvironment::new(client.clone()) @@ -331,12 +330,18 @@ impl Messaging for SlackAdapter { .context("failed to send slack thread reply")?; } } - OutboundResponse::File { filename, data, mime_type, caption } => { + OutboundResponse::File { + filename, + data, + mime_type, + caption, + } => { // Slack's v2 upload flow: get upload URL, upload bytes, complete let upload_url_response = session - .get_upload_url_external( - &SlackApiFilesGetUploadUrlExternalRequest::new(filename.clone(), data.len()), - ) + .get_upload_url_external(&SlackApiFilesGetUploadUrlExternalRequest::new( + filename.clone(), + data.len(), + )) .await .context("failed to get slack upload URL")?; @@ -350,13 +355,12 @@ impl Messaging for SlackAdapter { .context("failed to upload file to slack")?; let thread_ts = extract_thread_ts(message); - let file_complete = SlackApiFilesComplete::new(upload_url_response.file_id) - .with_title(filename); + let file_complete = + SlackApiFilesComplete::new(upload_url_response.file_id).with_title(filename); - let mut complete_request = SlackApiFilesCompleteUploadExternalRequest::new( - vec![file_complete], - ) - .with_channel_id(channel_id.clone()); + let mut complete_request = + 
SlackApiFilesCompleteUploadExternalRequest::new(vec![file_complete]) + .with_channel_id(channel_id.clone()); complete_request = complete_request.opt_initial_comment(caption); complete_request = complete_request.opt_thread_ts(thread_ts); @@ -367,8 +371,8 @@ impl Messaging for SlackAdapter { .context("failed to complete slack file upload")?; } OutboundResponse::Reaction(emoji) => { - let ts = extract_message_ts(message) - .context("missing slack_message_ts for reaction")?; + let ts = + extract_message_ts(message).context("missing slack_message_ts for reaction")?; let reaction_name = sanitize_reaction_name(&emoji); @@ -470,11 +474,7 @@ impl Messaging for SlackAdapter { Ok(()) } - async fn broadcast( - &self, - target: &str, - response: OutboundResponse, - ) -> crate::Result<()> { + async fn broadcast(&self, target: &str, response: OutboundResponse) -> crate::Result<()> { let (client, token) = self.create_session()?; let session = client.open_session(&token); let channel_id = SlackChannelId(target.to_string()); @@ -539,11 +539,7 @@ impl Messaging for SlackAdapter { HistoryMessage { author, - content: msg - .content - .text - .clone() - .unwrap_or_default(), + content: msg.content.text.clone().unwrap_or_default(), is_bot, } }) diff --git a/src/messaging/telegram.rs b/src/messaging/telegram.rs index 7af00b18b..d83545c3f 100644 --- a/src/messaging/telegram.rs +++ b/src/messaging/telegram.rs @@ -6,13 +6,13 @@ use crate::{Attachment, InboundMessage, MessageContent, OutboundResponse, Status use anyhow::Context as _; use arc_swap::ArcSwap; +use teloxide::Bot; use teloxide::payloads::setters::*; use teloxide::requests::{Request, Requester}; use teloxide::types::{ - ChatAction, ChatId, InputFile, MediaKind, MessageId, MessageKind, ReactionType, ReplyParameters, - UpdateKind, UserId, + ChatAction, ChatId, InputFile, MediaKind, MessageId, MessageKind, ReactionType, + ReplyParameters, UpdateKind, UserId, }; -use teloxide::Bot; use std::collections::HashMap; use std::sync::Arc; @@ 
-48,10 +48,7 @@ const MAX_MESSAGE_LENGTH: usize = 4096; const STREAM_EDIT_INTERVAL: std::time::Duration = std::time::Duration::from_millis(1000); impl TelegramAdapter { - pub fn new( - token: impl Into, - permissions: Arc>, - ) -> Self { + pub fn new(token: impl Into, permissions: Arc>) -> Self { let token = token.into(); let bot = Bot::new(&token); Self { @@ -249,7 +246,10 @@ impl Messaging for TelegramAdapter { .context("failed to send telegram message")?; } } - OutboundResponse::ThreadReply { thread_name: _, text } => { + OutboundResponse::ThreadReply { + thread_name: _, + text, + } => { self.stop_typing(&message.conversation_id).await; // Telegram doesn't have named threads. Reply to the source message instead. @@ -380,8 +380,10 @@ impl Messaging for TelegramAdapter { // Send one immediately, then repeat every 4 seconds. let handle = tokio::spawn(async move { loop { - if let Err(error) = - bot.send_chat_action(chat_id, ChatAction::Typing).send().await + if let Err(error) = bot + .send_chat_action(chat_id, ChatAction::Typing) + .send() + .await { tracing::debug!(%error, "failed to send typing indicator"); break; @@ -463,10 +465,7 @@ fn has_attachments(message: &teloxide::types::Message) -> bool { } /// Build `MessageContent` from a Telegram message. 
-fn build_content( - message: &teloxide::types::Message, - text: &Option, -) -> MessageContent { +fn build_content(message: &teloxide::types::Message, text: &Option) -> MessageContent { let attachments = extract_attachments(message); if attachments.is_empty() { @@ -635,10 +634,7 @@ fn build_metadata( metadata.insert("reply_to_text".into(), truncated.into()); } if let Some(from) = &reply.from { - metadata.insert( - "reply_to_author".into(), - build_display_name(from).into(), - ); + metadata.insert("reply_to_author".into(), build_display_name(from).into()); } } diff --git a/src/messaging/traits.rs b/src/messaging/traits.rs index c42147553..56631bfca 100644 --- a/src/messaging/traits.rs +++ b/src/messaging/traits.rs @@ -2,8 +2,8 @@ use crate::error::Result; use crate::{InboundMessage, OutboundResponse, StatusUpdate}; -use std::pin::Pin; use futures::Stream; +use std::pin::Pin; /// Message stream type. pub type InboundStream = Pin + Send>>; @@ -21,17 +21,17 @@ pub struct HistoryMessage { pub trait Messaging: Send + Sync + 'static { /// Unique name for this adapter. fn name(&self) -> &str; - + /// Start the adapter and return inbound message stream. fn start(&self) -> impl std::future::Future> + Send; - + /// Send a response to a message. fn respond( &self, message: &InboundMessage, response: OutboundResponse, ) -> impl std::future::Future> + Send; - + /// Send a status update. fn send_status( &self, @@ -40,7 +40,7 @@ pub trait Messaging: Send + Sync + 'static { ) -> impl std::future::Future> + Send { async { Ok(()) } } - + /// Broadcast a message. fn broadcast( &self, @@ -61,10 +61,10 @@ pub trait Messaging: Send + Sync + 'static { let _ = (message, limit); async { Ok(Vec::new()) } } - + /// Health check. fn health_check(&self) -> impl std::future::Future> + Send; - + /// Graceful shutdown. 
fn shutdown(&self) -> impl std::future::Future> + Send { async { Ok(()) } @@ -75,21 +75,23 @@ pub trait Messaging: Send + Sync + 'static { /// Use this when you need `Arc` for storing different adapters. pub trait MessagingDyn: Send + Sync + 'static { fn name(&self) -> &str; - - fn start<'a>(&'a self) -> Pin> + Send + 'a>>; - + + fn start<'a>( + &'a self, + ) -> Pin> + Send + 'a>>; + fn respond<'a>( &'a self, message: &'a InboundMessage, response: OutboundResponse, ) -> Pin> + Send + 'a>>; - + fn send_status<'a>( &'a self, message: &'a InboundMessage, status: StatusUpdate, ) -> Pin> + Send + 'a>>; - + fn broadcast<'a>( &'a self, target: &'a str, @@ -101,10 +103,13 @@ pub trait MessagingDyn: Send + Sync + 'static { message: &'a InboundMessage, limit: usize, ) -> Pin>> + Send + 'a>>; - - fn health_check<'a>(&'a self) -> Pin> + Send + 'a>>; - - fn shutdown<'a>(&'a self) -> Pin> + Send + 'a>>; + + fn health_check<'a>( + &'a self, + ) -> Pin> + Send + 'a>>; + + fn shutdown<'a>(&'a self) + -> Pin> + Send + 'a>>; } /// Blanket implementation: any type implementing Messaging automatically implements MessagingDyn. 
@@ -112,11 +117,13 @@ impl MessagingDyn for T { fn name(&self) -> &str { Messaging::name(self) } - - fn start<'a>(&'a self) -> Pin> + Send + 'a>> { + + fn start<'a>( + &'a self, + ) -> Pin> + Send + 'a>> { Box::pin(Messaging::start(self)) } - + fn respond<'a>( &'a self, message: &'a InboundMessage, @@ -124,7 +131,7 @@ impl MessagingDyn for T { ) -> Pin> + Send + 'a>> { Box::pin(Messaging::respond(self, message, response)) } - + fn send_status<'a>( &'a self, message: &'a InboundMessage, @@ -132,7 +139,7 @@ impl MessagingDyn for T { ) -> Pin> + Send + 'a>> { Box::pin(Messaging::send_status(self, message, status)) } - + fn broadcast<'a>( &'a self, target: &'a str, @@ -148,12 +155,16 @@ impl MessagingDyn for T { ) -> Pin>> + Send + 'a>> { Box::pin(Messaging::fetch_history(self, message, limit)) } - - fn health_check<'a>(&'a self) -> Pin> + Send + 'a>> { + + fn health_check<'a>( + &'a self, + ) -> Pin> + Send + 'a>> { Box::pin(Messaging::health_check(self)) } - - fn shutdown<'a>(&'a self) -> Pin> + Send + 'a>> { + + fn shutdown<'a>( + &'a self, + ) -> Pin> + Send + 'a>> { Box::pin(Messaging::shutdown(self)) } } diff --git a/src/messaging/webhook.rs b/src/messaging/webhook.rs index 431e899e0..49fc57660 100644 --- a/src/messaging/webhook.rs +++ b/src/messaging/webhook.rs @@ -9,9 +9,9 @@ use crate::messaging::traits::{InboundStream, Messaging}; use crate::{InboundMessage, MessageContent, OutboundResponse}; use anyhow::Context as _; +use axum::Router; use axum::extract::{Json, State}; use axum::http::StatusCode; -use axum::Router; use axum::routing::{get, post}; use serde::{Deserialize, Serialize}; From a0d4ffb2a4e8d23e248fce1c12356b4ab11d663d Mon Sep 17 00:00:00 2001 From: nmb Date: Tue, 17 Feb 2026 18:40:00 +0100 Subject: [PATCH 8/8] refactor(runtime): align daemon, hooks, and model plumbing - update startup/runtime wiring across daemon, main, hooks, and LLM integration\n- improve opencode worker/server flow and related test coverage\n- include matching docs updates for 
config and roadmap --- build.rs | 12 +- docs/content/docs/(configuration)/config.mdx | 2 + docs/content/docs/(deployment)/roadmap.mdx | 2 +- src/conversation.rs | 2 +- src/cron/scheduler.rs | 18 +- src/cron/store.rs | 6 +- src/daemon.rs | 44 ++-- src/db.rs | 24 +- src/hooks.rs | 4 +- src/hooks/cortex.rs | 6 +- src/hooks/spacebot.rs | 17 +- src/identity/files.rs | 21 +- src/lib.rs | 38 +++- src/llm/manager.rs | 68 ++++-- src/llm/model.rs | 193 ++++++++++------ src/main.rs | 159 +++++++++---- src/opencode/server.rs | 92 ++++---- src/opencode/worker.rs | 139 ++++++++---- src/prompts/engine.rs | 2 +- src/settings.rs | 2 +- src/skills.rs | 94 +++++--- src/update.rs | 11 +- tests/bulletin.rs | 59 +++-- tests/context_dump.rs | 227 ++++++++++++++----- tests/opencode_stream.rs | 21 +- 25 files changed, 834 insertions(+), 429 deletions(-) diff --git a/build.rs b/build.rs index ff9f77571..2507c62f5 100644 --- a/build.rs +++ b/build.rs @@ -15,7 +15,9 @@ fn main() { // Skip if bun isn't installed or node_modules is missing (CI without frontend deps) if !interface_dir.join("node_modules").exists() { - eprintln!("cargo:warning=interface/node_modules not found, skipping frontend build. Run `bun install` in interface/"); + eprintln!( + "cargo:warning=interface/node_modules not found, skipping frontend build. Run `bun install` in interface/" + ); ensure_dist_dir(); return; } @@ -28,10 +30,14 @@ fn main() { match status { Ok(s) if s.success() => {} Ok(s) => { - eprintln!("cargo:warning=frontend build exited with {s}, the binary will serve a stale or empty UI"); + eprintln!( + "cargo:warning=frontend build exited with {s}, the binary will serve a stale or empty UI" + ); } Err(e) => { - eprintln!("cargo:warning=failed to run `bun run build`: {e}. Install bun to build the frontend."); + eprintln!( + "cargo:warning=failed to run `bun run build`: {e}. Install bun to build the frontend." 
+ ); ensure_dist_dir(); } } diff --git a/docs/content/docs/(configuration)/config.mdx b/docs/content/docs/(configuration)/config.mdx index 88ca4630c..334f02f4b 100644 --- a/docs/content/docs/(configuration)/config.mdx +++ b/docs/content/docs/(configuration)/config.mdx @@ -24,6 +24,7 @@ spacebot --config /path/to.toml # CLI override anthropic_key = "env:ANTHROPIC_API_KEY" openai_key = "env:OPENAI_API_KEY" openrouter_key = "env:OPENROUTER_API_KEY" +ollama_key = "env:OLLAMA_API_KEY" # --- Instance Defaults --- # All agents inherit these. Individual agents can override any field. @@ -168,6 +169,7 @@ Model names include the provider as a prefix: | Anthropic | `anthropic/` | `anthropic/claude-sonnet-4-20250514` | | OpenAI | `openai/` | `openai/gpt-4o` | | OpenRouter | `openrouter//` | `openrouter/anthropic/claude-sonnet-4-20250514` | +| Ollama Cloud | `ollama/` | `ollama/gpt-oss:20b` | You can mix providers across process types. See [Routing](/docs/routing) for the full routing system. diff --git a/docs/content/docs/(deployment)/roadmap.mdx b/docs/content/docs/(deployment)/roadmap.mdx index 6711bdd2d..52401c0a8 100644 --- a/docs/content/docs/(deployment)/roadmap.mdx +++ b/docs/content/docs/(deployment)/roadmap.mdx @@ -18,7 +18,7 @@ The full message-in → LLM → response-out pipeline is wired end-to-end across - **Config** — hierarchical TOML with `Config`, `AgentConfig`, `ResolvedAgentConfig`, `Binding`, `MessagingConfig`. File watcher with event filtering and content hash debounce for hot-reload. 
- **Multi-agent** — per-agent database isolation, `Agent` struct bundles all dependencies - **Database connections** — SQLite + LanceDB + redb per-agent, migrations for all tables -- **LLM** — `SpacebotModel` implements Rig's `CompletionModel`, routes through `LlmManager` via HTTP with retries and fallback chains across 11 providers (Anthropic, OpenAI, OpenRouter, Z.ai, Groq, Together, Fireworks, DeepSeek, xAI, Mistral, OpenCode Zen) +- **LLM** — `SpacebotModel` implements Rig's `CompletionModel`, routes through `LlmManager` via HTTP with retries and fallback chains across 12 providers (Anthropic, OpenAI, OpenRouter, Ollama Cloud, Z.ai, Groq, Together, Fireworks, DeepSeek, xAI, Mistral, OpenCode Zen) - **Model routing** — `RoutingConfig` with process-type defaults, task overrides, fallback chains - **Memory** — full stack: types, SQLite store (CRUD + graph), LanceDB (embeddings + vector + FTS), fastembed, hybrid search (RRF fusion). `memory_type` filter wired end-to-end through SearchConfig. `total_cmp` for safe sorting. - **Memory maintenance** — decay + prune implemented diff --git a/src/conversation.rs b/src/conversation.rs index 47a3dae34..0097a7d42 100644 --- a/src/conversation.rs +++ b/src/conversation.rs @@ -1,8 +1,8 @@ //! Conversation history and context management. pub mod channels; -pub mod history; pub mod context; +pub mod history; pub use channels::ChannelStore; pub use history::{ConversationLogger, ProcessRunLogger, TimelineItem}; diff --git a/src/cron/scheduler.rs b/src/cron/scheduler.rs index 4e1b277ff..ebe33bd7d 100644 --- a/src/cron/scheduler.rs +++ b/src/cron/scheduler.rs @@ -14,7 +14,7 @@ use chrono::Timelike; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::RwLock; -use tokio::time::{interval, Duration}; +use tokio::time::{Duration, interval}; /// A cron job definition loaded from the database. 
#[derive(Debug, Clone)] @@ -165,9 +165,7 @@ impl Scheduler { // Look up interval before entering the loop let interval_secs = { let j = jobs.read().await; - j.get(&job_id) - .map(|j| j.interval_secs) - .unwrap_or(3600) + j.get(&job_id).map(|j| j.interval_secs).unwrap_or(3600) }; let mut ticker = interval(Duration::from_secs(interval_secs)); @@ -323,9 +321,9 @@ impl Scheduler { if let Some(job) = job { if !job.enabled { - return Err(crate::error::Error::Other( - anyhow::anyhow!("cron job is disabled"), - )); + return Err(crate::error::Error::Other(anyhow::anyhow!( + "cron job is disabled" + ))); } tracing::info!(cron_id = %job_id, "cron job triggered manually"); @@ -453,11 +451,7 @@ async fn run_cron_job(job: &CronJob, context: &CronContext) -> Result<()> { } else { None }; - if let Err(error) = context - .store - .log_execution(&job.id, true, summary) - .await - { + if let Err(error) = context.store.log_execution(&job.id, true, summary).await { tracing::warn!(%error, "failed to log cron execution"); } diff --git a/src/cron/store.rs b/src/cron/store.rs index f66882bf9..e61fd8fce 100644 --- a/src/cron/store.rs +++ b/src/cron/store.rs @@ -170,7 +170,11 @@ impl CronStore { } /// Load execution history for a specific cron job. - pub async fn load_executions(&self, cron_id: &str, limit: i64) -> Result> { + pub async fn load_executions( + &self, + cron_id: &str, + limit: i64, + ) -> Result> { let rows = sqlx::query( r#" SELECT id, executed_at, success, result_summary diff --git a/src/daemon.rs b/src/daemon.rs index 61a4d9732..cf6bdb15a 100644 --- a/src/daemon.rs +++ b/src/daemon.rs @@ -24,13 +24,8 @@ pub enum IpcCommand { #[serde(tag = "result", rename_all = "snake_case")] pub enum IpcResponse { Ok, - Status { - pid: u32, - uptime_seconds: u64, - }, - Error { - message: String, - }, + Status { pid: u32, uptime_seconds: u64 }, + Error { message: String }, } /// Paths for daemon runtime files, all derived from the instance directory. 
@@ -84,8 +79,12 @@ pub fn is_running(paths: &DaemonPaths) -> Option { /// Daemonize the current process. Returns in the child; the parent prints /// a message and exits. pub fn daemonize(paths: &DaemonPaths) -> anyhow::Result<()> { - std::fs::create_dir_all(&paths.log_dir) - .with_context(|| format!("failed to create log directory: {}", paths.log_dir.display()))?; + std::fs::create_dir_all(&paths.log_dir).with_context(|| { + format!( + "failed to create log directory: {}", + paths.log_dir.display() + ) + })?; let stdout = std::fs::OpenOptions::new() .create(true) @@ -105,7 +104,9 @@ pub fn daemonize(paths: &DaemonPaths) -> anyhow::Result<()> { .stdout(stdout) .stderr(stderr); - daemonize.start().map_err(|error| anyhow!("failed to daemonize: {error}"))?; + daemonize + .start() + .map_err(|error| anyhow!("failed to daemonize: {error}"))?; Ok(()) } @@ -140,9 +141,7 @@ pub fn init_foreground_tracing(debug: bool) { tracing_subscriber::EnvFilter::new("info") }; - tracing_subscriber::fmt() - .with_env_filter(filter) - .init(); + tracing_subscriber::fmt().with_env_filter(filter).init(); } /// Start the IPC server. 
Returns a shutdown receiver that the main event @@ -152,8 +151,9 @@ pub async fn start_ipc_server( ) -> anyhow::Result<(watch::Receiver, tokio::task::JoinHandle<()>)> { // Clean up any stale socket file if paths.socket.exists() { - std::fs::remove_file(&paths.socket) - .with_context(|| format!("failed to remove stale socket: {}", paths.socket.display()))?; + std::fs::remove_file(&paths.socket).with_context(|| { + format!("failed to remove stale socket: {}", paths.socket.display()) + })?; } let listener = UnixListener::bind(&paths.socket) @@ -170,7 +170,9 @@ pub async fn start_ipc_server( let shutdown_tx = shutdown_tx.clone(); let uptime = start_time.elapsed(); tokio::spawn(async move { - if let Err(error) = handle_ipc_connection(stream, &shutdown_tx, uptime).await { + if let Err(error) = + handle_ipc_connection(stream, &shutdown_tx, uptime).await + { tracing::warn!(%error, "IPC connection handler failed"); } }); @@ -213,12 +215,10 @@ async fn handle_ipc_connection( shutdown_tx.send(true).ok(); IpcResponse::Ok } - IpcCommand::Status => { - IpcResponse::Status { - pid: std::process::id(), - uptime_seconds: uptime.as_secs(), - } - } + IpcCommand::Status => IpcResponse::Status { + pid: std::process::id(), + uptime_seconds: uptime.as_secs(), + }, }; let mut response_bytes = serde_json::to_vec(&response)?; diff --git a/src/db.rs b/src/db.rs index 6e5032d60..5a1a3c54e 100644 --- a/src/db.rs +++ b/src/db.rs @@ -9,10 +9,10 @@ use std::path::Path; pub struct Db { /// SQLite pool for relational data. pub sqlite: SqlitePool, - + /// LanceDB connection for vector storage. pub lance: lancedb::Connection, - + /// Redb database for key-value config. 
pub redb: Arc, } @@ -27,35 +27,39 @@ impl Db { let sqlite = SqlitePool::connect(&sqlite_url) .await .with_context(|| "failed to connect to SQLite")?; - + // Run migrations sqlx::migrate!("./migrations") .run(&sqlite) .await .with_context(|| "failed to run database migrations")?; - + // LanceDB let lance_path = data_dir.join("lancedb"); - std::fs::create_dir_all(&lance_path) - .with_context(|| format!("failed to create LanceDB directory: {}", lance_path.display()))?; - + std::fs::create_dir_all(&lance_path).with_context(|| { + format!( + "failed to create LanceDB directory: {}", + lance_path.display() + ) + })?; + let lance = lancedb::connect(lance_path.to_str().unwrap_or("./lancedb")) .execute() .await .map_err(|e| DbError::LanceConnect(e.to_string()))?; - + // Redb let redb_path = data_dir.join("config.redb"); let redb = redb::Database::create(&redb_path) .with_context(|| format!("failed to create redb at: {}", redb_path.display()))?; - + Ok(Self { sqlite, lance, redb: Arc::new(redb), }) } - + /// Close all database connections gracefully. pub async fn close(self) { self.sqlite.close().await; diff --git a/src/hooks.rs b/src/hooks.rs index cc46ae59b..b1b02c0fc 100644 --- a/src/hooks.rs +++ b/src/hooks.rs @@ -1,7 +1,7 @@ //! Prompt hooks for observing and controlling agent behavior. 
-pub mod spacebot; pub mod cortex; +pub mod spacebot; -pub use spacebot::SpacebotHook; pub use cortex::CortexHook; +pub use spacebot::SpacebotHook; diff --git a/src/hooks/cortex.rs b/src/hooks/cortex.rs index 564a0ec49..538762f69 100644 --- a/src/hooks/cortex.rs +++ b/src/hooks/cortex.rs @@ -24,11 +24,7 @@ impl PromptHook for CortexHook where M: CompletionModel, { - async fn on_completion_call( - &self, - _prompt: &Message, - _history: &[Message], - ) -> HookAction { + async fn on_completion_call(&self, _prompt: &Message, _history: &[Message]) -> HookAction { // Cortex observes but doesn't intervene tracing::trace!("cortex: completion call observed"); HookAction::Continue diff --git a/src/hooks/spacebot.rs b/src/hooks/spacebot.rs index 55b6bab73..6beeb3660 100644 --- a/src/hooks/spacebot.rs +++ b/src/hooks/spacebot.rs @@ -63,7 +63,8 @@ impl SpacebotHook { // Google API keys Regex::new(r"AIza[0-9A-Za-z_-]{35}").expect("hardcoded regex"), // Discord bot tokens (base64 user ID . timestamp . HMAC) - Regex::new(r"[MN][A-Za-z0-9]{23,}\.[A-Za-z0-9_-]{6}\.[A-Za-z0-9_-]{27,}").expect("hardcoded regex"), + Regex::new(r"[MN][A-Za-z0-9]{23,}\.[A-Za-z0-9_-]{6}\.[A-Za-z0-9_-]{27,}") + .expect("hardcoded regex"), // Slack bot tokens Regex::new(r"xoxb-[0-9]{10,}-[0-9A-Za-z-]+").expect("hardcoded regex"), // Slack app tokens @@ -89,11 +90,7 @@ impl PromptHook for SpacebotHook where M: CompletionModel, { - async fn on_completion_call( - &self, - _prompt: &Message, - _history: &[Message], - ) -> HookAction { + async fn on_completion_call(&self, _prompt: &Message, _history: &[Message]) -> HookAction { // Log the completion call but don't block it tracing::debug!( process_id = %self.process_id, @@ -178,12 +175,16 @@ where leak_prefix = %&leak[..leak.len().min(8)], "secret leak detected in tool output, terminating agent" ); - return HookAction::Terminate { reason: "Tool output contained a secret. 
Agent terminated to prevent exfiltration.".into() }; + return HookAction::Terminate { + reason: "Tool output contained a secret. Agent terminated to prevent exfiltration." + .into(), + }; } // Cap the result stored in the broadcast event to avoid blowing up // event subscribers with multi-MB tool results. - let capped_result = crate::tools::truncate_output(result, crate::tools::MAX_TOOL_OUTPUT_BYTES); + let capped_result = + crate::tools::truncate_output(result, crate::tools::MAX_TOOL_OUTPUT_BYTES); let event = ProcessEvent::ToolCompleted { agent_id: self.agent_id.clone(), process_id: self.process_id.clone(), diff --git a/src/identity/files.rs b/src/identity/files.rs index 29ceb843f..e5b5c8d2e 100644 --- a/src/identity/files.rs +++ b/src/identity/files.rs @@ -47,9 +47,18 @@ impl Identity { /// Default identity file templates for new agents. const DEFAULT_IDENTITY_FILES: &[(&str, &str)] = &[ - ("SOUL.md", "\n"), - ("IDENTITY.md", "\n"), - ("USER.md", "\n"), + ( + "SOUL.md", + "\n", + ), + ( + "IDENTITY.md", + "\n", + ), + ( + "USER.md", + "\n", + ), ]; /// Write template identity files into an agent's workspace if they don't already exist. 
@@ -59,9 +68,9 @@ pub async fn scaffold_identity_files(workspace: &Path) -> crate::error::Result<( for (filename, content) in DEFAULT_IDENTITY_FILES { let target = workspace.join(filename); if !target.exists() { - tokio::fs::write(&target, content) - .await - .with_context(|| format!("failed to write identity template: {}", target.display()))?; + tokio::fs::write(&target, content).await.with_context(|| { + format!("failed to write identity template: {}", target.display()) + })?; tracing::info!(path = %target.display(), "wrote identity template"); } } diff --git a/src/lib.rs b/src/lib.rs index d97ef5901..6f6ff7717 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,10 +4,10 @@ pub mod agent; pub mod api; pub mod config; pub mod conversation; +pub mod cron; pub mod daemon; pub mod db; pub mod error; -pub mod cron; pub mod hooks; pub mod identity; pub mod llm; @@ -103,6 +103,12 @@ pub enum ProcessEvent { channel_id: ChannelId, conclusion: String, }, + BranchFailed { + agent_id: AgentId, + branch_id: BranchId, + channel_id: ChannelId, + error: String, + }, WorkerStarted { agent_id: AgentId, worker_id: WorkerId, @@ -180,8 +186,12 @@ pub struct AgentDeps { } impl AgentDeps { - pub fn memory_search(&self) -> &Arc { &self.memory_search } - pub fn llm_manager(&self) -> &Arc { &self.llm_manager } + pub fn memory_search(&self) -> &Arc { + &self.memory_search + } + pub fn llm_manager(&self) -> &Arc { + &self.llm_manager + } /// Load the current routing config snapshot. pub fn routing(&self) -> arc_swap::Guard> { @@ -297,9 +307,21 @@ pub enum StatusUpdate { Thinking, /// Cancel the typing indicator (e.g. when the skip tool fires). 
StopTyping, - ToolStarted { tool_name: String }, - ToolCompleted { tool_name: String }, - BranchStarted { branch_id: BranchId }, - WorkerStarted { worker_id: WorkerId, task: String }, - WorkerCompleted { worker_id: WorkerId, result: String }, + ToolStarted { + tool_name: String, + }, + ToolCompleted { + tool_name: String, + }, + BranchStarted { + branch_id: BranchId, + }, + WorkerStarted { + worker_id: WorkerId, + task: String, + }, + WorkerCompleted { + worker_id: WorkerId, + result: String, + }, } diff --git a/src/llm/manager.rs b/src/llm/manager.rs index 80c476e2e..40356acc5 100644 --- a/src/llm/manager.rs +++ b/src/llm/manager.rs @@ -38,27 +38,65 @@ impl LlmManager { /// Get the appropriate API key for a provider. pub fn get_api_key(&self, provider: &str) -> Result { match provider { - "anthropic" => self.config.anthropic_key.clone() + "anthropic" => self + .config + .anthropic_key + .clone() .ok_or_else(|| LlmError::MissingProviderKey("anthropic".into()).into()), - "openai" => self.config.openai_key.clone() + "openai" => self + .config + .openai_key + .clone() .ok_or_else(|| LlmError::MissingProviderKey("openai".into()).into()), - "openrouter" => self.config.openrouter_key.clone() + "openrouter" => self + .config + .openrouter_key + .clone() .ok_or_else(|| LlmError::MissingProviderKey("openrouter".into()).into()), - "zhipu" => self.config.zhipu_key.clone() + "ollama" => self + .config + .ollama_key + .clone() + .ok_or_else(|| LlmError::MissingProviderKey("ollama".into()).into()), + "zhipu" => self + .config + .zhipu_key + .clone() .ok_or_else(|| LlmError::MissingProviderKey("zhipu".into()).into()), - "groq" => self.config.groq_key.clone() + "groq" => self + .config + .groq_key + .clone() .ok_or_else(|| LlmError::MissingProviderKey("groq".into()).into()), - "together" => self.config.together_key.clone() + "together" => self + .config + .together_key + .clone() .ok_or_else(|| LlmError::MissingProviderKey("together".into()).into()), - "fireworks" => 
self.config.fireworks_key.clone() + "fireworks" => self + .config + .fireworks_key + .clone() .ok_or_else(|| LlmError::MissingProviderKey("fireworks".into()).into()), - "deepseek" => self.config.deepseek_key.clone() + "deepseek" => self + .config + .deepseek_key + .clone() .ok_or_else(|| LlmError::MissingProviderKey("deepseek".into()).into()), - "xai" => self.config.xai_key.clone() + "xai" => self + .config + .xai_key + .clone() .ok_or_else(|| LlmError::MissingProviderKey("xai".into()).into()), - "mistral" => self.config.mistral_key.clone() + "mistral" => self + .config + .mistral_key + .clone() .ok_or_else(|| LlmError::MissingProviderKey("mistral".into()).into()), - "opencode-zen" => self.config.opencode_zen_key.clone() + "opencode-zen" => self + .config + .opencode_zen_key + .clone() .ok_or_else(|| LlmError::MissingProviderKey("opencode-zen".into()).into()), _ => Err(LlmError::UnknownProvider(provider.into()).into()), } @@ -81,7 +119,9 @@ impl LlmManager { /// Record that a model hit a rate limit. pub async fn record_rate_limit(&self, model_name: &str) { - self.rate_limited.write().await + self.rate_limited + .write() + .await .insert(model_name.to_string(), Instant::now()); tracing::warn!(model = %model_name, "model rate limited, entering cooldown"); } @@ -98,7 +138,9 @@ impl LlmManager { /// Clean up expired rate limit entries. 
pub async fn cleanup_rate_limits(&self, cooldown_secs: u64) { - self.rate_limited.write().await + self.rate_limited + .write() + .await .retain(|_, limited_at| limited_at.elapsed().as_secs() < cooldown_secs); } } diff --git a/src/llm/model.rs b/src/llm/model.rs index 6394a4354..e8e2c156a 100644 --- a/src/llm/model.rs +++ b/src/llm/model.rs @@ -2,12 +2,10 @@ use crate::llm::manager::LlmManager; use crate::llm::routing::{ - self, RoutingConfig, MAX_FALLBACK_ATTEMPTS, MAX_RETRIES_PER_MODEL, RETRY_BASE_DELAY_MS, + self, MAX_FALLBACK_ATTEMPTS, MAX_RETRIES_PER_MODEL, RETRY_BASE_DELAY_MS, RoutingConfig, }; -use rig::completion::{ - self, CompletionError, CompletionModel, CompletionRequest, GetTokenUsage, -}; +use rig::completion::{self, CompletionError, CompletionModel, CompletionRequest, GetTokenUsage}; use rig::message::{ AssistantContent, DocumentSourceKind, Image, Message, MimeType, Text, ToolCall, ToolFunction, ToolResult, UserContent, @@ -50,9 +48,15 @@ pub struct SpacebotModel { } impl SpacebotModel { - pub fn provider(&self) -> &str { &self.provider } - pub fn model_name(&self) -> &str { &self.model_name } - pub fn full_model_name(&self) -> &str { &self.full_model_name } + pub fn provider(&self) -> &str { + &self.provider + } + pub fn model_name(&self) -> &str { + &self.model_name + } + pub fn full_model_name(&self) -> &str { + &self.full_model_name + } /// Attach routing config for fallback behavior. 
pub fn with_routing(mut self, routing: RoutingConfig) -> Self { @@ -69,6 +73,7 @@ impl SpacebotModel { "anthropic" => self.call_anthropic(request).await, "openai" => self.call_openai(request).await, "openrouter" => self.call_openrouter(request).await, + "ollama" => self.call_ollama(request).await, "zhipu" => self.call_zhipu(request).await, "groq" => self.call_groq(request).await, "together" => self.call_together(request).await, @@ -93,10 +98,7 @@ impl SpacebotModel { &self, model_name: &str, request: &CompletionRequest, - ) -> Result< - completion::CompletionResponse, - (CompletionError, bool), - > { + ) -> Result, (CompletionError, bool)> { let model = if model_name == self.full_model_name { self.clone() } else { @@ -203,11 +205,16 @@ impl CompletionModel for SpacebotModel { "primary model in rate-limit cooldown, skipping to fallbacks" ); } else { - match self.attempt_with_retries(&self.full_model_name, &request).await { + match self + .attempt_with_retries(&self.full_model_name, &request) + .await + { Ok(response) => return Ok(response), Err((error, was_rate_limit)) => { if was_rate_limit { - self.llm_manager.record_rate_limit(&self.full_model_name).await; + self.llm_manager + .record_rate_limit(&self.full_model_name) + .await; } if fallbacks.is_empty() { // No fallbacks — this is the final error @@ -224,7 +231,11 @@ impl CompletionModel for SpacebotModel { // Try fallback chain, each with their own retry loop for (index, fallback_name) in fallbacks.iter().take(MAX_FALLBACK_ATTEMPTS).enumerate() { - if self.llm_manager.is_rate_limited(fallback_name, cooldown).await { + if self + .llm_manager + .is_rate_limited(fallback_name, cooldown) + .await + { tracing::debug!( fallback = %fallback_name, "fallback model in cooldown, skipping" @@ -324,15 +335,17 @@ impl SpacebotModel { .map_err(|e| CompletionError::ProviderError(e.to_string()))?; let status = response.status(); - let response_text = response - .text() - .await - .map_err(|e| 
CompletionError::ProviderError(format!("failed to read response body: {e}")))?; - - let response_body: serde_json::Value = serde_json::from_str(&response_text) - .map_err(|e| CompletionError::ProviderError(format!( - "Anthropic response ({status}) is not valid JSON: {e}\nBody: {}", truncate_body(&response_text) - )))?; + let response_text = response.text().await.map_err(|e| { + CompletionError::ProviderError(format!("failed to read response body: {e}")) + })?; + + let response_body: serde_json::Value = + serde_json::from_str(&response_text).map_err(|e| { + CompletionError::ProviderError(format!( + "Anthropic response ({status}) is not valid JSON: {e}\nBody: {}", + truncate_body(&response_text) + )) + })?; if !status.is_success() { let message = response_body["error"]["message"] @@ -409,15 +422,17 @@ impl SpacebotModel { .map_err(|e| CompletionError::ProviderError(e.to_string()))?; let status = response.status(); - let response_text = response - .text() - .await - .map_err(|e| CompletionError::ProviderError(format!("failed to read response body: {e}")))?; - - let response_body: serde_json::Value = serde_json::from_str(&response_text) - .map_err(|e| CompletionError::ProviderError(format!( - "OpenAI response ({status}) is not valid JSON: {e}\nBody: {}", truncate_body(&response_text) - )))?; + let response_text = response.text().await.map_err(|e| { + CompletionError::ProviderError(format!("failed to read response body: {e}")) + })?; + + let response_body: serde_json::Value = + serde_json::from_str(&response_text).map_err(|e| { + CompletionError::ProviderError(format!( + "OpenAI response ({status}) is not valid JSON: {e}\nBody: {}", + truncate_body(&response_text) + )) + })?; if !status.is_success() { let message = response_body["error"]["message"] @@ -496,15 +511,17 @@ impl SpacebotModel { .map_err(|e| CompletionError::ProviderError(e.to_string()))?; let status = response.status(); - let response_text = response - .text() - .await - .map_err(|e| 
CompletionError::ProviderError(format!("failed to read response body: {e}")))?; - - let response_body: serde_json::Value = serde_json::from_str(&response_text) - .map_err(|e| CompletionError::ProviderError(format!( - "OpenRouter response ({status}) is not valid JSON: {e}\nBody: {}", truncate_body(&response_text) - )))?; + let response_text = response.text().await.map_err(|e| { + CompletionError::ProviderError(format!("failed to read response body: {e}")) + })?; + + let response_body: serde_json::Value = + serde_json::from_str(&response_text).map_err(|e| { + CompletionError::ProviderError(format!( + "OpenRouter response ({status}) is not valid JSON: {e}\nBody: {}", + truncate_body(&response_text) + )) + })?; if !status.is_success() { let message = response_body["error"]["message"] @@ -582,15 +599,17 @@ impl SpacebotModel { .map_err(|e| CompletionError::ProviderError(e.to_string()))?; let status = response.status(); - let response_text = response - .text() - .await - .map_err(|e| CompletionError::ProviderError(format!("failed to read response body: {e}")))?; - - let response_body: serde_json::Value = serde_json::from_str(&response_text) - .map_err(|e| CompletionError::ProviderError(format!( - "Z.ai response ({status}) is not valid JSON: {e}\nBody: {}", truncate_body(&response_text) - )))?; + let response_text = response.text().await.map_err(|e| { + CompletionError::ProviderError(format!("failed to read response body: {e}")) + })?; + + let response_body: serde_json::Value = + serde_json::from_str(&response_text).map_err(|e| { + CompletionError::ProviderError(format!( + "Z.ai response ({status}) is not valid JSON: {e}\nBody: {}", + truncate_body(&response_text) + )) + })?; if !status.is_success() { let message = response_body["error"]["message"] @@ -604,6 +623,19 @@ impl SpacebotModel { parse_openai_response(response_body, "Z.ai") } + async fn call_ollama( + &self, + request: CompletionRequest, + ) -> Result, CompletionError> { + self.call_openai_compatible( + request, 
+ "ollama", + "Ollama", + "https://ollama.com/v1/chat/completions", + ) + .await + } + /// Generic OpenAI-compatible API call. /// Used by providers that implement the OpenAI chat completions format. async fn call_openai_compatible( @@ -672,15 +704,17 @@ impl SpacebotModel { .map_err(|e| CompletionError::ProviderError(e.to_string()))?; let status = response.status(); - let response_text = response - .text() - .await - .map_err(|e| CompletionError::ProviderError(format!("failed to read response body: {e}")))?; - - let response_body: serde_json::Value = serde_json::from_str(&response_text) - .map_err(|e| CompletionError::ProviderError(format!( - "{provider_display_name} response ({status}) is not valid JSON: {e}\nBody: {}", truncate_body(&response_text) - )))?; + let response_text = response.text().await.map_err(|e| { + CompletionError::ProviderError(format!("failed to read response body: {e}")) + })?; + + let response_body: serde_json::Value = + serde_json::from_str(&response_text).map_err(|e| { + CompletionError::ProviderError(format!( + "{provider_display_name} response ({status}) is not valid JSON: {e}\nBody: {}", + truncate_body(&response_text) + )) + })?; if !status.is_success() { let message = response_body["error"]["message"] @@ -703,7 +737,8 @@ impl SpacebotModel { "groq", "Groq", "https://api.groq.com/openai/v1/chat/completions", - ).await + ) + .await } async fn call_together( @@ -715,7 +750,8 @@ impl SpacebotModel { "together", "Together AI", "https://api.together.xyz/v1/chat/completions", - ).await + ) + .await } async fn call_fireworks( @@ -727,7 +763,8 @@ impl SpacebotModel { "fireworks", "Fireworks AI", "https://api.fireworks.ai/inference/v1/chat/completions", - ).await + ) + .await } async fn call_deepseek( @@ -739,7 +776,8 @@ impl SpacebotModel { "deepseek", "DeepSeek", "https://api.deepseek.com/v1/chat/completions", - ).await + ) + .await } async fn call_xai( @@ -751,7 +789,8 @@ impl SpacebotModel { "xai", "xAI", 
"https://api.x.ai/v1/chat/completions", - ).await + ) + .await } async fn call_mistral( @@ -763,7 +802,8 @@ impl SpacebotModel { "mistral", "Mistral AI", "https://api.mistral.ai/v1/chat/completions", - ).await + ) + .await } async fn call_opencode_zen( @@ -775,7 +815,8 @@ impl SpacebotModel { "opencode-zen", "OpenCode Zen", "https://opencode.ai/zen/v1/chat/completions", - ).await + ) + .await } } @@ -1003,7 +1044,10 @@ fn make_tool_call(id: String, name: String, arguments: serde_json::Value) -> Too ToolCall { id, call_id: None, - function: ToolFunction { name: name.trim().to_string(), arguments }, + function: ToolFunction { + name: name.trim().to_string(), + arguments, + }, signature: None, additional_params: None, } @@ -1028,8 +1072,9 @@ fn parse_anthropic_response( let id = block["id"].as_str().unwrap_or("").to_string(); let name = block["name"].as_str().unwrap_or("").to_string(); let arguments = block["input"].clone(); - assistant_content - .push(AssistantContent::ToolCall(make_tool_call(id, name, arguments))); + assistant_content.push(AssistantContent::ToolCall(make_tool_call( + id, name, arguments, + ))); } _ => {} } @@ -1081,13 +1126,15 @@ fn parse_openai_response( .as_str() .and_then(|s| serde_json::from_str(s).ok()) .unwrap_or(serde_json::json!({})); - assistant_content - .push(AssistantContent::ToolCall(make_tool_call(id, name, arguments))); + assistant_content.push(AssistantContent::ToolCall(make_tool_call( + id, name, arguments, + ))); } } - let result_choice = OneOrMany::many(assistant_content) - .map_err(|_| CompletionError::ResponseError(format!("empty response from {provider_label}")))?; + let result_choice = OneOrMany::many(assistant_content).map_err(|_| { + CompletionError::ResponseError(format!("empty response from {provider_label}")) + })?; let input_tokens = body["usage"]["prompt_tokens"].as_u64().unwrap_or(0); let output_tokens = body["usage"]["completion_tokens"].as_u64().unwrap_or(0); diff --git a/src/main.rs b/src/main.rs index 
8e8af3a91..efeffb891 100644 --- a/src/main.rs +++ b/src/main.rs @@ -89,8 +89,7 @@ fn cmd_start( config_path.clone() } else if spacebot::config::Config::needs_onboarding() { // Returns Some(path) if CLI wizard ran, None if user chose the UI. - spacebot::config::run_onboarding() - .with_context(|| "onboarding failed")? + spacebot::config::run_onboarding().with_context(|| "onboarding failed")? } else { None }; @@ -199,7 +198,10 @@ fn cmd_status() -> anyhow::Result<()> { runtime.block_on(async { match spacebot::daemon::send_command(&paths, spacebot::daemon::IpcCommand::Status).await { - Ok(spacebot::daemon::IpcResponse::Status { pid, uptime_seconds }) => { + Ok(spacebot::daemon::IpcResponse::Status { + pid, + uptime_seconds, + }) => { let hours = uptime_seconds / 3600; let minutes = (uptime_seconds % 3600) / 60; let seconds = uptime_seconds % 60; @@ -225,20 +227,18 @@ fn cmd_status() -> anyhow::Result<()> { Ok(()) } -fn load_config(config_path: &Option) -> anyhow::Result { +fn load_config( + config_path: &Option, +) -> anyhow::Result { if let Some(path) = config_path { spacebot::config::Config::load_from_path(path) .with_context(|| format!("failed to load config from {}", path.display())) } else { - spacebot::config::Config::load() - .with_context(|| "failed to load configuration") + spacebot::config::Config::load().with_context(|| "failed to load configuration") } } -async fn run( - config: spacebot::config::Config, - foreground: bool, -) -> anyhow::Result<()> { +async fn run(config: spacebot::config::Config, foreground: bool) -> anyhow::Result<()> { let paths = spacebot::daemon::DaemonPaths::new(&config.instance_dir); tracing::info!("starting spacebot"); @@ -253,7 +253,9 @@ async fn run( let (provider_tx, mut provider_rx) = mpsc::channel::(1); // Start HTTP API server if enabled - let api_state = Arc::new(spacebot::api::ApiState::new_with_provider_sender(provider_tx)); + let api_state = Arc::new(spacebot::api::ApiState::new_with_provider_sender( + provider_tx, + )); 
// Start background update checker spacebot::update::spawn_update_checker(api_state.update_status.clone()); @@ -279,8 +281,10 @@ async fn run( tracing::info!("No LLM provider keys configured. Starting in setup mode."); if foreground { eprintln!("No LLM provider keys configured."); - eprintln!("Please add a provider key via the web UI at http://{}:{}", - config.api.bind, config.api.port); + eprintln!( + "Please add a provider key via the web UI at http://{}:{}", + config.api.bind, config.api.port + ); } } @@ -289,21 +293,20 @@ async fn run( let llm_manager = Arc::new( spacebot::llm::LlmManager::new(config.llm.clone()) .await - .with_context(|| "failed to initialize LLM manager")? + .with_context(|| "failed to initialize LLM manager")?, ); // Shared embedding model (stateless, agent-agnostic) let embedding_cache_dir = config.instance_dir.join("embedding_cache"); let embedding_model = Arc::new( spacebot::memory::EmbeddingModel::new(&embedding_cache_dir) - .context("failed to initialize embedding model")? 
+ .context("failed to initialize embedding model")?, ); tracing::info!("shared resources initialized"); // Initialize the language for all text lookups (must happen before PromptEngine/tools) - spacebot::prompts::text::init("en") - .with_context(|| "failed to initialize language")?; + spacebot::prompts::text::init("en").with_context(|| "failed to initialize language")?; // Create the PromptEngine with bundled templates (no file watching, no user overrides) let prompt_engine = spacebot::prompts::PromptEngine::new("en") @@ -387,7 +390,10 @@ async fn run( } if foreground { - eprintln!("spacebot running in foreground (pid {})", std::process::id()); + eprintln!( + "spacebot running in foreground (pid {})", + std::process::id() + ); } else { tracing::info!(pid = std::process::id(), "spacebot daemon started"); } @@ -722,7 +728,11 @@ async fn initialize_agents( cron_schedulers_for_shutdown: &mut Vec>, ingestion_handles: &mut Vec>, cortex_handles: &mut Vec>, - watcher_agents: &mut Vec<(String, std::path::PathBuf, Arc)>, + watcher_agents: &mut Vec<( + String, + std::path::PathBuf, + Arc, + )>, discord_permissions: &mut Option>>, slack_permissions: &mut Option>>, telegram_permissions: &mut Option>>, @@ -733,34 +743,65 @@ async fn initialize_agents( tracing::info!(agent_id = %agent_config.id, "initializing agent"); // Ensure agent directories exist - std::fs::create_dir_all(&agent_config.workspace) - .with_context(|| format!("failed to create workspace: {}", agent_config.workspace.display()))?; - std::fs::create_dir_all(&agent_config.data_dir) - .with_context(|| format!("failed to create data dir: {}", agent_config.data_dir.display()))?; - std::fs::create_dir_all(&agent_config.archives_dir) - .with_context(|| format!("failed to create archives dir: {}", agent_config.archives_dir.display()))?; - std::fs::create_dir_all(&agent_config.ingest_dir()) - .with_context(|| format!("failed to create ingest dir: {}", agent_config.ingest_dir().display()))?; - 
std::fs::create_dir_all(&agent_config.logs_dir()) - .with_context(|| format!("failed to create logs dir: {}", agent_config.logs_dir().display()))?; + std::fs::create_dir_all(&agent_config.workspace).with_context(|| { + format!( + "failed to create workspace: {}", + agent_config.workspace.display() + ) + })?; + std::fs::create_dir_all(&agent_config.data_dir).with_context(|| { + format!( + "failed to create data dir: {}", + agent_config.data_dir.display() + ) + })?; + std::fs::create_dir_all(&agent_config.archives_dir).with_context(|| { + format!( + "failed to create archives dir: {}", + agent_config.archives_dir.display() + ) + })?; + std::fs::create_dir_all(&agent_config.ingest_dir()).with_context(|| { + format!( + "failed to create ingest dir: {}", + agent_config.ingest_dir().display() + ) + })?; + std::fs::create_dir_all(&agent_config.logs_dir()).with_context(|| { + format!( + "failed to create logs dir: {}", + agent_config.logs_dir().display() + ) + })?; // Per-agent database connections let db = spacebot::db::Db::connect(&agent_config.data_dir) .await - .with_context(|| format!("failed to connect databases for agent '{}'", agent_config.id))?; + .with_context(|| { + format!( + "failed to connect databases for agent '{}'", + agent_config.id + ) + })?; // Per-agent settings store (redb-backed) let settings_path = agent_config.data_dir.join("settings.redb"); let settings_store = Arc::new( - spacebot::settings::SettingsStore::new(&settings_path) - .with_context(|| format!("failed to initialize settings store for agent '{}'", agent_config.id))? 
+ spacebot::settings::SettingsStore::new(&settings_path).with_context(|| { + format!( + "failed to initialize settings store for agent '{}'", + agent_config.id + ) + })?, ); // Per-agent memory system let memory_store = spacebot::memory::MemoryStore::new(db.sqlite.clone()); let embedding_table = spacebot::memory::EmbeddingTable::open_or_create(&db.lance) .await - .with_context(|| format!("failed to init embeddings for agent '{}'", agent_config.id))?; + .with_context(|| { + format!("failed to init embeddings for agent '{}'", agent_config.id) + })?; // Ensure FTS index exists for full-text search queries if let Err(error) = embedding_table.ensure_fts_index().await { @@ -781,14 +822,18 @@ async fn initialize_agents( // Scaffold identity templates if missing, then load spacebot::identity::scaffold_identity_files(&agent_config.workspace) .await - .with_context(|| format!("failed to scaffold identity files for agent '{}'", agent_config.id))?; + .with_context(|| { + format!( + "failed to scaffold identity files for agent '{}'", + agent_config.id + ) + })?; let identity = spacebot::identity::Identity::load(&agent_config.workspace).await; // Load skills (instance-level, then workspace overrides) - let skills = spacebot::skills::SkillSet::load( - &config.skills_dir(), - &agent_config.skills_dir(), - ).await; + let skills = + spacebot::skills::SkillSet::load(&config.skills_dir(), &agent_config.skills_dir()) + .await; // Build the RuntimeConfig with all hot-reloadable values let runtime_config = Arc::new(spacebot::config::RuntimeConfig::new( @@ -870,20 +915,21 @@ async fn initialize_agents( // Shared Discord permissions (hot-reloadable via file watcher) *discord_permissions = config.messaging.discord.as_ref().map(|discord_config| { - let perms = spacebot::config::DiscordPermissions::from_config(discord_config, &config.bindings); + let perms = + spacebot::config::DiscordPermissions::from_config(discord_config, &config.bindings); Arc::new(ArcSwap::from_pointee(perms)) }); if let 
Some(perms) = &*discord_permissions { api_state.set_discord_permissions(perms.clone()).await; } - - if let Some(discord_config) = &config.messaging.discord { if discord_config.enabled { let adapter = spacebot::messaging::discord::DiscordAdapter::new( &discord_config.token, - discord_permissions.clone().expect("discord permissions initialized when discord is enabled"), + discord_permissions + .clone() + .expect("discord permissions initialized when discord is enabled"), ); new_messaging_manager.register(adapter).await; } @@ -903,7 +949,9 @@ async fn initialize_agents( let adapter = spacebot::messaging::slack::SlackAdapter::new( &slack_config.bot_token, &slack_config.app_token, - slack_permissions.clone().expect("slack permissions initialized when slack is enabled"), + slack_permissions + .clone() + .expect("slack permissions initialized when slack is enabled"), ); new_messaging_manager.register(adapter).await; } @@ -911,7 +959,8 @@ async fn initialize_agents( // Shared Telegram permissions (hot-reloadable via file watcher) *telegram_permissions = config.messaging.telegram.as_ref().map(|telegram_config| { - let perms = spacebot::config::TelegramPermissions::from_config(telegram_config, &config.bindings); + let perms = + spacebot::config::TelegramPermissions::from_config(telegram_config, &config.bindings); Arc::new(ArcSwap::from_pointee(perms)) }); @@ -919,7 +968,9 @@ async fn initialize_agents( if telegram_config.enabled { let adapter = spacebot::messaging::telegram::TelegramAdapter::new( &telegram_config.token, - telegram_permissions.clone().expect("telegram permissions initialized when telegram is enabled"), + telegram_permissions + .clone() + .expect("telegram permissions initialized when telegram is enabled"), ); new_messaging_manager.register(adapter).await; } @@ -936,7 +987,9 @@ async fn initialize_agents( } *messaging_manager = Arc::new(new_messaging_manager); - api_state.set_messaging_manager(messaging_manager.clone()).await; + api_state + 
.set_messaging_manager(messaging_manager.clone()) + .await; // Start all messaging adapters and get the merged inbound stream let new_inbound = messaging_manager @@ -986,7 +1039,10 @@ async fn initialize_agents( let scheduler = Arc::new(spacebot::cron::Scheduler::new(cron_context)); // Make cron store and scheduler available via RuntimeConfig - agent.deps.runtime_config.set_cron(store.clone(), scheduler.clone()); + agent + .deps + .runtime_config + .set_cron(store.clone(), scheduler.clone()); match store.load_all().await { Ok(configs) => { @@ -1032,11 +1088,13 @@ async fn initialize_agents( // Start cortex bulletin loops and association loops for each agent for (agent_id, agent) in agents.iter() { let cortex_logger = spacebot::agent::cortex::CortexLogger::new(agent.db.sqlite.clone()); - let bulletin_handle = spacebot::agent::cortex::spawn_bulletin_loop(agent.deps.clone(), cortex_logger.clone()); + let bulletin_handle = + spacebot::agent::cortex::spawn_bulletin_loop(agent.deps.clone(), cortex_logger.clone()); cortex_handles.push(bulletin_handle); tracing::info!(agent_id = %agent_id, "cortex bulletin loop started"); - let association_handle = spacebot::agent::cortex::spawn_association_loop(agent.deps.clone(), cortex_logger); + let association_handle = + spacebot::agent::cortex::spawn_association_loop(agent.deps.clone(), cortex_logger); cortex_handles.push(association_handle); tracing::info!(agent_id = %agent_id, "cortex association loop started"); } @@ -1047,7 +1105,8 @@ async fn initialize_agents( for (agent_id, agent) in agents.iter() { let browser_config = (**agent.deps.runtime_config.browser_config.load()).clone(); let brave_search_key = (**agent.deps.runtime_config.brave_search_key.load()).clone(); - let conversation_logger = spacebot::conversation::history::ConversationLogger::new(agent.db.sqlite.clone()); + let conversation_logger = + spacebot::conversation::history::ConversationLogger::new(agent.db.sqlite.clone()); let channel_store = 
spacebot::conversation::ChannelStore::new(agent.db.sqlite.clone()); let tool_server = spacebot::tools::create_cortex_chat_tool_server( agent.deps.memory_search.clone(), diff --git a/src/opencode/server.rs b/src/opencode/server.rs index 5ad07c5ed..c047c418c 100644 --- a/src/opencode/server.rs +++ b/src/opencode/server.rs @@ -54,8 +54,8 @@ impl OpenCodeServer { let base_url = format!("http://127.0.0.1:{port}"); let env_config = OpenCodeEnvConfig::new(permissions); - let config_json = serde_json::to_string(&env_config) - .context("failed to serialize OpenCode config")?; + let config_json = + serde_json::to_string(&env_config).context("failed to serialize OpenCode config")?; tracing::info!( directory = %directory.display(), @@ -73,10 +73,13 @@ impl OpenCodeServer { .env("OPENCODE_PORT", port.to_string()) .kill_on_drop(true) .spawn() - .with_context(|| format!( - "failed to spawn OpenCode at '{}' for directory '{}'", - opencode_path, directory.display() - ))?; + .with_context(|| { + format!( + "failed to spawn OpenCode at '{}' for directory '{}'", + opencode_path, + directory.display() + ) + })?; let client = Client::builder() .timeout(std::time::Duration::from_secs(300)) @@ -185,7 +188,8 @@ impl OpenCodeServer { /// Check if the server is healthy. 
async fn health_check(&self) -> anyhow::Result { let url = format!("{}/global/health", self.base_url); - let response = self.client + let response = self + .client .get(&url) .timeout(std::time::Duration::from_secs(5)) .send() @@ -197,7 +201,8 @@ impl OpenCodeServer { // Fallback let url = format!("{}/api/health", self.base_url); - let response = self.client + let response = self + .client .get(&url) .timeout(std::time::Duration::from_secs(5)) .send() @@ -260,10 +265,12 @@ impl OpenCodeServer { .env("OPENCODE_PORT", port.to_string()) .kill_on_drop(true) .spawn() - .with_context(|| format!( - "failed to restart OpenCode server for '{}'", - self.directory.display() - ))?; + .with_context(|| { + format!( + "failed to restart OpenCode server for '{}'", + self.directory.display() + ) + })?; self.port = port; self.base_url = base_url; @@ -311,7 +318,8 @@ impl OpenCodeServer { let url = format!("{}/session", self.base_url); let body = CreateSessionRequest { title }; - let response = self.client + let response = self + .client .post(&url) .query(&[("directory", self.directory.to_str().unwrap_or("."))]) .json(&body) @@ -325,7 +333,9 @@ impl OpenCodeServer { bail!("create session failed ({status}): {text}"); } - response.json::().await + response + .json::() + .await .context("failed to parse session response") } @@ -337,7 +347,8 @@ impl OpenCodeServer { ) -> anyhow::Result { let url = format!("{}/session/{}/message", self.base_url, session_id); - let response = self.client + let response = self + .client .post(&url) .query(&[("directory", self.directory.to_str().unwrap_or("."))]) .json(request) @@ -351,7 +362,9 @@ impl OpenCodeServer { bail!("send prompt failed ({status}): {text}"); } - response.json::().await + response + .json::() + .await .context("failed to parse prompt response") } @@ -363,7 +376,8 @@ impl OpenCodeServer { ) -> anyhow::Result<()> { let url = format!("{}/session/{}/prompt_async", self.base_url, session_id); - let response = self.client + let response = 
self + .client .post(&url) .query(&[("directory", self.directory.to_str().unwrap_or("."))]) .json(request) @@ -384,7 +398,8 @@ impl OpenCodeServer { pub async fn abort_session(&self, session_id: &str) -> anyhow::Result<()> { let url = format!("{}/session/{}/abort", self.base_url, session_id); - let response = self.client + let response = self + .client .post(&url) .query(&[("directory", self.directory.to_str().unwrap_or("."))]) .send() @@ -412,7 +427,8 @@ impl OpenCodeServer { message: None, }; - let response = self.client + let response = self + .client .post(&url) .query(&[("directory", self.directory.to_str().unwrap_or("."))]) .json(&body) @@ -438,7 +454,8 @@ impl OpenCodeServer { let url = format!("{}/question/{}/reply", self.base_url, request_id); let body = QuestionReplyRequest { answers }; - let response = self.client + let response = self + .client .post(&url) .query(&[("directory", self.directory.to_str().unwrap_or("."))]) .json(&body) @@ -460,7 +477,8 @@ impl OpenCodeServer { pub async fn subscribe_events(&self) -> anyhow::Result { let url = format!("{}/event", self.base_url); - let response = self.client + let response = self + .client .get(&url) .query(&[("directory", self.directory.to_str().unwrap_or("."))]) .header("Accept", "text/event-stream") @@ -479,13 +497,11 @@ impl OpenCodeServer { } /// Get messages for a session (for reading final results). 
- pub async fn get_messages( - &self, - session_id: &str, - ) -> anyhow::Result> { + pub async fn get_messages(&self, session_id: &str) -> anyhow::Result> { let url = format!("{}/session/{}/message", self.base_url, session_id); - let response = self.client + let response = self + .client .get(&url) .query(&[("directory", self.directory.to_str().unwrap_or("."))]) .send() @@ -498,7 +514,9 @@ impl OpenCodeServer { bail!("get messages failed ({status}): {text}"); } - response.json::>().await + response + .json::>() + .await .context("failed to parse messages response") } } @@ -552,7 +570,8 @@ impl OpenCodeServerPool { &self, directory: &Path, ) -> anyhow::Result>> { - let canonical = directory.canonicalize() + let canonical = directory + .canonicalize() .with_context(|| format!("directory '{}' does not exist", directory.display()))?; let mut servers = self.servers.lock().await; @@ -574,11 +593,10 @@ impl OpenCodeServerPool { // Not in pool yet. Try reattaching to an existing server on the // deterministic port (left over from a previous spacebot run). 
- if let Some(reattached) = OpenCodeServer::reattach( - canonical.clone(), - &self.opencode_path, - &self.permissions, - ).await { + if let Some(reattached) = + OpenCodeServer::reattach(canonical.clone(), &self.opencode_path, &self.permissions) + .await + { let server = Arc::new(Mutex::new(reattached)); servers.insert(canonical, Arc::clone(&server)); return Ok(server); @@ -592,11 +610,9 @@ impl OpenCodeServerPool { ); } - let server = OpenCodeServer::spawn( - canonical.clone(), - &self.opencode_path, - &self.permissions, - ).await?; + let server = + OpenCodeServer::spawn(canonical.clone(), &self.opencode_path, &self.permissions) + .await?; let server = Arc::new(Mutex::new(server)); servers.insert(canonical, Arc::clone(&server)); diff --git a/src/opencode/worker.rs b/src/opencode/worker.rs index f421e9c72..1466c5292 100644 --- a/src/opencode/worker.rs +++ b/src/opencode/worker.rs @@ -12,7 +12,7 @@ use anyhow::{Context as _, bail}; use futures::StreamExt as _; use std::path::PathBuf; use std::sync::Arc; -use tokio::sync::{broadcast, mpsc, Mutex}; +use tokio::sync::{Mutex, broadcast, mpsc}; use uuid::Uuid; /// An OpenCode-backed worker that drives a coding session via subprocess. @@ -95,20 +95,25 @@ impl OpenCodeWorker { self.send_status("starting OpenCode server"); // Get or create server for this directory - let server = self.server_pool + let server = self + .server_pool .get_or_create(&self.directory) .await - .with_context(|| format!( - "failed to get OpenCode server for '{}'", - self.directory.display() - ))?; + .with_context(|| { + format!( + "failed to get OpenCode server for '{}'", + self.directory.display() + ) + })?; self.send_status("creating session"); // Create a session let session = { let guard = server.lock().await; - guard.create_session(Some(format!("spacebot-worker-{}", self.id))).await? + guard + .create_session(Some(format!("spacebot-worker-{}", self.id))) + .await? 
}; let session_id = session.id.clone(); @@ -141,15 +146,15 @@ impl OpenCodeWorker { self.send_status("sending task to OpenCode"); { let guard = server.lock().await; - guard.send_prompt_async(&session_id, &prompt_request).await?; + guard + .send_prompt_async(&session_id, &prompt_request) + .await?; } // Process SSE events until session goes idle or errors - let result_text = self.process_events( - event_response, - &session_id, - &server, - ).await?; + let result_text = self + .process_events(event_response, &session_id, &server) + .await?; // Interactive follow-up loop if let Some(mut input_rx) = self.input_rx.take() { @@ -176,10 +181,15 @@ impl OpenCodeWorker { { let guard = server.lock().await; - guard.send_prompt_async(&session_id, &follow_up_request).await?; + guard + .send_prompt_async(&session_id, &follow_up_request) + .await?; } - match self.process_events(event_response, &session_id, &server).await { + match self + .process_events(event_response, &session_id, &server) + .await + { Ok(_) => { self.send_status("waiting for follow-up"); } @@ -247,15 +257,18 @@ impl OpenCodeWorker { // Parse SSE lines from buffer while let Some(event) = extract_sse_event(&mut buffer) { - match self.handle_sse_event( - &event, - session_id, - server, - &mut last_text, - &mut current_tool, - &mut has_received_event, - &mut has_assistant_message, - ).await { + match self + .handle_sse_event( + &event, + session_id, + server, + &mut last_text, + &mut current_tool, + &mut has_received_event, + &mut has_assistant_message, + ) + .await + { EventAction::Continue => {} EventAction::Complete => return Ok(last_text.clone()), EventAction::Error(message) => bail!("OpenCode session error: {message}"), @@ -294,7 +307,11 @@ impl OpenCodeWorker { SseEvent::MessagePartUpdated { part, .. } => { *has_received_event = true; match part { - Part::Text { text, session_id: part_session, .. } => { + Part::Text { + text, + session_id: part_session, + .. 
+ } => { if let Some(sid) = part_session { if sid != session_id { return EventAction::Continue; @@ -303,7 +320,12 @@ impl OpenCodeWorker { *has_assistant_message = true; *last_text = text.clone(); } - Part::Tool { tool, state, session_id: part_session, .. } => { + Part::Tool { + tool, + state, + session_id: part_session, + .. + } => { if let Some(sid) = part_session { if sid != session_id { return EventAction::Continue; @@ -326,7 +348,9 @@ impl OpenCodeWorker { } ToolState::Error { error, .. } => { let description = error.as_deref().unwrap_or("unknown"); - self.send_status(&format!("tool error: {tool_name}: {description}")); + self.send_status(&format!( + "tool error: {tool_name}: {description}" + )); } ToolState::Pending { .. } => { // Tool queued, no status update needed @@ -340,7 +364,9 @@ impl OpenCodeWorker { EventAction::Continue } - SseEvent::SessionIdle { session_id: event_session_id } => { + SseEvent::SessionIdle { + session_id: event_session_id, + } => { if event_session_id != session_id { return EventAction::Continue; } @@ -360,7 +386,10 @@ impl OpenCodeWorker { EventAction::Complete } - SseEvent::SessionError { session_id: event_session_id, error } => { + SseEvent::SessionError { + session_id: event_session_id, + error, + } => { if event_session_id.as_deref() != Some(session_id) { return EventAction::Continue; } @@ -400,7 +429,10 @@ impl OpenCodeWorker { // Auto-allow (OPENCODE_CONFIG_CONTENT should prevent most prompts) let guard = server.lock().await; - if let Err(error) = guard.reply_permission(&permission.id, PermissionReply::Once).await { + if let Err(error) = guard + .reply_permission(&permission.id, PermissionReply::Once) + .await + { tracing::warn!( worker_id = %self.id, permission_id = %permission.id, @@ -429,29 +461,35 @@ impl OpenCodeWorker { worker_id: self.id, channel_id: self.channel_id.clone(), question_id: question.id.clone(), - questions: question.questions.iter().map(|q| { - QuestionInfo { + questions: question + .questions + .iter() + 
.map(|q| QuestionInfo { question: q.question.clone(), header: q.header.clone(), options: q.options.clone(), - } - }).collect(), + }) + .collect(), }); // Auto-select first option - let answers: Vec = question.questions.iter().map(|q| { - if let Some(first_option) = q.options.first() { - QuestionAnswer { - label: first_option.label.clone(), - description: first_option.description.clone(), - } - } else { - QuestionAnswer { - label: "continue".to_string(), - description: None, + let answers: Vec = question + .questions + .iter() + .map(|q| { + if let Some(first_option) = q.options.first() { + QuestionAnswer { + label: first_option.label.clone(), + description: first_option.description.clone(), + } + } else { + QuestionAnswer { + label: "continue".to_string(), + description: None, + } } - } - }).collect(); + }) + .collect(); let guard = server.lock().await; if let Err(error) = guard.reply_question(&question.id, answers).await { @@ -466,12 +504,17 @@ impl OpenCodeWorker { EventAction::Continue } - SseEvent::SessionStatus { session_id: event_session_id, status } => { + SseEvent::SessionStatus { + session_id: event_session_id, + status, + } => { if event_session_id != session_id { return EventAction::Continue; } match status { - SessionStatusPayload::Retry { attempt, message, .. } => { + SessionStatusPayload::Retry { + attempt, message, .. 
+ } => { let description = message.as_deref().unwrap_or("rate limited"); self.send_status(&format!("retry attempt {attempt}: {description}")); } diff --git a/src/prompts/engine.rs b/src/prompts/engine.rs index 480291b86..07e94439e 100644 --- a/src/prompts/engine.rs +++ b/src/prompts/engine.rs @@ -1,6 +1,6 @@ use crate::error::Result; use anyhow::Context; -use minijinja::{context, Environment, Value}; +use minijinja::{Environment, Value, context}; use std::collections::HashMap; use std::sync::Arc; diff --git a/src/settings.rs b/src/settings.rs index 4a2b4fb61..8c6d73490 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -2,4 +2,4 @@ pub mod store; -pub use store::{SettingsStore, WorkerLogMode, WORKER_LOG_MODE_KEY}; +pub use store::{SettingsStore, WORKER_LOG_MODE_KEY, WorkerLogMode}; diff --git a/src/skills.rs b/src/skills.rs index 7b6946e48..f8411400a 100644 --- a/src/skills.rs +++ b/src/skills.rs @@ -58,7 +58,9 @@ impl SkillSet { // Instance skills (lowest precedence) if instance_skills_dir.is_dir() { - if let Ok(skills) = load_skills_from_dir(instance_skills_dir, SkillSource::Instance).await { + if let Ok(skills) = + load_skills_from_dir(instance_skills_dir, SkillSource::Instance).await + { for skill in skills { set.skills.insert(skill.name.to_lowercase(), skill); } @@ -67,7 +69,9 @@ impl SkillSet { // Workspace skills (highest precedence, overrides instance) if workspace_skills_dir.is_dir() { - if let Ok(skills) = load_skills_from_dir(workspace_skills_dir, SkillSource::Workspace).await { + if let Ok(skills) = + load_skills_from_dir(workspace_skills_dir, SkillSource::Workspace).await + { for skill in skills { set.skills.insert(skill.name.to_lowercase(), skill); } @@ -194,26 +198,27 @@ async fn load_skills_from_dir(dir: &Path, source: SkillSource) -> anyhow::Result } /// Load a single skill from its SKILL.md file. 
-async fn load_skill(file_path: &Path, base_dir: &Path, source: SkillSource) -> anyhow::Result { +async fn load_skill( + file_path: &Path, + base_dir: &Path, + source: SkillSource, +) -> anyhow::Result { let raw = tokio::fs::read_to_string(file_path) .await .with_context(|| format!("failed to read {}", file_path.display()))?; let (frontmatter, body) = parse_frontmatter(&raw)?; - let name = frontmatter.get("name") - .cloned() - .unwrap_or_else(|| { - // Fall back to directory name if no name in frontmatter - base_dir.file_name() - .and_then(|n| n.to_str()) - .unwrap_or("unknown") - .to_string() - }); + let name = frontmatter.get("name").cloned().unwrap_or_else(|| { + // Fall back to directory name if no name in frontmatter + base_dir + .file_name() + .and_then(|n| n.to_str()) + .unwrap_or("unknown") + .to_string() + }); - let description = frontmatter.get("description") - .cloned() - .unwrap_or_default(); + let description = frontmatter.get("description").cloned().unwrap_or_default(); // Resolve {baseDir} template variable in the body let base_dir_str = base_dir.to_string_lossy(); @@ -277,8 +282,10 @@ fn parse_frontmatter(content: &str) -> anyhow::Result<(HashMap, // Strip surrounding quotes let value = value - .trim_start_matches('"').trim_end_matches('"') - .trim_start_matches('\'').trim_end_matches('\'') + .trim_start_matches('"') + .trim_end_matches('"') + .trim_start_matches('\'') + .trim_end_matches('\'') .to_string(); map.insert(key, value); @@ -308,7 +315,10 @@ mod tests { let (fm, body) = parse_frontmatter(content).unwrap(); assert_eq!(fm.get("name").unwrap(), "weather"); - assert_eq!(fm.get("description").unwrap(), "Get current weather and forecasts (no API key required)."); + assert_eq!( + fm.get("description").unwrap(), + "Get current weather and forecasts (no API key required)." 
+ ); assert_eq!(fm.get("homepage").unwrap(), "https://wttr.in/:help"); assert!(body.starts_with("# Weather")); } @@ -327,7 +337,10 @@ mod tests { let (fm, body) = parse_frontmatter(content).unwrap(); assert_eq!(fm.get("name").unwrap(), "github"); - assert_eq!(fm.get("description").unwrap(), "Interact with GitHub using the gh CLI."); + assert_eq!( + fm.get("description").unwrap(), + "Interact with GitHub using the gh CLI." + ); // metadata line is skipped (starts with {) assert!(!fm.contains_key("metadata")); assert!(body.starts_with("# GitHub Skill")); @@ -353,7 +366,10 @@ mod tests { "#}; let (fm, _body) = parse_frontmatter(content).unwrap(); - assert_eq!(fm.get("description").unwrap(), "A skill with 'quotes' inside"); + assert_eq!( + fm.get("description").unwrap(), + "A skill with 'quotes' inside" + ); } #[test] @@ -366,14 +382,17 @@ mod tests { #[test] fn test_skill_set_channel_prompt() { let mut set = SkillSet::default(); - set.skills.insert("weather".into(), Skill { - name: "weather".into(), - description: "Get weather forecasts".into(), - file_path: PathBuf::from("/skills/weather/SKILL.md"), - base_dir: PathBuf::from("/skills/weather"), - content: "# Weather\n\nUse curl.".into(), - source: SkillSource::Instance, - }); + set.skills.insert( + "weather".into(), + Skill { + name: "weather".into(), + description: "Get weather forecasts".into(), + file_path: PathBuf::from("/skills/weather/SKILL.md"), + base_dir: PathBuf::from("/skills/weather"), + content: "# Weather\n\nUse curl.".into(), + source: SkillSource::Instance, + }, + ); let engine = crate::prompts::PromptEngine::new("en").unwrap(); let prompt = set.render_channel_prompt(&engine); @@ -385,14 +404,17 @@ mod tests { #[test] fn test_skill_set_worker_prompt() { let mut set = SkillSet::default(); - set.skills.insert("weather".into(), Skill { - name: "weather".into(), - description: "Get weather forecasts".into(), - file_path: PathBuf::from("/skills/weather/SKILL.md"), - base_dir: 
PathBuf::from("/skills/weather"), - content: "# Weather\n\nUse curl.".into(), - source: SkillSource::Instance, - }); + set.skills.insert( + "weather".into(), + Skill { + name: "weather".into(), + description: "Get weather forecasts".into(), + file_path: PathBuf::from("/skills/weather/SKILL.md"), + base_dir: PathBuf::from("/skills/weather"), + content: "# Weather\n\nUse curl.".into(), + source: SkillSource::Instance, + }, + ); let engine = crate::prompts::PromptEngine::new("en").unwrap(); let prompt = set.render_worker_prompt("weather", &engine).unwrap(); diff --git a/src/update.rs b/src/update.rs index 7b810f84e..b782fe8dd 100644 --- a/src/update.rs +++ b/src/update.rs @@ -103,7 +103,10 @@ pub async fn check_for_update(status: &SharedUpdateStatus) { match result { Ok(release) => { - let tag = release.tag_name.strip_prefix('v').unwrap_or(&release.tag_name); + let tag = release + .tag_name + .strip_prefix('v') + .unwrap_or(&release.tag_name); let is_newer = is_newer_version(tag, CURRENT_VERSION); next.latest_version = Some(tag.to_string()); @@ -417,11 +420,7 @@ fn resolve_target_image(current_image: &str, new_version: &str) -> String { }; // Determine the variant suffix (slim, full, or default to slim) - let variant = if tag.contains("full") { - "full" - } else { - "slim" - }; + let variant = if tag.contains("full") { "full" } else { "slim" }; format!("{}:v{}-{}", base, new_version, variant) } diff --git a/tests/bulletin.rs b/tests/bulletin.rs index 815e660e6..aa4ebc3a4 100644 --- a/tests/bulletin.rs +++ b/tests/bulletin.rs @@ -12,8 +12,8 @@ use std::sync::Arc; /// Bootstrap an AgentDeps from the real ~/.spacebot config, using the first /// (default) agent's databases and config. 
async fn bootstrap_deps() -> anyhow::Result { - let config = spacebot::config::Config::load() - .context("failed to load ~/.spacebot/config.toml")?; + let config = + spacebot::config::Config::load().context("failed to load ~/.spacebot/config.toml")?; let llm_manager = Arc::new( spacebot::llm::LlmManager::new(config.llm.clone()) @@ -28,9 +28,7 @@ async fn bootstrap_deps() -> anyhow::Result { ); let resolved_agents = config.resolve_agents(); - let agent_config = resolved_agents - .first() - .context("no agents configured")?; + let agent_config = resolved_agents.first().context("no agents configured")?; let db = spacebot::db::Db::connect(&agent_config.data_dir) .await @@ -38,10 +36,9 @@ async fn bootstrap_deps() -> anyhow::Result { let memory_store = spacebot::memory::MemoryStore::new(db.sqlite.clone()); - let embedding_table = - spacebot::memory::EmbeddingTable::open_or_create(&db.lance) - .await - .context("failed to init embedding table")?; + let embedding_table = spacebot::memory::EmbeddingTable::open_or_create(&db.lance) + .await + .context("failed to init embedding table")?; if let Err(error) = embedding_table.ensure_fts_index().await { eprintln!("warning: FTS index creation failed: {error}"); @@ -54,16 +51,18 @@ async fn bootstrap_deps() -> anyhow::Result { )); let identity = spacebot::identity::Identity::load(&agent_config.workspace).await; - let prompts = spacebot::prompts::PromptEngine::new("en") - .context("failed to init prompt engine")?; - let skills = spacebot::skills::SkillSet::load( - &config.skills_dir(), - &agent_config.skills_dir(), - ) - .await; + let prompts = + spacebot::prompts::PromptEngine::new("en").context("failed to init prompt engine")?; + let skills = + spacebot::skills::SkillSet::load(&config.skills_dir(), &agent_config.skills_dir()).await; let runtime_config = Arc::new(spacebot::config::RuntimeConfig::new( - &config.instance_dir, agent_config, &config.defaults, prompts, identity, skills, + &config.instance_dir, + agent_config, + 
&config.defaults, + prompts, + identity, + skills, )); let (event_tx, _) = tokio::sync::broadcast::channel(16); @@ -87,7 +86,16 @@ async fn bootstrap_deps() -> anyhow::Result { fn test_bulletin_prompts_cover_all_memory_types() { // The cortex user prompt in cortex.rs lists types inline. Check the same // set against the canonical list so drift is caught at compile time. - let cortex_user_prompt_types = ["identity", "fact", "decision", "event", "preference", "observation", "goal", "todo"]; + let cortex_user_prompt_types = [ + "identity", + "fact", + "decision", + "event", + "preference", + "observation", + "goal", + "todo", + ]; for memory_type in spacebot::memory::types::MemoryType::ALL { let type_str = memory_type.to_string(); @@ -128,7 +136,10 @@ async fn test_memory_recall_returns_results() { ); } - assert!(!results.is_empty(), "hybrid_search should return results from a populated database"); + assert!( + !results.is_empty(), + "hybrid_search should return results from a populated database" + ); } #[tokio::test] @@ -145,7 +156,10 @@ async fn test_bulletin_generation() { // Verify the bulletin was stored let bulletin = deps.runtime_config.memory_bulletin.load(); - assert!(!bulletin.is_empty(), "bulletin should not be empty after generation"); + assert!( + !bulletin.is_empty(), + "bulletin should not be empty after generation" + ); let word_count = bulletin.split_whitespace().count(); println!("bulletin generated: {word_count} words"); @@ -153,5 +167,8 @@ async fn test_bulletin_generation() { println!("{bulletin}"); println!("---"); - assert!(word_count > 50, "bulletin should have meaningful content (got {word_count} words)"); + assert!( + word_count > 50, + "bulletin should have meaningful content (got {word_count} words)" + ); } diff --git a/tests/context_dump.rs b/tests/context_dump.rs index 69a88ee9f..3e00c92ed 100644 --- a/tests/context_dump.rs +++ b/tests/context_dump.rs @@ -12,8 +12,8 @@ use std::sync::Arc; /// Bootstrap AgentDeps from the real ~/.spacebot 
config (same as bulletin test). async fn bootstrap_deps() -> anyhow::Result<(spacebot::AgentDeps, spacebot::config::Config)> { - let config = spacebot::config::Config::load() - .context("failed to load ~/.spacebot/config.toml")?; + let config = + spacebot::config::Config::load().context("failed to load ~/.spacebot/config.toml")?; let llm_manager = Arc::new( spacebot::llm::LlmManager::new(config.llm.clone()) @@ -28,9 +28,7 @@ async fn bootstrap_deps() -> anyhow::Result<(spacebot::AgentDeps, spacebot::conf ); let resolved_agents = config.resolve_agents(); - let agent_config = resolved_agents - .first() - .context("no agents configured")?; + let agent_config = resolved_agents.first().context("no agents configured")?; let db = spacebot::db::Db::connect(&agent_config.data_dir) .await @@ -38,10 +36,9 @@ async fn bootstrap_deps() -> anyhow::Result<(spacebot::AgentDeps, spacebot::conf let memory_store = spacebot::memory::MemoryStore::new(db.sqlite.clone()); - let embedding_table = - spacebot::memory::EmbeddingTable::open_or_create(&db.lance) - .await - .context("failed to init embedding table")?; + let embedding_table = spacebot::memory::EmbeddingTable::open_or_create(&db.lance) + .await + .context("failed to init embedding table")?; if let Err(error) = embedding_table.ensure_fts_index().await { eprintln!("warning: FTS index creation failed: {error}"); @@ -54,16 +51,18 @@ async fn bootstrap_deps() -> anyhow::Result<(spacebot::AgentDeps, spacebot::conf )); let identity = spacebot::identity::Identity::load(&agent_config.workspace).await; - let prompts = spacebot::prompts::PromptEngine::new("en") - .context("failed to init prompt engine")?; - let skills = spacebot::skills::SkillSet::load( - &config.skills_dir(), - &agent_config.skills_dir(), - ) - .await; + let prompts = + spacebot::prompts::PromptEngine::new("en").context("failed to init prompt engine")?; + let skills = + spacebot::skills::SkillSet::load(&config.skills_dir(), &agent_config.skills_dir()).await; let 
runtime_config = Arc::new(spacebot::config::RuntimeConfig::new( - &config.instance_dir, agent_config, &config.defaults, prompts, identity, skills, + &config.instance_dir, + agent_config, + &config.defaults, + prompts, + identity, + skills, )); let (event_tx, _) = tokio::sync::broadcast::channel(16); @@ -162,10 +161,13 @@ async fn dump_channel_context() { print_stats("System prompt", &prompt); // Build the actual channel tool server with real tools registered - let conversation_logger = spacebot::conversation::ConversationLogger::new(deps.sqlite_pool.clone()); + let conversation_logger = + spacebot::conversation::ConversationLogger::new(deps.sqlite_pool.clone()); let channel_store = spacebot::conversation::ChannelStore::new(deps.sqlite_pool.clone()); let channel_id: spacebot::ChannelId = Arc::from("test-channel"); - let status_block = Arc::new(tokio::sync::RwLock::new(spacebot::agent::status::StatusBlock::new())); + let status_block = Arc::new(tokio::sync::RwLock::new( + spacebot::agent::status::StatusBlock::new(), + )); let (response_tx, _response_rx) = tokio::sync::mpsc::channel(16); let state = spacebot::agent::channel::ChannelState { @@ -212,7 +214,10 @@ async fn dump_channel_context() { println!("\n--- TOTAL CHANNEL CONTEXT: ~{} tokens ---", total / 4); let routing = rc.routing.load(); - println!("Model: {}", routing.resolve(spacebot::ProcessType::Channel, None)); + println!( + "Model: {}", + routing.resolve(spacebot::ProcessType::Channel, None) + ); println!("Max turns: {}", **rc.max_turns.load()); assert!(!prompt.is_empty()); @@ -236,7 +241,8 @@ async fn dump_branch_context() { print_stats("System prompt", &branch_prompt); // Build the actual branch tool server - let conversation_logger = spacebot::conversation::ConversationLogger::new(deps.sqlite_pool.clone()); + let conversation_logger = + spacebot::conversation::ConversationLogger::new(deps.sqlite_pool.clone()); let channel_store = spacebot::conversation::ChannelStore::new(deps.sqlite_pool.clone()); let 
branch_tool_server = spacebot::tools::create_branch_tool_server( deps.memory_search.clone(), @@ -260,7 +266,10 @@ async fn dump_branch_context() { println!("\n--- TOTAL BRANCH CONTEXT: ~{} tokens ---", total / 4); let routing = rc.routing.load(); - println!("Model: {}", routing.resolve(spacebot::ProcessType::Branch, None)); + println!( + "Model: {}", + routing.resolve(spacebot::ProcessType::Branch, None) + ); println!("Max turns: {}", **rc.branch_max_turns.load()); println!("History: cloned from channel at fork time (full conversation context)"); @@ -319,7 +328,10 @@ async fn dump_worker_context() { println!("\n--- TOTAL WORKER CONTEXT: ~{} tokens ---", total / 4); let routing = rc.routing.load(); - println!("Model: {}", routing.resolve(spacebot::ProcessType::Worker, None)); + println!( + "Model: {}", + routing.resolve(spacebot::ProcessType::Worker, None) + ); println!("Turns per segment: 25"); println!("History: fresh (empty). Workers have no channel context."); @@ -341,12 +353,16 @@ async fn dump_all_contexts() { let bulletin_success = spacebot::agent::cortex::generate_bulletin(&deps).await; if bulletin_success { let bulletin = rc.memory_bulletin.load(); - println!("Bulletin generated: {} words", bulletin.split_whitespace().count()); + println!( + "Bulletin generated: {} words", + bulletin.split_whitespace().count() + ); } else { println!("Bulletin generation failed (may not have memories or LLM keys)"); } - let conversation_logger = spacebot::conversation::ConversationLogger::new(deps.sqlite_pool.clone()); + let conversation_logger = + spacebot::conversation::ConversationLogger::new(deps.sqlite_pool.clone()); let channel_store = spacebot::conversation::ChannelStore::new(deps.sqlite_pool.clone()); // ── Channel ── @@ -360,7 +376,9 @@ async fn dump_all_contexts() { active_branches: Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())), active_workers: Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())), worker_inputs: 
Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())), - status_block: Arc::new(tokio::sync::RwLock::new(spacebot::agent::status::StatusBlock::new())), + status_block: Arc::new(tokio::sync::RwLock::new( + spacebot::agent::status::StatusBlock::new(), + )), deps: deps.clone(), conversation_logger: conversation_logger.clone(), channel_store: channel_store.clone(), @@ -370,14 +388,24 @@ async fn dump_all_contexts() { let channel_tool_server = rig::tool::server::ToolServer::new().run(); let skip_flag = spacebot::tools::new_skip_flag(); spacebot::tools::add_channel_tools( - &channel_tool_server, state, response_tx, "test", skip_flag, None, - ).await.expect("failed to add channel tools"); + &channel_tool_server, + state, + response_tx, + "test", + skip_flag, + None, + ) + .await + .expect("failed to add channel tools"); let channel_tool_defs = channel_tool_server.get_tool_defs(None).await.unwrap(); let channel_tools_text = format_tool_defs(&channel_tool_defs); print_section("CHANNEL SYSTEM PROMPT (with bulletin)", &channel_prompt); print_stats("System prompt", &channel_prompt); - print_section(&format!("CHANNEL TOOLS ({} tools)", channel_tool_defs.len()), &channel_tools_text); + print_section( + &format!("CHANNEL TOOLS ({} tools)", channel_tool_defs.len()), + &channel_tools_text, + ); print_stats("Tool definitions", &channel_tools_text); let channel_total = channel_prompt.len() + channel_tools_text.len(); println!("--- TOTAL CHANNEL: ~{} tokens ---", channel_total / 4); @@ -396,7 +424,10 @@ async fn dump_all_contexts() { print_section("BRANCH SYSTEM PROMPT", &branch_prompt); print_stats("System prompt", &branch_prompt); - print_section(&format!("BRANCH TOOLS ({} tools)", branch_tool_defs.len()), &branch_tools_text); + print_section( + &format!("BRANCH TOOLS ({} tools)", branch_tool_defs.len()), + &branch_tools_text, + ); print_stats("Tool definitions", &branch_tools_text); let branch_total = branch_prompt.len() + branch_tools_text.len(); println!("--- TOTAL 
BRANCH: ~{} tokens ---", branch_total / 4); @@ -421,7 +452,10 @@ async fn dump_all_contexts() { print_section("WORKER SYSTEM PROMPT", &worker_prompt); print_stats("System prompt", &worker_prompt); - print_section(&format!("WORKER TOOLS ({} tools)", worker_tool_defs.len()), &worker_tools_text); + print_section( + &format!("WORKER TOOLS ({} tools)", worker_tool_defs.len()), + &worker_tools_text, + ); print_stats("Tool definitions", &worker_tools_text); let worker_total = worker_prompt.len() + worker_tools_text.len(); println!("--- TOTAL WORKER: ~{} tokens ---", worker_total / 4); @@ -433,23 +467,53 @@ async fn dump_all_contexts() { let routing = rc.routing.load(); println!("\nRouting:"); - println!(" Channel: {}", routing.resolve(spacebot::ProcessType::Channel, None)); - println!(" Branch: {}", routing.resolve(spacebot::ProcessType::Branch, None)); - println!(" Worker: {}", routing.resolve(spacebot::ProcessType::Worker, None)); + println!( + " Channel: {}", + routing.resolve(spacebot::ProcessType::Channel, None) + ); + println!( + " Branch: {}", + routing.resolve(spacebot::ProcessType::Branch, None) + ); + println!( + " Worker: {}", + routing.resolve(spacebot::ProcessType::Worker, None) + ); println!("\nContext budget (initial turn, before any history):"); - println!(" Channel: ~{} tokens (prompt: ~{}, tools: ~{})", - channel_total / 4, channel_prompt.len() / 4, channel_tools_text.len() / 4); - println!(" Branch: ~{} tokens (prompt: ~{}, tools: ~{}) + cloned channel history", - branch_total / 4, branch_prompt.len() / 4, branch_tools_text.len() / 4); - println!(" Worker: ~{} tokens (prompt: ~{}, tools: ~{})", - worker_total / 4, worker_prompt.len() / 4, worker_tools_text.len() / 4); + println!( + " Channel: ~{} tokens (prompt: ~{}, tools: ~{})", + channel_total / 4, + channel_prompt.len() / 4, + channel_tools_text.len() / 4 + ); + println!( + " Branch: ~{} tokens (prompt: ~{}, tools: ~{}) + cloned channel history", + branch_total / 4, + branch_prompt.len() / 4, + 
branch_tools_text.len() / 4 + ); + println!( + " Worker: ~{} tokens (prompt: ~{}, tools: ~{})", + worker_total / 4, + worker_prompt.len() / 4, + worker_tools_text.len() / 4 + ); let context_window = **rc.context_window.load(); println!("\nContext window: {} tokens", context_window); - println!(" Channel headroom: ~{} tokens for history", context_window - channel_total / 4); - println!(" Branch headroom: ~{} tokens for history", context_window - branch_total / 4); - println!(" Worker headroom: ~{} tokens for history", context_window - worker_total / 4); + println!( + " Channel headroom: ~{} tokens for history", + context_window - channel_total / 4 + ); + println!( + " Branch headroom: ~{} tokens for history", + context_window - branch_total / 4 + ); + println!( + " Worker headroom: ~{} tokens for history", + context_window - worker_total / 4 + ); println!("\nTurn limits:"); println!(" Channel: {} max turns", **rc.max_turns.load()); @@ -458,26 +522,74 @@ async fn dump_all_contexts() { let compaction = rc.compaction.load(); println!("\nCompaction thresholds:"); - println!(" Background: {:.0}%", compaction.background_threshold * 100.0); - println!(" Aggressive: {:.0}%", compaction.aggressive_threshold * 100.0); - println!(" Emergency: {:.0}%", compaction.emergency_threshold * 100.0); + println!( + " Background: {:.0}%", + compaction.background_threshold * 100.0 + ); + println!( + " Aggressive: {:.0}%", + compaction.aggressive_threshold * 100.0 + ); + println!( + " Emergency: {:.0}%", + compaction.emergency_threshold * 100.0 + ); println!("\nTool counts:"); - println!(" Channel: {} tools ({})", + println!( + " Channel: {} tools ({})", channel_tool_defs.len(), - channel_tool_defs.iter().map(|d| d.name.as_str()).collect::>().join(", ")); - println!(" Branch: {} tools ({})", + channel_tool_defs + .iter() + .map(|d| d.name.as_str()) + .collect::>() + .join(", ") + ); + println!( + " Branch: {} tools ({})", branch_tool_defs.len(), - branch_tool_defs.iter().map(|d| 
d.name.as_str()).collect::>().join(", ")); - println!(" Worker: {} tools ({})", + branch_tool_defs + .iter() + .map(|d| d.name.as_str()) + .collect::>() + .join(", ") + ); + println!( + " Worker: {} tools ({})", worker_tool_defs.len(), - worker_tool_defs.iter().map(|d| d.name.as_str()).collect::>().join(", ")); + worker_tool_defs + .iter() + .map(|d| d.name.as_str()) + .collect::>() + .join(", ") + ); let identity = rc.identity.load(); println!("\nIdentity files:"); - println!(" SOUL.md: {}", if identity.soul.is_some() { "loaded" } else { "empty" }); - println!(" IDENTITY.md: {}", if identity.identity.is_some() { "loaded" } else { "empty" }); - println!(" USER.md: {}", if identity.user.is_some() { "loaded" } else { "empty" }); + println!( + " SOUL.md: {}", + if identity.soul.is_some() { + "loaded" + } else { + "empty" + } + ); + println!( + " IDENTITY.md: {}", + if identity.identity.is_some() { + "loaded" + } else { + "empty" + } + ); + println!( + " USER.md: {}", + if identity.user.is_some() { + "loaded" + } else { + "empty" + } + ); let skills = rc.skills.load(); if skills.is_empty() { @@ -488,5 +600,8 @@ async fn dump_all_contexts() { } let bulletin = rc.memory_bulletin.load(); - println!("\nMemory bulletin: {} words", bulletin.split_whitespace().count()); + println!( + "\nMemory bulletin: {} words", + bulletin.split_whitespace().count() + ); } diff --git a/tests/opencode_stream.rs b/tests/opencode_stream.rs index 7b38fb335..c3f2b6f89 100644 --- a/tests/opencode_stream.rs +++ b/tests/opencode_stream.rs @@ -57,7 +57,10 @@ async fn stream_events_from_live_server() { .await .expect("failed to subscribe to events"); - assert!(event_response.status().is_success(), "event subscription failed"); + assert!( + event_response.status().is_success(), + "event subscription failed" + ); // 2. 
Create a session let session: Session = client @@ -124,10 +127,7 @@ async fn stream_events_from_live_server() { let bytes = match chunk { Ok(b) => b, Err(e) => { - panic!( - "bytes_stream error after {} events: {e}", - events.len() - ); + panic!("bytes_stream error after {} events: {e}", events.len()); } }; @@ -159,7 +159,10 @@ async fn stream_events_from_live_server() { let envelope: SseEventEnvelope = match serde_json::from_str(&json_str) { Ok(e) => e, Err(err) => { - eprintln!("parse error: {err} on: {}", &json_str[..json_str.len().min(200)]); + eprintln!( + "parse error: {err} on: {}", + &json_str[..json_str.len().min(200)] + ); continue; } }; @@ -178,7 +181,11 @@ async fn stream_events_from_live_server() { } } SseEvent::MessagePartUpdated { part, .. } => { - if let Part::Text { session_id: Some(sid), .. } = part { + if let Part::Text { + session_id: Some(sid), + .. + } = part + { if sid == &session_id { saw_text = true; }