Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions crates/rmcp-macros/src/task_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ pub fn task_handler(attr: TokenStream, input: TokenStream) -> syn::Result<TokenS
use rmcp::task_manager::current_timestamp;
let task_id = request.task_id.clone();
let mut processor = (#processor).lock().await;
processor.collect_completed_results();

// Check completed results first
let completed = processor.peek_completed().iter().rev().find(|r| r.descriptor.operation_id == task_id);
Expand Down Expand Up @@ -200,7 +199,6 @@ pub fn task_handler(attr: TokenStream, input: TokenStream) -> syn::Result<TokenS
// Scope the lock so we can await outside if needed
{
let mut processor = (#processor).lock().await;
processor.collect_completed_results();

if let Some(task_result) = processor.take_completed_result(&task_id) {
match task_result.result {
Expand Down Expand Up @@ -256,7 +254,6 @@ pub fn task_handler(attr: TokenStream, input: TokenStream) -> syn::Result<TokenS
) -> Result<(), McpError> {
let task_id = request.task_id;
let mut processor = (#processor).lock().await;
processor.collect_completed_results();

if processor.cancel_task(&task_id) {
return Ok(());
Expand Down
31 changes: 18 additions & 13 deletions crates/rmcp/src/task_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ pub struct OperationProcessor {
running_tasks: HashMap<String, RunningTask>,
/// Completed results waiting to be collected
completed_results: Vec<TaskResult>,
task_result_receiver: Option<mpsc::UnboundedReceiver<TaskResult>>,
task_result_receiver: mpsc::UnboundedReceiver<TaskResult>,
task_result_sender: mpsc::UnboundedSender<TaskResult>,
}

Expand Down Expand Up @@ -138,7 +138,7 @@ impl OperationProcessor {
Self {
running_tasks: HashMap::new(),
completed_results: Vec::new(),
task_result_receiver: Some(task_result_receiver),
task_result_receiver,
task_result_sender,
}
}
Expand Down Expand Up @@ -195,18 +195,16 @@ impl OperationProcessor {
}

/// Collect completed results from running tasks and remove them from the running tasks map.
pub fn collect_completed_results(&mut self) -> Vec<TaskResult> {
if let Some(receiver) = &mut self.task_result_receiver {
while let Ok(result) = receiver.try_recv() {
self.running_tasks.remove(&result.descriptor.operation_id);
self.completed_results.push(result);
}
fn collect_completed_results(&mut self) {
while let Ok(result) = self.task_result_receiver.try_recv() {
self.running_tasks.remove(&result.descriptor.operation_id);
self.completed_results.push(result);
}
std::mem::take(&mut self.completed_results)
}

/// Check for tasks that have exceeded their timeout and handle them appropriately.
pub fn check_timeouts(&mut self) {
self.collect_completed_results();
let now = std::time::Instant::now();
let mut timed_out_tasks = Vec::new();

Expand All @@ -231,7 +229,8 @@ impl OperationProcessor {
}

/// Get the number of running tasks.
pub fn running_task_count(&self) -> usize {
pub fn running_task_count(&mut self) -> usize {
self.collect_completed_results();
self.running_tasks.len()
}

Expand All @@ -240,15 +239,19 @@ impl OperationProcessor {
for (_, task) in self.running_tasks.drain() {
task.task_handle.abort();
}
while self.task_result_receiver.try_recv().is_ok() {}
self.completed_results.clear();
}

/// List running task ids.
pub fn list_running(&self) -> Vec<String> {
pub fn list_running(&mut self) -> Vec<String> {
self.collect_completed_results();
self.running_tasks.keys().cloned().collect()
}

/// Note: collectors should call collect_completed_results; this provides a snapshot of queued results.
pub fn peek_completed(&self) -> &[TaskResult] {
/// Returns a snapshot of completed task results.
pub fn peek_completed(&mut self) -> &[TaskResult] {
self.collect_completed_results();
&self.completed_results
}

Expand All @@ -266,6 +269,7 @@ impl OperationProcessor {

/// Attempt to cancel a running task.
pub fn cancel_task(&mut self, task_id: &str) -> bool {
self.collect_completed_results();
if let Some(task) = self.running_tasks.remove(task_id) {
task.task_handle.abort();
// Insert a cancelled result so callers can observe the terminal state.
Expand All @@ -281,6 +285,7 @@ impl OperationProcessor {

/// Retrieve a completed task result if available.
pub fn take_completed_result(&mut self, task_id: &str) -> Option<TaskResult> {
self.collect_completed_results();
if let Some(position) = self
.completed_results
.iter()
Expand Down
2 changes: 1 addition & 1 deletion crates/rmcp/tests/test_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async fn executes_enqueued_future() {
.expect("submit operation");

tokio::time::sleep(Duration::from_millis(30)).await;
let results = processor.collect_completed_results();
let results = processor.peek_completed();
assert_eq!(results.len(), 1);
let payload = results[0]
.result
Expand Down