diff --git a/crates/rmcp-macros/src/task_handler.rs b/crates/rmcp-macros/src/task_handler.rs index f94cf130..09d43f96 100644 --- a/crates/rmcp-macros/src/task_handler.rs +++ b/crates/rmcp-macros/src/task_handler.rs @@ -132,7 +132,6 @@ pub fn task_handler(attr: TokenStream, input: TokenStream) -> syn::Result syn::Result syn::Result Result<(), McpError> { let task_id = request.task_id; let mut processor = (#processor).lock().await; - processor.collect_completed_results(); if processor.cancel_task(&task_id) { return Ok(()); diff --git a/crates/rmcp/src/task_manager.rs b/crates/rmcp/src/task_manager.rs index d8768902..774c542f 100644 --- a/crates/rmcp/src/task_manager.rs +++ b/crates/rmcp/src/task_manager.rs @@ -80,7 +80,7 @@ pub struct OperationProcessor { running_tasks: HashMap, /// Completed results waiting to be collected completed_results: Vec, - task_result_receiver: Option>, + task_result_receiver: mpsc::UnboundedReceiver, task_result_sender: mpsc::UnboundedSender, } @@ -138,7 +138,7 @@ impl OperationProcessor { Self { running_tasks: HashMap::new(), completed_results: Vec::new(), - task_result_receiver: Some(task_result_receiver), + task_result_receiver, task_result_sender, } } @@ -195,18 +195,16 @@ impl OperationProcessor { } /// Collect completed results from running tasks and remove them from the running tasks map. - pub fn collect_completed_results(&mut self) -> Vec { - if let Some(receiver) = &mut self.task_result_receiver { - while let Ok(result) = receiver.try_recv() { - self.running_tasks.remove(&result.descriptor.operation_id); - self.completed_results.push(result); - } + fn collect_completed_results(&mut self) { + while let Ok(result) = self.task_result_receiver.try_recv() { + self.running_tasks.remove(&result.descriptor.operation_id); + self.completed_results.push(result); } - std::mem::take(&mut self.completed_results) } /// Check for tasks that have exceeded their timeout and handle them appropriately. pub fn check_timeouts(&mut self) { + self.collect_completed_results(); let now = std::time::Instant::now(); let mut timed_out_tasks = Vec::new(); @@ -231,7 +229,8 @@ impl OperationProcessor { } /// Get the number of running tasks. - pub fn running_task_count(&self) -> usize { + pub fn running_task_count(&mut self) -> usize { + self.collect_completed_results(); self.running_tasks.len() } @@ -240,15 +239,19 @@ impl OperationProcessor { for (_, task) in self.running_tasks.drain() { task.task_handle.abort(); } + while self.task_result_receiver.try_recv().is_ok() {} self.completed_results.clear(); } + /// List running task ids. - pub fn list_running(&self) -> Vec { + pub fn list_running(&mut self) -> Vec { + self.collect_completed_results(); self.running_tasks.keys().cloned().collect() } - /// Note: collectors should call collect_completed_results; this provides a snapshot of queued results. - pub fn peek_completed(&self) -> &[TaskResult] { + /// Returns a snapshot of completed task results. + pub fn peek_completed(&mut self) -> &[TaskResult] { + self.collect_completed_results(); &self.completed_results } @@ -266,6 +269,7 @@ impl OperationProcessor { /// Attempt to cancel a running task. pub fn cancel_task(&mut self, task_id: &str) -> bool { + self.collect_completed_results(); if let Some(task) = self.running_tasks.remove(task_id) { task.task_handle.abort(); // Insert a cancelled result so callers can observe the terminal state. @@ -281,6 +285,7 @@ impl OperationProcessor { /// Retrieve a completed task result if available. pub fn take_completed_result(&mut self, task_id: &str) -> Option { + self.collect_completed_results(); if let Some(position) = self .completed_results .iter() diff --git a/crates/rmcp/tests/test_task.rs b/crates/rmcp/tests/test_task.rs index 31fc9a9b..9ad0b200 100644 --- a/crates/rmcp/tests/test_task.rs +++ b/crates/rmcp/tests/test_task.rs @@ -36,7 +36,7 @@ async fn executes_enqueued_future() { .expect("submit operation"); tokio::time::sleep(Duration::from_millis(30)).await; - let results = processor.collect_completed_results(); + let results = processor.peek_completed(); assert_eq!(results.len(), 1); let payload = results[0] .result