diff --git a/COMMIT_MESSAGE_ISSUE_378.txt b/COMMIT_MESSAGE_ISSUE_378.txt new file mode 100644 index 00000000000..e3e91b35070 --- /dev/null +++ b/COMMIT_MESSAGE_ISSUE_378.txt @@ -0,0 +1 @@ +fix(rmcp-client): label phase errors, retry transient init (#378) diff --git a/PR_BODY_ISSUE_378.md b/PR_BODY_ISSUE_378.md new file mode 100644 index 00000000000..8e2578fdfc4 --- /dev/null +++ b/PR_BODY_ISSUE_378.md @@ -0,0 +1,8 @@ +## Summary +- add explicit phase context to initialize/list/call MCP failures so logs show where the request died +- retry the Streamable HTTP initialize handshake once on obvious transient network errors before surfacing failure +- cover the new helpers with unit tests for phase labeling and retry gating + +## Testing +- cargo test -p code-rmcp-client +- ./build-fast.sh diff --git a/code-rs/rmcp-client/src/rmcp_client.rs b/code-rs/rmcp-client/src/rmcp_client.rs index 505dedee751..45e70ec5c1c 100644 --- a/code-rs/rmcp-client/src/rmcp_client.rs +++ b/code-rs/rmcp-client/src/rmcp_client.rs @@ -1,5 +1,8 @@ use std::collections::HashMap; +use std::error::Error as StdError; use std::ffi::OsString; +use std::fmt; +use std::future::Future; use std::io; use std::process::Stdio; use std::sync::Arc; @@ -24,9 +27,14 @@ use rmcp::service::{self}; use rmcp::transport::StreamableHttpClientTransport; use rmcp::transport::child_process::TokioChildProcess; use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig; +use reqwest::Error as ReqwestError; use tokio::io::AsyncBufReadExt; use tokio::io::BufReader; use tokio::process::Command; + +const INITIALIZE_RETRY_BASE_DELAY_MS: u64 = 200; +const INITIALIZE_RETRY_MAX_DELAY_MS: u64 = 1_600; +const INITIALIZE_MAX_RETRIES: usize = 3; use tokio::sync::Mutex; use tokio::time; use tracing::info; @@ -41,7 +49,11 @@ use crate::utils::run_with_timeout; enum PendingTransport { ChildProcess(TokioChildProcess), - StreamableHttp(StreamableHttpClientTransport), + StreamableHttp { + transport: StreamableHttpClientTransport, + url: String, + bearer_token: Option, + }, } enum ClientState { @@ -53,6 +65,23 @@ enum ClientState { }, } +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +enum Phase { + Initialize, + ListTools, + CallTool, +} + +impl Phase { + fn as_str(self) -> &'static str { + match self { + Phase::Initialize => "initialize", + Phase::ListTools => "list_tools", + Phase::CallTool => "call_tool", + } + } +} + /// MCP client implemented on top of the official `rmcp` SDK. /// https://github.com/modelcontextprotocol/rust-sdk pub struct RmcpClient { @@ -105,16 +134,15 @@ impl RmcpClient { } pub fn new_streamable_http_client(url: String, bearer_token: Option) -> Result { - let mut config = StreamableHttpClientTransportConfig::with_uri(url); - if let Some(token) = bearer_token { - config = config.auth_header(format!("Bearer {token}")); - } - - let transport = StreamableHttpClientTransport::from_config(config); + let transport = build_streamable_http_transport(&url, bearer_token.as_deref()); Ok(Self { state: Mutex::new(ClientState::Connecting { - transport: Some(PendingTransport::StreamableHttp(transport)), + transport: Some(PendingTransport::StreamableHttp { + transport, + url, + bearer_token, + }), }), }) } @@ -126,52 +154,66 @@ impl RmcpClient { params: InitializeRequestParams, timeout: Option, ) -> Result { - let transport = { + let pending_transport = { let mut guard = self.state.lock().await; match &mut *guard { ClientState::Connecting { transport } => transport .take() .ok_or_else(|| anyhow!("client already initializing"))?, - ClientState::Ready { .. } => { - return Err(anyhow!("client already initialized")); - } + ClientState::Ready { .. } => return Err(anyhow!("client already initialized")), } }; - let client_info = convert_to_rmcp::<_, InitializeRequestParam>(params.clone())?; - let client_handler = LoggingClientHandler::new(client_info); - let service_future = match transport { + let service = match pending_transport { PendingTransport::ChildProcess(transport) => { - service::serve_client(client_handler.clone(), transport).boxed() + let client_info = convert_to_rmcp::<_, InitializeRequestParam>(params.clone())?; + let client_handler = LoggingClientHandler::new(client_info); + let service_future = service::serve_client(client_handler.clone(), transport).boxed(); + await_handshake(service_future, timeout) + .await + .map_err(|err| annotate_phase_error(Phase::Initialize, err))? } - PendingTransport::StreamableHttp(transport) => { - service::serve_client(client_handler, transport).boxed() + PendingTransport::StreamableHttp { + mut transport, + url, + bearer_token, + } => { + let mut attempt = 0; + loop { + let client_info = convert_to_rmcp::<_, InitializeRequestParam>(params.clone())?; + let client_handler = LoggingClientHandler::new(client_info); + let service_future = service::serve_client(client_handler.clone(), transport).boxed(); + match await_handshake(service_future, timeout).await { + Ok(service) => break service, + Err(err) => { + let err = annotate_phase_error(Phase::Initialize, err); + if let Some(delay) = retry_delay_for_initialize(&err, attempt) { + attempt += 1; + time::sleep(delay).await; + transport = build_streamable_http_transport(&url, bearer_token.as_deref()); + continue; + } + return Err(err); + } + } + } } }; - let service = match timeout { - Some(duration) => match time::timeout(duration, service_future).await { - Ok(Ok(service)) => service, - Ok(Err(err)) => return Err(handshake_failed_error(err)), - Err(_) => return Err(handshake_timeout_error(duration)), - }, - None => match service_future.await { - Ok(service) => service, - Err(err) => return Err(handshake_failed_error(err)), - }, - }; - let initialize_result_rmcp = service .peer() .peer_info() - .ok_or_else(|| anyhow!("handshake succeeded but server info was missing"))?; + .ok_or_else(|| annotate_phase_error(Phase::Initialize, anyhow!("handshake succeeded but server info was missing")))?; let initialize_result: InitializeResult = convert_to_mcp(initialize_result_rmcp)?; if initialize_result.protocol_version != MCP_SCHEMA_VERSION { let reported_version = initialize_result.protocol_version.clone(); - return Err(anyhow!( - "MCP server reported protocol version {reported_version}, but this client expects {}. Update either side so both speak the same schema.", - MCP_SCHEMA_VERSION + return Err(annotate_phase_error( + Phase::Initialize, + anyhow!( + "MCP server reported protocol version {reported_version}, but this client expects {}. Update either side so both speak the same schema.", + MCP_SCHEMA_VERSION + ), )); } @@ -196,7 +238,9 @@ impl RmcpClient { .transpose()?; let fut = service.list_tools(rmcp_params); - let result = run_with_timeout(fut, timeout, "tools/list").await?; + let result = run_with_timeout(fut, timeout, "tools/list") + .await + .map_err(|err| annotate_phase_error(Phase::ListTools, err))?; convert_to_mcp(result) } @@ -210,7 +254,9 @@ impl RmcpClient { let params = CallToolRequestParams { arguments, name }; let rmcp_params: CallToolRequestParam = convert_to_rmcp(params)?; let fut = service.call_tool(rmcp_params); - let rmcp_result = run_with_timeout(fut, timeout, "tools/call").await?; + let rmcp_result = run_with_timeout(fut, timeout, "tools/call") + .await + .map_err(|err| annotate_phase_error(Phase::CallTool, err))?; convert_call_tool_result(rmcp_result) } @@ -229,6 +275,88 @@ impl RmcpClient { } } +async fn await_handshake( + future: F, + timeout: Option, +) -> Result> +where + F: Future< + Output = Result< + RunningService, + E, + >, + >, + E: Into, +{ + if let Some(duration) = timeout { + match time::timeout(duration, future).await { + Ok(Ok(service)) => Ok(service), + Ok(Err(err)) => Err(handshake_failed_error(err)), + Err(_) => Err(handshake_timeout_error(duration)), + } + } else { + future.await.map_err(handshake_failed_error) + } +} + +fn annotate_phase_error(phase: Phase, err: anyhow::Error) -> anyhow::Error { + err.context(format!("phase={}", phase.as_str())) +} + +fn retry_delay_for_initialize(err: &anyhow::Error, attempt: usize) -> Option { + if attempt >= INITIALIZE_MAX_RETRIES { + return None; + } + + let retryable = err.chain().any(|source| { + if let Some(reqwest_err) = source.downcast_ref::() { + if reqwest_err.is_timeout() || reqwest_err.is_connect() { + return true; + } + } + + if let Some(io_err) = source.downcast_ref::() { + if matches!( + io_err.kind(), + io::ErrorKind::TimedOut + | io::ErrorKind::ConnectionRefused + | io::ErrorKind::ConnectionReset + | io::ErrorKind::BrokenPipe + | io::ErrorKind::NotConnected + | io::ErrorKind::WouldBlock, + ) { + return true; + } + } + + source.downcast_ref::().is_some() + }); + + if retryable { + Some(initialize_retry_delay(attempt)) + } else { + None + } +} + +fn initialize_retry_delay(attempt: usize) -> Duration { + let capped_attempt = attempt.min(4); + let multiplier = 1u64 << capped_attempt; + let delay = INITIALIZE_RETRY_BASE_DELAY_MS.saturating_mul(multiplier); + Duration::from_millis(delay.min(INITIALIZE_RETRY_MAX_DELAY_MS)) +} + +fn build_streamable_http_transport( + url: &str, + bearer_token: Option<&str>, +) -> StreamableHttpClientTransport { + let mut config = StreamableHttpClientTransportConfig::with_uri(url.to_string()); + if let Some(token) = bearer_token { + config = config.auth_header(format!("Bearer {token}")); + } + StreamableHttpClientTransport::from_config(config) +} + fn handshake_failed_error(err: impl Into) -> anyhow::Error { let err = err.into(); anyhow!( @@ -237,14 +365,29 @@ fn handshake_failed_error(err: impl Into) -> anyhow::Error { } fn handshake_timeout_error(duration: Duration) -> anyhow::Error { - anyhow!( - "timed out handshaking with MCP server after {duration:?} (expected MCP schema version {MCP_SCHEMA_VERSION})" - ) + anyhow!(HandshakeTimeoutError(duration)) } +#[derive(Debug)] +struct HandshakeTimeoutError(Duration); + +impl fmt::Display for HandshakeTimeoutError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "timed out awaiting MCP handshake after {:?}", + self.0 + ) + } +} + +impl StdError for HandshakeTimeoutError {} + #[cfg(test)] mod tests { use super::*; + use anyhow::anyhow; + use std::time::Duration; #[test] fn mcp_schema_version_is_well_formed() { @@ -257,4 +400,52 @@ mod tests { ); assert!(parts.iter().all(|segment| !segment.trim().is_empty())); } + + #[test] + fn annotate_phase_error_adds_phase_label() { + let err = annotate_phase_error(Phase::ListTools, anyhow!("boom")); + let message = err.to_string(); + assert_eq!(message, "phase=list_tools"); + let sources: Vec = err.chain().map(|source| source.to_string()).collect(); + assert!(sources.iter().any(|s| s.contains("boom")), "sources: {sources:?}"); + } + + #[test] + fn retry_delay_for_initialize_detects_transient_errors() { + let timeout_err = annotate_phase_error( + Phase::Initialize, + anyhow!(io::Error::new(io::ErrorKind::TimedOut, "timed out")), + ); + assert_eq!( + retry_delay_for_initialize(&timeout_err, 0), + Some(Duration::from_millis(INITIALIZE_RETRY_BASE_DELAY_MS)) + ); + assert_eq!(retry_delay_for_initialize(&timeout_err, INITIALIZE_MAX_RETRIES), None); + + let mismatch_err = annotate_phase_error(Phase::Initialize, anyhow!("protocol mismatch")); + assert_eq!(retry_delay_for_initialize(&mismatch_err, 0), None); + } + + #[test] + fn retry_delay_handles_handshake_timeout() { + let err = annotate_phase_error( + Phase::Initialize, + handshake_timeout_error(Duration::from_secs(1)), + ); + assert!(retry_delay_for_initialize(&err, 0).is_some()); + } + + #[test] + fn initialize_retry_delay_exponential_and_capped() { + let first = initialize_retry_delay(0); + let second = initialize_retry_delay(1); + let capped = initialize_retry_delay(10); + + assert_eq!(first, Duration::from_millis(INITIALIZE_RETRY_BASE_DELAY_MS)); + assert_eq!(second, Duration::from_millis(INITIALIZE_RETRY_BASE_DELAY_MS * 2)); + assert_eq!( + capped, + Duration::from_millis(INITIALIZE_RETRY_MAX_DELAY_MS) + ); + } }