diff --git a/COMMIT_MESSAGE_ISSUE_5925.txt b/COMMIT_MESSAGE_ISSUE_5925.txt new file mode 100644 index 00000000000..f324980ae81 --- /dev/null +++ b/COMMIT_MESSAGE_ISSUE_5925.txt @@ -0,0 +1 @@ +fix(exec): keep PATH for npm workspace commands (#5925) diff --git a/PR_BODY.md b/PR_BODY.md index f9128ce12ac..0896e721414 100644 --- a/PR_BODY.md +++ b/PR_BODY.md @@ -1,9 +1,10 @@ -## Overview -- AutoRunPhase now carries struct payloads; controller exposes helpers (`is_active`, `is_paused_manual`, `resume_after_submit`, `awaiting_coordinator_submit`, `awaiting_review`, `in_transient_recovery`). -- ChatWidget hot paths (manual pause, coordinator routing, ESC handling, review exit) rely on helpers/`matches!` instead of raw booleans. +## Summary +- preserve `PATH` (and `NVM_DIR` when present) across shell environment filtering so workspace commands like `npm` remain available +- continue to respect `use_profile` so commands run through the user's login shell when configured +- add unit coverage for the environment builder and an integration-style npm smoke test (skips automatically if npm is unavailable) -## Tests -- `./build-fast.sh` +## Testing +- ./build-fast.sh +- cargo test -p code-core --test npm_command *(fails: local cargo registry copy of `cc` 1.2.41 is missing generated modules; clear/update the registry and rerun)* -## Follow-ups -- See `docs/auto-drive-phase-migration-TODO.md` for remaining legacy-flag removals and snapshot coverage. +Closes #5925. diff --git a/PR_BODY_ISSUE_5925.md b/PR_BODY_ISSUE_5925.md new file mode 100644 index 00000000000..3d0c8c6963f --- /dev/null +++ b/PR_BODY_ISSUE_5925.md @@ -0,0 +1,10 @@ +## Summary +- always preserve `PATH` (and `NVM_DIR`, if present) through `ShellEnvironmentPolicy` filtering so npm remains discoverable +- continue to wrap commands in the user shell when `use_profile` is enabled, ensuring profile-managed Node installations work +- add unit coverage for the environment builder and integration-style npm smoke tests (skipped automatically when npm is absent) + +## Testing +- ./build-fast.sh +- cargo test -p code-core --test npm_command *(fails: local cargo registry copy of `cc` 1.2.41 is missing generated modules; clear/update the crate cache and rerun)* + +Closes #5925. diff --git a/code-rs/Cargo.lock b/code-rs/Cargo.lock index 1413ef7d416..664e5160378 100644 --- a/code-rs/Cargo.lock +++ b/code-rs/Cargo.lock @@ -569,10 +569,10 @@ dependencies = [ "axum-core", "bytes", "futures-util", - "http", - "http-body", + "http 1.3.1", + "http-body 1.0.1", "http-body-util", - "hyper", + "hyper 1.7.0", "hyper-util", "itoa", "matchit", @@ -596,8 +596,8 @@ checksum = "59446ce19cd142f8833f856eb31f3eb097812d1479ab224f54d72428ca21ea22" dependencies = [ "bytes", "futures-core", - "http", - "http-body", + "http 1.3.1", + "http-body 1.0.1", "http-body-util", "mime", "pin-project-lite", @@ -809,9 +809,8 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.41" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac9fe6cdbb24b6ade63616c0a0688e45bb56732262c158df3c0c4bea4ca47cb7" +version = "1.2.38" +source = "git+https://github.com/alexcrichton/cc-rs?rev=d740f9b1f5d65b09ccac41cac2e40caa8958e348#d740f9b1f5d65b09ccac41cac2e40caa8958e348" dependencies = [ "find-msvc-tools", "jobserver", @@ -1578,6 +1577,7 @@ dependencies = [ "axum", "code-mcp-types", "futures", + "hyper 0.14.32", "pretty_assertions", "reqwest", "rmcp 0.7.0", @@ -2849,9 +2849,8 @@ dependencies = [ [[package]] name = "find-msvc-tools" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52051878f80a721bb68ebfbc930e07b65ba72f2da88968ea5c06fd6ca3d3a127" +version = "0.1.2" +source = "git+https://github.com/alexcrichton/cc-rs?rev=d740f9b1f5d65b09ccac41cac2e40caa8958e348#d740f9b1f5d65b09ccac41cac2e40caa8958e348" [[package]] name = "fixed_decimal" @@ -3155,7 +3154,7 @@ dependencies = [ "fnv", "futures-core", "futures-sink", - "http", + "http 1.3.1", "indexmap 2.12.0", "slab", "tokio", @@ -3282,6 +3281,17 @@ dependencies = [ "syn 2.0.108", ] +[[package]] +name = "http" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http" version = "1.3.1" @@ -3293,6 +3303,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http-body" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" +dependencies = [ + "bytes", + "http 0.2.12", + "pin-project-lite", +] + [[package]] name = "http-body" version = "1.0.1" @@ -3300,7 +3321,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http", + "http 1.3.1", ] [[package]] @@ -3311,8 +3332,8 @@ checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" dependencies = [ "bytes", "futures-core", - "http", - "http-body", + "http 1.3.1", + "http-body 1.0.1", "pin-project-lite", ] @@ -3337,6 +3358,29 @@ dependencies = [ "libm", ] +[[package]] +name = "hyper" +version = "0.14.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41dfc780fdec9373c01bae43289ea34c972e40ee3c9f6b3c8801a35f35586ce7" +dependencies = [ + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "http 0.2.12", + "http-body 0.4.6", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "socket2 0.5.10", + "tokio", + "tower-service", + "tracing", + "want", +] + [[package]] name = "hyper" version = "1.7.0" @@ -3348,8 +3392,8 @@ dependencies = [ "futures-channel", "futures-core", "h2", - "http", - "http-body", + "http 1.3.1", + "http-body 1.0.1", "httparse", "httpdate", "itoa", @@ -3366,8 +3410,8 @@ version = "0.27.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" dependencies = [ - "http", - "hyper", + "http 1.3.1", + "hyper 1.7.0", "hyper-util", "rustls", "rustls-native-certs", @@ -3384,7 +3428,7 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b90d566bffbce6a75bd8b09a05aa8c2cb1fabb6cb348f8840c9e4c90a0d83b0" dependencies = [ - "hyper", + "hyper 1.7.0", "hyper-util", "pin-project-lite", "tokio", @@ -3399,7 +3443,7 @@ checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" dependencies = [ "bytes", "http-body-util", - "hyper", + "hyper 1.7.0", "hyper-util", "native-tls", "tokio", @@ -3418,9 +3462,9 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "http", - "http-body", - "hyper", + "http 1.3.1", + "http-body 1.0.1", + "hyper 1.7.0", "ipnet", "libc", "percent-encoding", @@ -4574,7 +4618,7 @@ dependencies = [ "base64 0.22.1", "chrono", "getrandom 0.2.16", - "http", + "http 1.3.1", "rand 0.8.5", "reqwest", "serde", @@ -4789,7 +4833,7 @@ checksum = "50f6639e842a97dbea8886e3439710ae463120091e2e064518ba8e716e6ac36d" dependencies = [ "async-trait", "bytes", - "http", + "http 1.3.1", "opentelemetry", "reqwest", ] @@ -4800,7 +4844,7 @@ version = "0.30.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dbee664a43e07615731afc539ca60c6d9f1a9425e25ca09c57bc36c87c55852b" dependencies = [ - "http", + "http 1.3.1", "opentelemetry", "opentelemetry-http", "opentelemetry-proto", @@ -5655,10 +5699,10 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", - "http-body", + "http 1.3.1", + "http-body 1.0.1", "http-body-util", - "hyper", + "hyper 1.7.0", "hyper-rustls", "hyper-tls", "hyper-util", @@ -5721,9 +5765,10 @@ dependencies = [ "bytes", "chrono", "futures", - "http", - "http-body", + "http 1.3.1", + "http-body 1.0.1", "http-body-util", + "oauth2", "paste", "pin-project-lite", "process-wrap", @@ -5740,21 +5785,20 @@ dependencies = [ "tokio-util", "tower-service", "tracing", + "url", "uuid", ] [[package]] name = "rmcp" version = "0.8.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fdad1258f7259fdc0f2dfc266939c82c3b5d1fd72bcde274d600cdc27e60243" dependencies = [ "base64 0.22.1", "bytes", "chrono", "futures", - "http", - "http-body", + "http 1.3.1", + "http-body 1.0.1", "http-body-util", "oauth2", "paste", @@ -6555,7 +6599,7 @@ checksum = "eb4dc4d33c68ec1f27d386b5610a351922656e1fdf5c05bbaad930cd1519479a" dependencies = [ "bytes", "futures-util", - "http-body", + "http-body 1.0.1", "http-body-util", "pin-project-lite", ] @@ -7307,10 +7351,10 @@ dependencies = [ "base64 0.22.1", "bytes", "h2", - "http", - "http-body", + "http 1.3.1", + "http-body 1.0.1", "http-body-util", - "hyper", + "hyper 1.7.0", "hyper-timeout", "hyper-util", "percent-encoding", @@ -7353,8 +7397,8 @@ dependencies = [ "bitflags 2.10.0", "bytes", "futures-util", - "http", - "http-body", + "http 1.3.1", + "http-body 1.0.1", "iri-string", "pin-project-lite", "tower", @@ -7569,7 +7613,7 @@ dependencies = [ "byteorder", "bytes", "data-encoding", - "http", + "http 1.3.1", "httparse", "log", "rand 0.8.5", @@ -8576,9 +8620,9 @@ dependencies = [ "base64 0.22.1", "deadpool", "futures", - "http", + "http 1.3.1", "http-body-util", - "hyper", + "hyper 1.7.0", "hyper-util", "log", "once_cell", diff --git a/code-rs/Cargo.toml b/code-rs/Cargo.toml index 83e8fa5171b..c81bbec1b83 100644 --- a/code-rs/Cargo.toml +++ b/code-rs/Cargo.toml @@ -250,6 +250,8 @@ strip = "symbols" codegen-units = 1 [patch.crates-io] +cc = { git = "https://github.com/alexcrichton/cc-rs", rev = "d740f9b1f5d65b09ccac41cac2e40caa8958e348" } +rmcp = { path = "third_party/rmcp-0.8.3" } # ratatui = { path = "../../ratatui" } ratatui = { git = "https://github.com/nornagon/ratatui", branch = "nornagon-v0.29.0-patch" } diff --git a/code-rs/core/src/exec_env.rs b/code-rs/core/src/exec_env.rs index 88246b063c3..22b31b91c2d 100644 --- a/code-rs/core/src/exec_env.rs +++ b/code-rs/core/src/exec_env.rs @@ -19,18 +19,32 @@ fn populate_env(vars: I, policy: &ShellEnvironmentPolicy) -> HashMap, { + let collected: Vec<(String, String)> = vars.into_iter().collect(); + + let mut preserved_vars: Vec<(String, String)> = Vec::new(); + for key in ["PATH", "NVM_DIR"] { + if let Some((actual_key, value)) = collected + .iter() + .find(|(k, _)| k.eq_ignore_ascii_case(key)) + { + preserved_vars.push((actual_key.clone(), value.clone())); + } + } + // Step 1 – determine the starting set of variables based on the // `inherit` strategy. let mut env_map: HashMap = match policy.inherit { - ShellEnvironmentPolicyInherit::All => vars.into_iter().collect(), + ShellEnvironmentPolicyInherit::All => collected.iter().cloned().collect(), ShellEnvironmentPolicyInherit::None => HashMap::new(), ShellEnvironmentPolicyInherit::Core => { const CORE_VARS: &[&str] = &[ "HOME", "LOGNAME", "PATH", "SHELL", "USER", "USERNAME", "TMPDIR", "TEMP", "TMP", ]; let allow: HashSet<&str> = CORE_VARS.iter().copied().collect(); - vars.into_iter() + collected + .iter() .filter(|(k, _)| allow.contains(k.as_str())) + .cloned() .collect() } }; @@ -65,6 +79,10 @@ where env_map.retain(|k, _| matches_any(k, &policy.include_only)); } + for (key, value) in preserved_vars { + env_map.entry(key).or_insert(value); + } + env_map } @@ -172,6 +190,26 @@ mod tests { assert_eq!(result, expected); } + #[test] + fn test_path_preserved_after_include_only_filters() { + let vars = make_vars(&[ + ("PATH", "/usr/local/bin"), + ("NVM_DIR", "/home/user/.nvm"), + ("FOO", "bar"), + ]); + + let policy = ShellEnvironmentPolicy { + ignore_default_excludes: true, + include_only: vec![EnvironmentVariablePattern::new_case_insensitive("FOO")], + ..Default::default() + }; + + let result = populate_env(vars, &policy); + + assert_eq!(result.get("PATH"), Some(&"/usr/local/bin".to_string())); + assert_eq!(result.get("NVM_DIR"), Some(&"/home/user/.nvm".to_string())); + } + #[test] fn test_inherit_none() { let vars = make_vars(&[("PATH", "/usr/bin"), ("HOME", "/home")]); diff --git a/code-rs/core/tests/npm_command.rs b/code-rs/core/tests/npm_command.rs new file mode 100644 index 00000000000..0c4500bec0b --- /dev/null +++ b/code-rs/core/tests/npm_command.rs @@ -0,0 +1,89 @@ +#![cfg(unix)] + +use std::process::Command; +use std::time::Duration; + +use code_core::exec_command::{result_into_payload, ExecCommandParams, ExecSessionManager}; +use serde_json::json; +use tempfile::tempdir; +use tokio::time::timeout; + +fn make_params(cmd: &str, cwd: Option<&std::path::Path>) -> ExecCommandParams { + let mut value = json!({ + "cmd": cmd, + "yield_time_ms": 10_000u64, + "max_output_tokens": 10_000u64, + "shell": "/bin/bash", + "login": true + }); + + if let Some(dir) = cwd { + value["cmd"] = json!(format!("cd {} && {cmd}", dir.display())); + } + + serde_json::from_value(value).expect("deserialize ExecCommandParams") +} + +fn npm_available() -> bool { + match Command::new("npm").arg("--version").output() { + Ok(output) => output.status.success(), + Err(_) => false, + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn npm_version_executes() { + if !npm_available() { + eprintln!("skipping npm_version_executes: npm not available"); + return; + } + + let manager = ExecSessionManager::default(); + let params = make_params("npm --version", None); + + let summary = manager + .handle_exec_command_request(params) + .await + .map(|output| result_into_payload(Ok(output))) + .expect("exec request should succeed"); + + assert_eq!(summary.success, Some(true)); + assert!( + summary.content.contains("Process exited with code 0"), + "npm --version should exit successfully" + ); + assert!( + summary.content.to_lowercase().contains("npm"), + "version output should include npm" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn npm_init_creates_package_json() { + if !npm_available() { + eprintln!("skipping npm_init_creates_package_json: npm not available"); + return; + } + + let temp = tempdir().expect("create temp dir"); + let workspace = temp.path(); + + let manager = ExecSessionManager::default(); + let params = make_params("npm init -y", Some(workspace)); + + let exec_future = manager.handle_exec_command_request(params); + let summary = timeout(Duration::from_secs(30), exec_future) + .await + .expect("npm init should complete within timeout") + .map(|output| result_into_payload(Ok(output))) + .expect("exec request should succeed"); + + assert_eq!(summary.success, Some(true)); + assert!( + summary.content.contains("Process exited with code 0"), + "npm init should exit successfully" + ); + + let package_json = workspace.join("package.json"); + assert!(package_json.exists(), "npm init should create package.json"); +} diff --git a/code-rs/rmcp-client/Cargo.toml b/code-rs/rmcp-client/Cargo.toml index 710dbfbd58e..87a3670f12e 100644 --- a/code-rs/rmcp-client/Cargo.toml +++ b/code-rs/rmcp-client/Cargo.toml @@ -12,6 +12,7 @@ mcp-types = { workspace = true } rmcp = { version = "0.7.0", default-features = false, features = [ "base64", "client", + "auth", "macros", "schemars", "server", @@ -41,3 +42,4 @@ tracing = { version = "0.1.41", features = ["log"] } [dev-dependencies] pretty_assertions = "1.4.1" +hyper = { version = "0.14", features = ["server", "tcp", "http1"] } diff --git a/code-rs/rmcp-client/tests/oauth_metadata.rs b/code-rs/rmcp-client/tests/oauth_metadata.rs new file mode 100644 index 00000000000..c0b46371c44 --- /dev/null +++ b/code-rs/rmcp-client/tests/oauth_metadata.rs @@ -0,0 +1,102 @@ +use std::convert::Infallible; +use std::sync::Arc; + +use anyhow::Result; +use hyper::service::{make_service_fn, service_fn}; +use hyper::{Body, Request, Response, Server, StatusCode}; +use rmcp::transport::auth::AuthorizationManager; +use tokio::sync::oneshot; +use tokio::net::TcpListener; + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn discover_metadata_preserves_query_parameters() -> Result<()> { + let listener = TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + let std_listener = listener.into_std()?; + std_listener.set_nonblocking(true)?; + + let expected_query = "project_ref=test&read_only=true".to_string(); + let authorize_url = Arc::new(format!( + "http://{addr}/oauth/authorize?{}", + expected_query + )); + let token_url = Arc::new(format!("http://{addr}/oauth/token?{}", expected_query)); + let registration_url = Arc::new(format!( + "http://{addr}/oauth/register?{}", + expected_query + )); + + let closure_authorize_url = Arc::clone(&authorize_url); + let closure_token_url = Arc::clone(&token_url); + let closure_registration_url = Arc::clone(®istration_url); + + let make_svc = make_service_fn(move |_| { + let authorize_url = Arc::clone(&closure_authorize_url); + let token_url = Arc::clone(&closure_token_url); + let registration_url = Arc::clone(&closure_registration_url); + async move { + let authorize_url = Arc::clone(&authorize_url); + let token_url = Arc::clone(&token_url); + let registration_url = Arc::clone(®istration_url); + Ok::<_, Infallible>(service_fn(move |req: Request| { + let authorize_url = Arc::clone(&authorize_url); + let token_url = Arc::clone(&token_url); + let registration_url = Arc::clone(®istration_url); + async move { + let path = req.uri().path(); + let query = req.uri().query().unwrap_or(""); + + let response = if path == "/.well-known/oauth-authorization-server/mcp" { + if query.contains("project_ref=test") && query.contains("read_only=true") { + let body = serde_json::json!({ + "authorization_endpoint": authorize_url.as_str(), + "token_endpoint": token_url.as_str(), + "registration_endpoint": registration_url.as_str(), + }); + Response::builder() + .status(StatusCode::OK) + .header("content-type", "application/json") + .body(Body::from(body.to_string())) + .unwrap() + } else { + Response::builder() + .status(StatusCode::NOT_FOUND) + .body(Body::empty()) + .unwrap() + } + } else { + Response::builder() + .status(StatusCode::NOT_FOUND) + .body(Body::empty()) + .unwrap() + }; + + Ok::<_, Infallible>(response) + } + })) + } + }); + + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let server = Server::from_tcp(std_listener)? + .http1_only(true) + .serve(make_svc) + .with_graceful_shutdown(async move { + let _ = shutdown_rx.await; + }); + + let server_handle = tokio::spawn(server); + + let base_url = format!("http://{addr}/mcp?{}", expected_query); + let manager = AuthorizationManager::new(&base_url).await?; + let metadata = manager.discover_metadata().await?; + + assert_eq!(metadata.authorization_endpoint, authorize_url.as_str()); + assert_eq!(metadata.token_endpoint, token_url.as_str()); + assert_eq!(metadata.registration_endpoint, registration_url.as_str()); + + let _ = shutdown_tx.send(()); + server_handle.await??; + + Ok(()) +} diff --git a/code-rs/third_party/rmcp-0.8.3/.cargo-ok b/code-rs/third_party/rmcp-0.8.3/.cargo-ok new file mode 100644 index 00000000000..5f8b795830a --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/.cargo-ok @@ -0,0 +1 @@ +{"v":1} \ No newline at end of file diff --git a/code-rs/third_party/rmcp-0.8.3/.cargo_vcs_info.json b/code-rs/third_party/rmcp-0.8.3/.cargo_vcs_info.json new file mode 100644 index 00000000000..32251420571 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/.cargo_vcs_info.json @@ -0,0 +1,6 @@ +{ + "git": { + "sha1": "9012709079b6b0431863f7d8c4ed9c8c49926c09" + }, + "path_in_vcs": "crates/rmcp" +} \ No newline at end of file diff --git a/code-rs/third_party/rmcp-0.8.3/CHANGELOG.md b/code-rs/third_party/rmcp-0.8.3/CHANGELOG.md new file mode 100644 index 00000000000..b78a2d54934 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/CHANGELOG.md @@ -0,0 +1,314 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +## [0.8.3](https://github.com/modelcontextprotocol/rust-sdk/compare/rmcp-v0.8.2...rmcp-v0.8.3) - 2025-10-22 + +### Fixed + +- accept 204 in addition to 202 on initialize ([#497](https://github.com/modelcontextprotocol/rust-sdk/pull/497)) + +## [0.8.2](https://github.com/modelcontextprotocol/rust-sdk/compare/rmcp-v0.8.1...rmcp-v0.8.2) - 2025-10-21 + +### Added + +- add type-safe elicitation schema support ([#465](https://github.com/modelcontextprotocol/rust-sdk/pull/465)) ([#466](https://github.com/modelcontextprotocol/rust-sdk/pull/466)) +- *(SEP-973)* following change Icon.sizes from string to string array ([#479](https://github.com/modelcontextprotocol/rust-sdk/pull/479)) + +### Fixed + +- *(oauth)* three oauth discovery and registration issues ([#489](https://github.com/modelcontextprotocol/rust-sdk/pull/489)) +- *(oauth)* dynamic client registration should be optional ([#463](https://github.com/modelcontextprotocol/rust-sdk/pull/463)) +- *(sse-client)* consume control frames; refresh message endpoint ([#448](https://github.com/modelcontextprotocol/rust-sdk/pull/448)) + +### Other + +- Streamable HTTP: drain SSE frames until the initialize response, ignoring early notifications to prevent handshake timeouts ([#467](https://github.com/modelcontextprotocol/rust-sdk/pull/467)) +- bump crate version in README.md ([#471](https://github.com/modelcontextprotocol/rust-sdk/pull/471)) + +## [0.8.1](https://github.com/modelcontextprotocol/rust-sdk/compare/rmcp-v0.8.0...rmcp-v0.8.1) - 2025-10-07 + +### Fixed + +- *(oauth)* pass bearer token to all streamable http requests ([#476](https://github.com/modelcontextprotocol/rust-sdk/pull/476)) +- fix spellcheck on intentional typo in CHANGELOG ([#470](https://github.com/modelcontextprotocol/rust-sdk/pull/470)) + +## [0.8.0](https://github.com/modelcontextprotocol/rust-sdk/compare/rmcp-v0.7.0...rmcp-v0.8.0) - 2025-10-04 + +### Added + +- allow clients to override client_name ([#469](https://github.com/modelcontextprotocol/rust-sdk/pull/469)) + +### Fixed + +- *(oauth)* support suffixed and prefixed well-known paths ([#459](https://github.com/modelcontextprotocol/rust-sdk/pull/459)) +- generate default schema for tools with no params ([#446](https://github.com/modelcontextprotocol/rust-sdk/pull/446)) + +### Other + +- bump to rust 1.90.0 ([#453](https://github.com/modelcontextprotocol/rust-sdk/pull/453)) + +## [0.7.0](https://github.com/modelcontextprotocol/rust-sdk/compare/rmcp-v0.6.4...rmcp-v0.7.0) - 2025-09-24 + +### Fixed + +- return auth errors ([#451](https://github.com/modelcontextprotocol/rust-sdk/pull/451)) +- *(oauth)* do not treat empty secret as valid for public clients ([#443](https://github.com/modelcontextprotocol/rust-sdk/pull/443)) +- *(clippy)* add doc comment for generated tool attr fn ([#439](https://github.com/modelcontextprotocol/rust-sdk/pull/439)) +- *(oauth)* require CSRF token as part of the OAuth authorization flow. ([#435](https://github.com/modelcontextprotocol/rust-sdk/pull/435)) + +### Other + +- *(root)* Add Terminator to Built with rmcp section ([#437](https://github.com/modelcontextprotocol/rust-sdk/pull/437)) +- Non-empty paths in OAuth2 Authorization Server Metadata URLs ([#441](https://github.com/modelcontextprotocol/rust-sdk/pull/441)) + +## [0.6.4](https://github.com/modelcontextprotocol/rust-sdk/compare/rmcp-v0.6.3...rmcp-v0.6.4) - 2025-09-11 + +### Added + +- *(SEP-973)* add support for icons and websiteUrl across relevant types ([#432](https://github.com/modelcontextprotocol/rust-sdk/pull/432)) +- implement context-aware completion (MCP 2025-06-18) ([#396](https://github.com/modelcontextprotocol/rust-sdk/pull/396)) +- add `title` field for data types ([#410](https://github.com/modelcontextprotocol/rust-sdk/pull/410)) + +### Fixed + +- crates/rmcp/src/handler/client/progress.rs XXXXXX -> dispatcher ([#429](https://github.com/modelcontextprotocol/rust-sdk/pull/429)) +- build issue due to missing struct field ([#427](https://github.com/modelcontextprotocol/rust-sdk/pull/427)) +- generate simple {} schema for tools with no parameters ([#425](https://github.com/modelcontextprotocol/rust-sdk/pull/425)) + +### Other + +- Skip notification in initialization handshake ([#421](https://github.com/modelcontextprotocol/rust-sdk/pull/421)) +- add nvim-mcp project built by rmcp ([#422](https://github.com/modelcontextprotocol/rust-sdk/pull/422)) + +## [0.6.3](https://github.com/modelcontextprotocol/rust-sdk/compare/rmcp-v0.6.2...rmcp-v0.6.3) - 2025-09-04 + +### Fixed + +- change JSON-RPC request ID type from u32 to i64 ([#416](https://github.com/modelcontextprotocol/rust-sdk/pull/416)) + +## [0.6.2](https://github.com/modelcontextprotocol/rust-sdk/compare/rmcp-v0.6.1...rmcp-v0.6.2) - 2025-09-04 + +### Added + +- *(rmcp)* add optional _meta to CallToolResult, EmbeddedResource, and ResourceContents ([#386](https://github.com/modelcontextprotocol/rust-sdk/pull/386)) + +### Fixed + +- resolve compatibility issues with servers sending LSP notifications ([#413](https://github.com/modelcontextprotocol/rust-sdk/pull/413)) +- remove batched json rpc support ([#408](https://github.com/modelcontextprotocol/rust-sdk/pull/408)) +- transport-streamable-http-server depends on transport-worker ([#405](https://github.com/modelcontextprotocol/rust-sdk/pull/405)) +- *(typo)* correct typo in error message for transport cancellation and field. ([#404](https://github.com/modelcontextprotocol/rust-sdk/pull/404)) + +### Other + +- Spec conformance: meta support and spec updates ([#415](https://github.com/modelcontextprotocol/rust-sdk/pull/415)) +- add the rmcp-openapi and rmcp-actix-web related projects ([#406](https://github.com/modelcontextprotocol/rust-sdk/pull/406)) + +## [0.6.1](https://github.com/modelcontextprotocol/rust-sdk/compare/rmcp-v0.6.0...rmcp-v0.6.1) - 2025-08-29 + +### Added + +- *(rmcp)* add authorization header support for the streamable http client ([#390](https://github.com/modelcontextprotocol/rust-sdk/pull/390)) +- *(model)* add helpers to build enum from concrete values ([#393](https://github.com/modelcontextprotocol/rust-sdk/pull/393)) +- *(model)* expose client method name ([#391](https://github.com/modelcontextprotocol/rust-sdk/pull/391)) +- add resource_link support to tools and prompts ([#381](https://github.com/modelcontextprotocol/rust-sdk/pull/381)) +- Add prompt support ([#351](https://github.com/modelcontextprotocol/rust-sdk/pull/351)) +- include reqwest in transport-streamble-http-client feature ([#376](https://github.com/modelcontextprotocol/rust-sdk/pull/376)) + +### Fixed + +- *(auth)* url parse is not correct ([#402](https://github.com/modelcontextprotocol/rust-sdk/pull/402)) +- *(readme)* missing use declarations, more accurate server instructions ([#399](https://github.com/modelcontextprotocol/rust-sdk/pull/399)) +- enhance transport graceful shutdown with proper writer closure ([#392](https://github.com/modelcontextprotocol/rust-sdk/pull/392)) + +### Other + +- simplify remove_route method signature ([#401](https://github.com/modelcontextprotocol/rust-sdk/pull/401)) + +## [0.6.0](https://github.com/modelcontextprotocol/rust-sdk/compare/rmcp-v0.5.0...rmcp-v0.6.0) - 2025-08-19 + +### Added + +- Add MCP Elicitation support ([#332](https://github.com/modelcontextprotocol/rust-sdk/pull/332)) +- keep internal error in worker's quit reason ([#372](https://github.com/modelcontextprotocol/rust-sdk/pull/372)) + +### Fixed + +- match shape of the calltoolresult schema ([#377](https://github.com/modelcontextprotocol/rust-sdk/pull/377)) +- make stdio shutdown more graceful ([#364](https://github.com/modelcontextprotocol/rust-sdk/pull/364)) +- *(tool)* remove unnecessary schema validation ([#375](https://github.com/modelcontextprotocol/rust-sdk/pull/375)) +- *(rmcp)* return serialized json with structured content ([#368](https://github.com/modelcontextprotocol/rust-sdk/pull/368)) + +### Other + +- add related project rustfs-mcp ([#378](https://github.com/modelcontextprotocol/rust-sdk/pull/378)) +- *(streamable)* add document for extracting http info ([#373](https://github.com/modelcontextprotocol/rust-sdk/pull/373)) + +## [0.5.0](https://github.com/modelcontextprotocol/rust-sdk/compare/rmcp-v0.4.1...rmcp-v0.5.0) - 2025-08-07 + +### Fixed + +- correct numeric types in progress notifications ([#361](https://github.com/modelcontextprotocol/rust-sdk/pull/361)) + +## [0.4.1](https://github.com/modelcontextprotocol/rust-sdk/compare/rmcp-v0.4.0...rmcp-v0.4.1) - 2025-08-07 + +### Fixed + +- *(rmcp)* allow both content and structured content ([#359](https://github.com/modelcontextprotocol/rust-sdk/pull/359)) + +## [0.4.0](https://github.com/modelcontextprotocol/rust-sdk/compare/rmcp-v0.3.2...rmcp-v0.4.0) - 2025-08-05 + +### Added + +- [**breaking**] Add support for `Tool.outputSchema` and `CallToolResult.structuredContent` ([#316](https://github.com/modelcontextprotocol/rust-sdk/pull/316)) + +### Fixed + +- don't wrap errors in streamable http auth client ([#353](https://github.com/modelcontextprotocol/rust-sdk/pull/353)) +- *(prompt)* remove unused code ([#343](https://github.com/modelcontextprotocol/rust-sdk/pull/343)) + +## [0.3.2](https://github.com/modelcontextprotocol/rust-sdk/compare/rmcp-v0.3.1...rmcp-v0.3.2) - 2025-07-30 + +### Fixed + +- *(capabilities)* do not serialize None as null for `list_changed` ([#341](https://github.com/modelcontextprotocol/rust-sdk/pull/341)) +- *(Transport)* close oneshot transport on error ([#340](https://github.com/modelcontextprotocol/rust-sdk/pull/340)) +- *(oauth)* expose OAuthTokenResponse publicly ([#335](https://github.com/modelcontextprotocol/rust-sdk/pull/335)) + +## [0.3.1](https://github.com/modelcontextprotocol/rust-sdk/compare/rmcp-v0.3.0...rmcp-v0.3.1) - 2025-07-29 + +### Fixed + +- use mimeType instead of mime_type for MCP specification compliance ([#339](https://github.com/modelcontextprotocol/rust-sdk/pull/339)) +- return a 405 for GET and DELETE if stateful_mode=false ([#331](https://github.com/modelcontextprotocol/rust-sdk/pull/331)) +- propagate tracing spans when spawning new tokio tasks ([#334](https://github.com/modelcontextprotocol/rust-sdk/pull/334)) +- Explicitly added client_id as an extra parameter causes bad token requests ([#322](https://github.com/modelcontextprotocol/rust-sdk/pull/322)) + +### Other + +- Fix formatting in crate descriptions in README.md ([#333](https://github.com/modelcontextprotocol/rust-sdk/pull/333)) + +## [0.3.0](https://github.com/modelcontextprotocol/rust-sdk/compare/rmcp-v0.2.1...rmcp-v0.3.0) - 2025-07-15 + +### Added + +- unified error type ([#308](https://github.com/modelcontextprotocol/rust-sdk/pull/308)) +- *(transport)* add builder & expose child stderr ([#305](https://github.com/modelcontextprotocol/rust-sdk/pull/305)) + +### Other + +- *(deps)* update schemars requirement from 0.8 to 1.0 ([#258](https://github.com/modelcontextprotocol/rust-sdk/pull/258)) +- *(rmcp)* bump rmcp-macros version to match ([#311](https://github.com/modelcontextprotocol/rust-sdk/pull/311)) +- fix packages used for server example ([#309](https://github.com/modelcontextprotocol/rust-sdk/pull/309)) + +## [0.2.1](https://github.com/modelcontextprotocol/rust-sdk/compare/rmcp-v0.2.0...rmcp-v0.2.1) - 2025-07-03 + +### Other + +- *(docs)* Minor README updates ([#301](https://github.com/modelcontextprotocol/rust-sdk/pull/301)) + +## [0.2.0](https://github.com/modelcontextprotocol/rust-sdk/compare/rmcp-v0.1.5...rmcp-v0.2.0) - 2025-07-02 + +### Added + +- mark boxed http body as sync ([#291](https://github.com/modelcontextprotocol/rust-sdk/pull/291)) +- add progress notification handling and related structures ([#282](https://github.com/modelcontextprotocol/rust-sdk/pull/282)) +- allow failable service creation in streamable HTTP tower service ([#244](https://github.com/modelcontextprotocol/rust-sdk/pull/244)) +- provide more context information ([#236](https://github.com/modelcontextprotocol/rust-sdk/pull/236)) +- stateless mode of streamable http client ([#233](https://github.com/modelcontextprotocol/rust-sdk/pull/233)) +- add cancellation_token method to `RunningService` ([#218](https://github.com/modelcontextprotocol/rust-sdk/pull/218)) +- better http server support ([#199](https://github.com/modelcontextprotocol/rust-sdk/pull/199)) +- throw initialize error detail ([#192](https://github.com/modelcontextprotocol/rust-sdk/pull/192)) +- sse client optionally skip the endpoint event ([#187](https://github.com/modelcontextprotocol/rust-sdk/pull/187)) +- *(server)* add annotation to tool macro ([#184](https://github.com/modelcontextprotocol/rust-sdk/pull/184)) +- *(model)* add json schema generation support for all model types ([#176](https://github.com/modelcontextprotocol/rust-sdk/pull/176)) +- *(openapi)* add OpenAPI v3 compatibility and test for nullable field schema workaround ([#135](https://github.com/modelcontextprotocol/rust-sdk/pull/135)) ([#137](https://github.com/modelcontextprotocol/rust-sdk/pull/137)) +- *(extension)* extract http request part into rmcp extension ([#163](https://github.com/modelcontextprotocol/rust-sdk/pull/163)) +- *(transport)* support streamable http server ([#152](https://github.com/modelcontextprotocol/rust-sdk/pull/152)) +- *(oauth)* fixes + cache client credentials ([#157](https://github.com/modelcontextprotocol/rust-sdk/pull/157)) +- allow use of reqwest without ring provider ([#155](https://github.com/modelcontextprotocol/rust-sdk/pull/155)) +- extensions to context ([#102](https://github.com/modelcontextprotocol/rust-sdk/pull/102)) +- revision-2025-03-26 without streamable http ([#84](https://github.com/modelcontextprotocol/rust-sdk/pull/84)) +- *(tool)* allow tool call return a serializable value in json format ([#75](https://github.com/modelcontextprotocol/rust-sdk/pull/75)) ([#78](https://github.com/modelcontextprotocol/rust-sdk/pull/78)) +- Sse server auto ping ([#74](https://github.com/modelcontextprotocol/rust-sdk/pull/74)) +- *(transport)* Sse client transport trait ([#67](https://github.com/modelcontextprotocol/rust-sdk/pull/67)) + +### Fixed + +- let users decide what to wrap in child process command ([#279](https://github.com/modelcontextprotocol/rust-sdk/pull/279)) +- cancellable initialization process ([#280](https://github.com/modelcontextprotocol/rust-sdk/pull/280)) +- inject part into extension when handing init req ([#275](https://github.com/modelcontextprotocol/rust-sdk/pull/275)) +- streamable http server close request channel on response([#266](https://github.com/modelcontextprotocol/rust-sdk/pull/266)) ([#270](https://github.com/modelcontextprotocol/rust-sdk/pull/270)) +- streamable http client close on response ([#268](https://github.com/modelcontextprotocol/rust-sdk/pull/268)) +- expose TokioChildWrapper::id() in TokioChildProcess and TokioChildProcessOut ([#254](https://github.com/modelcontextprotocol/rust-sdk/pull/254)) +- add compatibility handling for non-standard notifications in async_rw ([#247](https://github.com/modelcontextprotocol/rust-sdk/pull/247)) +- allow SSE server router to be nested ([#240](https://github.com/modelcontextprotocol/rust-sdk/pull/240)) +- error for status in post method of streamable http client ([#238](https://github.com/modelcontextprotocol/rust-sdk/pull/238)) +- disable wasmbind in chrono for wasm32-unknown-unknown ([#234](https://github.com/modelcontextprotocol/rust-sdk/pull/234)) +- *(examples)* add clients in examples's readme ([#225](https://github.com/modelcontextprotocol/rust-sdk/pull/225)) +- generic ServerHandler ([#223](https://github.com/modelcontextprotocol/rust-sdk/pull/223)) +- comment error ([#215](https://github.com/modelcontextprotocol/rust-sdk/pull/215)) +- resolve the server 406 error in API calls ([#203](https://github.com/modelcontextprotocol/rust-sdk/pull/203)) +- sse endpoint build follow js's `new URL(url, base)` ([#197](https://github.com/modelcontextprotocol/rust-sdk/pull/197)) +- more friendly interface to get service error ([#190](https://github.com/modelcontextprotocol/rust-sdk/pull/190)) +- cleanup zombie processes for child process client ([#156](https://github.com/modelcontextprotocol/rust-sdk/pull/156)) +- *(schemar)* use self-defined settings ([#180](https://github.com/modelcontextprotocol/rust-sdk/pull/180)) +- *(transport-sse-server)* cleanup on connection drop ([#165](https://github.com/modelcontextprotocol/rust-sdk/pull/165)) +- *(test)* skip serialize tool's annotation if empty ([#160](https://github.com/modelcontextprotocol/rust-sdk/pull/160)) +- fix resource leak ([#136](https://github.com/modelcontextprotocol/rust-sdk/pull/136)) +- *(handler)* do call handler methods when initialize server ([#118](https://github.com/modelcontextprotocol/rust-sdk/pull/118)) +- *(server)* schemars compilation errors ([#104](https://github.com/modelcontextprotocol/rust-sdk/pull/104)) +- *(test)* fix test introduced by #97 ([#101](https://github.com/modelcontextprotocol/rust-sdk/pull/101)) +- *(macro)* add generics marco types support ([#98](https://github.com/modelcontextprotocol/rust-sdk/pull/98)) +- *(typo)* nit language corrections ([#90](https://github.com/modelcontextprotocol/rust-sdk/pull/90)) +- *(typo)* s/marcos/macros/ ([#85](https://github.com/modelcontextprotocol/rust-sdk/pull/85)) +- *(client)* add error enum while deal client info ([#76](https://github.com/modelcontextprotocol/rust-sdk/pull/76)) +- *(notification)* fix wrongly error report in notification ([#70](https://github.com/modelcontextprotocol/rust-sdk/pull/70)) +- *(test)* fix tool deserialization error ([#68](https://github.com/modelcontextprotocol/rust-sdk/pull/68)) +- *(server)* add error enum while deal server info ([#51](https://github.com/modelcontextprotocol/rust-sdk/pull/51)) + +### Other + +- add simpling example and test ([#289](https://github.com/modelcontextprotocol/rust-sdk/pull/289)) +- add update for test_message_schema ([#286](https://github.com/modelcontextprotocol/rust-sdk/pull/286)) +- add notion clear in model.rs ([#284](https://github.com/modelcontextprotocol/rust-sdk/pull/284)) +- cov settings, and fix several building warnings ([#281](https://github.com/modelcontextprotocol/rust-sdk/pull/281)) +- refactor tool macros and router implementation ([#261](https://github.com/modelcontextprotocol/rust-sdk/pull/261)) +- fix regression in URL joining ([#265](https://github.com/modelcontextprotocol/rust-sdk/pull/265)) +- remove erroneous definitions_path ([#264](https://github.com/modelcontextprotocol/rust-sdk/pull/264)) +- allow using a TokioCommandWrap for TokioChildProcess::new closes #243 ([#245](https://github.com/modelcontextprotocol/rust-sdk/pull/245)) +- Fix typo ([#249](https://github.com/modelcontextprotocol/rust-sdk/pull/249)) +- provide http server as tower service ([#228](https://github.com/modelcontextprotocol/rust-sdk/pull/228)) +- *(deps)* update sse-stream requirement from 0.1.4 to 0.2.0 ([#230](https://github.com/modelcontextprotocol/rust-sdk/pull/230)) +- Server info is only retrieved once during initialization ([#214](https://github.com/modelcontextprotocol/rust-sdk/pull/214)) +- *(deps)* update base64 requirement from 0.21 to 0.22 ([#209](https://github.com/modelcontextprotocol/rust-sdk/pull/209)) +- revert badge ([#202](https://github.com/modelcontextprotocol/rust-sdk/pull/202)) +- use hierarchical readme for publishing ([#198](https://github.com/modelcontextprotocol/rust-sdk/pull/198)) +- Ci/coverage badge ([#191](https://github.com/modelcontextprotocol/rust-sdk/pull/191)) +- fix error introduced by merge, and reorganize feature ([#185](https://github.com/modelcontextprotocol/rust-sdk/pull/185)) +- Transport trait and worker transport, and streamable http client with those new features. ([#167](https://github.com/modelcontextprotocol/rust-sdk/pull/167)) +- add oauth2 support ([#130](https://github.com/modelcontextprotocol/rust-sdk/pull/130)) +- remove un-used tower.rs ([#125](https://github.com/modelcontextprotocol/rust-sdk/pull/125)) +- update calculator example description ([#115](https://github.com/modelcontextprotocol/rust-sdk/pull/115)) +- fix the url ([#120](https://github.com/modelcontextprotocol/rust-sdk/pull/120)) +- add a simple chat client for example ([#119](https://github.com/modelcontextprotocol/rust-sdk/pull/119)) +- add an overview to `rmcp/src/lib.rs` ([#116](https://github.com/modelcontextprotocol/rust-sdk/pull/116)) +- *(context)* test context request handling and refactor for reusable client-server tests ([#97](https://github.com/modelcontextprotocol/rust-sdk/pull/97)) +- *(logging)* Add tests for logging ([#96](https://github.com/modelcontextprotocol/rust-sdk/pull/96)) +- Adopt Devcontainer for Development Environment ([#81](https://github.com/modelcontextprotocol/rust-sdk/pull/81)) +- fix typos ([#79](https://github.com/modelcontextprotocol/rust-sdk/pull/79)) +- format and fix typo ([#72](https://github.com/modelcontextprotocol/rust-sdk/pull/72)) +- add documentation generation job ([#59](https://github.com/modelcontextprotocol/rust-sdk/pull/59)) +- add test with js server ([#65](https://github.com/modelcontextprotocol/rust-sdk/pull/65)) +- fmt the project ([#54](https://github.com/modelcontextprotocol/rust-sdk/pull/54)) +- *(sse_server)* separate router and server startup ([#52](https://github.com/modelcontextprotocol/rust-sdk/pull/52)) +- fix broken link ([#53](https://github.com/modelcontextprotocol/rust-sdk/pull/53)) +- fix the branch name for git dependency ([#46](https://github.com/modelcontextprotocol/rust-sdk/pull/46)) +- Move whole rmcp crate to official rust sdk ([#44](https://github.com/modelcontextprotocol/rust-sdk/pull/44)) +- Initial commit diff --git a/code-rs/third_party/rmcp-0.8.3/Cargo.lock b/code-rs/third_party/rmcp-0.8.3/Cargo.lock new file mode 100644 index 00000000000..e05fb56c83b --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/Cargo.lock @@ -0,0 +1,2353 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + +[[package]] +name = "anyhow" +version = "1.0.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" + +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "axum" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a18ed336352031311f4e0b4dd2ff392d4fbb370777c9d18d7fc9d7359f73871" +dependencies = [ + "axum-core", + "bytes", + "form_urlencoded", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "serde_core", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59446ce19cd142f8833f856eb31f3eb097812d1479ab224f54d72428ca21ea22" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "bitflags" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + +[[package]] +name = "bumpalo" +version = "3.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" + +[[package]] +name = "bytes" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" + +[[package]] +name = "cc" +version = "1.2.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac9fe6cdbb24b6ade63616c0a0688e45bb56732262c158df3c0c4bea4ca47cb7" +dependencies = [ + "find-msvc-tools", + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + +[[package]] +name = "chrono" +version = "0.4.42" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "145052bdd345b87320e369255277e3fb5152762ad123a901ef5c262dd38fe8d2" +dependencies = [ + "iana-time-zone", + "js-sys", + "num-traits", + "serde", + "wasm-bindgen", + "windows-link 0.2.1", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "darling" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1247195ecd7e3c85f83c8d2a366e4210d588e802133e1e355180a9870b517ea4" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn", +] + +[[package]] +name = "darling_macro" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81" +dependencies = [ + "darling_core", + "quote", + "syn", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "dyn-clone" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "find-msvc-tools" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52051878f80a721bb68ebfbc930e07b65ba72f2da88968ea5c06fd6ca3d3a127" + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "form_urlencoded" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" + +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" + +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" + +[[package]] +name = "futures-task" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" + +[[package]] +name = "futures-util" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + +[[package]] +name = "generic-array" +version = "0.14.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bb6743198531e02858aeaea5398fcc883e71851fcbcb5a2f773e2fb6cb1edf2" +dependencies = [ + "typenum", + "version_check", +] + +[[package]] +name = "getrandom" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" +dependencies = [ + "cfg-if", + "js-sys", + "libc", + "wasi", + "wasm-bindgen", +] + +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "js-sys", + "libc", + "r-efi", + "wasip2", + "wasm-bindgen", +] + +[[package]] +name = "hashbrown" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5419bdc4f6a9207fbeba6d11b604d481addf78ecd10c11ad51e76c2f6482748d" + +[[package]] +name = "http" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + +[[package]] +name = "hyper" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb3aa54a13a0dfe7fbe3a59e0c76093041720fdc77b110cc0fc260fafb4dc51e" +dependencies = [ + "atomic-waker", + "bytes", + "futures-channel", + "futures-core", + "http", + "http-body", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "pin-utils", + "smallvec", + "tokio", + "want", +] + +[[package]] +name = "hyper-rustls" +version = "0.27.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" +dependencies = [ + "http", + "hyper", + "hyper-util", + "rustls", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", + "webpki-roots", +] + +[[package]] +name = "hyper-util" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c6995591a8f1380fcb4ba966a252a4b29188d51d2b89e3a252f5305be65aea8" +dependencies = [ + "base64", + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "http", + "http-body", + "hyper", + "ipnet", + "libc", + "percent-encoding", + "pin-project-lite", + "socket2", + "tokio", + "tower-service", + "tracing", +] + +[[package]] +name = "iana-time-zone" +version = "0.1.64" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33e57f83510bb73707521ebaffa789ec8caf86f9657cad665b092b581d40e9fb" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "log", + "wasm-bindgen", + "windows-core 0.62.2", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "icu_collections" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "200072f5d0e3614556f94a9930d5dc3e0662a652823904c3a75dc3b0af7fee47" +dependencies = [ + "displaydoc", + "potential_utf", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locale_core" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cde2700ccaed3872079a65fb1a78f6c0a36c91570f28755dda67bc8f7d9f00a" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_normalizer" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "436880e8e18df4d7bbc06d58432329d6458cc84531f7ac5f024e93deadb37979" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00210d6893afc98edb752b664b8890f0ef174c8adbb8d0be9710fa66fbbf72d3" + +[[package]] +name = "icu_properties" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "016c619c1eeb94efb86809b015c58f479963de65bdb6253345c1a1276f22e32b" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_locale_core", + "icu_properties_data", + "icu_provider", + "potential_utf", + "zerotrie", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "298459143998310acd25ffe6810ed544932242d3f07083eee1084d83a71bd632" + +[[package]] +name = "icu_provider" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03c80da27b5f4187909049ee2d72f276f0d9f99a42c306bd0131ecfe04d8e5af" +dependencies = [ + "displaydoc", + "icu_locale_core", + "stable_deref_trait", + "tinystr", + "writeable", + "yoke", + "zerofrom", + "zerotrie", + "zerovec", +] + +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + +[[package]] +name = "idna" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" +dependencies = [ + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" +dependencies = [ + "icu_normalizer", + "icu_properties", +] + +[[package]] +name = "indexmap" +version = "2.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6717a8d2a5a929a1a2eb43a12812498ed141a0bcfb7e8f7844fbdbe4303bba9f" +dependencies = [ + "equivalent", + "hashbrown", +] + +[[package]] +name = "ipnet" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" + +[[package]] +name = "iri-string" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbc5ebe9c3a1a7a5127f920a418f7585e9e758e911d0466ed004f393b0e380b2" +dependencies = [ + "memchr", + "serde", +] + +[[package]] +name = "itoa" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" + +[[package]] +name = "js-sys" +version = "0.3.81" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec48937a97411dcb524a265206ccd4c90bb711fca92b2792c407f268825b9305" +dependencies = [ + "once_cell", + "wasm-bindgen", +] + +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + +[[package]] +name = "libc" +version = "0.2.177" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976" + +[[package]] +name = "litemap" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956" + +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" +dependencies = [ + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432" + +[[package]] +name = "lru-slab" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" + +[[package]] +name = "matchers" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" +dependencies = [ + "regex-automata", +] + +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + +[[package]] +name = "memchr" +version = "2.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" + +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + +[[package]] +name = "mio" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69d83b0086dc8ecf3ce9ae2874b2d1290252e2a30720bea58a5c6639b0092873" +dependencies = [ + "libc", + "wasi", + "windows-sys 0.61.2", +] + +[[package]] +name = "nix" +version = "0.30.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74523f3a35e05aba87a1d978330aef40f67b0304ac79c1c00b294c9830543db6" +dependencies = [ + "bitflags", + "cfg-if", + "cfg_aliases", + "libc", +] + +[[package]] +name = "nu-ansi-term" +version = "0.50.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "oauth2" +version = "5.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51e219e79014df21a225b1860a479e2dcd7cbd9130f4defd4bd0e191ea31d67d" +dependencies = [ + "base64", + "chrono", + "getrandom 0.2.16", + "http", + "rand 0.8.5", + "reqwest", + "serde", + "serde_json", + "serde_path_to_error", + "sha2", + "thiserror 1.0.69", + "url", +] + +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-link 0.2.1", +] + +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + +[[package]] +name = "percent-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" + +[[package]] +name = "pin-project-lite" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "potential_utf" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84df19adbe5b5a0782edcab45899906947ab039ccf4573713735ee7de1e6b08a" +dependencies = [ + "zerovec", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "proc-macro2" +version = "1.0.101" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "process-wrap" +version = "8.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3ef4f2f0422f23a82ec9f628ea2acd12871c81a9362b02c43c1aa86acfc3ba1" +dependencies = [ + "futures", + "indexmap", + "nix", + "tokio", + "tracing", + "windows", +] + +[[package]] +name = "quinn" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" +dependencies = [ + "bytes", + "cfg_aliases", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash", + "rustls", + "socket2", + "thiserror 2.0.17", + "tokio", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-proto" +version = "0.11.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" +dependencies = [ + "bytes", + "getrandom 0.3.4", + "lru-slab", + "rand 0.9.2", + "ring", + "rustc-hash", + "rustls", + "rustls-pki-types", + "slab", + "thiserror 2.0.17", + "tinyvec", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-udp" +version = "0.5.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" +dependencies = [ + "cfg_aliases", + "libc", + "once_cell", + "socket2", + "tracing", + "windows-sys 0.60.2", +] + +[[package]] +name = "quote" +version = "1.0.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce25767e7b499d1b604768e7cde645d14cc8584231ea6b295e9c9eb22c02e1d1" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha 0.9.0", + "rand_core 0.9.3", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.3", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom 0.2.16", +] + +[[package]] +name = "rand_core" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" +dependencies = [ + "getrandom 0.3.4", +] + +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + +[[package]] +name = "ref-cast" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f354300ae66f76f1c85c5f84693f0ce81d747e2c3f21a45fef496d89c960bf7d" +dependencies = [ + "ref-cast-impl", +] + +[[package]] +name = "ref-cast-impl" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7186006dcb21920990093f30e3dea63b7d6e977bf1256be20c3563a5db070da" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "regex-automata" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5276caf25ac86c8d810222b3dbb938e512c55c6831a10f3e6ed1c93b84041f1c" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" + +[[package]] +name = "reqwest" +version = "0.12.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d0946410b9f7b082a427e4ef5c8ff541a88b357bc6c637c40db3a68ac70a36f" +dependencies = [ + "base64", + "bytes", + "futures-core", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-rustls", + "hyper-util", + "js-sys", + "log", + "percent-encoding", + "pin-project-lite", + "quinn", + "rustls", + "rustls-pki-types", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tokio-rustls", + "tokio-util", + "tower", + "tower-http", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "wasm-streams", + "web-sys", + "webpki-roots", +] + +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.16", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + +[[package]] +name = "rmcp" +version = "0.8.3" +dependencies = [ + "anyhow", + "async-trait", + "axum", + "base64", + "bytes", + "chrono", + "futures", + "http", + "http-body", + "http-body-util", + "oauth2", + "paste", + "pin-project-lite", + "process-wrap", + "rand 0.9.2", + "reqwest", + "rmcp-macros", + "schemars", + "serde", + "serde_json", + "sse-stream", + "thiserror 2.0.17", + "tokio", + "tokio-stream", + "tokio-util", + "tower-service", + "tracing", + "tracing-subscriber", + "url", + "uuid", +] + +[[package]] +name = "rmcp-macros" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ede0589a208cc7ce81d1be68aa7e74b917fcd03c81528408bab0457e187dcd9b" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "serde_json", + "syn", +] + +[[package]] +name = "rustc-hash" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" + +[[package]] +name = "rustls" +version = "0.23.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a9586e9ee2b4f8fab52a0048ca7334d7024eef48e2cb9407e3497bb7cab7fa7" +dependencies = [ + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pki-types" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79" +dependencies = [ + "web-time", + "zeroize", +] + +[[package]] +name = "rustls-webpki" +version = "0.103.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e10b3f4191e8a80e6b43eebabfac91e5dcecebb27a71f04e820c47ec41d314bf" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "ryu" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" + +[[package]] +name = "schemars" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82d20c4491bc164fa2f6c5d44565947a52ad80b9505d8e36f8d54c27c739fcd0" +dependencies = [ + "chrono", + "dyn-clone", + "ref-cast", + "schemars_derive", + "serde", + "serde_json", +] + +[[package]] +name = "schemars_derive" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33d020396d1d138dc19f1165df7545479dcd58d93810dc5d646a16e55abefa80" +dependencies = [ + "proc-macro2", + "quote", + "serde_derive_internals", + "syn", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_derive_internals" +version = "0.29.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.145" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", + "serde_core", +] + +[[package]] +name = "serde_path_to_error" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" +dependencies = [ + "itoa", + "serde", + "serde_core", +] + +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "sha2" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "signal-hook-registry" +version = "1.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2a4719bff48cee6b39d12c020eeb490953ad2443b7055bd0b21fca26bd8c28b" +dependencies = [ + "libc", +] + +[[package]] +name = "slab" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "socket2" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17129e116933cf371d018bb80ae557e889637989d8638274fb25622827b03881" +dependencies = [ + "libc", + "windows-sys 0.60.2", +] + +[[package]] +name = "sse-stream" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb4dc4d33c68ec1f27d386b5610a351922656e1fdf5c05bbaad930cd1519479a" +dependencies = [ + "bytes", + "futures-util", + "http-body", + "http-body-util", + "pin-project-lite", +] + +[[package]] +name = "stable_deref_trait" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" + +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + +[[package]] +name = "syn" +version = "2.0.107" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a26dbd934e5451d21ef060c018dae56fc073894c5a7896f882928a76e6d081b" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" +dependencies = [ + "futures-core", +] + +[[package]] +name = "synstructure" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8" +dependencies = [ + "thiserror-impl 2.0.17", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "thread_local" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "tinystr" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d4f6d1145dcb577acf783d4e601bc1d76a13337bb54e6233add580b07344c8b" +dependencies = [ + "displaydoc", + "zerovec", +] + +[[package]] +name = "tinyvec" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + +[[package]] +name = "tokio" +version = "1.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff360e02eab121e0bc37a2d3b4d4dc622e6eda3a8e5253d5435ecf5bd4c68408" +dependencies = [ + "bytes", + "libc", + "mio", + "parking_lot", + "pin-project-lite", + "signal-hook-registry", + "socket2", + "tokio-macros", + "windows-sys 0.61.2", +] + +[[package]] +name = "tokio-macros" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio-rustls" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" +dependencies = [ + "rustls", + "tokio", +] + +[[package]] +name = "tokio-stream" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tokio-util" +version = "0.7.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14307c986784f72ef81c89db7d9e28d6ac26d16213b109ea501696195e6e3ce5" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tower" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper", + "tokio", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-http" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adc82fd73de2a9722ac5da747f12383d2bfdb93591ee6c58486e0097890f05f2" +dependencies = [ + "bitflags", + "bytes", + "futures-util", + "http", + "http-body", + "iri-string", + "pin-project-lite", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + +[[package]] +name = "tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + +[[package]] +name = "tracing" +version = "0.1.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" +dependencies = [ + "log", + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tracing-core" +version = "0.1.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" +dependencies = [ + "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex-automata", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", +] + +[[package]] +name = "try-lock" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" + +[[package]] +name = "typenum" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" + +[[package]] +name = "unicode-ident" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "462eeb75aeb73aea900253ce739c8e18a67423fadf006037cd3ff27e82748a06" + +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] +name = "url" +version = "2.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08bc136a29a3d1758e07a9cca267be308aeebf5cfd5a10f3f67ab2097683ef5b" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", + "serde", +] + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + +[[package]] +name = "uuid" +version = "1.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f87b8aa10b915a06587d0dec516c282ff295b475d94abf425d62b57710070a2" +dependencies = [ + "getrandom 0.3.4", + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "wasip2" +version = "1.0.1+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasm-bindgen" +version = "0.2.104" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1da10c01ae9f1ae40cbfac0bac3b1e724b320abfcf52229f80b547c0d250e2d" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.104" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "671c9a5a66f49d8a47345ab942e2cb93c7d1d0339065d4f8139c486121b43b19" +dependencies = [ + "bumpalo", + "log", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e038d41e478cc73bae0ff9b36c60cff1c98b8f38f8d7e8061e79ee63608ac5c" +dependencies = [ + "cfg-if", + "js-sys", + "once_cell", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.104" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ca60477e4c59f5f2986c50191cd972e3a50d8a95603bc9434501cf156a9a119" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.104" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f07d2f20d4da7b26400c9f4a0511e6e0345b040694e8a75bd41d578fa4421d7" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.104" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bad67dc8b2a1a6e5448428adec4c3e84c43e561d8c9ee8a9e5aabeb193ec41d1" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "wasm-streams" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + +[[package]] +name = "web-sys" +version = "0.3.81" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9367c417a924a74cae129e6a2ae3b47fabb1f8995595ab474029da749a8be120" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "webpki-roots" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32b130c0d2d49f8b6889abc456e795e82525204f27c42cf767cf0d7734e089b8" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "windows" +version = "0.61.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9babd3a767a4c1aef6900409f85f5d53ce2544ccdfaa86dad48c91782c6d6893" +dependencies = [ + "windows-collections", + "windows-core 0.61.2", + "windows-future", + "windows-link 0.1.3", + "windows-numerics", +] + +[[package]] +name = "windows-collections" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3beeceb5e5cfd9eb1d76b381630e82c4241ccd0d27f1a39ed41b2760b255c5e8" +dependencies = [ + "windows-core 0.61.2", +] + +[[package]] +name = "windows-core" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link 0.1.3", + "windows-result 0.3.4", + "windows-strings 0.4.2", +] + +[[package]] +name = "windows-core" +version = "0.62.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link 0.2.1", + "windows-result 0.4.1", + "windows-strings 0.5.1", +] + +[[package]] +name = "windows-future" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc6a41e98427b19fe4b73c550f060b59fa592d7d686537eebf9385621bfbad8e" +dependencies = [ + "windows-core 0.61.2", + "windows-link 0.1.3", + "windows-threading", +] + +[[package]] +name = "windows-implement" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-interface" +version = "0.59.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-link" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-numerics" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9150af68066c4c5c07ddc0ce30421554771e528bde427614c61038bc2c92c2b1" +dependencies = [ + "windows-core 0.61.2", + "windows-link 0.1.3", +] + +[[package]] +name = "windows-result" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" +dependencies = [ + "windows-link 0.1.3", +] + +[[package]] +name = "windows-result" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" +dependencies = [ + "windows-link 0.2.1", +] + +[[package]] +name = "windows-strings" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" +dependencies = [ + "windows-link 0.1.3", +] + +[[package]] +name = "windows-strings" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" +dependencies = [ + "windows-link 0.2.1", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" +dependencies = [ + "windows-targets 0.53.5", +] + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link 0.2.1", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm 0.52.6", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + +[[package]] +name = "windows-targets" +version = "0.53.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" +dependencies = [ + "windows-link 0.2.1", + "windows_aarch64_gnullvm 0.53.1", + "windows_aarch64_msvc 0.53.1", + "windows_i686_gnu 0.53.1", + "windows_i686_gnullvm 0.53.1", + "windows_i686_msvc 0.53.1", + "windows_x86_64_gnu 0.53.1", + "windows_x86_64_gnullvm 0.53.1", + "windows_x86_64_msvc 0.53.1", +] + +[[package]] +name = "windows-threading" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b66463ad2e0ea3bbf808b7f1d371311c80e115c0b71d60efc142cafbcfb057a6" +dependencies = [ + "windows-link 0.1.3", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_i686_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" + +[[package]] +name = "wit-bindgen" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" + +[[package]] +name = "writeable" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb" + +[[package]] +name = "yoke" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f41bb01b8226ef4bfd589436a297c53d118f65921786300e427be8d487695cc" +dependencies = [ + "serde", + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zerocopy" +version = "0.8.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0894878a5fa3edfd6da3f88c4805f4c8558e2b996227a3d864f47fe11e38282c" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zerofrom" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + +[[package]] +name = "zerotrie" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36f0bbd478583f79edad978b407914f61b2972f5af6fa089686016be8f9af595" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", +] + +[[package]] +name = "zerovec" +version = "0.11.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7aa2bd55086f1ab526693ecbe444205da57e25f4489879da80635a46d90e73b" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/code-rs/third_party/rmcp-0.8.3/Cargo.toml b/code-rs/third_party/rmcp-0.8.3/Cargo.toml new file mode 100644 index 00000000000..445f1d7edad --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/Cargo.toml @@ -0,0 +1,436 @@ +# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO +# +# When uploading crates to the registry Cargo will automatically +# "normalize" Cargo.toml files for maximal compatibility +# with all versions of Cargo and also rewrite `path` dependencies +# to registry (e.g., crates.io) dependencies. +# +# If you are reading this file be aware that the original Cargo.toml +# will likely look very different (and much more reasonable). +# See Cargo.toml.orig for the original contents. + +[package] +edition = "2024" +name = "rmcp" +version = "0.8.3" +build = false +autolib = false +autobins = false +autoexamples = false +autotests = false +autobenches = false +description = "Rust SDK for Model Context Protocol" +homepage = "https://github.com/modelcontextprotocol/rust-sdk" +documentation = "https://docs.rs/rmcp" +readme = "README.md" +license = "MIT" +repository = "https://github.com/modelcontextprotocol/rust-sdk/" +resolver = "2" + +[package.metadata.docs.rs] +all-features = true +rustdoc-args = [ + "--cfg", + "docsrs", +] + +[features] +__reqwest = ["dep:reqwest"] +auth = [ + "dep:oauth2", + "__reqwest", + "dep:url", +] +client = ["dep:tokio-stream"] +client-side-sse = [ + "dep:sse-stream", + "dep:http", +] +default = [ + "base64", + "macros", + "server", +] +elicitation = [] +macros = [ + "dep:rmcp-macros", + "dep:paste", +] +reqwest = [ + "__reqwest", + "reqwest?/rustls-tls", +] +reqwest-tls-no-provider = [ + "__reqwest", + "reqwest?/rustls-tls-no-provider", +] +schemars = ["dep:schemars"] +server = [ + "transport-async-rw", + "dep:schemars", +] +server-side-http = [ + "uuid", + "dep:rand", + "dep:tokio-stream", + "dep:http", + "dep:http-body", + "dep:http-body-util", + "dep:bytes", + "dep:sse-stream", + "tower", +] +tower = ["dep:tower-service"] +transport-async-rw = [ + "tokio/io-util", + "tokio-util/codec", +] +transport-child-process = [ + "transport-async-rw", + "tokio/process", + "dep:process-wrap", +] +transport-io = [ + "transport-async-rw", + "tokio/io-std", +] +transport-sse-client = [ + "client-side-sse", + "transport-worker", +] +transport-sse-client-reqwest = [ + "transport-sse-client", + "reqwest", +] +transport-sse-server = [ + "transport-async-rw", + "transport-worker", + "server-side-http", + "dep:axum", +] +transport-streamable-http-client = [ + "client-side-sse", + "transport-worker", +] +transport-streamable-http-client-reqwest = [ + "transport-streamable-http-client", + "reqwest", +] +transport-streamable-http-server = [ + "transport-streamable-http-server-session", + "server-side-http", + "transport-worker", +] +transport-streamable-http-server-session = [ + "transport-async-rw", + "dep:tokio-stream", +] +transport-worker = ["dep:tokio-stream"] + +[lib] +name = "rmcp" +path = "src/lib.rs" + +[[test]] +name = "test_completion" +path = "tests/test_completion.rs" + +[[test]] +name = "test_complex_schema" +path = "tests/test_complex_schema.rs" + +[[test]] +name = "test_deserialization" +path = "tests/test_deserialization.rs" + +[[test]] +name = "test_elicitation" +path = "tests/test_elicitation.rs" +required-features = [ + "elicitation", + "client", + "server", +] + +[[test]] +name = "test_embedded_resource_meta" +path = "tests/test_embedded_resource_meta.rs" + +[[test]] +name = "test_json_schema_detection" +path = "tests/test_json_schema_detection.rs" + +[[test]] +name = "test_logging" +path = "tests/test_logging.rs" +required-features = [ + "server", + "client", +] + +[[test]] +name = "test_message_protocol" +path = "tests/test_message_protocol.rs" +required-features = ["client"] + +[[test]] +name = "test_message_schema" +path = "tests/test_message_schema.rs" +required-features = [ + "server", + "client", + "schemars", +] + +[[test]] +name = "test_notification" +path = "tests/test_notification.rs" +required-features = [ + "server", + "client", +] + +[[test]] +name = "test_progress_subscriber" +path = "tests/test_progress_subscriber.rs" +required-features = [ + "server", + "client", + "macros", +] + +[[test]] +name = "test_prompt_handler" +path = "tests/test_prompt_handler.rs" + +[[test]] +name = "test_prompt_macro_annotations" +path = "tests/test_prompt_macro_annotations.rs" + +[[test]] +name = "test_prompt_macros" +path = "tests/test_prompt_macros.rs" + +[[test]] +name = "test_prompt_routers" +path = "tests/test_prompt_routers.rs" + +[[test]] +name = "test_resource_link" +path = "tests/test_resource_link.rs" + +[[test]] +name = "test_resource_link_integration" +path = "tests/test_resource_link_integration.rs" + +[[test]] +name = "test_sampling" +path = "tests/test_sampling.rs" + +[[test]] +name = "test_structured_output" +path = "tests/test_structured_output.rs" + +[[test]] +name = "test_tool_builder_methods" +path = "tests/test_tool_builder_methods.rs" + +[[test]] +name = "test_tool_handler" +path = "tests/test_tool_handler.rs" + +[[test]] +name = "test_tool_macro_annotations" +path = "tests/test_tool_macro_annotations.rs" + +[[test]] +name = "test_tool_macros" +path = "tests/test_tool_macros.rs" +required-features = [ + "server", + "client", +] + +[[test]] +name = "test_tool_result_meta" +path = "tests/test_tool_result_meta.rs" + +[[test]] +name = "test_tool_routers" +path = "tests/test_tool_routers.rs" + +[[test]] +name = "test_with_js" +path = "tests/test_with_js.rs" +required-features = [ + "server", + "client", + "transport-sse-server", + "transport-child-process", + "transport-streamable-http-server", + "transport-streamable-http-client", + "__reqwest", +] + +[[test]] +name = "test_with_python" +path = "tests/test_with_python.rs" +required-features = [ + "reqwest", + "server", + "client", + "transport-sse-server", + "transport-sse-client", + "transport-child-process", +] + +[dependencies.axum] +version = "0.8" +features = [] +optional = true + +[dependencies.base64] +version = "0.22" +optional = true + +[dependencies.bytes] +version = "1" +optional = true + +[dependencies.futures] +version = "0.3" + +[dependencies.http] +version = "1" +optional = true + +[dependencies.http-body] +version = "1" +optional = true + +[dependencies.http-body-util] +version = "0.1" +optional = true + +[dependencies.oauth2] +version = "5.0" +optional = true + +[dependencies.paste] +version = "1" +optional = true + +[dependencies.pin-project-lite] +version = "0.2" + +[dependencies.process-wrap] +version = "8.2" +features = ["tokio1"] +optional = true + +[dependencies.rand] +version = "0.9" +optional = true + +[dependencies.reqwest] +version = "0.12" +features = [ + "json", + "stream", +] +optional = true +default-features = false + +[dependencies.rmcp-macros] +version = "0.8.3" +optional = true + +[dependencies.schemars] +version = "1.0" +features = ["chrono04"] +optional = true + +[dependencies.serde] +version = "1.0" +features = [ + "derive", + "rc", +] + +[dependencies.serde_json] +version = "1.0" + +[dependencies.sse-stream] +version = "0.2" +optional = true + +[dependencies.thiserror] +version = "2" + +[dependencies.tokio] +version = "1" +features = [ + "sync", + "macros", + "rt", + "time", +] + +[dependencies.tokio-stream] +version = "0.1" +optional = true + +[dependencies.tokio-util] +version = "0.7" + +[dependencies.tower-service] +version = "0.3" +optional = true + +[dependencies.tracing] +version = "0.1" + +[dependencies.url] +version = "2.4" +optional = true + +[dependencies.uuid] +version = "1" +features = ["v4"] +optional = true + +[dev-dependencies.anyhow] +version = "1.0" + +[dev-dependencies.async-trait] +version = "0.1" + +[dev-dependencies.schemars] +version = "1.0" +features = ["chrono04"] + +[dev-dependencies.tokio] +version = "1" +features = ["full"] + +[dev-dependencies.tracing-subscriber] +version = "0.3" +features = [ + "env-filter", + "std", + "fmt", +] + +[target.'cfg(all(target_family = "wasm", target_os = "unknown"))'.dependencies.chrono] +version = "0.4.38" +features = [ + "serde", + "clock", + "std", + "oldtime", +] +default-features = false + +[target.'cfg(not(all(target_family = "wasm", target_os = "unknown")))'.dependencies.chrono] +version = "0.4.38" +features = ["serde"] diff --git a/code-rs/third_party/rmcp-0.8.3/Cargo.toml.orig b/code-rs/third_party/rmcp-0.8.3/Cargo.toml.orig new file mode 100644 index 00000000000..514f4a08a4b --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/Cargo.toml.orig @@ -0,0 +1,212 @@ +[package] +name = "rmcp" +license = { workspace = true } +version = { workspace = true } +edition = { workspace = true } +repository = { workspace = true } +homepage = { workspace = true } +readme = { workspace = true } +description = "Rust SDK for Model Context Protocol" +documentation = "https://docs.rs/rmcp" + +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--cfg", "docsrs"] + +[dependencies] +serde = { version = "1.0", features = ["derive", "rc"] } +serde_json = "1.0" +thiserror = "2" +tokio = { version = "1", features = ["sync", "macros", "rt", "time"] } +futures = "0.3" +tracing = { version = "0.1" } +tokio-util = { version = "0.7" } +pin-project-lite = "0.2" +paste = { version = "1", optional = true } + +# oauth2 support +oauth2 = { version = "5.0", optional = true } + +# for auto generate schema +schemars = { version = "1.0", optional = true, features = ["chrono04"] } + +# for image encoding +base64 = { version = "0.22", optional = true } + +# for SSE client +reqwest = { version = "0.12", default-features = false, features = [ + "json", + "stream", +], optional = true } + +sse-stream = { version = "0.2", optional = true } + +http = { version = "1", optional = true } +url = { version = "2.4", optional = true } + +# For tower compatibility +tower-service = { version = "0.3", optional = true } + +# for child process transport +process-wrap = { version = "8.2", features = ["tokio1"], optional = true } + +# for ws transport +# tokio-tungstenite ={ version = "0.26", optional = true } + +# for http-server transport +axum = { version = "0.8", features = [], optional = true } +rand = { version = "0.9", optional = true } +tokio-stream = { version = "0.1", optional = true } +uuid = { version = "1", features = ["v4"], optional = true } +http-body = { version = "1", optional = true } +http-body-util = { version = "0.1", optional = true } +bytes = { version = "1", optional = true } +# macro +rmcp-macros = { workspace = true, optional = true } +[target.'cfg(not(all(target_family = "wasm", target_os = "unknown")))'.dependencies] +chrono = { version = "0.4.38", features = ["serde"] } + +[target.'cfg(all(target_family = "wasm", target_os = "unknown"))'.dependencies] +chrono = { version = "0.4.38", default-features = false, features = [ + "serde", + "clock", + "std", + "oldtime", +] } + +[features] +default = ["base64", "macros", "server"] +client = ["dep:tokio-stream"] +server = ["transport-async-rw", "dep:schemars"] +macros = ["dep:rmcp-macros", "dep:paste"] +elicitation = [] + +# reqwest http client +__reqwest = ["dep:reqwest"] + +reqwest = ["__reqwest", "reqwest?/rustls-tls"] + +reqwest-tls-no-provider = ["__reqwest", "reqwest?/rustls-tls-no-provider"] + +server-side-http = [ + "uuid", + "dep:rand", + "dep:tokio-stream", + "dep:http", + "dep:http-body", + "dep:http-body-util", + "dep:bytes", + "dep:sse-stream", + "tower", +] +# SSE client +client-side-sse = ["dep:sse-stream", "dep:http"] + +transport-sse-client = ["client-side-sse", "transport-worker"] +transport-sse-client-reqwest = ["transport-sse-client", "reqwest"] + +transport-worker = ["dep:tokio-stream"] + + +# Streamable HTTP client +transport-streamable-http-client = ["client-side-sse", "transport-worker"] +transport-streamable-http-client-reqwest = ["transport-streamable-http-client", "reqwest"] + + +transport-async-rw = ["tokio/io-util", "tokio-util/codec"] +transport-io = ["transport-async-rw", "tokio/io-std"] +transport-child-process = [ + "transport-async-rw", + "tokio/process", + "dep:process-wrap", +] +transport-sse-server = [ + "transport-async-rw", + "transport-worker", + "server-side-http", + "dep:axum", +] +transport-streamable-http-server = [ + "transport-streamable-http-server-session", + "server-side-http", + "transport-worker", +] +transport-streamable-http-server-session = [ + "transport-async-rw", + "dep:tokio-stream", +] +# transport-ws = ["transport-io", "dep:tokio-tungstenite"] +tower = ["dep:tower-service"] +auth = ["dep:oauth2", "__reqwest", "dep:url"] +schemars = ["dep:schemars"] + +[dev-dependencies] +tokio = { version = "1", features = ["full"] } +schemars = { version = "1.0", features = ["chrono04"] } + +anyhow = "1.0" +tracing-subscriber = { version = "0.3", features = [ + "env-filter", + "std", + "fmt", +] } +async-trait = "0.1" +[[test]] +name = "test_tool_macros" +required-features = ["server", "client"] +path = "tests/test_tool_macros.rs" + +[[test]] +name = "test_with_python" +required-features = [ + "reqwest", + "server", + "client", + "transport-sse-server", + "transport-sse-client", + "transport-child-process", +] +path = "tests/test_with_python.rs" + +[[test]] +name = "test_with_js" +required-features = [ + "server", + "client", + "transport-sse-server", + "transport-child-process", + "transport-streamable-http-server", + "transport-streamable-http-client", + "__reqwest", +] +path = "tests/test_with_js.rs" + +[[test]] +name = "test_notification" +required-features = ["server", "client"] +path = "tests/test_notification.rs" + +[[test]] +name = "test_logging" +required-features = ["server", "client"] +path = "tests/test_logging.rs" + +[[test]] +name = "test_message_protocol" +required-features = ["client"] +path = "tests/test_message_protocol.rs" + +[[test]] +name = "test_message_schema" +required-features = ["server", "client", "schemars"] +path = "tests/test_message_schema.rs" + +[[test]] +name = "test_progress_subscriber" +required-features = ["server", "client", "macros"] +path = "tests/test_progress_subscriber.rs" + +[[test]] +name = "test_elicitation" +required-features = ["elicitation", "client", "server"] +path = "tests/test_elicitation.rs" diff --git a/code-rs/third_party/rmcp-0.8.3/README.md b/code-rs/third_party/rmcp-0.8.3/README.md new file mode 100644 index 00000000000..057295aa7db --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/README.md @@ -0,0 +1,247 @@ +# RMCP: Rust Model Context Protocol + +`rmcp` is the official Rust implementation of the Model Context Protocol (MCP), a protocol designed for AI assistants to communicate with other services. This library can be used to build both servers that expose capabilities to AI assistants and clients that interact with such servers. + +wait for the first release. + + + + +## Quick Start + +### Server Implementation + +Creating a server with tools is simple using the `#[tool]` macro: + +```rust, ignore +use rmcp::{ + handler::server::router::tool::ToolRouter, model::*, tool, tool_handler, tool_router, + transport::stdio, ErrorData as McpError, ServiceExt, +}; +use std::future::Future; +use std::sync::Arc; +use tokio::sync::Mutex; + +#[derive(Clone)] +pub struct Counter { + counter: Arc>, + tool_router: ToolRouter, +} + +#[tool_router] +impl Counter { + fn new() -> Self { + Self { + counter: Arc::new(Mutex::new(0)), + tool_router: Self::tool_router(), + } + } + + #[tool(description = "Increment the counter by 1")] + async fn increment(&self) -> Result { + let mut counter = self.counter.lock().await; + *counter += 1; + Ok(CallToolResult::success(vec![Content::text( + counter.to_string(), + )])) + } + + #[tool(description = "Get the current counter value")] + async fn get(&self) -> Result { + let counter = self.counter.lock().await; + Ok(CallToolResult::success(vec![Content::text( + counter.to_string(), + )])) + } +} + +// Implement the server handler +#[tool_handler] +impl rmcp::ServerHandler for Counter { + fn get_info(&self) -> ServerInfo { + ServerInfo { + instructions: Some("A simple counter that tallies the number of times the increment tool has been used".into()), + capabilities: ServerCapabilities::builder().enable_tools().build(), + ..Default::default() + } + } +} + +// Run the server +#[tokio::main] +async fn main() -> Result<(), Box> { + // Create and run the server with STDIO transport + let service = Counter::new().serve(stdio()).await.inspect_err(|e| { + println!("Error starting server: {}", e); + })?; + service.waiting().await?; + + Ok(()) +} +``` + +### Client Implementation + +Creating a client to interact with a server: + +```rust, ignore +use rmcp::{ + model::CallToolRequestParam, + service::ServiceExt, + transport::{TokioChildProcess, ConfigureCommandExt} +}; +use tokio::process::Command; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // Connect to a server running as a child process + let service = () + .serve(TokioChildProcess::new(Command::new("uvx").configure( + |cmd| { + cmd.arg("mcp-server-git"); + }, + ))?) + .await?; + + // Get server information + let server_info = service.peer_info(); + println!("Connected to server: {server_info:#?}"); + + // List available tools + let tools = service.list_tools(Default::default()).await?; + println!("Available tools: {tools:#?}"); + + // Call a tool + let result = service + .call_tool(CallToolRequestParam { + name: "increment".into(), + arguments: None, + }) + .await?; + println!("Result: {result:#?}"); + + // Gracefully close the connection + service.cancel().await?; + + Ok(()) +} +``` + +## Transport Options + +RMCP supports multiple transport mechanisms, each suited for different use cases: + +### `transport-async-rw` +Low-level interface for asynchronous read/write operations. This is the foundation for many other transports. + +### `transport-io` +For working directly with I/O streams (`tokio::io::AsyncRead` and `tokio::io::AsyncWrite`). + +### `transport-child-process` +Run MCP servers as child processes and communicate via standard I/O. + +Example: +```rust +use rmcp::transport::TokioChildProcess; +use tokio::process::Command; + +let transport = TokioChildProcess::new(Command::new("mcp-server"))?; +let service = client.serve(transport).await?; +``` + + + +## Access with peer interface when handling message + +You can get the [`Peer`](crate::service::Peer) struct from [`NotificationContext`](crate::service::NotificationContext) and [`RequestContext`](crate::service::RequestContext). + +```rust, ignore +# use rmcp::{ +# ServerHandler, +# model::{LoggingLevel, LoggingMessageNotificationParam, ProgressNotificationParam}, +# service::{NotificationContext, RoleServer}, +# }; +# pub struct Handler; + +impl ServerHandler for Handler { + async fn on_progress( + &self, + notification: ProgressNotificationParam, + context: NotificationContext, + ) { + let peer = context.peer; + let _ = peer + .notify_logging_message(LoggingMessageNotificationParam { + level: LoggingLevel::Info, + logger: None, + data: serde_json::json!({ + "message": format!("Progress: {}", notification.progress), + }), + }) + .await; + } +} +``` + + +## Manage Multi Services + +For many cases you need to manage several service in a collection, you can call `into_dyn` to convert services into the same type. +```rust, ignore +let service = service.into_dyn(); +``` + +## Feature Flags + +RMCP uses feature flags to control which components are included: + +- `client`: Enable client functionality +- `server`: Enable server functionality and the tool system +- `macros`: Enable the `#[tool]` macro (enabled by default) +- Transport-specific features: + - `transport-async-rw`: Async read/write support + - `transport-io`: I/O stream support + - `transport-child-process`: Child process support + - `transport-sse-client` / `transport-sse-server`: SSE support (client agnostic) + - `transport-sse-client-reqwest`: a default `reqwest` implementation of the SSE client + - `transport-streamable-http-client` / `transport-streamable-http-server`: HTTP streaming (client agnostic, see [`StreamableHttpClientTransport`] for details) + - `transport-streamable-http-client-reqwest`: a default `reqwest` implementation of the streamable http client +- `auth`: OAuth2 authentication support +- `schemars`: JSON Schema generation (for tool definitions) + + +## Transports + +- `transport-io`: Server stdio transport +- `transport-sse-server`: Server SSE transport +- `transport-child-process`: Client stdio transport +- `transport-sse-client`: Client sse transport +- `transport-streamable-http-server` streamable http server transport +- `transport-streamable-http-client` streamable http client transport + +
+Transport +The transport type must implemented [`Transport`] trait, which allow it send message concurrently and receive message sequentially. +There are 3 pairs of standard transport types: + +| transport | client | server | +|:-: |:-: |:-: | +| std IO | [`child_process::TokioChildProcess`] | [`io::stdio`] | +| streamable http | [`streamable_http_client::StreamableHttpClientTransport`] | [`streamable_http_server::session::create_session`] | +| sse | [`sse_client::SseClientTransport`] | [`sse_server::SseServer`] | + +#### [IntoTransport](`IntoTransport`) trait +[`IntoTransport`] is a helper trait that implicitly convert a type into a transport type. + +These types is automatically implemented [`IntoTransport`] trait +1. A type that already implement both [`futures::Sink`] and [`futures::Stream`] trait, or a tuple `(Tx, Rx)` where `Tx` is [`futures::Sink`] and `Rx` is [`futures::Stream`]. +2. A type that implement both [`tokio::io::AsyncRead`] and [`tokio::io::AsyncWrite`] trait. or a tuple `(R, W)` where `R` is [`tokio::io::AsyncRead`] and `W` is [`tokio::io::AsyncWrite`]. +3. A type that implement [Worker](`worker::Worker`) trait. +4. A type that implement [`Transport`] trait. + +
+ +## License + +This project is licensed under the terms specified in the repository's LICENSE file. \ No newline at end of file diff --git a/code-rs/third_party/rmcp-0.8.3/src/error.rs b/code-rs/third_party/rmcp-0.8.3/src/error.rs new file mode 100644 index 00000000000..e0da2b3d4b8 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/error.rs @@ -0,0 +1,56 @@ +use std::{borrow::Cow, fmt::Display}; + +use crate::ServiceError; +pub use crate::model::ErrorData; +#[deprecated( + note = "Use `rmcp::ErrorData` instead, `rmcp::ErrorData` could become `RmcpError` in the future." +)] +pub type Error = ErrorData; +impl Display for ErrorData { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}: {}", self.code.0, self.message)?; + if let Some(data) = &self.data { + write!(f, "({})", data)?; + } + Ok(()) + } +} + +impl std::error::Error for ErrorData {} + +/// This is an unified error type for the errors could be returned by the service. +#[derive(Debug, thiserror::Error)] +pub enum RmcpError { + #[error("Service error: {0}")] + Service(#[from] ServiceError), + #[cfg(feature = "client")] + #[error("Client initialization error: {0}")] + ClientInitialize(#[from] crate::service::ClientInitializeError), + #[cfg(feature = "server")] + #[error("Server initialization error: {0}")] + ServerInitialize(#[from] crate::service::ServerInitializeError), + #[error("Runtime error: {0}")] + Runtime(#[from] tokio::task::JoinError), + #[error("Transport creation error: {error}")] + // TODO: Maybe we can introduce something like `TryIntoTransport` to auto wrap transport type, + // but it could be an breaking change, so we could do it in the future. + TransportCreation { + into_transport_type_name: Cow<'static, str>, + into_transport_type_id: std::any::TypeId, + #[source] + error: Box, + }, + // and cancellation shouldn't be an error? +} + +impl RmcpError { + pub fn transport_creation( + error: impl Into>, + ) -> Self { + RmcpError::TransportCreation { + into_transport_type_id: std::any::TypeId::of::(), + into_transport_type_name: std::any::type_name::().into(), + error: error.into(), + } + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/handler.rs b/code-rs/third_party/rmcp-0.8.3/src/handler.rs new file mode 100644 index 00000000000..c2b9737b8b1 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/handler.rs @@ -0,0 +1,6 @@ +#[cfg(feature = "client")] +#[cfg_attr(docsrs, doc(cfg(feature = "client")))] +pub mod client; +#[cfg(feature = "server")] +#[cfg_attr(docsrs, doc(cfg(feature = "server")))] +pub mod server; diff --git a/code-rs/third_party/rmcp-0.8.3/src/handler/client.rs b/code-rs/third_party/rmcp-0.8.3/src/handler/client.rs new file mode 100644 index 00000000000..147f2fc29d5 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/handler/client.rs @@ -0,0 +1,183 @@ +pub mod progress; +use crate::{ + error::ErrorData as McpError, + model::*, + service::{NotificationContext, RequestContext, RoleClient, Service, ServiceRole}, +}; + +impl Service for H { + async fn handle_request( + &self, + request: ::PeerReq, + context: RequestContext, + ) -> Result<::Resp, McpError> { + match request { + ServerRequest::PingRequest(_) => self.ping(context).await.map(ClientResult::empty), + ServerRequest::CreateMessageRequest(request) => self + .create_message(request.params, context) + .await + .map(Box::new) + .map(ClientResult::CreateMessageResult), + ServerRequest::ListRootsRequest(_) => self + .list_roots(context) + .await + .map(ClientResult::ListRootsResult), + ServerRequest::CreateElicitationRequest(request) => self + .create_elicitation(request.params, context) + .await + .map(ClientResult::CreateElicitationResult), + } + } + + async fn handle_notification( + &self, + notification: ::PeerNot, + context: NotificationContext, + ) -> Result<(), McpError> { + match notification { + ServerNotification::CancelledNotification(notification) => { + self.on_cancelled(notification.params, context).await + } + ServerNotification::ProgressNotification(notification) => { + self.on_progress(notification.params, context).await + } + ServerNotification::LoggingMessageNotification(notification) => { + self.on_logging_message(notification.params, context).await + } + ServerNotification::ResourceUpdatedNotification(notification) => { + self.on_resource_updated(notification.params, context).await + } + ServerNotification::ResourceListChangedNotification(_notification_no_param) => { + self.on_resource_list_changed(context).await + } + ServerNotification::ToolListChangedNotification(_notification_no_param) => { + self.on_tool_list_changed(context).await + } + ServerNotification::PromptListChangedNotification(_notification_no_param) => { + self.on_prompt_list_changed(context).await + } + }; + Ok(()) + } + + fn get_info(&self) -> ::Info { + self.get_info() + } +} + +#[allow(unused_variables)] +pub trait ClientHandler: Sized + Send + Sync + 'static { + fn ping( + &self, + context: RequestContext, + ) -> impl Future> + Send + '_ { + std::future::ready(Ok(())) + } + + fn create_message( + &self, + params: CreateMessageRequestParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + std::future::ready(Err( + McpError::method_not_found::(), + )) + } + + fn list_roots( + &self, + context: RequestContext, + ) -> impl Future> + Send + '_ { + std::future::ready(Ok(ListRootsResult::default())) + } + + /// Handle an elicitation request from a server asking for user input. + /// + /// This method is called when a server needs interactive input from the user + /// during tool execution. Implementations should present the message to the user, + /// collect their input according to the requested schema, and return the result. + /// + /// # Arguments + /// * `request` - The elicitation request with message and schema + /// * `context` - The request context + /// + /// # Returns + /// The user's response including action (accept/decline/cancel) and optional data + /// + /// # Default Behavior + /// The default implementation automatically declines all elicitation requests. + /// Real clients should override this to provide user interaction. + fn create_elicitation( + &self, + request: CreateElicitationRequestParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + // Default implementation declines all requests - real clients should override this + let _ = (request, context); + std::future::ready(Ok(CreateElicitationResult { + action: ElicitationAction::Decline, + content: None, + })) + } + + fn on_cancelled( + &self, + params: CancelledNotificationParam, + context: NotificationContext, + ) -> impl Future + Send + '_ { + std::future::ready(()) + } + fn on_progress( + &self, + params: ProgressNotificationParam, + context: NotificationContext, + ) -> impl Future + Send + '_ { + std::future::ready(()) + } + fn on_logging_message( + &self, + params: LoggingMessageNotificationParam, + context: NotificationContext, + ) -> impl Future + Send + '_ { + std::future::ready(()) + } + fn on_resource_updated( + &self, + params: ResourceUpdatedNotificationParam, + context: NotificationContext, + ) -> impl Future + Send + '_ { + std::future::ready(()) + } + fn on_resource_list_changed( + &self, + context: NotificationContext, + ) -> impl Future + Send + '_ { + std::future::ready(()) + } + fn on_tool_list_changed( + &self, + context: NotificationContext, + ) -> impl Future + Send + '_ { + std::future::ready(()) + } + fn on_prompt_list_changed( + &self, + context: NotificationContext, + ) -> impl Future + Send + '_ { + std::future::ready(()) + } + + fn get_info(&self) -> ClientInfo { + ClientInfo::default() + } +} + +/// Do nothing, with default client info. +impl ClientHandler for () {} + +/// Do nothing, with a specific client info. +impl ClientHandler for ClientInfo { + fn get_info(&self) -> ClientInfo { + self.clone() + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/handler/client/progress.rs b/code-rs/third_party/rmcp-0.8.3/src/handler/client/progress.rs new file mode 100644 index 00000000000..04f31610ff5 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/handler/client/progress.rs @@ -0,0 +1,100 @@ +use std::{collections::HashMap, sync::Arc}; + +use futures::{Stream, StreamExt}; +use tokio::sync::RwLock; +use tokio_stream::wrappers::ReceiverStream; + +use crate::model::{ProgressNotificationParam, ProgressToken}; +type Dispatcher = + Arc>>>; + +/// A dispatcher for progress notifications. +#[derive(Debug, Clone, Default)] +pub struct ProgressDispatcher { + pub(crate) dispatcher: Dispatcher, +} + +impl ProgressDispatcher { + const CHANNEL_SIZE: usize = 16; + pub fn new() -> Self { + Self::default() + } + + /// Handle a progress notification by sending it to the appropriate subscriber + pub async fn handle_notification(&self, notification: ProgressNotificationParam) { + let token = ¬ification.progress_token; + if let Some(sender) = self.dispatcher.read().await.get(token).cloned() { + let send_result = sender.send(notification).await; + if let Err(e) = send_result { + tracing::warn!("Failed to send progress notification: {e}"); + } + } + } + + /// Subscribe to progress notifications for a specific token. + /// + /// If you drop the returned `ProgressSubscriber`, it will automatically unsubscribe from notifications for that token. + pub async fn subscribe(&self, progress_token: ProgressToken) -> ProgressSubscriber { + let (sender, receiver) = tokio::sync::mpsc::channel(Self::CHANNEL_SIZE); + self.dispatcher + .write() + .await + .insert(progress_token.clone(), sender); + let receiver = ReceiverStream::new(receiver); + ProgressSubscriber { + progress_token, + receiver, + dispatcher: self.dispatcher.clone(), + } + } + + /// Unsubscribe from progress notifications for a specific token. + pub async fn unsubscribe(&self, token: &ProgressToken) { + self.dispatcher.write().await.remove(token); + } + + /// Clear all dispatcher. + pub async fn clear(&self) { + let mut dispatcher = self.dispatcher.write().await; + dispatcher.clear(); + } +} + +pub struct ProgressSubscriber { + pub(crate) progress_token: ProgressToken, + pub(crate) receiver: ReceiverStream, + pub(crate) dispatcher: Dispatcher, +} + +impl ProgressSubscriber { + pub fn progress_token(&self) -> &ProgressToken { + &self.progress_token + } +} + +impl Stream for ProgressSubscriber { + type Item = ProgressNotificationParam; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.receiver.poll_next_unpin(cx) + } + + fn size_hint(&self) -> (usize, Option) { + self.receiver.size_hint() + } +} + +impl Drop for ProgressSubscriber { + fn drop(&mut self) { + let token = self.progress_token.clone(); + self.receiver.close(); + let dispatcher = self.dispatcher.clone(); + tokio::spawn(async move { + let mut dispatcher = dispatcher.write_owned().await; + dispatcher.remove(&token); + }); + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/handler/server.rs b/code-rs/third_party/rmcp-0.8.3/src/handler/server.rs new file mode 100644 index 00000000000..4f9edbc0f98 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/handler/server.rs @@ -0,0 +1,231 @@ +use crate::{ + error::ErrorData as McpError, + model::*, + service::{NotificationContext, RequestContext, RoleServer, Service, ServiceRole}, +}; + +pub mod common; +pub mod prompt; +mod resource; +pub mod router; +pub mod tool; +pub mod wrapper; +impl Service for H { + async fn handle_request( + &self, + request: ::PeerReq, + context: RequestContext, + ) -> Result<::Resp, McpError> { + match request { + ClientRequest::InitializeRequest(request) => self + .initialize(request.params, context) + .await + .map(ServerResult::InitializeResult), + ClientRequest::PingRequest(_request) => { + self.ping(context).await.map(ServerResult::empty) + } + ClientRequest::CompleteRequest(request) => self + .complete(request.params, context) + .await + .map(ServerResult::CompleteResult), + ClientRequest::SetLevelRequest(request) => self + .set_level(request.params, context) + .await + .map(ServerResult::empty), + ClientRequest::GetPromptRequest(request) => self + .get_prompt(request.params, context) + .await + .map(ServerResult::GetPromptResult), + ClientRequest::ListPromptsRequest(request) => self + .list_prompts(request.params, context) + .await + .map(ServerResult::ListPromptsResult), + ClientRequest::ListResourcesRequest(request) => self + .list_resources(request.params, context) + .await + .map(ServerResult::ListResourcesResult), + ClientRequest::ListResourceTemplatesRequest(request) => self + .list_resource_templates(request.params, context) + .await + .map(ServerResult::ListResourceTemplatesResult), + ClientRequest::ReadResourceRequest(request) => self + .read_resource(request.params, context) + .await + .map(ServerResult::ReadResourceResult), + ClientRequest::SubscribeRequest(request) => self + .subscribe(request.params, context) + .await + .map(ServerResult::empty), + ClientRequest::UnsubscribeRequest(request) => self + .unsubscribe(request.params, context) + .await + .map(ServerResult::empty), + ClientRequest::CallToolRequest(request) => self + .call_tool(request.params, context) + .await + .map(ServerResult::CallToolResult), + ClientRequest::ListToolsRequest(request) => self + .list_tools(request.params, context) + .await + .map(ServerResult::ListToolsResult), + } + } + + async fn handle_notification( + &self, + notification: ::PeerNot, + context: NotificationContext, + ) -> Result<(), McpError> { + match notification { + ClientNotification::CancelledNotification(notification) => { + self.on_cancelled(notification.params, context).await + } + ClientNotification::ProgressNotification(notification) => { + self.on_progress(notification.params, context).await + } + ClientNotification::InitializedNotification(_notification) => { + self.on_initialized(context).await + } + ClientNotification::RootsListChangedNotification(_notification) => { + self.on_roots_list_changed(context).await + } + }; + Ok(()) + } + + fn get_info(&self) -> ::Info { + self.get_info() + } +} + +#[allow(unused_variables)] +pub trait ServerHandler: Sized + Send + Sync + 'static { + fn ping( + &self, + context: RequestContext, + ) -> impl Future> + Send + '_ { + std::future::ready(Ok(())) + } + // handle requests + fn initialize( + &self, + request: InitializeRequestParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + if context.peer.peer_info().is_none() { + context.peer.set_peer_info(request); + } + std::future::ready(Ok(self.get_info())) + } + fn complete( + &self, + request: CompleteRequestParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + std::future::ready(Ok(CompleteResult::default())) + } + fn set_level( + &self, + request: SetLevelRequestParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + std::future::ready(Err(McpError::method_not_found::())) + } + fn get_prompt( + &self, + request: GetPromptRequestParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + std::future::ready(Err(McpError::method_not_found::())) + } + fn list_prompts( + &self, + request: Option, + context: RequestContext, + ) -> impl Future> + Send + '_ { + std::future::ready(Ok(ListPromptsResult::default())) + } + fn list_resources( + &self, + request: Option, + context: RequestContext, + ) -> impl Future> + Send + '_ { + std::future::ready(Ok(ListResourcesResult::default())) + } + fn list_resource_templates( + &self, + request: Option, + context: RequestContext, + ) -> impl Future> + Send + '_ { + std::future::ready(Ok(ListResourceTemplatesResult::default())) + } + fn read_resource( + &self, + request: ReadResourceRequestParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + std::future::ready(Err( + McpError::method_not_found::(), + )) + } + fn subscribe( + &self, + request: SubscribeRequestParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + std::future::ready(Err(McpError::method_not_found::())) + } + fn unsubscribe( + &self, + request: UnsubscribeRequestParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + std::future::ready(Err(McpError::method_not_found::())) + } + fn call_tool( + &self, + request: CallToolRequestParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + std::future::ready(Err(McpError::method_not_found::())) + } + fn list_tools( + &self, + request: Option, + context: RequestContext, + ) -> impl Future> + Send + '_ { + std::future::ready(Ok(ListToolsResult::default())) + } + + fn on_cancelled( + &self, + notification: CancelledNotificationParam, + context: NotificationContext, + ) -> impl Future + Send + '_ { + std::future::ready(()) + } + fn on_progress( + &self, + notification: ProgressNotificationParam, + context: NotificationContext, + ) -> impl Future + Send + '_ { + std::future::ready(()) + } + fn on_initialized( + &self, + context: NotificationContext, + ) -> impl Future + Send + '_ { + tracing::info!("client initialized"); + std::future::ready(()) + } + fn on_roots_list_changed( + &self, + context: NotificationContext, + ) -> impl Future + Send + '_ { + std::future::ready(()) + } + + fn get_info(&self) -> ServerInfo { + ServerInfo::default() + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/handler/server/common.rs b/code-rs/third_party/rmcp-0.8.3/src/handler/server/common.rs new file mode 100644 index 00000000000..b36696cbfe9 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/handler/server/common.rs @@ -0,0 +1,146 @@ +//! Common utilities shared between tool and prompt handlers + +use std::{any::TypeId, collections::HashMap, sync::Arc}; + +use schemars::JsonSchema; + +use crate::{ + RoleServer, model::JsonObject, schemars::generate::SchemaSettings, service::RequestContext, +}; + +/// A shortcut for generating a JSON schema for a type. +pub fn schema_for_type() -> JsonObject { + // explicitly to align json schema version to official specifications. + // https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-03-26/schema.json + // TODO: update to 2020-12 waiting for the mcp spec update + let mut settings = SchemaSettings::draft07(); + settings.transforms = vec![Box::new(schemars::transform::AddNullable::default())]; + let generator = settings.into_generator(); + let schema = generator.into_root_schema_for::(); + let object = serde_json::to_value(schema).expect("failed to serialize schema"); + match object { + serde_json::Value::Object(object) => object, + _ => panic!( + "Schema serialization produced non-object value: expected JSON object but got {:?}", + object + ), + } +} + +/// Call [`schema_for_type`] with a cache +pub fn cached_schema_for_type() -> Arc { + thread_local! { + static CACHE_FOR_TYPE: std::sync::RwLock>> = Default::default(); + }; + CACHE_FOR_TYPE.with(|cache| { + if let Some(x) = cache + .read() + .expect("schema cache lock poisoned") + .get(&TypeId::of::()) + { + x.clone() + } else { + let schema = schema_for_type::(); + let schema = Arc::new(schema); + cache + .write() + .expect("schema cache lock poisoned") + .insert(TypeId::of::(), schema.clone()); + schema + } + }) +} + +/// Trait for extracting parts from a context, unifying tool and prompt extraction +pub trait FromContextPart: Sized { + fn from_context_part(context: &mut C) -> Result; +} + +/// Common extractors that can be used by both tool and prompt handlers +impl FromContextPart for RequestContext +where + C: AsRequestContext, +{ + fn from_context_part(context: &mut C) -> Result { + Ok(context.as_request_context().clone()) + } +} + +impl FromContextPart for tokio_util::sync::CancellationToken +where + C: AsRequestContext, +{ + fn from_context_part(context: &mut C) -> Result { + Ok(context.as_request_context().ct.clone()) + } +} + +impl FromContextPart for crate::model::Extensions +where + C: AsRequestContext, +{ + fn from_context_part(context: &mut C) -> Result { + Ok(context.as_request_context().extensions.clone()) + } +} + +pub struct Extension(pub T); + +impl FromContextPart for Extension +where + C: AsRequestContext, + T: Send + Sync + 'static + Clone, +{ + fn from_context_part(context: &mut C) -> Result { + let extension = context + .as_request_context() + .extensions + .get::() + .cloned() + .ok_or_else(|| { + crate::ErrorData::invalid_params( + format!("missing extension {}", std::any::type_name::()), + None, + ) + })?; + Ok(Extension(extension)) + } +} + +impl FromContextPart for crate::Peer +where + C: AsRequestContext, +{ + fn from_context_part(context: &mut C) -> Result { + Ok(context.as_request_context().peer.clone()) + } +} + +impl FromContextPart for crate::model::Meta +where + C: AsRequestContext, +{ + fn from_context_part(context: &mut C) -> Result { + let request_context = context.as_request_context_mut(); + let mut meta = crate::model::Meta::default(); + std::mem::swap(&mut meta, &mut request_context.meta); + Ok(meta) + } +} + +pub struct RequestId(pub crate::model::RequestId); + +impl FromContextPart for RequestId +where + C: AsRequestContext, +{ + fn from_context_part(context: &mut C) -> Result { + Ok(RequestId(context.as_request_context().id.clone())) + } +} + +/// Trait for types that can provide access to RequestContext +pub trait AsRequestContext { + fn as_request_context(&self) -> &RequestContext; + fn as_request_context_mut(&mut self) -> &mut RequestContext; +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/handler/server/prompt.rs b/code-rs/third_party/rmcp-0.8.3/src/handler/server/prompt.rs new file mode 100644 index 00000000000..5c262e2b755 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/handler/server/prompt.rs @@ -0,0 +1,367 @@ +//! Prompt handling infrastructure for MCP servers +//! +//! This module provides the core types and traits for implementing prompt handlers +//! in MCP servers. Prompts allow servers to provide reusable templates for LLM +//! interactions with customizable arguments. + +use std::{future::Future, marker::PhantomData}; + +use futures::future::{BoxFuture, FutureExt}; +use serde::de::DeserializeOwned; + +use super::common::{AsRequestContext, FromContextPart}; +pub use super::common::{Extension, RequestId}; +use crate::{ + RoleServer, + handler::server::wrapper::Parameters, + model::{GetPromptResult, PromptMessage}, + service::RequestContext, +}; + +/// Context for prompt retrieval operations +pub struct PromptContext<'a, S> { + pub server: &'a S, + pub name: String, + pub arguments: Option>, + pub context: RequestContext, +} + +impl<'a, S> PromptContext<'a, S> { + pub fn new( + server: &'a S, + name: String, + arguments: Option>, + context: RequestContext, + ) -> Self { + Self { + server, + name, + arguments, + context, + } + } +} + +impl AsRequestContext for PromptContext<'_, S> { + fn as_request_context(&self) -> &RequestContext { + &self.context + } + + fn as_request_context_mut(&mut self) -> &mut RequestContext { + &mut self.context + } +} + +/// Trait for handling prompt retrieval +pub trait GetPromptHandler { + fn handle( + self, + context: PromptContext<'_, S>, + ) -> BoxFuture<'_, Result>; +} + +/// Type alias for dynamic prompt handlers +pub type DynGetPromptHandler = dyn for<'a> Fn(PromptContext<'a, S>) -> BoxFuture<'a, Result> + + Send + + Sync; + +/// Adapter type for async methods that return `Vec` +pub struct AsyncMethodAdapter(PhantomData); + +/// Adapter type for async methods with parameters that return `Vec` +pub struct AsyncMethodWithArgsAdapter(PhantomData); + +/// Adapter types for macro-generated implementations +#[allow(clippy::type_complexity)] +pub struct AsyncPromptAdapter(PhantomData fn(Fut) -> R>); +pub struct SyncPromptAdapter(PhantomData R>); +pub struct AsyncPromptMethodAdapter(PhantomData R>); +pub struct SyncPromptMethodAdapter(PhantomData R>); + +/// Trait for types that can be converted into GetPromptResult +pub trait IntoGetPromptResult { + fn into_get_prompt_result(self) -> Result; +} + +impl IntoGetPromptResult for GetPromptResult { + fn into_get_prompt_result(self) -> Result { + Ok(self) + } +} + +impl IntoGetPromptResult for Vec { + fn into_get_prompt_result(self) -> Result { + Ok(GetPromptResult { + description: None, + messages: self, + }) + } +} + +impl IntoGetPromptResult for Result { + fn into_get_prompt_result(self) -> Result { + self.and_then(|v| v.into_get_prompt_result()) + } +} + +// Future wrapper that automatically handles IntoGetPromptResult conversion +pin_project_lite::pin_project! { + #[project = IntoGetPromptResultFutProj] + pub enum IntoGetPromptResultFut { + Pending { + #[pin] + fut: F, + _marker: PhantomData, + }, + Ready { + #[pin] + result: futures::future::Ready>, + } + } +} + +impl Future for IntoGetPromptResultFut +where + F: Future, + R: IntoGetPromptResult, +{ + type Output = Result; + + fn poll( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + match self.project() { + IntoGetPromptResultFutProj::Pending { fut, _marker } => fut + .poll(cx) + .map(IntoGetPromptResult::into_get_prompt_result), + IntoGetPromptResultFutProj::Ready { result } => result.poll(cx), + } + } +} + +// Prompt-specific extractor for prompt name +pub struct PromptName(pub String); + +impl FromContextPart> for PromptName { + fn from_context_part(context: &mut PromptContext) -> Result { + Ok(Self(context.name.clone())) + } +} + +// Special implementation for Parameters that handles prompt arguments +impl FromContextPart> for Parameters

+where + P: DeserializeOwned, +{ + fn from_context_part(context: &mut PromptContext) -> Result { + let params = if let Some(args_map) = context.arguments.take() { + let args_value = serde_json::Value::Object(args_map); + serde_json::from_value::

(args_value).map_err(|e| { + crate::ErrorData::invalid_params(format!("Failed to parse parameters: {}", e), None) + })? + } else { + // Try to deserialize from empty object for optional fields + serde_json::from_value::

(serde_json::json!({})).map_err(|e| { + crate::ErrorData::invalid_params( + format!("Missing required parameters: {}", e), + None, + ) + })? + }; + Ok(Parameters(params)) + } +} + +// Macro to generate GetPromptHandler implementations for various parameter combinations +macro_rules! impl_prompt_handler_for { + ($($T: ident)*) => { + impl_prompt_handler_for!([] [$($T)*]); + }; + // finished + ([$($Tn: ident)*] []) => { + impl_prompt_handler_for!(@impl $($Tn)*); + }; + ([$($Tn: ident)*] [$Tn_1: ident $($Rest: ident)*]) => { + impl_prompt_handler_for!(@impl $($Tn)*); + impl_prompt_handler_for!([$($Tn)* $Tn_1] [$($Rest)*]); + }; + (@impl $($Tn: ident)*) => { + // Implementation for async methods (transformed by #[prompt] macro) + impl<$($Tn,)* S, F, R> GetPromptHandler for F + where + $( + $Tn: for<'a> FromContextPart> + Send, + )* + F: FnOnce(&S, $($Tn,)*) -> BoxFuture<'_, R> + Send, + R: IntoGetPromptResult + Send + 'static, + S: Send + Sync + 'static, + { + #[allow(unused_variables, non_snake_case, unused_mut)] + fn handle( + self, + mut context: PromptContext<'_, S>, + ) -> BoxFuture<'_, Result> + { + $( + let result = $Tn::from_context_part(&mut context); + let $Tn = match result { + Ok(value) => value, + Err(e) => return std::future::ready(Err(e)).boxed(), + }; + )* + let service = context.server; + let fut = self(service, $($Tn,)*); + async move { + let result = fut.await; + result.into_get_prompt_result() + }.boxed() + } + } + + + // Implementation for sync methods + impl<$($Tn,)* S, F, R> GetPromptHandler> for F + where + $( + $Tn: for<'a> FromContextPart> + Send, + )* + F: FnOnce(&S, $($Tn,)*) -> R + Send, + R: IntoGetPromptResult + Send, + S: Send + Sync, + { + #[allow(unused_variables, non_snake_case, unused_mut)] + fn handle( + self, + mut context: PromptContext<'_, S>, + ) -> BoxFuture<'_, Result> + { + $( + let result = $Tn::from_context_part(&mut context); + let $Tn = match result { + Ok(value) => value, + Err(e) => return std::future::ready(Err(e)).boxed(), + }; + )* + let service = context.server; + let result = self(service, $($Tn,)*); + std::future::ready(result.into_get_prompt_result()).boxed() + } + } + + + // AsyncPromptAdapter - for standalone functions returning GetPromptResult + impl<$($Tn,)* S, F, Fut, R> GetPromptHandler> for F + where + $( + $Tn: for<'a> FromContextPart> + Send + 'static, + )* + F: FnOnce($($Tn,)*) -> Fut + Send + 'static, + Fut: Future> + Send + 'static, + R: IntoGetPromptResult + Send + 'static, + S: Send + Sync + 'static, + { + #[allow(unused_variables, non_snake_case, unused_mut)] + fn handle( + self, + mut context: PromptContext<'_, S>, + ) -> BoxFuture<'_, Result> + { + // Extract all parameters before moving into the async block + $( + let result = $Tn::from_context_part(&mut context); + let $Tn = match result { + Ok(value) => value, + Err(e) => return std::future::ready(Err(e)).boxed(), + }; + )* + + // Since we're dealing with standalone functions that don't take &S, + // we can return a 'static future + Box::pin(async move { + let result = self($($Tn,)*).await?; + result.into_get_prompt_result() + }) + } + } + + + // SyncPromptAdapter - for standalone sync functions returning Result + impl<$($Tn,)* S, F, R> GetPromptHandler> for F + where + $( + $Tn: for<'a> FromContextPart> + Send + 'static, + )* + F: FnOnce($($Tn,)*) -> Result + Send + 'static, + R: IntoGetPromptResult + Send + 'static, + S: Send + Sync, + { + #[allow(unused_variables, non_snake_case, unused_mut)] + fn handle( + self, + mut context: PromptContext<'_, S>, + ) -> BoxFuture<'_, Result> + { + $( + let result = $Tn::from_context_part(&mut context); + let $Tn = match result { + Ok(value) => value, + Err(e) => return std::future::ready(Err(e)).boxed(), + }; + )* + let result = self($($Tn,)*); + std::future::ready(result.and_then(|r| r.into_get_prompt_result())).boxed() + } + } + + }; +} + +// Invoke the macro to generate implementations for up to 16 parameters +impl_prompt_handler_for!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); + +/// Extract prompt arguments from a type's JSON schema +/// This function analyzes the schema of a type and extracts the properties +/// as PromptArgument entries with name, description, and required status +pub fn cached_arguments_from_schema() +-> Option> { + let schema = super::common::cached_schema_for_type::(); + let schema_value = serde_json::Value::Object((*schema).clone()); + + let properties = schema_value.get("properties").and_then(|p| p.as_object()); + + if let Some(props) = properties { + let required = schema_value + .get("required") + .and_then(|r| r.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.as_str()) + .collect::>() + }) + .unwrap_or_default(); + + let mut arguments = Vec::new(); + for (name, prop_schema) in props { + let description = prop_schema + .get("description") + .and_then(|d| d.as_str()) + .map(|s| s.to_string()); + + arguments.push(crate::model::PromptArgument { + name: name.clone(), + title: None, + description, + required: Some(required.contains(name.as_str())), + }); + } + + if arguments.is_empty() { + None + } else { + Some(arguments) + } + } else { + None + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/handler/server/resource.rs b/code-rs/third_party/rmcp-0.8.3/src/handler/server/resource.rs new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/handler/server/resource.rs @@ -0,0 +1 @@ + diff --git a/code-rs/third_party/rmcp-0.8.3/src/handler/server/router.rs b/code-rs/third_party/rmcp-0.8.3/src/handler/server/router.rs new file mode 100644 index 00000000000..23b15ddef42 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/handler/server/router.rs @@ -0,0 +1,138 @@ +use std::sync::Arc; + +use prompt::{IntoPromptRoute, PromptRoute}; +use tool::{IntoToolRoute, ToolRoute}; + +use super::ServerHandler; +use crate::{ + RoleServer, Service, + model::{ClientRequest, ListPromptsResult, ListToolsResult, ServerResult}, + service::NotificationContext, +}; + +pub mod prompt; +pub mod tool; + +pub struct Router { + pub tool_router: tool::ToolRouter, + pub prompt_router: prompt::PromptRouter, + pub service: Arc, +} + +impl Router +where + S: ServerHandler, +{ + pub fn new(service: S) -> Self { + Self { + tool_router: tool::ToolRouter::new(), + prompt_router: prompt::PromptRouter::new(), + service: Arc::new(service), + } + } + + pub fn with_tool(mut self, route: R) -> Self + where + R: IntoToolRoute, + { + self.tool_router.add_route(route.into_tool_route()); + self + } + + pub fn with_tools(mut self, routes: impl IntoIterator>) -> Self { + for route in routes { + self.tool_router.add_route(route); + } + self + } + + pub fn with_prompt(mut self, route: R) -> Self + where + R: IntoPromptRoute, + { + self.prompt_router.add_route(route.into_prompt_route()); + self + } + + pub fn with_prompts(mut self, routes: impl IntoIterator>) -> Self { + for route in routes { + self.prompt_router.add_route(route); + } + self + } +} + +impl Service for Router +where + S: ServerHandler, +{ + async fn handle_notification( + &self, + notification: ::PeerNot, + context: NotificationContext, + ) -> Result<(), crate::ErrorData> { + self.service + .handle_notification(notification, context) + .await + } + async fn handle_request( + &self, + request: ::PeerReq, + context: crate::service::RequestContext, + ) -> Result<::Resp, crate::ErrorData> { + match request { + ClientRequest::CallToolRequest(request) => { + if self.tool_router.has_route(request.params.name.as_ref()) + || !self.tool_router.transparent_when_not_found + { + let tool_call_context = crate::handler::server::tool::ToolCallContext::new( + self.service.as_ref(), + request.params, + context, + ); + let result = self.tool_router.call(tool_call_context).await?; + Ok(ServerResult::CallToolResult(result)) + } else { + self.service + .handle_request(ClientRequest::CallToolRequest(request), context) + .await + } + } + ClientRequest::ListToolsRequest(_) => { + let tools = self.tool_router.list_all(); + Ok(ServerResult::ListToolsResult(ListToolsResult { + tools, + next_cursor: None, + })) + } + ClientRequest::GetPromptRequest(request) => { + if self.prompt_router.has_route(request.params.name.as_ref()) { + let prompt_context = crate::handler::server::prompt::PromptContext::new( + self.service.as_ref(), + request.params.name, + request.params.arguments, + context, + ); + let result = self.prompt_router.get_prompt(prompt_context).await?; + Ok(ServerResult::GetPromptResult(result)) + } else { + self.service + .handle_request(ClientRequest::GetPromptRequest(request), context) + .await + } + } + ClientRequest::ListPromptsRequest(_) => { + let prompts = self.prompt_router.list_all(); + Ok(ServerResult::ListPromptsResult(ListPromptsResult { + prompts, + next_cursor: None, + })) + } + rest => self.service.handle_request(rest, context).await, + } + } + + fn get_info(&self) -> ::Info { + self.service.get_info() + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/handler/server/router/prompt.rs b/code-rs/third_party/rmcp-0.8.3/src/handler/server/router/prompt.rs new file mode 100644 index 00000000000..a48d2ad7db5 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/handler/server/router/prompt.rs @@ -0,0 +1,213 @@ +use std::{borrow::Cow, sync::Arc}; + +use futures::future::BoxFuture; + +use crate::{ + handler::server::prompt::{DynGetPromptHandler, GetPromptHandler, PromptContext}, + model::{GetPromptResult, Prompt}, +}; + +pub struct PromptRoute { + #[allow(clippy::type_complexity)] + pub get: Arc>, + pub attr: crate::model::Prompt, +} + +impl std::fmt::Debug for PromptRoute { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PromptRoute") + .field("name", &self.attr.name) + .field("description", &self.attr.description) + .field("arguments", &self.attr.arguments) + .finish() + } +} + +impl Clone for PromptRoute { + fn clone(&self) -> Self { + Self { + get: self.get.clone(), + attr: self.attr.clone(), + } + } +} + +impl PromptRoute { + pub fn new(attr: impl Into, handler: H) -> Self + where + H: GetPromptHandler + Send + Sync + Clone + 'static, + { + Self { + get: Arc::new(move |context: PromptContext| { + let handler = handler.clone(); + handler.handle(context) + }), + attr: attr.into(), + } + } + + pub fn new_dyn(attr: impl Into, handler: H) -> Self + where + H: for<'a> Fn( + PromptContext<'a, S>, + ) -> BoxFuture<'a, Result> + + Send + + Sync + + 'static, + { + Self { + get: Arc::new(handler), + attr: attr.into(), + } + } + + pub fn name(&self) -> &str { + &self.attr.name + } +} + +pub trait IntoPromptRoute { + fn into_prompt_route(self) -> PromptRoute; +} + +impl IntoPromptRoute for (P, H) +where + S: Send + Sync + 'static, + A: 'static, + H: GetPromptHandler + Send + Sync + Clone + 'static, + P: Into, +{ + fn into_prompt_route(self) -> PromptRoute { + PromptRoute::new(self.0.into(), self.1) + } +} + +impl IntoPromptRoute for PromptRoute +where + S: Send + Sync + 'static, +{ + fn into_prompt_route(self) -> PromptRoute { + self + } +} + +/// Adapter for functions generated by the #\[prompt\] macro +pub struct PromptAttrGenerateFunctionAdapter; + +impl IntoPromptRoute for F +where + S: Send + Sync + 'static, + F: Fn() -> PromptRoute, +{ + fn into_prompt_route(self) -> PromptRoute { + (self)() + } +} + +#[derive(Debug)] +pub struct PromptRouter { + #[allow(clippy::type_complexity)] + pub map: std::collections::HashMap, PromptRoute>, +} + +impl Default for PromptRouter { + fn default() -> Self { + Self { + map: std::collections::HashMap::new(), + } + } +} + +impl Clone for PromptRouter { + fn clone(&self) -> Self { + Self { + map: self.map.clone(), + } + } +} + +impl IntoIterator for PromptRouter { + type Item = PromptRoute; + type IntoIter = std::collections::hash_map::IntoValues, PromptRoute>; + + fn into_iter(self) -> Self::IntoIter { + self.map.into_values() + } +} + +impl PromptRouter +where + S: Send + Sync + 'static, +{ + pub fn new() -> Self { + Self { + map: std::collections::HashMap::new(), + } + } + + pub fn with_route(mut self, route: R) -> Self + where + R: IntoPromptRoute, + { + self.add_route(route.into_prompt_route()); + self + } + + pub fn add_route(&mut self, item: PromptRoute) { + self.map.insert(item.attr.name.clone().into(), item); + } + + pub fn merge(&mut self, other: PromptRouter) { + for item in other.map.into_values() { + self.add_route(item); + } + } + + pub fn remove_route(&mut self, name: &str) { + self.map.remove(name); + } + + pub fn has_route(&self, name: &str) -> bool { + self.map.contains_key(name) + } + + pub async fn get_prompt( + &self, + context: PromptContext<'_, S>, + ) -> Result { + let item = self.map.get(context.name.as_str()).ok_or_else(|| { + crate::ErrorData::invalid_params( + format!("prompt '{}' not found", context.name), + Some(serde_json::json!({ + "available_prompts": self.list_all().iter().map(|p| &p.name).collect::>() + })), + ) + })?; + (item.get)(context).await + } + + pub fn list_all(&self) -> Vec { + self.map.values().map(|item| item.attr.clone()).collect() + } +} + +impl std::ops::Add> for PromptRouter +where + S: Send + Sync + 'static, +{ + type Output = Self; + + fn add(mut self, other: PromptRouter) -> Self::Output { + self.merge(other); + self + } +} + +impl std::ops::AddAssign> for PromptRouter +where + S: Send + Sync + 'static, +{ + fn add_assign(&mut self, other: PromptRouter) { + self.merge(other); + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/handler/server/router/tool.rs b/code-rs/third_party/rmcp-0.8.3/src/handler/server/router/tool.rs new file mode 100644 index 00000000000..bdef13e0698 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/handler/server/router/tool.rs @@ -0,0 +1,275 @@ +use std::{borrow::Cow, sync::Arc}; + +use futures::{FutureExt, future::BoxFuture}; +use schemars::JsonSchema; + +use crate::{ + handler::server::tool::{ + CallToolHandler, DynCallToolHandler, ToolCallContext, schema_for_type, + }, + model::{CallToolResult, Tool, ToolAnnotations}, +}; + +pub struct ToolRoute { + #[allow(clippy::type_complexity)] + pub call: Arc>, + pub attr: crate::model::Tool, +} + +impl std::fmt::Debug for ToolRoute { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ToolRoute") + .field("name", &self.attr.name) + .field("description", &self.attr.description) + .field("input_schema", &self.attr.input_schema) + .finish() + } +} + +impl Clone for ToolRoute { + fn clone(&self) -> Self { + Self { + call: self.call.clone(), + attr: self.attr.clone(), + } + } +} + +impl ToolRoute { + pub fn new(attr: impl Into, call: C) -> Self + where + C: CallToolHandler + Send + Sync + Clone + 'static, + { + Self { + call: Arc::new(move |context: ToolCallContext| { + let call = call.clone(); + context.invoke(call).boxed() + }), + attr: attr.into(), + } + } + pub fn new_dyn(attr: impl Into, call: C) -> Self + where + C: for<'a> Fn( + ToolCallContext<'a, S>, + ) -> BoxFuture<'a, Result> + + Send + + Sync + + 'static, + { + Self { + call: Arc::new(call), + attr: attr.into(), + } + } + pub fn name(&self) -> &str { + &self.attr.name + } +} + +pub trait IntoToolRoute { + fn into_tool_route(self) -> ToolRoute; +} + +impl IntoToolRoute for (T, C) +where + S: Send + Sync + 'static, + C: CallToolHandler + Send + Sync + Clone + 'static, + T: Into, +{ + fn into_tool_route(self) -> ToolRoute { + ToolRoute::new(self.0.into(), self.1) + } +} + +impl IntoToolRoute for ToolRoute +where + S: Send + Sync + 'static, +{ + fn into_tool_route(self) -> ToolRoute { + self + } +} + +pub struct ToolAttrGenerateFunctionAdapter; +impl IntoToolRoute for F +where + S: Send + Sync + 'static, + F: Fn() -> ToolRoute, +{ + fn into_tool_route(self) -> ToolRoute { + (self)() + } +} + +pub trait CallToolHandlerExt: Sized +where + Self: CallToolHandler + Send + Sync + Clone + 'static, +{ + fn name(self, name: impl Into>) -> WithToolAttr; +} + +impl CallToolHandlerExt for C +where + C: CallToolHandler + Send + Sync + Clone + 'static, +{ + fn name(self, name: impl Into>) -> WithToolAttr { + WithToolAttr { + attr: Tool::new( + name.into(), + "", + schema_for_type::(), + ), + call: self, + _marker: std::marker::PhantomData, + } + } +} + +pub struct WithToolAttr +where + C: CallToolHandler + Send + Sync + Clone + 'static, +{ + pub attr: crate::model::Tool, + pub call: C, + pub _marker: std::marker::PhantomData, +} + +impl IntoToolRoute for WithToolAttr +where + C: CallToolHandler + Send + Sync + Clone + 'static, + S: Send + Sync + 'static, +{ + fn into_tool_route(self) -> ToolRoute { + ToolRoute::new(self.attr, self.call) + } +} + +impl WithToolAttr +where + C: CallToolHandler + Send + Sync + Clone + 'static, +{ + pub fn description(mut self, description: impl Into>) -> Self { + self.attr.description = Some(description.into()); + self + } + pub fn parameters(mut self) -> Self { + self.attr.input_schema = schema_for_type::().into(); + self + } + pub fn parameters_value(mut self, schema: serde_json::Value) -> Self { + self.attr.input_schema = crate::model::object(schema).into(); + self + } + pub fn annotation(mut self, annotation: impl Into) -> Self { + self.attr.annotations = Some(annotation.into()); + self + } +} +#[derive(Debug)] +pub struct ToolRouter { + #[allow(clippy::type_complexity)] + pub map: std::collections::HashMap, ToolRoute>, + + pub transparent_when_not_found: bool, +} + +impl Default for ToolRouter { + fn default() -> Self { + Self { + map: std::collections::HashMap::new(), + transparent_when_not_found: false, + } + } +} +impl Clone for ToolRouter { + fn clone(&self) -> Self { + Self { + map: self.map.clone(), + transparent_when_not_found: self.transparent_when_not_found, + } + } +} + +impl IntoIterator for ToolRouter { + type Item = ToolRoute; + type IntoIter = std::collections::hash_map::IntoValues, ToolRoute>; + + fn into_iter(self) -> Self::IntoIter { + self.map.into_values() + } +} + +impl ToolRouter +where + S: Send + Sync + 'static, +{ + pub fn new() -> Self { + Self { + map: std::collections::HashMap::new(), + transparent_when_not_found: false, + } + } + pub fn with_route(mut self, route: R) -> Self + where + R: IntoToolRoute, + { + self.add_route(route.into_tool_route()); + self + } + + pub fn add_route(&mut self, item: ToolRoute) { + self.map.insert(item.attr.name.clone(), item); + } + + pub fn merge(&mut self, other: ToolRouter) { + for item in other.map.into_values() { + self.add_route(item); + } + } + + pub fn remove_route(&mut self, name: &str) { + self.map.remove(name); + } + pub fn has_route(&self, name: &str) -> bool { + self.map.contains_key(name) + } + pub async fn call( + &self, + context: ToolCallContext<'_, S>, + ) -> Result { + let item = self + .map + .get(context.name()) + .ok_or_else(|| crate::ErrorData::invalid_params("tool not found", None))?; + + let result = (item.call)(context).await?; + + Ok(result) + } + + pub fn list_all(&self) -> Vec { + self.map.values().map(|item| item.attr.clone()).collect() + } +} + +impl std::ops::Add> for ToolRouter +where + S: Send + Sync + 'static, +{ + type Output = Self; + + fn add(mut self, other: ToolRouter) -> Self::Output { + self.merge(other); + self + } +} + +impl std::ops::AddAssign> for ToolRouter +where + S: Send + Sync + 'static, +{ + fn add_assign(&mut self, other: ToolRouter) { + self.merge(other); + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/handler/server/tool.rs b/code-rs/third_party/rmcp-0.8.3/src/handler/server/tool.rs new file mode 100644 index 00000000000..cf842679a3e --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/handler/server/tool.rs @@ -0,0 +1,326 @@ +use std::{ + borrow::Cow, + future::{Future, Ready}, + marker::PhantomData, +}; + +use futures::future::{BoxFuture, FutureExt}; +use serde::de::DeserializeOwned; + +use super::common::{AsRequestContext, FromContextPart}; +pub use super::{ + common::{Extension, RequestId, cached_schema_for_type, schema_for_type}, + router::tool::{ToolRoute, ToolRouter}, +}; +use crate::{ + RoleServer, + handler::server::wrapper::Parameters, + model::{CallToolRequestParam, CallToolResult, IntoContents, JsonObject}, + service::RequestContext, +}; + +/// Deserialize a JSON object into a type +pub fn parse_json_object(input: JsonObject) -> Result { + serde_json::from_value(serde_json::Value::Object(input)).map_err(|e| { + crate::ErrorData::invalid_params( + format!("failed to deserialize parameters: {error}", error = e), + None, + ) + }) +} +pub struct ToolCallContext<'s, S> { + pub request_context: RequestContext, + pub service: &'s S, + pub name: Cow<'static, str>, + pub arguments: Option, +} + +impl<'s, S> ToolCallContext<'s, S> { + pub fn new( + service: &'s S, + CallToolRequestParam { name, arguments }: CallToolRequestParam, + request_context: RequestContext, + ) -> Self { + Self { + request_context, + service, + name, + arguments, + } + } + pub fn name(&self) -> &str { + &self.name + } + pub fn request_context(&self) -> &RequestContext { + &self.request_context + } +} + +impl AsRequestContext for ToolCallContext<'_, S> { + fn as_request_context(&self) -> &RequestContext { + &self.request_context + } + + fn as_request_context_mut(&mut self) -> &mut RequestContext { + &mut self.request_context + } +} + +pub trait IntoCallToolResult { + fn into_call_tool_result(self) -> Result; +} + +impl IntoCallToolResult for T { + fn into_call_tool_result(self) -> Result { + Ok(CallToolResult::success(self.into_contents())) + } +} + +impl IntoCallToolResult for Result { + fn into_call_tool_result(self) -> Result { + match self { + Ok(value) => Ok(CallToolResult::success(value.into_contents())), + Err(error) => Ok(CallToolResult::error(error.into_contents())), + } + } +} + +impl IntoCallToolResult for Result { + fn into_call_tool_result(self) -> Result { + match self { + Ok(value) => value.into_call_tool_result(), + Err(error) => Err(error), + } + } +} + +pin_project_lite::pin_project! { + #[project = IntoCallToolResultFutProj] + pub enum IntoCallToolResultFut { + Pending { + #[pin] + fut: F, + _marker: PhantomData, + }, + Ready { + #[pin] + result: Ready>, + } + } +} + +impl Future for IntoCallToolResultFut +where + F: Future, + R: IntoCallToolResult, +{ + type Output = Result; + + fn poll( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + match self.project() { + IntoCallToolResultFutProj::Pending { fut, _marker } => { + fut.poll(cx).map(IntoCallToolResult::into_call_tool_result) + } + IntoCallToolResultFutProj::Ready { result } => result.poll(cx), + } + } +} + +impl IntoCallToolResult for Result { + fn into_call_tool_result(self) -> Result { + self + } +} + +pub trait CallToolHandler { + fn call( + self, + context: ToolCallContext<'_, S>, + ) -> BoxFuture<'_, Result>; +} + +pub type DynCallToolHandler = dyn for<'s> Fn(ToolCallContext<'s, S>) -> BoxFuture<'s, Result> + + Send + + Sync; + +// Tool-specific extractor for tool name +pub struct ToolName(pub Cow<'static, str>); + +impl FromContextPart> for ToolName { + fn from_context_part(context: &mut ToolCallContext) -> Result { + Ok(Self(context.name.clone())) + } +} + +// Special implementation for Parameters that handles tool arguments +impl FromContextPart> for Parameters

+where + P: DeserializeOwned, +{ + fn from_context_part(context: &mut ToolCallContext) -> Result { + let arguments = context.arguments.take().unwrap_or_default(); + let value: P = + serde_json::from_value(serde_json::Value::Object(arguments)).map_err(|e| { + crate::ErrorData::invalid_params( + format!("failed to deserialize parameters: {error}", error = e), + None, + ) + })?; + Ok(Parameters(value)) + } +} + +// Special implementation for JsonObject that takes tool arguments +impl FromContextPart> for JsonObject { + fn from_context_part(context: &mut ToolCallContext) -> Result { + let object = context.arguments.take().unwrap_or_default(); + Ok(object) + } +} + +impl<'s, S> ToolCallContext<'s, S> { + pub fn invoke(self, h: H) -> BoxFuture<'s, Result> + where + H: CallToolHandler, + { + h.call(self) + } +} +#[allow(clippy::type_complexity)] +pub struct AsyncAdapter(PhantomData fn(Fut) -> R>); +pub struct SyncAdapter(PhantomData R>); +// #[allow(clippy::type_complexity)] +pub struct AsyncMethodAdapter(PhantomData R>); +pub struct SyncMethodAdapter(PhantomData R>); + +macro_rules! impl_for { + ($($T: ident)*) => { + impl_for!([] [$($T)*]); + }; + // finished + ([$($Tn: ident)*] []) => { + impl_for!(@impl $($Tn)*); + }; + ([$($Tn: ident)*] [$Tn_1: ident $($Rest: ident)*]) => { + impl_for!(@impl $($Tn)*); + impl_for!([$($Tn)* $Tn_1] [$($Rest)*]); + }; + (@impl $($Tn: ident)*) => { + impl<$($Tn,)* S, F, R> CallToolHandler> for F + where + $( + $Tn: for<'a> FromContextPart> , + )* + F: FnOnce(&S, $($Tn,)*) -> BoxFuture<'_, R>, + + // Need RTN support here(I guess), https://github.com/rust-lang/rust/pull/138424 + // Fut: Future + Send + 'a, + R: IntoCallToolResult + Send + 'static, + S: Send + Sync + 'static, + { + #[allow(unused_variables, non_snake_case, unused_mut)] + fn call( + self, + mut context: ToolCallContext<'_, S>, + ) -> BoxFuture<'_, Result>{ + $( + let result = $Tn::from_context_part(&mut context); + let $Tn = match result { + Ok(value) => value, + Err(e) => return std::future::ready(Err(e)).boxed(), + }; + )* + let service = context.service; + let fut = self(service, $($Tn,)*); + async move { + let result = fut.await; + result.into_call_tool_result() + }.boxed() + } + } + + impl<$($Tn,)* S, F, Fut, R> CallToolHandler> for F + where + $( + $Tn: for<'a> FromContextPart> , + )* + F: FnOnce($($Tn,)*) -> Fut + Send + , + Fut: Future + Send + 'static, + R: IntoCallToolResult + Send + 'static, + S: Send + Sync, + { + #[allow(unused_variables, non_snake_case, unused_mut)] + fn call( + self, + mut context: ToolCallContext, + ) -> BoxFuture<'static, Result>{ + $( + let result = $Tn::from_context_part(&mut context); + let $Tn = match result { + Ok(value) => value, + Err(e) => return std::future::ready(Err(e)).boxed(), + }; + )* + let fut = self($($Tn,)*); + async move { + let result = fut.await; + result.into_call_tool_result() + }.boxed() + } + } + + impl<$($Tn,)* S, F, R> CallToolHandler> for F + where + $( + $Tn: for<'a> FromContextPart> + , + )* + F: FnOnce(&S, $($Tn,)*) -> R + Send + , + R: IntoCallToolResult + Send + , + S: Send + Sync, + { + #[allow(unused_variables, non_snake_case, unused_mut)] + fn call( + self, + mut context: ToolCallContext, + ) -> BoxFuture<'static, Result> { + $( + let result = $Tn::from_context_part(&mut context); + let $Tn = match result { + Ok(value) => value, + Err(e) => return std::future::ready(Err(e)).boxed(), + }; + )* + std::future::ready(self(context.service, $($Tn,)*).into_call_tool_result()).boxed() + } + } + + impl<$($Tn,)* S, F, R> CallToolHandler> for F + where + $( + $Tn: for<'a> FromContextPart> + , + )* + F: FnOnce($($Tn,)*) -> R + Send + , + R: IntoCallToolResult + Send + , + S: Send + Sync, + { + #[allow(unused_variables, non_snake_case, unused_mut)] + fn call( + self, + mut context: ToolCallContext, + ) -> BoxFuture<'static, Result> { + $( + let result = $Tn::from_context_part(&mut context); + let $Tn = match result { + Ok(value) => value, + Err(e) => return std::future::ready(Err(e)).boxed(), + }; + )* + std::future::ready(self($($Tn,)*).into_call_tool_result()).boxed() + } + } + }; +} +impl_for!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); diff --git a/code-rs/third_party/rmcp-0.8.3/src/handler/server/wrapper.rs b/code-rs/third_party/rmcp-0.8.3/src/handler/server/wrapper.rs new file mode 100644 index 00000000000..d9f2c86d97c --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/handler/server/wrapper.rs @@ -0,0 +1,4 @@ +mod json; +mod parameters; +pub use json::*; +pub use parameters::*; diff --git a/code-rs/third_party/rmcp-0.8.3/src/handler/server/wrapper/json.rs b/code-rs/third_party/rmcp-0.8.3/src/handler/server/wrapper/json.rs new file mode 100644 index 00000000000..8eae3026888 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/handler/server/wrapper/json.rs @@ -0,0 +1,54 @@ +use std::borrow::Cow; + +use schemars::JsonSchema; +use serde::Serialize; + +use crate::{ + handler::server::tool::IntoCallToolResult, + model::{CallToolResult, IntoContents}, +}; + +/// Json wrapper for structured output +/// +/// When used with tools, this wrapper indicates that the value should be +/// serialized as structured JSON content with an associated schema. +/// The framework will place the JSON in the `structured_content` field +/// of the tool result rather than the regular `content` field. +pub struct Json(pub T); + +// Implement JsonSchema for Json to delegate to T's schema +impl JsonSchema for Json { + fn schema_name() -> Cow<'static, str> { + T::schema_name() + } + + fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema { + T::json_schema(generator) + } +} + +// Implementation for Json to create structured content +impl IntoCallToolResult for Json { + fn into_call_tool_result(self) -> Result { + let value = serde_json::to_value(self.0).map_err(|e| { + crate::ErrorData::internal_error( + format!("Failed to serialize structured content: {}", e), + None, + ) + })?; + + Ok(CallToolResult::structured(value)) + } +} + +// Implementation for Result, E> +impl IntoCallToolResult + for Result, E> +{ + fn into_call_tool_result(self) -> Result { + match self { + Ok(value) => value.into_call_tool_result(), + Err(error) => Ok(CallToolResult::error(error.into_contents())), + } + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/handler/server/wrapper/parameters.rs b/code-rs/third_party/rmcp-0.8.3/src/handler/server/wrapper/parameters.rs new file mode 100644 index 00000000000..9de73dd6683 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/handler/server/wrapper/parameters.rs @@ -0,0 +1,55 @@ +use schemars::JsonSchema; + +/// Parameter extractor for tools and prompts +/// +/// When used in tool and prompt handlers, this wrapper extracts and deserializes +/// parameters from the incoming request. The framework will automatically parse +/// the JSON arguments from tool calls or prompt arguments and deserialize them +/// into the specified type `P`. +/// +/// The `#[serde(transparent)]` attribute ensures that the wrapper doesn't add +/// an extra layer in the JSON structure - it directly delegates serialization +/// and deserialization to the inner type `P`. +/// +/// # Usage +/// +/// Use `Parameters` as a parameter in your tool or prompt handler functions: +/// +/// ```rust +/// # use rmcp::handler::server::wrapper::Parameters; +/// # use schemars::JsonSchema; +/// # use serde::{Deserialize, Serialize}; +/// #[derive(Deserialize, JsonSchema)] +/// struct CalculationRequest { +/// operation: String, +/// a: f64, +/// b: f64, +/// } +/// +/// // In a tool handler +/// async fn calculate(params: Parameters) -> Result { +/// let request = params.0; // Extract the inner value +/// match request.operation.as_str() { +/// "add" => Ok((request.a + request.b).to_string()), +/// _ => Err("Unknown operation".to_string()), +/// } +/// } +/// ``` +/// +/// The framework handles the extraction automatically: +/// - For tools: Parses the `arguments` field from tool call requests +/// - For prompts: Parses the `arguments` field from prompt requests +/// - Returns appropriate error responses if deserialization fails +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[serde(transparent)] +pub struct Parameters

(pub P); + +impl JsonSchema for Parameters

{ + fn schema_name() -> std::borrow::Cow<'static, str> { + P::schema_name() + } + + fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema { + P::json_schema(generator) + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/lib.rs b/code-rs/third_party/rmcp-0.8.3/src/lib.rs new file mode 100644 index 00000000000..3476390a58e --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/lib.rs @@ -0,0 +1,182 @@ +#![cfg_attr(docsrs, feature(doc_cfg))] +#![cfg_attr(docsrs, allow(unused_attributes))] +//! The official Rust SDK for the Model Context Protocol (MCP). +//! +//! The MCP is a protocol that allows AI assistants to communicate with other +//! services. `rmcp` is the official Rust implementation of this protocol. +//! +//! There are two ways in which the library can be used, namely to build a +//! server or to build a client. +//! +//! ## Server +//! +//! A server is a service that exposes capabilities. For example, a common +//! use-case is for the server to make multiple tools available to clients such +//! as Claude Desktop or the Cursor IDE. +//! +//! For example, to implement a server that has a tool that can count, you would +//! make an object for that tool and add an implementation with the `#[tool_router]` macro: +//! +//! ```rust +//! use std::sync::Arc; +//! use rmcp::{ErrorData as McpError, model::*, tool, tool_router, handler::server::tool::ToolRouter}; +//! use tokio::sync::Mutex; +//! +//! #[derive(Clone)] +//! pub struct Counter { +//! counter: Arc>, +//! tool_router: ToolRouter, +//! } +//! +//! #[tool_router] +//! impl Counter { +//! fn new() -> Self { +//! Self { +//! counter: Arc::new(Mutex::new(0)), +//! tool_router: Self::tool_router(), +//! } +//! } +//! +//! #[tool(description = "Increment the counter by 1")] +//! async fn increment(&self) -> Result { +//! let mut counter = self.counter.lock().await; +//! *counter += 1; +//! Ok(CallToolResult::success(vec![Content::text( +//! counter.to_string(), +//! )])) +//! } +//! } +//! ``` +//! +//! ### Structured Output +//! +//! Tools can also return structured JSON data with schemas. Use the [`Json`] wrapper: +//! +//! ```rust +//! # use rmcp::{tool, tool_router, handler::server::{tool::ToolRouter, wrapper::Parameters}, Json}; +//! # use schemars::JsonSchema; +//! # use serde::{Serialize, Deserialize}; +//! # +//! #[derive(Serialize, Deserialize, JsonSchema)] +//! struct CalculationRequest { +//! a: i32, +//! b: i32, +//! operation: String, +//! } +//! +//! #[derive(Serialize, Deserialize, JsonSchema)] +//! struct CalculationResult { +//! result: i32, +//! operation: String, +//! } +//! +//! # #[derive(Clone)] +//! # struct Calculator { +//! # tool_router: ToolRouter, +//! # } +//! # +//! # #[tool_router] +//! # impl Calculator { +//! #[tool(name = "calculate", description = "Perform a calculation")] +//! async fn calculate(&self, params: Parameters) -> Result, String> { +//! let result = match params.0.operation.as_str() { +//! "add" => params.0.a + params.0.b, +//! "multiply" => params.0.a * params.0.b, +//! _ => return Err("Unknown operation".to_string()), +//! }; +//! +//! Ok(Json(CalculationResult { result, operation: params.0.operation })) +//! } +//! # } +//! ``` +//! +//! The `#[tool]` macro automatically generates an output schema from the `CalculationResult` type. +//! +//! Next also implement [ServerHandler] for your server type and start the server inside +//! `main` by calling `.serve(...)`. See the examples directory in the repository for more information. +//! +//! ## Client +//! +//! A client can be used to interact with a server. Clients can be used to get a +//! list of the available tools and to call them. For example, we can `uv` to +//! start a MCP server in Python and then list the tools and call `git status` +//! as follows: +//! +//! ```rust +//! use anyhow::Result; +//! use rmcp::{model::CallToolRequestParam, service::ServiceExt, transport::{TokioChildProcess, ConfigureCommandExt}}; +//! use tokio::process::Command; +//! +//! async fn client() -> Result<()> { +//! let service = ().serve(TokioChildProcess::new(Command::new("uvx").configure(|cmd| { +//! cmd.arg("mcp-server-git"); +//! }))?).await?; +//! +//! // Initialize +//! let server_info = service.peer_info(); +//! println!("Connected to server: {server_info:#?}"); +//! +//! // List tools +//! let tools = service.list_tools(Default::default()).await?; +//! println!("Available tools: {tools:#?}"); +//! +//! // Call tool 'git_status' with arguments = {"repo_path": "."} +//! let tool_result = service +//! .call_tool(CallToolRequestParam { +//! name: "git_status".into(), +//! arguments: serde_json::json!({ "repo_path": "." }).as_object().cloned(), +//! }) +//! .await?; +//! println!("Tool result: {tool_result:#?}"); +//! +//! service.cancel().await?; +//! Ok(()) +//! } +//! ``` +mod error; +#[allow(deprecated)] +pub use error::{Error, ErrorData, RmcpError}; + +/// Basic data types in MCP specification +pub mod model; +#[cfg(any(feature = "client", feature = "server"))] +#[cfg_attr(docsrs, doc(cfg(any(feature = "client", feature = "server"))))] +pub mod service; +#[cfg(feature = "client")] +#[cfg_attr(docsrs, doc(cfg(feature = "client")))] +pub use handler::client::ClientHandler; +#[cfg(feature = "server")] +#[cfg_attr(docsrs, doc(cfg(feature = "server")))] +pub use handler::server::ServerHandler; +#[cfg(feature = "server")] +#[cfg_attr(docsrs, doc(cfg(feature = "server")))] +pub use handler::server::wrapper::Json; +#[cfg(any(feature = "client", feature = "server"))] +#[cfg_attr(docsrs, doc(cfg(any(feature = "client", feature = "server"))))] +pub use service::{Peer, Service, ServiceError, ServiceExt}; +#[cfg(feature = "client")] +#[cfg_attr(docsrs, doc(cfg(feature = "client")))] +pub use service::{RoleClient, serve_client}; +#[cfg(feature = "server")] +#[cfg_attr(docsrs, doc(cfg(feature = "server")))] +pub use service::{RoleServer, serve_server}; + +pub mod handler; +pub mod transport; + +// re-export +#[cfg(all(feature = "macros", feature = "server"))] +#[cfg_attr(docsrs, doc(cfg(all(feature = "macros", feature = "server"))))] +pub use paste::paste; +#[cfg(all(feature = "macros", feature = "server"))] +#[cfg_attr(docsrs, doc(cfg(all(feature = "macros", feature = "server"))))] +pub use rmcp_macros::*; +#[cfg(all(feature = "macros", feature = "server"))] +#[cfg_attr(docsrs, doc(cfg(all(feature = "macros", feature = "server"))))] +pub use schemars; +#[cfg(feature = "macros")] +#[cfg_attr(docsrs, doc(cfg(feature = "macros")))] +pub use serde; +#[cfg(feature = "macros")] +#[cfg_attr(docsrs, doc(cfg(feature = "macros")))] +pub use serde_json; diff --git a/code-rs/third_party/rmcp-0.8.3/src/model.rs b/code-rs/third_party/rmcp-0.8.3/src/model.rs new file mode 100644 index 00000000000..fb757c09524 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/model.rs @@ -0,0 +1,2180 @@ +use std::{borrow::Cow, sync::Arc}; +mod annotated; +mod capabilities; +mod content; +mod elicitation_schema; +mod extension; +mod meta; +mod prompt; +mod resource; +mod serde_impl; +mod tool; +pub use annotated::*; +pub use capabilities::*; +pub use content::*; +pub use elicitation_schema::*; +pub use extension::*; +pub use meta::*; +pub use prompt::*; +pub use resource::*; +use serde::{Deserialize, Serialize, de::DeserializeOwned}; +use serde_json::Value; +pub use tool::*; + +/// A JSON object type alias for convenient handling of JSON data. +/// +/// You can use [`crate::object!`] or [`crate::model::object`] to create a json object quickly. +/// This is commonly used for storing arbitrary JSON data in MCP messages. +pub type JsonObject = serde_json::Map; + +/// unwrap the JsonObject under [`serde_json::Value`] +/// +/// # Panic +/// This will panic when the value is not a object in debug mode. +pub fn object(value: serde_json::Value) -> JsonObject { + debug_assert!(value.is_object()); + match value { + serde_json::Value::Object(map) => map, + _ => JsonObject::default(), + } +} + +/// Use this macro just like [`serde_json::json!`] +#[cfg(feature = "macros")] +#[cfg_attr(docsrs, doc(cfg(feature = "macros")))] +#[macro_export] +macro_rules! object { + ({$($tt:tt)*}) => { + $crate::model::object(serde_json::json! { + {$($tt)*} + }) + }; +} + +/// This is commonly used for representing empty objects in MCP messages. +/// +/// without returning any specific data. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Copy, Eq)] +#[cfg_attr(feature = "server", derive(schemars::JsonSchema))] +pub struct EmptyObject {} + +pub trait ConstString: Default { + const VALUE: &str; + fn as_str(&self) -> &'static str { + Self::VALUE + } +} +#[macro_export] +macro_rules! const_string { + ($name:ident = $value:literal) => { + #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] + pub struct $name; + + impl ConstString for $name { + const VALUE: &str = $value; + } + + impl serde::Serialize for $name { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + $value.serialize(serializer) + } + } + + impl<'de> serde::Deserialize<'de> for $name { + fn deserialize(deserializer: D) -> Result<$name, D::Error> + where + D: serde::Deserializer<'de>, + { + let s: String = serde::Deserialize::deserialize(deserializer)?; + if s == $value { + Ok($name) + } else { + Err(serde::de::Error::custom(format!(concat!( + "expect const string value \"", + $value, + "\"" + )))) + } + } + } + + #[cfg(feature = "schemars")] + impl schemars::JsonSchema for $name { + fn schema_name() -> Cow<'static, str> { + Cow::Borrowed(stringify!($name)) + } + + fn json_schema(_: &mut schemars::SchemaGenerator) -> schemars::Schema { + use serde_json::{Map, json}; + + let mut schema_map = Map::new(); + schema_map.insert("type".to_string(), json!("string")); + schema_map.insert("format".to_string(), json!("const")); + schema_map.insert("const".to_string(), json!($value)); + + schemars::Schema::from(schema_map) + } + } + }; +} + +const_string!(JsonRpcVersion2_0 = "2.0"); + +// ============================================================================= +// CORE PROTOCOL TYPES +// ============================================================================= + +/// Represents the MCP protocol version used for communication. +/// +/// This ensures compatibility between clients and servers by specifying +/// which version of the Model Context Protocol is being used. +#[derive(Debug, Clone, Eq, PartialEq, Hash, PartialOrd)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct ProtocolVersion(Cow<'static, str>); + +impl Default for ProtocolVersion { + fn default() -> Self { + Self::LATEST + } +} + +impl std::fmt::Display for ProtocolVersion { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +impl ProtocolVersion { + pub const V_2025_06_18: Self = Self(Cow::Borrowed("2025-06-18")); + pub const V_2025_03_26: Self = Self(Cow::Borrowed("2025-03-26")); + pub const V_2024_11_05: Self = Self(Cow::Borrowed("2024-11-05")); + // Keep LATEST at 2025-03-26 until full 2025-06-18 compliance and automated testing are in place. + pub const LATEST: Self = Self::V_2025_03_26; +} + +impl Serialize for ProtocolVersion { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + self.0.serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for ProtocolVersion { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let s: String = Deserialize::deserialize(deserializer)?; + #[allow(clippy::single_match)] + match s.as_str() { + "2024-11-05" => return Ok(ProtocolVersion::V_2024_11_05), + "2025-03-26" => return Ok(ProtocolVersion::V_2025_03_26), + "2025-06-18" => return Ok(ProtocolVersion::V_2025_06_18), + _ => {} + } + Ok(ProtocolVersion(Cow::Owned(s))) + } +} + +/// A flexible identifier type that can be either a number or a string. +/// +/// This is commonly used for request IDs and other identifiers in JSON-RPC +/// where the specification allows both numeric and string values. +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub enum NumberOrString { + /// A numeric identifier + Number(i64), + /// A string identifier + String(Arc), +} + +impl NumberOrString { + pub fn into_json_value(self) -> Value { + match self { + NumberOrString::Number(n) => Value::Number(serde_json::Number::from(n)), + NumberOrString::String(s) => Value::String(s.to_string()), + } + } +} + +impl std::fmt::Display for NumberOrString { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + NumberOrString::Number(n) => n.fmt(f), + NumberOrString::String(s) => s.fmt(f), + } + } +} + +impl Serialize for NumberOrString { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + NumberOrString::Number(n) => n.serialize(serializer), + NumberOrString::String(s) => s.serialize(serializer), + } + } +} + +impl<'de> Deserialize<'de> for NumberOrString { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let value: Value = Deserialize::deserialize(deserializer)?; + match value { + Value::Number(n) => { + if let Some(i) = n.as_i64() { + Ok(NumberOrString::Number(i)) + } else if let Some(u) = n.as_u64() { + // Handle large unsigned numbers that fit in i64 + if u <= i64::MAX as u64 { + Ok(NumberOrString::Number(u as i64)) + } else { + Err(serde::de::Error::custom("Number too large for i64")) + } + } else { + Err(serde::de::Error::custom("Expected an integer")) + } + } + Value::String(s) => Ok(NumberOrString::String(s.into())), + _ => Err(serde::de::Error::custom("Expect number or string")), + } + } +} + +#[cfg(feature = "schemars")] +impl schemars::JsonSchema for NumberOrString { + fn schema_name() -> Cow<'static, str> { + Cow::Borrowed("NumberOrString") + } + + fn json_schema(_: &mut schemars::SchemaGenerator) -> schemars::Schema { + use serde_json::{Map, json}; + + let mut number_schema = Map::new(); + number_schema.insert("type".to_string(), json!("number")); + + let mut string_schema = Map::new(); + string_schema.insert("type".to_string(), json!("string")); + + let mut schema_map = Map::new(); + schema_map.insert("oneOf".to_string(), json!([number_schema, string_schema])); + + schemars::Schema::from(schema_map) + } +} + +/// Type alias for request identifiers used in JSON-RPC communication. +pub type RequestId = NumberOrString; + +/// A token used to track the progress of long-running operations. +/// +/// Progress tokens allow clients and servers to associate progress notifications +/// with specific requests, enabling real-time updates on operation status. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Hash, Eq)] +#[serde(transparent)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct ProgressToken(pub NumberOrString); + +// ============================================================================= +// JSON-RPC MESSAGE STRUCTURES +// ============================================================================= + +/// Represents a JSON-RPC request with method, parameters, and extensions. +/// +/// This is the core structure for all MCP requests, containing: +/// - `method`: The name of the method being called +/// - `params`: The parameters for the method +/// - `extensions`: Additional context data (similar to HTTP headers) +#[derive(Debug, Clone, Default)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct Request { + pub method: M, + pub params: P, + /// extensions will carry anything possible in the context, including [`Meta`] + /// + /// this is similar with the Extensions in `http` crate + #[cfg_attr(feature = "schemars", schemars(skip))] + pub extensions: Extensions, +} + +impl Request { + pub fn new(params: P) -> Self { + Self { + method: Default::default(), + params, + extensions: Extensions::default(), + } + } +} + +impl GetExtensions for Request { + fn extensions(&self) -> &Extensions { + &self.extensions + } + fn extensions_mut(&mut self) -> &mut Extensions { + &mut self.extensions + } +} + +#[derive(Debug, Clone, Default)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct RequestOptionalParam { + pub method: M, + // #[serde(skip_serializing_if = "Option::is_none")] + pub params: Option

, + /// extensions will carry anything possible in the context, including [`Meta`] + /// + /// this is similar with the Extensions in `http` crate + #[cfg_attr(feature = "schemars", schemars(skip))] + pub extensions: Extensions, +} + +impl RequestOptionalParam { + pub fn with_param(params: P) -> Self { + Self { + method: Default::default(), + params: Some(params), + extensions: Extensions::default(), + } + } +} + +#[derive(Debug, Clone, Default)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct RequestNoParam { + pub method: M, + /// extensions will carry anything possible in the context, including [`Meta`] + /// + /// this is similar with the Extensions in `http` crate + #[cfg_attr(feature = "schemars", schemars(skip))] + pub extensions: Extensions, +} + +impl GetExtensions for RequestNoParam { + fn extensions(&self) -> &Extensions { + &self.extensions + } + fn extensions_mut(&mut self) -> &mut Extensions { + &mut self.extensions + } +} +#[derive(Debug, Clone, Default)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct Notification { + pub method: M, + pub params: P, + /// extensions will carry anything possible in the context, including [`Meta`] + /// + /// this is similar with the Extensions in `http` crate + #[cfg_attr(feature = "schemars", schemars(skip))] + pub extensions: Extensions, +} + +impl Notification { + pub fn new(params: P) -> Self { + Self { + method: Default::default(), + params, + extensions: Extensions::default(), + } + } +} + +#[derive(Debug, Clone, Default)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct NotificationNoParam { + pub method: M, + /// extensions will carry anything possible in the context, including [`Meta`] + /// + /// this is similar with the Extensions in `http` crate + #[cfg_attr(feature = "schemars", schemars(skip))] + pub extensions: Extensions, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct JsonRpcRequest { + pub jsonrpc: JsonRpcVersion2_0, + pub id: RequestId, + #[serde(flatten)] + pub request: R, +} + +type DefaultResponse = JsonObject; +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct JsonRpcResponse { + pub jsonrpc: JsonRpcVersion2_0, + pub id: RequestId, + pub result: R, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct JsonRpcError { + pub jsonrpc: JsonRpcVersion2_0, + pub id: RequestId, + pub error: ErrorData, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct JsonRpcNotification { + pub jsonrpc: JsonRpcVersion2_0, + #[serde(flatten)] + pub notification: N, +} + +/// Standard JSON-RPC error codes used throughout the MCP protocol. +/// +/// These codes follow the JSON-RPC 2.0 specification and provide +/// standardized error reporting across all MCP implementations. +#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)] +#[serde(transparent)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct ErrorCode(pub i32); + +impl ErrorCode { + pub const RESOURCE_NOT_FOUND: Self = Self(-32002); + pub const INVALID_REQUEST: Self = Self(-32600); + pub const METHOD_NOT_FOUND: Self = Self(-32601); + pub const INVALID_PARAMS: Self = Self(-32602); + pub const INTERNAL_ERROR: Self = Self(-32603); + pub const PARSE_ERROR: Self = Self(-32700); +} + +/// Error information for JSON-RPC error responses. +/// +/// This structure follows the JSON-RPC 2.0 specification for error reporting, +/// providing a standardized way to communicate errors between clients and servers. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct ErrorData { + /// The error type that occurred (using standard JSON-RPC error codes) + pub code: ErrorCode, + + /// A short description of the error. The message SHOULD be limited to a concise single sentence. + pub message: Cow<'static, str>, + + /// Additional information about the error. The value of this member is defined by the + /// sender (e.g. detailed error information, nested errors etc.). + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, +} + +impl ErrorData { + pub fn new( + code: ErrorCode, + message: impl Into>, + data: Option, + ) -> Self { + Self { + code, + message: message.into(), + data, + } + } + pub fn resource_not_found(message: impl Into>, data: Option) -> Self { + Self::new(ErrorCode::RESOURCE_NOT_FOUND, message, data) + } + pub fn parse_error(message: impl Into>, data: Option) -> Self { + Self::new(ErrorCode::PARSE_ERROR, message, data) + } + pub fn invalid_request(message: impl Into>, data: Option) -> Self { + Self::new(ErrorCode::INVALID_REQUEST, message, data) + } + pub fn method_not_found() -> Self { + Self::new(ErrorCode::METHOD_NOT_FOUND, M::VALUE, None) + } + pub fn invalid_params(message: impl Into>, data: Option) -> Self { + Self::new(ErrorCode::INVALID_PARAMS, message, data) + } + pub fn internal_error(message: impl Into>, data: Option) -> Self { + Self::new(ErrorCode::INTERNAL_ERROR, message, data) + } +} + +/// Represents any JSON-RPC message that can be sent or received. +/// +/// This enum covers all possible message types in the JSON-RPC protocol: +/// individual requests/responses, notifications, and errors. +/// It serves as the top-level message container for MCP communication. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(untagged)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub enum JsonRpcMessage { + /// A single request expecting a response + Request(JsonRpcRequest), + /// A response to a previous request + Response(JsonRpcResponse), + /// A one-way notification (no response expected) + Notification(JsonRpcNotification), + /// An error response + Error(JsonRpcError), +} + +impl JsonRpcMessage { + #[inline] + pub const fn request(request: Req, id: RequestId) -> Self { + JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: JsonRpcVersion2_0, + id, + request, + }) + } + #[inline] + pub const fn response(response: Resp, id: RequestId) -> Self { + JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: JsonRpcVersion2_0, + id, + result: response, + }) + } + #[inline] + pub const fn error(error: ErrorData, id: RequestId) -> Self { + JsonRpcMessage::Error(JsonRpcError { + jsonrpc: JsonRpcVersion2_0, + id, + error, + }) + } + #[inline] + pub const fn notification(notification: Not) -> Self { + JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: JsonRpcVersion2_0, + notification, + }) + } + pub fn into_request(self) -> Option<(Req, RequestId)> { + match self { + JsonRpcMessage::Request(r) => Some((r.request, r.id)), + _ => None, + } + } + pub fn into_response(self) -> Option<(Resp, RequestId)> { + match self { + JsonRpcMessage::Response(r) => Some((r.result, r.id)), + _ => None, + } + } + pub fn into_notification(self) -> Option { + match self { + JsonRpcMessage::Notification(n) => Some(n.notification), + _ => None, + } + } + pub fn into_error(self) -> Option<(ErrorData, RequestId)> { + match self { + JsonRpcMessage::Error(e) => Some((e.error, e.id)), + _ => None, + } + } + pub fn into_result(self) -> Option<(Result, RequestId)> { + match self { + JsonRpcMessage::Response(r) => Some((Ok(r.result), r.id)), + JsonRpcMessage::Error(e) => Some((Err(e.error), e.id)), + + _ => None, + } + } +} + +// ============================================================================= +// INITIALIZATION AND CONNECTION SETUP +// ============================================================================= + +/// # Empty result +/// A response that indicates success but carries no data. +pub type EmptyResult = EmptyObject; + +impl From<()> for EmptyResult { + fn from(_value: ()) -> Self { + EmptyResult {} + } +} + +impl From for () { + fn from(_value: EmptyResult) {} +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct CancelledNotificationParam { + pub request_id: RequestId, + pub reason: Option, +} + +const_string!(CancelledNotificationMethod = "notifications/cancelled"); + +/// # Cancellation +/// This notification can be sent by either side to indicate that it is cancelling a previously-issued request. +/// +/// The request SHOULD still be in-flight, but due to communication latency, it is always possible that this notification MAY arrive after the request has already finished. +/// +/// This notification indicates that the result will be unused, so any associated processing SHOULD cease. +/// +/// A client MUST NOT attempt to cancel its `initialize` request. +pub type CancelledNotification = + Notification; + +const_string!(InitializeResultMethod = "initialize"); +/// # Initialization +/// This request is sent from the client to the server when it first connects, asking it to begin initialization. +pub type InitializeRequest = Request; + +const_string!(InitializedNotificationMethod = "notifications/initialized"); +/// This notification is sent from the client to the server after initialization has finished. +pub type InitializedNotification = NotificationNoParam; + +/// Parameters sent by a client when initializing a connection to an MCP server. +/// +/// This contains the client's protocol version, capabilities, and implementation +/// information, allowing the server to understand what the client supports. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct InitializeRequestParam { + /// The MCP protocol version this client supports + pub protocol_version: ProtocolVersion, + /// The capabilities this client supports (sampling, roots, etc.) + pub capabilities: ClientCapabilities, + /// Information about the client implementation + pub client_info: Implementation, +} + +/// The server's response to an initialization request. +/// +/// Contains the server's protocol version, capabilities, and implementation +/// information, along with optional instructions for the client. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct InitializeResult { + /// The MCP protocol version this server supports + pub protocol_version: ProtocolVersion, + /// The capabilities this server provides (tools, resources, prompts, etc.) + pub capabilities: ServerCapabilities, + /// Information about the server implementation + pub server_info: Implementation, + /// Optional human-readable instructions about using this server + #[serde(skip_serializing_if = "Option::is_none")] + pub instructions: Option, +} + +pub type ServerInfo = InitializeResult; +pub type ClientInfo = InitializeRequestParam; + +#[allow(clippy::derivable_impls)] +impl Default for ServerInfo { + fn default() -> Self { + ServerInfo { + protocol_version: ProtocolVersion::default(), + capabilities: ServerCapabilities::default(), + server_info: Implementation::from_build_env(), + instructions: None, + } + } +} + +#[allow(clippy::derivable_impls)] +impl Default for ClientInfo { + fn default() -> Self { + ClientInfo { + protocol_version: ProtocolVersion::default(), + capabilities: ClientCapabilities::default(), + client_info: Implementation::from_build_env(), + } + } +} + +/// A URL pointing to an icon resource or a base64-encoded data URI. +/// +/// Clients that support rendering icons MUST support at least the following MIME types: +/// - image/png - PNG images (safe, universal compatibility) +/// - image/jpeg (and image/jpg) - JPEG images (safe, universal compatibility) +/// +/// Clients that support rendering icons SHOULD also support: +/// - image/svg+xml - SVG images (scalable but requires security precautions) +/// - image/webp - WebP images (modern, efficient format) +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct Icon { + /// A standard URI pointing to an icon resource + pub src: String, + /// Optional override if the server's MIME type is missing or generic + #[serde(skip_serializing_if = "Option::is_none")] + pub mime_type: Option, + /// Size specification, each string should be in WxH format (e.g., `\"48x48\"`, `\"96x96\"`) or `\"any\"` for scalable formats like SVG + #[serde(skip_serializing_if = "Option::is_none")] + pub sizes: Option>, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct Implementation { + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, + pub version: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub icons: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub website_url: Option, +} + +impl Default for Implementation { + fn default() -> Self { + Self::from_build_env() + } +} + +impl Implementation { + pub fn from_build_env() -> Self { + Implementation { + name: env!("CARGO_CRATE_NAME").to_owned(), + title: None, + version: env!("CARGO_PKG_VERSION").to_owned(), + icons: None, + website_url: None, + } + } +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct PaginatedRequestParam { + #[serde(skip_serializing_if = "Option::is_none")] + pub cursor: Option, +} +// ============================================================================= +// PROGRESS AND PAGINATION +// ============================================================================= + +const_string!(PingRequestMethod = "ping"); +pub type PingRequest = RequestNoParam; + +const_string!(ProgressNotificationMethod = "notifications/progress"); +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct ProgressNotificationParam { + pub progress_token: ProgressToken, + /// The progress thus far. This should increase every time progress is made, even if the total is unknown. + pub progress: f64, + /// Total number of items to process (or total progress required), if known + #[serde(skip_serializing_if = "Option::is_none")] + pub total: Option, + /// An optional message describing the current progress. + #[serde(skip_serializing_if = "Option::is_none")] + pub message: Option, +} + +pub type ProgressNotification = Notification; + +pub type Cursor = String; + +macro_rules! paginated_result { + ($t:ident { + $i_item: ident: $t_item: ty + }) => { + #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] + #[serde(rename_all = "camelCase")] + #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] + pub struct $t { + #[serde(skip_serializing_if = "Option::is_none")] + pub next_cursor: Option, + pub $i_item: $t_item, + } + + impl $t { + pub fn with_all_items( + items: $t_item, + ) -> Self { + Self { + next_cursor: None, + $i_item: items, + } + } + } + }; +} + +// ============================================================================= +// RESOURCE MANAGEMENT +// ============================================================================= + +const_string!(ListResourcesRequestMethod = "resources/list"); +/// Request to list all available resources from a server +pub type ListResourcesRequest = + RequestOptionalParam; + +paginated_result!(ListResourcesResult { + resources: Vec +}); + +const_string!(ListResourceTemplatesRequestMethod = "resources/templates/list"); +/// Request to list all available resource templates from a server +pub type ListResourceTemplatesRequest = + RequestOptionalParam; + +paginated_result!(ListResourceTemplatesResult { + resource_templates: Vec +}); + +const_string!(ReadResourceRequestMethod = "resources/read"); +/// Parameters for reading a specific resource +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct ReadResourceRequestParam { + /// The URI of the resource to read + pub uri: String, +} + +/// Result containing the contents of a read resource +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct ReadResourceResult { + /// The actual content of the resource + pub contents: Vec, +} + +/// Request to read a specific resource +pub type ReadResourceRequest = Request; + +const_string!(ResourceListChangedNotificationMethod = "notifications/resources/list_changed"); +/// Notification sent when the list of available resources changes +pub type ResourceListChangedNotification = + NotificationNoParam; + +const_string!(SubscribeRequestMethod = "resources/subscribe"); +/// Parameters for subscribing to resource updates +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct SubscribeRequestParam { + /// The URI of the resource to subscribe to + pub uri: String, +} +/// Request to subscribe to resource updates +pub type SubscribeRequest = Request; + +const_string!(UnsubscribeRequestMethod = "resources/unsubscribe"); +/// Parameters for unsubscribing from resource updates +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct UnsubscribeRequestParam { + /// The URI of the resource to unsubscribe from + pub uri: String, +} +/// Request to unsubscribe from resource updates +pub type UnsubscribeRequest = Request; + +const_string!(ResourceUpdatedNotificationMethod = "notifications/resources/updated"); +/// Parameters for a resource update notification +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct ResourceUpdatedNotificationParam { + /// The URI of the resource that was updated + pub uri: String, +} +/// Notification sent when a subscribed resource is updated +pub type ResourceUpdatedNotification = + Notification; + +// ============================================================================= +// PROMPT MANAGEMENT +// ============================================================================= + +const_string!(ListPromptsRequestMethod = "prompts/list"); +/// Request to list all available prompts from a server +pub type ListPromptsRequest = RequestOptionalParam; + +paginated_result!(ListPromptsResult { + prompts: Vec +}); + +const_string!(GetPromptRequestMethod = "prompts/get"); +/// Parameters for retrieving a specific prompt +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct GetPromptRequestParam { + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option, +} +/// Request to get a specific prompt +pub type GetPromptRequest = Request; + +const_string!(PromptListChangedNotificationMethod = "notifications/prompts/list_changed"); +/// Notification sent when the list of available prompts changes +pub type PromptListChangedNotification = NotificationNoParam; + +const_string!(ToolListChangedNotificationMethod = "notifications/tools/list_changed"); +/// Notification sent when the list of available tools changes +pub type ToolListChangedNotification = NotificationNoParam; + +// ============================================================================= +// LOGGING +// ============================================================================= + +/// Logging levels supported by the MCP protocol +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Copy)] +#[serde(rename_all = "lowercase")] //match spec +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub enum LoggingLevel { + Debug, + Info, + Notice, + Warning, + Error, + Critical, + Alert, + Emergency, +} + +const_string!(SetLevelRequestMethod = "logging/setLevel"); +/// Parameters for setting the logging level +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct SetLevelRequestParam { + /// The desired logging level + pub level: LoggingLevel, +} +/// Request to set the logging level +pub type SetLevelRequest = Request; + +const_string!(LoggingMessageNotificationMethod = "notifications/message"); +/// Parameters for a logging message notification +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct LoggingMessageNotificationParam { + /// The severity level of this log message + pub level: LoggingLevel, + /// Optional logger name that generated this message + #[serde(skip_serializing_if = "Option::is_none")] + pub logger: Option, + /// The actual log data + pub data: Value, +} +/// Notification containing a log message +pub type LoggingMessageNotification = + Notification; + +// ============================================================================= +// SAMPLING (LLM INTERACTION) +// ============================================================================= + +const_string!(CreateMessageRequestMethod = "sampling/createMessage"); +pub type CreateMessageRequest = Request; + +/// Represents the role of a participant in a conversation or message exchange. +/// +/// Used in sampling and chat contexts to distinguish between different +/// types of message senders in the conversation flow. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub enum Role { + /// A human user or client making a request + User, + /// An AI assistant or server providing a response + Assistant, +} + +/// A message in a sampling conversation, containing a role and content. +/// +/// This represents a single message in a conversation flow, used primarily +/// in LLM sampling requests where the conversation history is important +/// for generating appropriate responses. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct SamplingMessage { + /// The role of the message sender (User or Assistant) + pub role: Role, + /// The actual content of the message (text, image, etc.) + pub content: Content, +} + +/// Specifies how much context should be included in sampling requests. +/// +/// This allows clients to control what additional context information +/// should be provided to the LLM when processing sampling requests. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub enum ContextInclusion { + /// Include context from all connected MCP servers + #[serde(rename = "allServers")] + AllServers, + /// Include no additional context + #[serde(rename = "none")] + None, + /// Include context only from the requesting server + #[serde(rename = "thisServer")] + ThisServer, +} + +/// Parameters for creating a message through LLM sampling. +/// +/// This structure contains all the necessary information for a client to +/// generate an LLM response, including conversation history, model preferences, +/// and generation parameters. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct CreateMessageRequestParam { + /// The conversation history and current messages + pub messages: Vec, + /// Preferences for model selection and behavior + #[serde(skip_serializing_if = "Option::is_none")] + pub model_preferences: Option, + /// System prompt to guide the model's behavior + #[serde(skip_serializing_if = "Option::is_none")] + pub system_prompt: Option, + /// How much context to include from MCP servers + #[serde(skip_serializing_if = "Option::is_none")] + pub include_context: Option, + /// Temperature for controlling randomness (0.0 to 1.0) + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + /// Maximum number of tokens to generate + pub max_tokens: u32, + /// Sequences that should stop generation + #[serde(skip_serializing_if = "Option::is_none")] + pub stop_sequences: Option>, + /// Additional metadata for the request + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option, +} + +/// Preferences for model selection and behavior in sampling requests. +/// +/// This allows servers to express their preferences for which model to use +/// and how to balance different priorities when the client has multiple +/// model options available. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct ModelPreferences { + /// Specific model names or families to prefer (e.g., "claude", "gpt") + #[serde(skip_serializing_if = "Option::is_none")] + pub hints: Option>, + /// Priority for cost optimization (0.0 to 1.0, higher = prefer cheaper models) + #[serde(skip_serializing_if = "Option::is_none")] + pub cost_priority: Option, + /// Priority for speed/latency (0.0 to 1.0, higher = prefer faster models) + #[serde(skip_serializing_if = "Option::is_none")] + pub speed_priority: Option, + /// Priority for intelligence/capability (0.0 to 1.0, higher = prefer more capable models) + #[serde(skip_serializing_if = "Option::is_none")] + pub intelligence_priority: Option, +} + +/// A hint suggesting a preferred model name or family. +/// +/// Model hints are advisory suggestions that help clients choose appropriate +/// models. They can be specific model names or general families like "claude" or "gpt". +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct ModelHint { + /// The suggested model name or family identifier + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, +} + +// ============================================================================= +// COMPLETION AND AUTOCOMPLETE +// ============================================================================= + +/// Context for completion requests providing previously resolved arguments. +/// +/// This enables context-aware completion where subsequent argument completions +/// can take into account the values of previously resolved arguments. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct CompletionContext { + /// Previously resolved argument values that can inform completion suggestions + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option>, +} + +impl CompletionContext { + /// Create a new empty completion context + pub fn new() -> Self { + Self::default() + } + + /// Create a completion context with the given arguments + pub fn with_arguments(arguments: std::collections::HashMap) -> Self { + Self { + arguments: Some(arguments), + } + } + + /// Get a specific argument value by name + pub fn get_argument(&self, name: &str) -> Option<&String> { + self.arguments.as_ref()?.get(name) + } + + /// Check if the context has any arguments + pub fn has_arguments(&self) -> bool { + self.arguments.as_ref().is_some_and(|args| !args.is_empty()) + } + + /// Get all argument names + pub fn argument_names(&self) -> impl Iterator { + self.arguments + .as_ref() + .into_iter() + .flat_map(|args| args.keys()) + .map(|k| k.as_str()) + } +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct CompleteRequestParam { + pub r#ref: Reference, + pub argument: ArgumentInfo, + /// Optional context containing previously resolved argument values + #[serde(skip_serializing_if = "Option::is_none")] + pub context: Option, +} + +pub type CompleteRequest = Request; + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct CompletionInfo { + pub values: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub total: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub has_more: Option, +} + +impl CompletionInfo { + /// Maximum number of completion values allowed per response according to MCP specification + pub const MAX_VALUES: usize = 100; + + /// Create a new CompletionInfo with validation for maximum values + pub fn new(values: Vec) -> Result { + if values.len() > Self::MAX_VALUES { + return Err(format!( + "Too many completion values: {} (max: {})", + values.len(), + Self::MAX_VALUES + )); + } + Ok(Self { + values, + total: None, + has_more: None, + }) + } + + /// Create CompletionInfo with all values and no pagination + pub fn with_all_values(values: Vec) -> Result { + let completion = Self::new(values)?; + Ok(Self { + total: Some(completion.values.len() as u32), + has_more: Some(false), + ..completion + }) + } + + /// Create CompletionInfo with pagination information + pub fn with_pagination( + values: Vec, + total: Option, + has_more: bool, + ) -> Result { + let completion = Self::new(values)?; + Ok(Self { + total, + has_more: Some(has_more), + ..completion + }) + } + + /// Check if this completion response indicates more results are available + pub fn has_more_results(&self) -> bool { + self.has_more.unwrap_or(false) + } + + /// Get the total number of available completions, if known + pub fn total_available(&self) -> Option { + self.total + } + + /// Validate that the completion info complies with MCP specification + pub fn validate(&self) -> Result<(), String> { + if self.values.len() > Self::MAX_VALUES { + return Err(format!( + "Too many completion values: {} (max: {})", + self.values.len(), + Self::MAX_VALUES + )); + } + Ok(()) + } +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct CompleteResult { + pub completion: CompletionInfo, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(tag = "type")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub enum Reference { + #[serde(rename = "ref/resource")] + Resource(ResourceReference), + #[serde(rename = "ref/prompt")] + Prompt(PromptReference), +} + +impl Reference { + /// Create a prompt reference + pub fn for_prompt(name: impl Into) -> Self { + // Not accepting `title` currently as it'll break the API + // Until further decision, keep it `None`, modify later + // if required, add `title` to the API + Self::Prompt(PromptReference { + name: name.into(), + title: None, + }) + } + + /// Create a resource reference + pub fn for_resource(uri: impl Into) -> Self { + Self::Resource(ResourceReference { uri: uri.into() }) + } + + /// Get the reference type as a string + pub fn reference_type(&self) -> &'static str { + match self { + Self::Prompt(_) => "ref/prompt", + Self::Resource(_) => "ref/resource", + } + } + + /// Extract prompt name if this is a prompt reference + pub fn as_prompt_name(&self) -> Option<&str> { + match self { + Self::Prompt(prompt_ref) => Some(&prompt_ref.name), + _ => None, + } + } + + /// Extract resource URI if this is a resource reference + pub fn as_resource_uri(&self) -> Option<&str> { + match self { + Self::Resource(resource_ref) => Some(&resource_ref.uri), + _ => None, + } + } +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct ResourceReference { + pub uri: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct PromptReference { + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, +} + +const_string!(CompleteRequestMethod = "completion/complete"); +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct ArgumentInfo { + pub name: String, + pub value: String, +} + +// ============================================================================= +// ROOTS AND WORKSPACE MANAGEMENT +// ============================================================================= + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct Root { + pub uri: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, +} + +const_string!(ListRootsRequestMethod = "roots/list"); +pub type ListRootsRequest = RequestNoParam; + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct ListRootsResult { + pub roots: Vec, +} + +const_string!(RootsListChangedNotificationMethod = "notifications/roots/list_changed"); +pub type RootsListChangedNotification = NotificationNoParam; + +// ============================================================================= +// ELICITATION (INTERACTIVE USER INPUT) +// ============================================================================= + +// Method constants for elicitation operations. +// Elicitation allows servers to request interactive input from users during tool execution. +const_string!(ElicitationCreateRequestMethod = "elicitation/create"); +const_string!(ElicitationResponseNotificationMethod = "notifications/elicitation/response"); + +/// Represents the possible actions a user can take in response to an elicitation request. +/// +/// When a server requests user input through elicitation, the user can: +/// - Accept: Provide the requested information and continue +/// - Decline: Refuse to provide the information but continue the operation +/// - Cancel: Stop the entire operation +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub enum ElicitationAction { + /// User accepts the request and provides the requested information + Accept, + /// User declines to provide the information but allows the operation to continue + Decline, + /// User cancels the entire operation + Cancel, +} + +/// Parameters for creating an elicitation request to gather user input. +/// +/// This structure contains everything needed to request interactive input from a user: +/// - A human-readable message explaining what information is needed +/// - A type-safe schema defining the expected structure of the response +/// +/// # Example +/// +/// ```rust +/// use rmcp::model::*; +/// +/// let params = CreateElicitationRequestParam { +/// message: "Please provide your email".to_string(), +/// requested_schema: ElicitationSchema::builder() +/// .required_email("email") +/// .build() +/// .unwrap(), +/// }; +/// ``` +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct CreateElicitationRequestParam { + /// Human-readable message explaining what input is needed from the user. + /// This should be clear and provide sufficient context for the user to understand + /// what information they need to provide. + pub message: String, + + /// Type-safe schema defining the expected structure and validation rules for the user's response. + /// This enforces the MCP 2025-06-18 specification that elicitation schemas must be objects + /// with primitive-typed properties. + pub requested_schema: ElicitationSchema, +} + +/// The result returned by a client in response to an elicitation request. +/// +/// Contains the user's decision (accept/decline/cancel) and optionally their input data +/// if they chose to accept the request. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct CreateElicitationResult { + /// The user's decision on how to handle the elicitation request + pub action: ElicitationAction, + + /// The actual data provided by the user, if they accepted the request. + /// Must conform to the JSON schema specified in the original request. + /// Only present when action is Accept. + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, +} + +/// Request type for creating an elicitation to gather user input +pub type CreateElicitationRequest = + Request; + +// ============================================================================= +// TOOL EXECUTION RESULTS +// ============================================================================= + +/// The result of a tool call operation. +/// +/// Contains the content returned by the tool execution and an optional +/// flag indicating whether the operation resulted in an error. +#[derive(Debug, Serialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct CallToolResult { + /// The content returned by the tool (text, images, etc.) + pub content: Vec, + /// An optional JSON object that represents the structured result of the tool call + #[serde(skip_serializing_if = "Option::is_none")] + pub structured_content: Option, + /// Whether this result represents an error condition + #[serde(skip_serializing_if = "Option::is_none")] + pub is_error: Option, + /// Optional protocol-level metadata for this result + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option, +} + +impl CallToolResult { + /// Create a successful tool result with unstructured content + pub fn success(content: Vec) -> Self { + CallToolResult { + content, + structured_content: None, + is_error: Some(false), + meta: None, + } + } + /// Create an error tool result with unstructured content + pub fn error(content: Vec) -> Self { + CallToolResult { + content, + structured_content: None, + is_error: Some(true), + meta: None, + } + } + /// Create a successful tool result with structured content + /// + /// # Example + /// + /// ```rust,ignore + /// use rmcp::model::CallToolResult; + /// use serde_json::json; + /// + /// let result = CallToolResult::structured(json!({ + /// "temperature": 22.5, + /// "humidity": 65, + /// "description": "Partly cloudy" + /// })); + /// ``` + pub fn structured(value: Value) -> Self { + CallToolResult { + content: vec![Content::text(value.to_string())], + structured_content: Some(value), + is_error: Some(false), + meta: None, + } + } + /// Create an error tool result with structured content + /// + /// # Example + /// + /// ```rust,ignore + /// use rmcp::model::CallToolResult; + /// use serde_json::json; + /// + /// let result = CallToolResult::structured_error(json!({ + /// "error_code": "INVALID_INPUT", + /// "message": "Temperature value out of range", + /// "details": { + /// "min": -50, + /// "max": 50, + /// "provided": 100 + /// } + /// })); + /// ``` + pub fn structured_error(value: Value) -> Self { + CallToolResult { + content: vec![Content::text(value.to_string())], + structured_content: Some(value), + is_error: Some(true), + meta: None, + } + } + + /// Convert the `structured_content` part of response into a certain type. + /// + /// # About json schema validation + /// Since rust is a strong type language, we don't need to do json schema validation here. + /// + /// But if you do have to validate the response data, you can use [`jsonschema`](https://crates.io/crates/jsonschema) crate. + pub fn into_typed(self) -> Result + where + T: DeserializeOwned, + { + let raw_text = match (self.structured_content, &self.content.first()) { + (Some(value), _) => return serde_json::from_value(value), + (None, Some(contents)) => { + if let Some(text) = contents.as_text() { + let text = &text.text; + Some(text) + } else { + None + } + } + (None, None) => None, + }; + if let Some(text) = raw_text { + return serde_json::from_str(text); + } + serde_json::from_value(serde_json::Value::Null) + } +} + +// Custom deserialize implementation to validate mutual exclusivity +impl<'de> Deserialize<'de> for CallToolResult { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + #[serde(rename_all = "camelCase")] + struct CallToolResultHelper { + #[serde(skip_serializing_if = "Option::is_none")] + content: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + structured_content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + is_error: Option, + /// Accept `_meta` during deserialization + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + meta: Option, + } + + let helper = CallToolResultHelper::deserialize(deserializer)?; + let result = CallToolResult { + content: helper.content.unwrap_or_default(), + structured_content: helper.structured_content, + is_error: helper.is_error, + meta: helper.meta, + }; + + // Validate mutual exclusivity + if result.content.is_empty() && result.structured_content.is_none() { + return Err(serde::de::Error::custom( + "CallToolResult must have either content or structured_content", + )); + } + + Ok(result) + } +} + +const_string!(ListToolsRequestMethod = "tools/list"); +/// Request to list all available tools from a server +pub type ListToolsRequest = RequestOptionalParam; + +paginated_result!( + ListToolsResult { + tools: Vec + } +); + +const_string!(CallToolRequestMethod = "tools/call"); +/// Parameters for calling a tool provided by an MCP server. +/// +/// Contains the tool name and optional arguments needed to execute +/// the tool operation. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct CallToolRequestParam { + /// The name of the tool to call + pub name: Cow<'static, str>, + /// Arguments to pass to the tool (must match the tool's input schema) + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option, +} + +/// Request to call a specific tool +pub type CallToolRequest = Request; + +/// The result of a sampling/createMessage request containing the generated response. +/// +/// This structure contains the generated message along with metadata about +/// how the generation was performed and why it stopped. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct CreateMessageResult { + /// The identifier of the model that generated the response + pub model: String, + /// The reason why generation stopped (e.g., "endTurn", "maxTokens") + #[serde(skip_serializing_if = "Option::is_none")] + pub stop_reason: Option, + /// The generated message with role and content + #[serde(flatten)] + pub message: SamplingMessage, +} + +impl CreateMessageResult { + pub const STOP_REASON_END_TURN: &str = "endTurn"; + pub const STOP_REASON_END_SEQUENCE: &str = "stopSequence"; + pub const STOP_REASON_END_MAX_TOKEN: &str = "maxTokens"; +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct GetPromptResult { + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + pub messages: Vec, +} + +// ============================================================================= +// MESSAGE TYPE UNIONS +// ============================================================================= + +macro_rules! ts_union { + ( + export type $U:ident = + $($rest:tt)* + ) => { + ts_union!(@declare $U { $($rest)* }); + ts_union!(@impl_from $U { $($rest)* }); + }; + (@declare $U:ident { $($variant:tt)* }) => { + ts_union!(@declare_variant $U { } {$($variant)*} ); + }; + (@declare_variant $U:ident { $($declared:tt)* } {$(|)? box $V:ident $($rest:tt)*}) => { + ts_union!(@declare_variant $U { $($declared)* $V(Box<$V>), } {$($rest)*}); + }; + (@declare_variant $U:ident { $($declared:tt)* } {$(|)? $V:ident $($rest:tt)*}) => { + ts_union!(@declare_variant $U { $($declared)* $V($V), } {$($rest)*}); + }; + (@declare_variant $U:ident { $($declared:tt)* } { ; }) => { + ts_union!(@declare_end $U { $($declared)* } ); + }; + (@declare_end $U:ident { $($declared:tt)* }) => { + #[derive(Debug, Serialize, Deserialize, Clone)] + #[serde(untagged)] + #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] + pub enum $U { + $($declared)* + } + }; + (@impl_from $U: ident {$(|)? box $V:ident $($rest:tt)*}) => { + impl From<$V> for $U { + fn from(value: $V) -> Self { + $U::$V(Box::new(value)) + } + } + ts_union!(@impl_from $U {$($rest)*}); + }; + (@impl_from $U: ident {$(|)? $V:ident $($rest:tt)*}) => { + impl From<$V> for $U { + fn from(value: $V) -> Self { + $U::$V(value) + } + } + ts_union!(@impl_from $U {$($rest)*}); + }; + (@impl_from $U: ident { ; }) => {}; + (@impl_from $U: ident { }) => {}; +} + +ts_union!( + export type ClientRequest = + | PingRequest + | InitializeRequest + | CompleteRequest + | SetLevelRequest + | GetPromptRequest + | ListPromptsRequest + | ListResourcesRequest + | ListResourceTemplatesRequest + | ReadResourceRequest + | SubscribeRequest + | UnsubscribeRequest + | CallToolRequest + | ListToolsRequest; +); + +impl ClientRequest { + pub fn method(&self) -> &'static str { + match &self { + ClientRequest::PingRequest(r) => r.method.as_str(), + ClientRequest::InitializeRequest(r) => r.method.as_str(), + ClientRequest::CompleteRequest(r) => r.method.as_str(), + ClientRequest::SetLevelRequest(r) => r.method.as_str(), + ClientRequest::GetPromptRequest(r) => r.method.as_str(), + ClientRequest::ListPromptsRequest(r) => r.method.as_str(), + ClientRequest::ListResourcesRequest(r) => r.method.as_str(), + ClientRequest::ListResourceTemplatesRequest(r) => r.method.as_str(), + ClientRequest::ReadResourceRequest(r) => r.method.as_str(), + ClientRequest::SubscribeRequest(r) => r.method.as_str(), + ClientRequest::UnsubscribeRequest(r) => r.method.as_str(), + ClientRequest::CallToolRequest(r) => r.method.as_str(), + ClientRequest::ListToolsRequest(r) => r.method.as_str(), + } + } +} + +ts_union!( + export type ClientNotification = + | CancelledNotification + | ProgressNotification + | InitializedNotification + | RootsListChangedNotification; +); + +ts_union!( + export type ClientResult = box CreateMessageResult | ListRootsResult | CreateElicitationResult | EmptyResult; +); + +impl ClientResult { + pub fn empty(_: ()) -> ClientResult { + ClientResult::EmptyResult(EmptyResult {}) + } +} + +pub type ClientJsonRpcMessage = JsonRpcMessage; + +ts_union!( + export type ServerRequest = + | PingRequest + | CreateMessageRequest + | ListRootsRequest + | CreateElicitationRequest; +); + +ts_union!( + export type ServerNotification = + | CancelledNotification + | ProgressNotification + | LoggingMessageNotification + | ResourceUpdatedNotification + | ResourceListChangedNotification + | ToolListChangedNotification + | PromptListChangedNotification; +); + +ts_union!( + export type ServerResult = + | InitializeResult + | CompleteResult + | GetPromptResult + | ListPromptsResult + | ListResourcesResult + | ListResourceTemplatesResult + | ReadResourceResult + | CallToolResult + | ListToolsResult + | CreateElicitationResult + | EmptyResult + ; +); + +impl ServerResult { + pub fn empty(_: ()) -> ServerResult { + ServerResult::EmptyResult(EmptyResult {}) + } +} + +pub type ServerJsonRpcMessage = JsonRpcMessage; + +impl TryInto for ServerNotification { + type Error = ServerNotification; + fn try_into(self) -> Result { + if let ServerNotification::CancelledNotification(t) = self { + Ok(t) + } else { + Err(self) + } + } +} + +impl TryInto for ClientNotification { + type Error = ClientNotification; + fn try_into(self) -> Result { + if let ClientNotification::CancelledNotification(t) = self { + Ok(t) + } else { + Err(self) + } + } +} + +// ============================================================================= +// TESTS +// ============================================================================= + +#[cfg(test)] +mod tests { + use serde_json::json; + + use super::*; + + #[test] + fn test_notification_serde() { + let raw = json!( { + "jsonrpc": JsonRpcVersion2_0, + "method": InitializedNotificationMethod, + }); + let message: ClientJsonRpcMessage = + serde_json::from_value(raw.clone()).expect("invalid notification"); + match &message { + ClientJsonRpcMessage::Notification(JsonRpcNotification { + notification: ClientNotification::InitializedNotification(_n), + .. + }) => {} + _ => panic!("Expected Notification"), + } + let json = serde_json::to_value(message).expect("valid json"); + assert_eq!(json, raw); + } + + #[test] + fn test_request_conversion() { + let raw = json!( { + "jsonrpc": JsonRpcVersion2_0, + "id": 1, + "method": "request", + "params": {"key": "value"}, + }); + let message: JsonRpcMessage = serde_json::from_value(raw.clone()).expect("invalid request"); + + match &message { + JsonRpcMessage::Request(r) => { + assert_eq!(r.id, RequestId::Number(1)); + assert_eq!(r.request.method, "request"); + assert_eq!( + &r.request.params, + json!({"key": "value"}) + .as_object() + .expect("should be an object") + ); + } + _ => panic!("Expected Request"), + } + let json = serde_json::to_value(&message).expect("valid json"); + assert_eq!(json, raw); + } + + #[test] + fn test_initial_request_response_serde() { + let request = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": { + "roots": { + "listChanged": true + }, + "sampling": {} + }, + "clientInfo": { + "name": "ExampleClient", + "version": "1.0.0" + } + } + }); + let raw_response_json = json!({ + "jsonrpc": "2.0", + "id": 1, + "result": { + "protocolVersion": "2024-11-05", + "capabilities": { + "logging": {}, + "prompts": { + "listChanged": true + }, + "resources": { + "subscribe": true, + "listChanged": true + }, + "tools": { + "listChanged": true + } + }, + "serverInfo": { + "name": "ExampleServer", + "version": "1.0.0" + } + } + }); + let request: ClientJsonRpcMessage = + serde_json::from_value(request.clone()).expect("invalid request"); + let (request, id) = request.into_request().expect("should be a request"); + assert_eq!(id, RequestId::Number(1)); + match request { + ClientRequest::InitializeRequest(Request { + method: _, + params: + InitializeRequestParam { + protocol_version: _, + capabilities, + client_info, + }, + .. + }) => { + assert_eq!(capabilities.roots.unwrap().list_changed, Some(true)); + assert_eq!(capabilities.sampling.unwrap().len(), 0); + assert_eq!(client_info.name, "ExampleClient"); + assert_eq!(client_info.version, "1.0.0"); + } + _ => panic!("Expected InitializeRequest"), + } + let server_response: ServerJsonRpcMessage = + serde_json::from_value(raw_response_json.clone()).expect("invalid response"); + let (response, id) = server_response + .clone() + .into_response() + .expect("expect response"); + assert_eq!(id, RequestId::Number(1)); + match response { + ServerResult::InitializeResult(InitializeResult { + protocol_version: _, + capabilities, + server_info, + instructions, + }) => { + assert_eq!(capabilities.logging.unwrap().len(), 0); + assert_eq!(capabilities.prompts.unwrap().list_changed, Some(true)); + assert_eq!( + capabilities.resources.as_ref().unwrap().subscribe, + Some(true) + ); + assert_eq!(capabilities.resources.unwrap().list_changed, Some(true)); + assert_eq!(capabilities.tools.unwrap().list_changed, Some(true)); + assert_eq!(server_info.name, "ExampleServer"); + assert_eq!(server_info.version, "1.0.0"); + assert_eq!(server_info.icons, None); + assert_eq!(instructions, None); + } + other => panic!("Expected InitializeResult, got {other:?}"), + } + + let server_response_json: Value = serde_json::to_value(&server_response).expect("msg"); + + assert_eq!(server_response_json, raw_response_json); + } + + #[test] + fn test_negative_and_large_request_ids() { + // Test negative ID + let negative_id_json = json!({ + "jsonrpc": "2.0", + "id": -1, + "method": "test", + "params": {} + }); + + let message: JsonRpcMessage = + serde_json::from_value(negative_id_json.clone()).expect("Should parse negative ID"); + + match &message { + JsonRpcMessage::Request(r) => { + assert_eq!(r.id, RequestId::Number(-1)); + } + _ => panic!("Expected Request"), + } + + // Test roundtrip serialization + let serialized = serde_json::to_value(&message).expect("Should serialize"); + assert_eq!(serialized, negative_id_json); + + // Test large negative ID + let large_negative_json = json!({ + "jsonrpc": "2.0", + "id": -9007199254740991i64, // JavaScript's MIN_SAFE_INTEGER + "method": "test", + "params": {} + }); + + let message: JsonRpcMessage = serde_json::from_value(large_negative_json.clone()) + .expect("Should parse large negative ID"); + + match &message { + JsonRpcMessage::Request(r) => { + assert_eq!(r.id, RequestId::Number(-9007199254740991i64)); + } + _ => panic!("Expected Request"), + } + + // Test large positive ID (JavaScript's MAX_SAFE_INTEGER) + let large_positive_json = json!({ + "jsonrpc": "2.0", + "id": 9007199254740991i64, + "method": "test", + "params": {} + }); + + let message: JsonRpcMessage = serde_json::from_value(large_positive_json.clone()) + .expect("Should parse large positive ID"); + + match &message { + JsonRpcMessage::Request(r) => { + assert_eq!(r.id, RequestId::Number(9007199254740991i64)); + } + _ => panic!("Expected Request"), + } + + // Test zero ID + let zero_id_json = json!({ + "jsonrpc": "2.0", + "id": 0, + "method": "test", + "params": {} + }); + + let message: JsonRpcMessage = + serde_json::from_value(zero_id_json.clone()).expect("Should parse zero ID"); + + match &message { + JsonRpcMessage::Request(r) => { + assert_eq!(r.id, RequestId::Number(0)); + } + _ => panic!("Expected Request"), + } + } + + #[test] + fn test_protocol_version_order() { + let v1 = ProtocolVersion::V_2024_11_05; + let v2 = ProtocolVersion::V_2025_03_26; + assert!(v1 < v2); + } + + #[test] + fn test_icon_serialization() { + let icon = Icon { + src: "https://example.com/icon.png".to_string(), + mime_type: Some("image/png".to_string()), + sizes: Some(vec!["48x48".to_string()]), + }; + + let json = serde_json::to_value(&icon).unwrap(); + assert_eq!(json["src"], "https://example.com/icon.png"); + assert_eq!(json["mimeType"], "image/png"); + assert_eq!(json["sizes"][0], "48x48"); + + // Test deserialization + let deserialized: Icon = serde_json::from_value(json).unwrap(); + assert_eq!(deserialized, icon); + } + + #[test] + fn test_icon_minimal() { + let icon = Icon { + src: "data:image/svg+xml;base64,PHN2Zy8+".to_string(), + mime_type: None, + sizes: None, + }; + + let json = serde_json::to_value(&icon).unwrap(); + assert_eq!(json["src"], "data:image/svg+xml;base64,PHN2Zy8+"); + assert!(json.get("mimeType").is_none()); + assert!(json.get("sizes").is_none()); + } + + #[test] + fn test_implementation_with_icons() { + let implementation = Implementation { + name: "test-server".to_string(), + title: Some("Test Server".to_string()), + version: "1.0.0".to_string(), + icons: Some(vec![ + Icon { + src: "https://example.com/icon.png".to_string(), + mime_type: Some("image/png".to_string()), + sizes: Some(vec!["48x48".to_string()]), + }, + Icon { + src: "https://example.com/icon.svg".to_string(), + mime_type: Some("image/svg+xml".to_string()), + sizes: Some(vec!["any".to_string()]), + }, + ]), + website_url: Some("https://example.com".to_string()), + }; + + let json = serde_json::to_value(&implementation).unwrap(); + assert_eq!(json["name"], "test-server"); + assert_eq!(json["websiteUrl"], "https://example.com"); + assert!(json["icons"].is_array()); + assert_eq!(json["icons"][0]["src"], "https://example.com/icon.png"); + assert_eq!(json["icons"][0]["sizes"][0], "48x48"); + assert_eq!(json["icons"][1]["mimeType"], "image/svg+xml"); + assert_eq!(json["icons"][1]["sizes"][0], "any"); + } + + #[test] + fn test_backward_compatibility() { + // Test that old JSON without icons still deserializes correctly + let old_json = json!({ + "name": "legacy-server", + "version": "0.9.0" + }); + + let implementation: Implementation = serde_json::from_value(old_json).unwrap(); + assert_eq!(implementation.name, "legacy-server"); + assert_eq!(implementation.version, "0.9.0"); + assert_eq!(implementation.icons, None); + assert_eq!(implementation.website_url, None); + } + + #[test] + fn test_initialize_with_icons() { + let init_result = InitializeResult { + protocol_version: ProtocolVersion::default(), + capabilities: ServerCapabilities::default(), + server_info: Implementation { + name: "icon-server".to_string(), + title: None, + version: "2.0.0".to_string(), + icons: Some(vec![Icon { + src: "https://example.com/server.png".to_string(), + mime_type: Some("image/png".to_string()), + sizes: Some(vec!["48x48".to_string()]), + }]), + website_url: Some("https://docs.example.com".to_string()), + }, + instructions: None, + }; + + let json = serde_json::to_value(&init_result).unwrap(); + assert!(json["serverInfo"]["icons"].is_array()); + assert_eq!( + json["serverInfo"]["icons"][0]["src"], + "https://example.com/server.png" + ); + assert_eq!(json["serverInfo"]["icons"][0]["sizes"][0], "48x48"); + assert_eq!(json["serverInfo"]["websiteUrl"], "https://docs.example.com"); + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/model/annotated.rs b/code-rs/third_party/rmcp-0.8.3/src/model/annotated.rs new file mode 100644 index 00000000000..f9921146a46 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/model/annotated.rs @@ -0,0 +1,224 @@ +use std::ops::{Deref, DerefMut}; + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; + +use super::{ + RawAudioContent, RawContent, RawEmbeddedResource, RawImageContent, RawResource, + RawResourceTemplate, RawTextContent, Role, +}; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct Annotations { + #[serde(skip_serializing_if = "Option::is_none")] + pub audience: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub priority: Option, + #[serde(skip_serializing_if = "Option::is_none", rename = "lastModified")] + pub last_modified: Option>, +} + +impl Annotations { + /// Creates a new Annotations instance specifically for resources + /// optional priority, and a timestamp (defaults to now if None) + pub fn for_resource(priority: f32, timestamp: DateTime) -> Self { + assert!( + (0.0..=1.0).contains(&priority), + "Priority {priority} must be between 0.0 and 1.0" + ); + Annotations { + priority: Some(priority), + last_modified: Some(timestamp), + audience: None, + } + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct Annotated { + #[serde(flatten)] + pub raw: T, + #[serde(skip_serializing_if = "Option::is_none")] + pub annotations: Option, +} + +impl Deref for Annotated { + type Target = T; + fn deref(&self) -> &Self::Target { + &self.raw + } +} + +impl DerefMut for Annotated { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.raw + } +} + +impl Annotated { + pub fn new(raw: T, annotations: Option) -> Self { + Self { raw, annotations } + } + pub fn remove_annotation(&mut self) -> Option { + self.annotations.take() + } + pub fn audience(&self) -> Option<&Vec> { + self.annotations.as_ref().and_then(|a| a.audience.as_ref()) + } + pub fn priority(&self) -> Option { + self.annotations.as_ref().and_then(|a| a.priority) + } + pub fn timestamp(&self) -> Option> { + self.annotations.as_ref().and_then(|a| a.last_modified) + } + pub fn with_audience(self, audience: Vec) -> Annotated + where + Self: Sized, + { + if let Some(annotations) = self.annotations { + Annotated { + raw: self.raw, + annotations: Some(Annotations { + audience: Some(audience), + ..annotations + }), + } + } else { + Annotated { + raw: self.raw, + annotations: Some(Annotations { + audience: Some(audience), + priority: None, + last_modified: None, + }), + } + } + } + pub fn with_priority(self, priority: f32) -> Annotated + where + Self: Sized, + { + if let Some(annotations) = self.annotations { + Annotated { + raw: self.raw, + annotations: Some(Annotations { + priority: Some(priority), + ..annotations + }), + } + } else { + Annotated { + raw: self.raw, + annotations: Some(Annotations { + priority: Some(priority), + last_modified: None, + audience: None, + }), + } + } + } + pub fn with_timestamp(self, timestamp: DateTime) -> Annotated + where + Self: Sized, + { + if let Some(annotations) = self.annotations { + Annotated { + raw: self.raw, + annotations: Some(Annotations { + last_modified: Some(timestamp), + ..annotations + }), + } + } else { + Annotated { + raw: self.raw, + annotations: Some(Annotations { + last_modified: Some(timestamp), + priority: None, + audience: None, + }), + } + } + } + pub fn with_timestamp_now(self) -> Annotated + where + Self: Sized, + { + self.with_timestamp(Utc::now()) + } +} + +mod sealed { + pub trait Sealed {} +} +macro_rules! annotate { + ($T: ident) => { + impl sealed::Sealed for $T {} + impl AnnotateAble for $T {} + }; +} + +annotate!(RawContent); +annotate!(RawTextContent); +annotate!(RawImageContent); +annotate!(RawAudioContent); +annotate!(RawEmbeddedResource); +annotate!(RawResource); +annotate!(RawResourceTemplate); + +pub trait AnnotateAble: sealed::Sealed { + fn optional_annotate(self, annotations: Option) -> Annotated + where + Self: Sized, + { + Annotated::new(self, annotations) + } + fn annotate(self, annotations: Annotations) -> Annotated + where + Self: Sized, + { + Annotated::new(self, Some(annotations)) + } + fn no_annotation(self) -> Annotated + where + Self: Sized, + { + Annotated::new(self, None) + } + fn with_audience(self, audience: Vec) -> Annotated + where + Self: Sized, + { + self.annotate(Annotations { + audience: Some(audience), + ..Default::default() + }) + } + fn with_priority(self, priority: f32) -> Annotated + where + Self: Sized, + { + self.annotate(Annotations { + priority: Some(priority), + ..Default::default() + }) + } + fn with_timestamp(self, timestamp: DateTime) -> Annotated + where + Self: Sized, + { + self.annotate(Annotations { + last_modified: Some(timestamp), + ..Default::default() + }) + } + fn with_timestamp_now(self) -> Annotated + where + Self: Sized, + { + self.with_timestamp(Utc::now()) + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/model/capabilities.rs b/code-rs/third_party/rmcp-0.8.3/src/model/capabilities.rs new file mode 100644 index 00000000000..7399141f990 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/model/capabilities.rs @@ -0,0 +1,346 @@ +use std::{collections::BTreeMap, marker::PhantomData}; + +use paste::paste; +use serde::{Deserialize, Serialize}; + +use super::JsonObject; +pub type ExperimentalCapabilities = BTreeMap; + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct PromptsCapability { + #[serde(skip_serializing_if = "Option::is_none")] + pub list_changed: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct ResourcesCapability { + #[serde(skip_serializing_if = "Option::is_none")] + pub subscribe: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub list_changed: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct ToolsCapability { + #[serde(skip_serializing_if = "Option::is_none")] + pub list_changed: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct RootsCapabilities { + #[serde(skip_serializing_if = "Option::is_none")] + pub list_changed: Option, +} + +/// Capability for handling elicitation requests from servers. +/// +/// Elicitation allows servers to request interactive input from users during tool execution. +/// This capability indicates that a client can handle elicitation requests and present +/// appropriate UI to users for collecting the requested information. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct ElicitationCapability { + /// Whether the client supports JSON Schema validation for elicitation responses. + /// When true, the client will validate user input against the requested_schema + /// before sending the response back to the server. + #[serde(skip_serializing_if = "Option::is_none")] + pub schema_validation: Option, +} + +/// +/// # Builder +/// ```rust +/// # use rmcp::model::ClientCapabilities; +/// let cap = ClientCapabilities::builder() +/// .enable_experimental() +/// .enable_roots() +/// .enable_roots_list_changed() +/// .build(); +/// ``` +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct ClientCapabilities { + #[serde(skip_serializing_if = "Option::is_none")] + pub experimental: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub roots: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub sampling: Option, + /// Capability to handle elicitation requests from servers for interactive user input + #[serde(skip_serializing_if = "Option::is_none")] + pub elicitation: Option, +} + +/// +/// ## Builder +/// ```rust +/// # use rmcp::model::ServerCapabilities; +/// let cap = ServerCapabilities::builder() +/// .enable_logging() +/// .enable_experimental() +/// .enable_prompts() +/// .enable_resources() +/// .enable_tools() +/// .enable_tool_list_changed() +/// .build(); +/// ``` +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct ServerCapabilities { + #[serde(skip_serializing_if = "Option::is_none")] + pub experimental: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub logging: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub completions: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub prompts: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub resources: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option, +} + +macro_rules! builder { + ($Target: ident {$($f: ident: $T: ty),* $(,)?}) => { + paste! { + #[derive(Default, Clone, Copy, Debug)] + pub struct [<$Target BuilderState>]< + $(const [<$f:upper>]: bool = false,)* + >; + #[derive(Debug, Default)] + pub struct [<$Target Builder>]]> { + $(pub $f: Option<$T>,)* + pub state: PhantomData + } + impl $Target { + #[doc = "Create a new [`" $Target "`] builder."] + pub fn builder() -> [<$Target Builder>] { + <[<$Target Builder>]>::default() + } + } + impl [<$Target Builder>] { + pub fn build(self) -> $Target { + $Target { + $( $f: self.$f, )* + } + } + } + impl From<[<$Target Builder>]> for $Target { + fn from(builder: [<$Target Builder>]) -> Self { + builder.build() + } + } + } + builder!($Target @toggle $($f: $T,) *); + + }; + ($Target: ident @toggle $f0: ident: $T0: ty, $($f: ident: $T: ty,)*) => { + builder!($Target @toggle [][$f0: $T0][$($f: $T,)*]); + }; + ($Target: ident @toggle [$($ff: ident: $Tf: ty,)*][$fn: ident: $TN: ty][$fn_1: ident: $Tn_1: ty, $($ft: ident: $Tt: ty,)*]) => { + builder!($Target @impl_toggle [$($ff: $Tf,)*][$fn: $TN][$fn_1: $Tn_1, $($ft:$Tt,)*]); + builder!($Target @toggle [$($ff: $Tf,)* $fn: $TN,][$fn_1: $Tn_1][$($ft:$Tt,)*]); + }; + ($Target: ident @toggle [$($ff: ident: $Tf: ty,)*][$fn: ident: $TN: ty][]) => { + builder!($Target @impl_toggle [$($ff: $Tf,)*][$fn: $TN][]); + }; + ($Target: ident @impl_toggle [$($ff: ident: $Tf: ty,)*][$fn: ident: $TN: ty][$($ft: ident: $Tt: ty,)*]) => { + paste! { + impl< + $(const [<$ff:upper>]: bool,)* + $(const [<$ft:upper>]: bool,)* + > [<$Target Builder>]<[<$Target BuilderState>]< + $([<$ff:upper>],)* + false, + $([<$ft:upper>],)* + >> { + pub fn [](self) -> [<$Target Builder>]<[<$Target BuilderState>]< + $([<$ff:upper>],)* + true, + $([<$ft:upper>],)* + >> { + [<$Target Builder>] { + $( $ff: self.$ff, )* + $fn: Some($TN::default()), + $( $ft: self.$ft, )* + state: PhantomData + } + } + pub fn [](self, $fn: $TN) -> [<$Target Builder>]<[<$Target BuilderState>]< + $([<$ff:upper>],)* + true, + $([<$ft:upper>],)* + >> { + [<$Target Builder>] { + $( $ff: self.$ff, )* + $fn: Some($fn), + $( $ft: self.$ft, )* + state: PhantomData + } + } + } + // do we really need to disable some thing in builder? + // impl< + // $(const [<$ff:upper>]: bool,)* + // $(const [<$ft:upper>]: bool,)* + // > [<$Target Builder>]<[<$Target BuilderState>]< + // $([<$ff:upper>],)* + // true, + // $([<$ft:upper>],)* + // >> { + // pub fn [](self) -> [<$Target Builder>]<[<$Target BuilderState>]< + // $([<$ff:upper>],)* + // false, + // $([<$ft:upper>],)* + // >> { + // [<$Target Builder>] { + // $( $ff: self.$ff, )* + // $fn: None, + // $( $ft: self.$ft, )* + // state: PhantomData + // } + // } + // } + } + } +} + +builder! { + ServerCapabilities { + experimental: ExperimentalCapabilities, + logging: JsonObject, + completions: JsonObject, + prompts: PromptsCapability, + resources: ResourcesCapability, + tools: ToolsCapability + } +} + +impl + ServerCapabilitiesBuilder> +{ + pub fn enable_tool_list_changed(mut self) -> Self { + if let Some(c) = self.tools.as_mut() { + c.list_changed = Some(true); + } + self + } +} + +impl + ServerCapabilitiesBuilder> +{ + pub fn enable_prompts_list_changed(mut self) -> Self { + if let Some(c) = self.prompts.as_mut() { + c.list_changed = Some(true); + } + self + } +} + +impl + ServerCapabilitiesBuilder> +{ + pub fn enable_resources_list_changed(mut self) -> Self { + if let Some(c) = self.resources.as_mut() { + c.list_changed = Some(true); + } + self + } + + pub fn enable_resources_subscribe(mut self) -> Self { + if let Some(c) = self.resources.as_mut() { + c.subscribe = Some(true); + } + self + } +} + +builder! { + ClientCapabilities{ + experimental: ExperimentalCapabilities, + roots: RootsCapabilities, + sampling: JsonObject, + elicitation: ElicitationCapability, + } +} + +impl + ClientCapabilitiesBuilder> +{ + pub fn enable_roots_list_changed(mut self) -> Self { + if let Some(c) = self.roots.as_mut() { + c.list_changed = Some(true); + } + self + } +} + +#[cfg(feature = "elicitation")] +impl + ClientCapabilitiesBuilder> +{ + /// Enable JSON Schema validation for elicitation responses. + /// When enabled, the client will validate user input against the requested_schema + /// before sending responses back to the server. + pub fn enable_elicitation_schema_validation(mut self) -> Self { + if let Some(c) = self.elicitation.as_mut() { + c.schema_validation = Some(true); + } + self + } +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn test_builder() { + let builder = ::default() + .enable_logging() + .enable_experimental() + .enable_prompts() + .enable_resources() + .enable_tools() + .enable_tool_list_changed(); + assert_eq!(builder.logging, Some(JsonObject::default())); + assert_eq!(builder.prompts, Some(PromptsCapability::default())); + assert_eq!(builder.resources, Some(ResourcesCapability::default())); + assert_eq!( + builder.tools, + Some(ToolsCapability { + list_changed: Some(true), + }) + ); + assert_eq!( + builder.experimental, + Some(ExperimentalCapabilities::default()) + ); + let client_builder = ::default() + .enable_experimental() + .enable_roots() + .enable_roots_list_changed() + .enable_sampling(); + assert_eq!( + client_builder.experimental, + Some(ExperimentalCapabilities::default()) + ); + assert_eq!( + client_builder.roots, + Some(RootsCapabilities { + list_changed: Some(true), + }) + ); + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/model/content.rs b/code-rs/third_party/rmcp-0.8.3/src/model/content.rs new file mode 100644 index 00000000000..7c60fafd953 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/model/content.rs @@ -0,0 +1,293 @@ +//! Content sent around agents, extensions, and LLMs +//! The various content types can be display to humans but also understood by models +//! They include optional annotations used to help inform agent usage +use serde::{Deserialize, Serialize}; +use serde_json::json; + +use super::{AnnotateAble, Annotated, resource::ResourceContents}; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct RawTextContent { + pub text: String, + /// Optional protocol-level metadata for this content block + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option, +} +pub type TextContent = Annotated; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct RawImageContent { + /// The base64-encoded image + pub data: String, + pub mime_type: String, + /// Optional protocol-level metadata for this content block + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option, +} + +pub type ImageContent = Annotated; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct RawEmbeddedResource { + /// Optional protocol-level metadata for this content block + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option, + pub resource: ResourceContents, +} +pub type EmbeddedResource = Annotated; + +impl EmbeddedResource { + pub fn get_text(&self) -> String { + match &self.resource { + ResourceContents::TextResourceContents { text, .. } => text.clone(), + _ => String::new(), + } + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct RawAudioContent { + pub data: String, + pub mime_type: String, +} + +pub type AudioContent = Annotated; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub enum RawContent { + Text(RawTextContent), + Image(RawImageContent), + Resource(RawEmbeddedResource), + Audio(RawAudioContent), + ResourceLink(super::resource::RawResource), +} + +pub type Content = Annotated; + +impl RawContent { + pub fn json(json: S) -> Result { + let json = serde_json::to_string(&json).map_err(|e| { + crate::ErrorData::internal_error( + "fail to serialize response to json", + Some(json!( + {"reason": e.to_string()} + )), + ) + })?; + Ok(RawContent::text(json)) + } + + pub fn text>(text: S) -> Self { + RawContent::Text(RawTextContent { + text: text.into(), + meta: None, + }) + } + + pub fn image, T: Into>(data: S, mime_type: T) -> Self { + RawContent::Image(RawImageContent { + data: data.into(), + mime_type: mime_type.into(), + meta: None, + }) + } + + pub fn resource(resource: ResourceContents) -> Self { + RawContent::Resource(RawEmbeddedResource { + meta: None, + resource, + }) + } + + pub fn embedded_text, T: Into>(uri: S, content: T) -> Self { + RawContent::Resource(RawEmbeddedResource { + meta: None, + resource: ResourceContents::TextResourceContents { + uri: uri.into(), + mime_type: Some("text".to_string()), + text: content.into(), + meta: None, + }, + }) + } + + /// Get the text content if this is a TextContent variant + pub fn as_text(&self) -> Option<&RawTextContent> { + match self { + RawContent::Text(text) => Some(text), + _ => None, + } + } + + /// Get the image content if this is an ImageContent variant + pub fn as_image(&self) -> Option<&RawImageContent> { + match self { + RawContent::Image(image) => Some(image), + _ => None, + } + } + + /// Get the resource content if this is an ImageContent variant + pub fn as_resource(&self) -> Option<&RawEmbeddedResource> { + match self { + RawContent::Resource(resource) => Some(resource), + _ => None, + } + } + + /// Get the resource link if this is a ResourceLink variant + pub fn as_resource_link(&self) -> Option<&super::resource::RawResource> { + match self { + RawContent::ResourceLink(link) => Some(link), + _ => None, + } + } + + /// Create a resource link content + pub fn resource_link(resource: super::resource::RawResource) -> Self { + RawContent::ResourceLink(resource) + } +} + +impl Content { + pub fn text>(text: S) -> Self { + RawContent::text(text).no_annotation() + } + + pub fn image, T: Into>(data: S, mime_type: T) -> Self { + RawContent::image(data, mime_type).no_annotation() + } + + pub fn resource(resource: ResourceContents) -> Self { + RawContent::resource(resource).no_annotation() + } + + pub fn embedded_text, T: Into>(uri: S, content: T) -> Self { + RawContent::embedded_text(uri, content).no_annotation() + } + + pub fn json(json: S) -> Result { + RawContent::json(json).map(|c| c.no_annotation()) + } + + /// Create a resource link content + pub fn resource_link(resource: super::resource::RawResource) -> Self { + RawContent::resource_link(resource).no_annotation() + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct JsonContent(S); +/// Types that can be converted into a list of contents +pub trait IntoContents { + fn into_contents(self) -> Vec; +} + +impl IntoContents for Content { + fn into_contents(self) -> Vec { + vec![self] + } +} + +impl IntoContents for String { + fn into_contents(self) -> Vec { + vec![Content::text(self)] + } +} + +impl IntoContents for () { + fn into_contents(self) -> Vec { + vec![] + } +} + +#[cfg(test)] +mod tests { + use serde_json; + + use super::*; + + #[test] + fn test_image_content_serialization() { + let image_content = RawImageContent { + data: "base64data".to_string(), + mime_type: "image/png".to_string(), + meta: None, + }; + + let json = serde_json::to_string(&image_content).unwrap(); + println!("ImageContent JSON: {}", json); + + // Verify it contains mimeType (camelCase) not mime_type (snake_case) + assert!(json.contains("mimeType")); + assert!(!json.contains("mime_type")); + } + + #[test] + fn test_audio_content_serialization() { + let audio_content = RawAudioContent { + data: "base64audiodata".to_string(), + mime_type: "audio/wav".to_string(), + }; + + let json = serde_json::to_string(&audio_content).unwrap(); + println!("AudioContent JSON: {}", json); + + // Verify it contains mimeType (camelCase) not mime_type (snake_case) + assert!(json.contains("mimeType")); + assert!(!json.contains("mime_type")); + } + + #[test] + fn test_resource_link_serialization() { + use super::super::resource::RawResource; + + let resource_link = RawContent::ResourceLink(RawResource { + uri: "file:///test.txt".to_string(), + name: "test.txt".to_string(), + title: None, + description: Some("A test file".to_string()), + mime_type: Some("text/plain".to_string()), + size: Some(100), + icons: None, + }); + + let json = serde_json::to_string(&resource_link).unwrap(); + println!("ResourceLink JSON: {}", json); + + // Verify it contains the correct type tag + assert!(json.contains("\"type\":\"resource_link\"")); + assert!(json.contains("\"uri\":\"file:///test.txt\"")); + assert!(json.contains("\"name\":\"test.txt\"")); + } + + #[test] + fn test_resource_link_deserialization() { + let json = r#"{ + "type": "resource_link", + "uri": "file:///example.txt", + "name": "example.txt", + "description": "Example file", + "mimeType": "text/plain" + }"#; + + let content: RawContent = serde_json::from_str(json).unwrap(); + + if let RawContent::ResourceLink(resource) = content { + assert_eq!(resource.uri, "file:///example.txt"); + assert_eq!(resource.name, "example.txt"); + assert_eq!(resource.description, Some("Example file".to_string())); + assert_eq!(resource.mime_type, Some("text/plain".to_string())); + } else { + panic!("Expected ResourceLink variant"); + } + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/model/elicitation_schema.rs b/code-rs/third_party/rmcp-0.8.3/src/model/elicitation_schema.rs new file mode 100644 index 00000000000..095e28e9883 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/model/elicitation_schema.rs @@ -0,0 +1,1180 @@ +//! Type-safe schema definitions for MCP elicitation requests. +//! +//! This module provides strongly-typed schema definitions for elicitation requests +//! that comply with the MCP 2025-06-18 specification. Elicitation schemas must be +//! objects with primitive-typed properties. +//! +//! # Example +//! +//! ```rust +//! use rmcp::model::*; +//! +//! let schema = ElicitationSchema::builder() +//! .required_email("email") +//! .required_integer("age", 0, 150) +//! .optional_bool("newsletter", false) +//! .build(); +//! ``` + +use std::{borrow::Cow, collections::BTreeMap}; + +use serde::{Deserialize, Serialize}; + +use crate::{const_string, model::ConstString}; + +// ============================================================================= +// CONST TYPES FOR JSON SCHEMA TYPE FIELD +// ============================================================================= + +const_string!(ObjectTypeConst = "object"); +const_string!(StringTypeConst = "string"); +const_string!(NumberTypeConst = "number"); +const_string!(IntegerTypeConst = "integer"); +const_string!(BooleanTypeConst = "boolean"); +const_string!(EnumTypeConst = "string"); + +// ============================================================================= +// PRIMITIVE SCHEMA DEFINITIONS +// ============================================================================= + +/// Primitive schema definition for elicitation properties. +/// +/// According to MCP 2025-06-18 specification, elicitation schemas must have +/// properties of primitive types only (string, number, integer, boolean, enum). +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[serde(untagged)] +pub enum PrimitiveSchema { + /// String property (with optional enum constraint) + String(StringSchema), + /// Number property (with optional enum constraint) + Number(NumberSchema), + /// Integer property (with optional enum constraint) + Integer(IntegerSchema), + /// Boolean property + Boolean(BooleanSchema), + /// Enum property (explicit enum schema) + Enum(EnumSchema), +} + +// ============================================================================= +// STRING SCHEMA +// ============================================================================= + +/// String format types allowed by the MCP specification. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[serde(rename_all = "kebab-case")] +pub enum StringFormat { + /// Email address format + Email, + /// URI format + Uri, + /// Date format (YYYY-MM-DD) + Date, + /// Date-time format (ISO 8601) + DateTime, +} + +/// Schema definition for string properties. +/// +/// Compliant with MCP 2025-06-18 specification for elicitation schemas. +/// Supports only the fields allowed by the MCP spec: +/// - format limited to: "email", "uri", "date", "date-time" +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[serde(rename_all = "camelCase")] +pub struct StringSchema { + /// Type discriminator + #[serde(rename = "type")] + pub type_: StringTypeConst, + + /// Optional title for the schema + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option>, + + /// Human-readable description + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option>, + + /// Minimum string length + #[serde(skip_serializing_if = "Option::is_none")] + pub min_length: Option, + + /// Maximum string length + #[serde(skip_serializing_if = "Option::is_none")] + pub max_length: Option, + + /// String format - limited to: "email", "uri", "date", "date-time" + #[serde(skip_serializing_if = "Option::is_none")] + pub format: Option, +} + +impl Default for StringSchema { + fn default() -> Self { + Self { + type_: StringTypeConst, + title: None, + description: None, + min_length: None, + max_length: None, + format: None, + } + } +} + +impl StringSchema { + /// Create a new string schema + pub fn new() -> Self { + Self::default() + } + + /// Create an email string schema + pub fn email() -> Self { + Self { + format: Some(StringFormat::Email), + ..Default::default() + } + } + + /// Create a URI string schema + pub fn uri() -> Self { + Self { + format: Some(StringFormat::Uri), + ..Default::default() + } + } + + /// Create a date string schema + pub fn date() -> Self { + Self { + format: Some(StringFormat::Date), + ..Default::default() + } + } + + /// Create a date-time string schema + pub fn date_time() -> Self { + Self { + format: Some(StringFormat::DateTime), + ..Default::default() + } + } + + /// Set title + pub fn title(mut self, title: impl Into>) -> Self { + self.title = Some(title.into()); + self + } + + /// Set description + pub fn description(mut self, description: impl Into>) -> Self { + self.description = Some(description.into()); + self + } + + /// Set minimum and maximum length + pub fn with_length(mut self, min: u32, max: u32) -> Result { + if min > max { + return Err("min_length must be <= max_length"); + } + self.min_length = Some(min); + self.max_length = Some(max); + Ok(self) + } + + /// Set minimum and maximum length (panics on invalid input) + pub fn length(mut self, min: u32, max: u32) -> Self { + assert!(min <= max, "min_length must be <= max_length"); + self.min_length = Some(min); + self.max_length = Some(max); + self + } + + /// Set minimum length + pub fn min_length(mut self, min: u32) -> Self { + self.min_length = Some(min); + self + } + + /// Set maximum length + pub fn max_length(mut self, max: u32) -> Self { + self.max_length = Some(max); + self + } + + /// Set format (limited to: "email", "uri", "date", "date-time") + pub fn format(mut self, format: StringFormat) -> Self { + self.format = Some(format); + self + } +} + +// ============================================================================= +// NUMBER SCHEMA +// ============================================================================= + +/// Schema definition for number properties (floating-point). +/// +/// Compliant with MCP 2025-06-18 specification for elicitation schemas. +/// Supports only the fields allowed by the MCP spec. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[serde(rename_all = "camelCase")] +pub struct NumberSchema { + /// Type discriminator + #[serde(rename = "type")] + pub type_: NumberTypeConst, + + /// Optional title for the schema + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option>, + + /// Human-readable description + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option>, + + /// Minimum value (inclusive) + #[serde(skip_serializing_if = "Option::is_none")] + pub minimum: Option, + + /// Maximum value (inclusive) + #[serde(skip_serializing_if = "Option::is_none")] + pub maximum: Option, +} + +impl Default for NumberSchema { + fn default() -> Self { + Self { + type_: NumberTypeConst, + title: None, + description: None, + minimum: None, + maximum: None, + } + } +} + +impl NumberSchema { + /// Create a new number schema + pub fn new() -> Self { + Self::default() + } + + /// Set minimum and maximum (inclusive) + pub fn with_range(mut self, min: f64, max: f64) -> Result { + if min > max { + return Err("minimum must be <= maximum"); + } + self.minimum = Some(min); + self.maximum = Some(max); + Ok(self) + } + + /// Set minimum and maximum (panics on invalid input) + pub fn range(mut self, min: f64, max: f64) -> Self { + assert!(min <= max, "minimum must be <= maximum"); + self.minimum = Some(min); + self.maximum = Some(max); + self + } + + /// Set minimum (inclusive) + pub fn minimum(mut self, min: f64) -> Self { + self.minimum = Some(min); + self + } + + /// Set maximum (inclusive) + pub fn maximum(mut self, max: f64) -> Self { + self.maximum = Some(max); + self + } + + /// Set title + pub fn title(mut self, title: impl Into>) -> Self { + self.title = Some(title.into()); + self + } + + /// Set description + pub fn description(mut self, description: impl Into>) -> Self { + self.description = Some(description.into()); + self + } +} + +// ============================================================================= +// INTEGER SCHEMA +// ============================================================================= + +/// Schema definition for integer properties. +/// +/// Compliant with MCP 2025-06-18 specification for elicitation schemas. +/// Supports only the fields allowed by the MCP spec. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[serde(rename_all = "camelCase")] +pub struct IntegerSchema { + /// Type discriminator + #[serde(rename = "type")] + pub type_: IntegerTypeConst, + + /// Optional title for the schema + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option>, + + /// Human-readable description + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option>, + + /// Minimum value (inclusive) + #[serde(skip_serializing_if = "Option::is_none")] + pub minimum: Option, + + /// Maximum value (inclusive) + #[serde(skip_serializing_if = "Option::is_none")] + pub maximum: Option, +} + +impl Default for IntegerSchema { + fn default() -> Self { + Self { + type_: IntegerTypeConst, + title: None, + description: None, + minimum: None, + maximum: None, + } + } +} + +impl IntegerSchema { + /// Create a new integer schema + pub fn new() -> Self { + Self::default() + } + + /// Set minimum and maximum (inclusive) + pub fn with_range(mut self, min: i64, max: i64) -> Result { + if min > max { + return Err("minimum must be <= maximum"); + } + self.minimum = Some(min); + self.maximum = Some(max); + Ok(self) + } + + /// Set minimum and maximum (panics on invalid input) + pub fn range(mut self, min: i64, max: i64) -> Self { + assert!(min <= max, "minimum must be <= maximum"); + self.minimum = Some(min); + self.maximum = Some(max); + self + } + + /// Set minimum (inclusive) + pub fn minimum(mut self, min: i64) -> Self { + self.minimum = Some(min); + self + } + + /// Set maximum (inclusive) + pub fn maximum(mut self, max: i64) -> Self { + self.maximum = Some(max); + self + } + + /// Set title + pub fn title(mut self, title: impl Into>) -> Self { + self.title = Some(title.into()); + self + } + + /// Set description + pub fn description(mut self, description: impl Into>) -> Self { + self.description = Some(description.into()); + self + } +} + +// ============================================================================= +// BOOLEAN SCHEMA +// ============================================================================= + +/// Schema definition for boolean properties. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[serde(rename_all = "camelCase")] +pub struct BooleanSchema { + /// Type discriminator + #[serde(rename = "type")] + pub type_: BooleanTypeConst, + + /// Optional title for the schema + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option>, + + /// Human-readable description + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option>, + + /// Default value + #[serde(skip_serializing_if = "Option::is_none")] + pub default: Option, +} + +impl Default for BooleanSchema { + fn default() -> Self { + Self { + type_: BooleanTypeConst, + title: None, + description: None, + default: None, + } + } +} + +impl BooleanSchema { + /// Create a new boolean schema + pub fn new() -> Self { + Self::default() + } + + /// Set title + pub fn title(mut self, title: impl Into>) -> Self { + self.title = Some(title.into()); + self + } + + /// Set description + pub fn description(mut self, description: impl Into>) -> Self { + self.description = Some(description.into()); + self + } + + /// Set default value + pub fn with_default(mut self, default: bool) -> Self { + self.default = Some(default); + self + } +} + +// ============================================================================= +// ENUM SCHEMA +// ============================================================================= + +/// Schema definition for enum properties. +/// +/// Compliant with MCP 2025-06-18 specification for elicitation schemas. +/// Enums must have string type and can optionally include human-readable names. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[serde(rename_all = "camelCase")] +pub struct EnumSchema { + /// Type discriminator (always "string" for enums) + #[serde(rename = "type")] + pub type_: StringTypeConst, + + /// Allowed enum values (string values only per MCP spec) + #[serde(rename = "enum")] + pub enum_values: Vec, + + /// Optional human-readable names for each enum value + #[serde(skip_serializing_if = "Option::is_none")] + pub enum_names: Option>, + + /// Optional title for the schema + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option>, + + /// Human-readable description + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option>, +} + +impl EnumSchema { + /// Create a new enum schema with string values + pub fn new(values: Vec) -> Self { + Self { + type_: StringTypeConst, + enum_values: values, + enum_names: None, + title: None, + description: None, + } + } + + /// Set enum names (human-readable names for each enum value) + pub fn enum_names(mut self, names: Vec) -> Self { + self.enum_names = Some(names); + self + } + + /// Set title + pub fn title(mut self, title: impl Into>) -> Self { + self.title = Some(title.into()); + self + } + + /// Set description + pub fn description(mut self, description: impl Into>) -> Self { + self.description = Some(description.into()); + self + } +} + +// ============================================================================= +// ELICITATION SCHEMA +// ============================================================================= + +/// Type-safe elicitation schema for requesting structured user input. +/// +/// This enforces the MCP 2025-06-18 specification that elicitation schemas +/// must be objects with primitive-typed properties. +/// +/// # Example +/// +/// ```rust +/// use rmcp::model::*; +/// +/// let schema = ElicitationSchema::builder() +/// .required_email("email") +/// .required_integer("age", 0, 150) +/// .optional_bool("newsletter", false) +/// .build(); +/// ``` +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[serde(rename_all = "camelCase")] +pub struct ElicitationSchema { + /// Always "object" for elicitation schemas + #[serde(rename = "type")] + pub type_: ObjectTypeConst, + + /// Optional title for the schema + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option>, + + /// Property definitions (must be primitive types) + pub properties: BTreeMap, + + /// List of required property names + #[serde(skip_serializing_if = "Option::is_none")] + pub required: Option>, + + /// Optional description of what this schema represents + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option>, +} + +impl ElicitationSchema { + /// Create a new elicitation schema with the given properties + pub fn new(properties: BTreeMap) -> Self { + Self { + type_: ObjectTypeConst, + title: None, + properties, + required: None, + description: None, + } + } + + /// Convert from a JSON Schema object (typically generated by schemars) + /// + /// This allows converting from JsonObject to ElicitationSchema, which is useful + /// when working with automatically generated schemas from types. + /// + /// # Example + /// + /// ```rust,ignore + /// use rmcp::model::*; + /// + /// let json_schema = schema_for_type::(); + /// let elicitation_schema = ElicitationSchema::from_json_schema(json_schema)?; + /// ``` + /// + /// # Errors + /// + /// Returns a [`serde_json::Error`] if the JSON object cannot be deserialized + /// into a valid ElicitationSchema. + pub fn from_json_schema(schema: crate::model::JsonObject) -> Result { + serde_json::from_value(serde_json::Value::Object(schema)) + } + + /// Generate an ElicitationSchema from a Rust type that implements JsonSchema + /// + /// This is a convenience method that combines schema generation and conversion. + /// It uses the same schema generation settings as the rest of the MCP SDK. + /// + /// # Example + /// + /// ```rust,ignore + /// use rmcp::model::*; + /// use schemars::JsonSchema; + /// use serde::{Deserialize, Serialize}; + /// + /// #[derive(JsonSchema, Serialize, Deserialize)] + /// struct UserInput { + /// name: String, + /// age: u32, + /// } + /// + /// let schema = ElicitationSchema::from_type::()?; + /// ``` + /// + /// # Errors + /// + /// Returns a [`serde_json::Error`] if the generated schema cannot be converted + /// to a valid ElicitationSchema. + #[cfg(feature = "schemars")] + pub fn from_type() -> Result + where + T: schemars::JsonSchema, + { + use crate::schemars::generate::SchemaSettings; + + let mut settings = SchemaSettings::draft07(); + settings.transforms = vec![Box::new(schemars::transform::AddNullable::default())]; + let generator = settings.into_generator(); + let schema = generator.into_root_schema_for::(); + let object = serde_json::to_value(schema).expect("failed to serialize schema"); + match object { + serde_json::Value::Object(object) => Self::from_json_schema(object), + _ => panic!( + "Schema serialization produced non-object value: expected JSON object but got {:?}", + object + ), + } + } + + /// Set the required fields + pub fn with_required(mut self, required: Vec) -> Self { + self.required = Some(required); + self + } + + /// Set the title + pub fn with_title(mut self, title: impl Into>) -> Self { + self.title = Some(title.into()); + self + } + + /// Set the description + pub fn with_description(mut self, description: impl Into>) -> Self { + self.description = Some(description.into()); + self + } + + /// Create a builder for constructing elicitation schemas fluently + pub fn builder() -> ElicitationSchemaBuilder { + ElicitationSchemaBuilder::new() + } +} + +// ============================================================================= +// BUILDER +// ============================================================================= + +/// Fluent builder for constructing elicitation schemas. +/// +/// # Example +/// +/// ```rust +/// use rmcp::model::*; +/// +/// let schema = ElicitationSchema::builder() +/// .required_email("email") +/// .required_integer("age", 0, 150) +/// .optional_bool("newsletter", false) +/// .description("User registration") +/// .build(); +/// ``` +#[derive(Debug, Default)] +pub struct ElicitationSchemaBuilder { + pub properties: BTreeMap, + pub required: Vec, + pub title: Option>, + pub description: Option>, +} + +impl ElicitationSchemaBuilder { + /// Create a new builder + pub fn new() -> Self { + Self::default() + } + + /// Add a property to the schema + pub fn property(mut self, name: impl Into, schema: PrimitiveSchema) -> Self { + self.properties.insert(name.into(), schema); + self + } + + /// Add a required property to the schema + pub fn required_property(mut self, name: impl Into, schema: PrimitiveSchema) -> Self { + let name_str = name.into(); + self.required.push(name_str.clone()); + self.properties.insert(name_str, schema); + self + } + + // =========================================================================== + // TYPED PROPERTY METHODS - Cleaner API without PrimitiveSchema wrapper + // =========================================================================== + + /// Add a string property with custom builder (required) + pub fn string_property( + mut self, + name: impl Into, + f: impl FnOnce(StringSchema) -> StringSchema, + ) -> Self { + self.properties + .insert(name.into(), PrimitiveSchema::String(f(StringSchema::new()))); + self + } + + /// Add a required string property with custom builder + pub fn required_string_property( + mut self, + name: impl Into, + f: impl FnOnce(StringSchema) -> StringSchema, + ) -> Self { + let name_str = name.into(); + self.required.push(name_str.clone()); + self.properties + .insert(name_str, PrimitiveSchema::String(f(StringSchema::new()))); + self + } + + /// Add a number property with custom builder + pub fn number_property( + mut self, + name: impl Into, + f: impl FnOnce(NumberSchema) -> NumberSchema, + ) -> Self { + self.properties + .insert(name.into(), PrimitiveSchema::Number(f(NumberSchema::new()))); + self + } + + /// Add a required number property with custom builder + pub fn required_number_property( + mut self, + name: impl Into, + f: impl FnOnce(NumberSchema) -> NumberSchema, + ) -> Self { + let name_str = name.into(); + self.required.push(name_str.clone()); + self.properties + .insert(name_str, PrimitiveSchema::Number(f(NumberSchema::new()))); + self + } + + /// Add an integer property with custom builder + pub fn integer_property( + mut self, + name: impl Into, + f: impl FnOnce(IntegerSchema) -> IntegerSchema, + ) -> Self { + self.properties.insert( + name.into(), + PrimitiveSchema::Integer(f(IntegerSchema::new())), + ); + self + } + + /// Add a required integer property with custom builder + pub fn required_integer_property( + mut self, + name: impl Into, + f: impl FnOnce(IntegerSchema) -> IntegerSchema, + ) -> Self { + let name_str = name.into(); + self.required.push(name_str.clone()); + self.properties + .insert(name_str, PrimitiveSchema::Integer(f(IntegerSchema::new()))); + self + } + + /// Add a boolean property with custom builder + pub fn bool_property( + mut self, + name: impl Into, + f: impl FnOnce(BooleanSchema) -> BooleanSchema, + ) -> Self { + self.properties.insert( + name.into(), + PrimitiveSchema::Boolean(f(BooleanSchema::new())), + ); + self + } + + /// Add a required boolean property with custom builder + pub fn required_bool_property( + mut self, + name: impl Into, + f: impl FnOnce(BooleanSchema) -> BooleanSchema, + ) -> Self { + let name_str = name.into(); + self.required.push(name_str.clone()); + self.properties + .insert(name_str, PrimitiveSchema::Boolean(f(BooleanSchema::new()))); + self + } + + // =========================================================================== + // CONVENIENCE METHODS - Simple common cases + // =========================================================================== + + /// Add a required string property + pub fn required_string(self, name: impl Into) -> Self { + self.required_property(name, PrimitiveSchema::String(StringSchema::new())) + } + + /// Add an optional string property + pub fn optional_string(self, name: impl Into) -> Self { + self.property(name, PrimitiveSchema::String(StringSchema::new())) + } + + /// Add a required email property + pub fn required_email(self, name: impl Into) -> Self { + self.required_property(name, PrimitiveSchema::String(StringSchema::email())) + } + + /// Add an optional email property + pub fn optional_email(self, name: impl Into) -> Self { + self.property(name, PrimitiveSchema::String(StringSchema::email())) + } + + /// Add a required string property with custom builder + pub fn required_string_with( + self, + name: impl Into, + f: impl FnOnce(StringSchema) -> StringSchema, + ) -> Self { + self.required_property(name, PrimitiveSchema::String(f(StringSchema::new()))) + } + + /// Add an optional string property with custom builder + pub fn optional_string_with( + self, + name: impl Into, + f: impl FnOnce(StringSchema) -> StringSchema, + ) -> Self { + self.property(name, PrimitiveSchema::String(f(StringSchema::new()))) + } + + // Convenience methods for numbers + + /// Add a required number property with range + pub fn required_number(self, name: impl Into, min: f64, max: f64) -> Self { + self.required_property( + name, + PrimitiveSchema::Number(NumberSchema::new().range(min, max)), + ) + } + + /// Add an optional number property with range + pub fn optional_number(self, name: impl Into, min: f64, max: f64) -> Self { + self.property( + name, + PrimitiveSchema::Number(NumberSchema::new().range(min, max)), + ) + } + + /// Add a required number property with custom builder + pub fn required_number_with( + self, + name: impl Into, + f: impl FnOnce(NumberSchema) -> NumberSchema, + ) -> Self { + self.required_property(name, PrimitiveSchema::Number(f(NumberSchema::new()))) + } + + /// Add an optional number property with custom builder + pub fn optional_number_with( + self, + name: impl Into, + f: impl FnOnce(NumberSchema) -> NumberSchema, + ) -> Self { + self.property(name, PrimitiveSchema::Number(f(NumberSchema::new()))) + } + + // Convenience methods for integers + + /// Add a required integer property with range + pub fn required_integer(self, name: impl Into, min: i64, max: i64) -> Self { + self.required_property( + name, + PrimitiveSchema::Integer(IntegerSchema::new().range(min, max)), + ) + } + + /// Add an optional integer property with range + pub fn optional_integer(self, name: impl Into, min: i64, max: i64) -> Self { + self.property( + name, + PrimitiveSchema::Integer(IntegerSchema::new().range(min, max)), + ) + } + + /// Add a required integer property with custom builder + pub fn required_integer_with( + self, + name: impl Into, + f: impl FnOnce(IntegerSchema) -> IntegerSchema, + ) -> Self { + self.required_property(name, PrimitiveSchema::Integer(f(IntegerSchema::new()))) + } + + /// Add an optional integer property with custom builder + pub fn optional_integer_with( + self, + name: impl Into, + f: impl FnOnce(IntegerSchema) -> IntegerSchema, + ) -> Self { + self.property(name, PrimitiveSchema::Integer(f(IntegerSchema::new()))) + } + + // Convenience methods for booleans + + /// Add a required boolean property + pub fn required_bool(self, name: impl Into) -> Self { + self.required_property(name, PrimitiveSchema::Boolean(BooleanSchema::new())) + } + + /// Add an optional boolean property with default value + pub fn optional_bool(self, name: impl Into, default: bool) -> Self { + self.property( + name, + PrimitiveSchema::Boolean(BooleanSchema::new().with_default(default)), + ) + } + + /// Add a required boolean property with custom builder + pub fn required_bool_with( + self, + name: impl Into, + f: impl FnOnce(BooleanSchema) -> BooleanSchema, + ) -> Self { + self.required_property(name, PrimitiveSchema::Boolean(f(BooleanSchema::new()))) + } + + /// Add an optional boolean property with custom builder + pub fn optional_bool_with( + self, + name: impl Into, + f: impl FnOnce(BooleanSchema) -> BooleanSchema, + ) -> Self { + self.property(name, PrimitiveSchema::Boolean(f(BooleanSchema::new()))) + } + + // Enum convenience methods + + /// Add a required enum property + pub fn required_enum(self, name: impl Into, values: Vec) -> Self { + self.required_property(name, PrimitiveSchema::Enum(EnumSchema::new(values))) + } + + /// Add an optional enum property + pub fn optional_enum(self, name: impl Into, values: Vec) -> Self { + self.property(name, PrimitiveSchema::Enum(EnumSchema::new(values))) + } + + /// Mark an existing property as required + pub fn mark_required(mut self, name: impl Into) -> Self { + self.required.push(name.into()); + self + } + + /// Set the schema title + pub fn title(mut self, title: impl Into>) -> Self { + self.title = Some(title.into()); + self + } + + /// Set the schema description + pub fn description(mut self, description: impl Into>) -> Self { + self.description = Some(description.into()); + self + } + + /// Build the elicitation schema with validation + /// + /// # Errors + /// + /// Returns an error if: + /// - Required fields reference non-existent properties + /// - No properties are defined (empty schema) + pub fn build(self) -> Result { + // Validate that all required fields exist in properties + if !self.required.is_empty() { + for field_name in &self.required { + if !self.properties.contains_key(field_name) { + return Err("Required field does not exist in properties"); + } + } + } + + Ok(ElicitationSchema { + type_: ObjectTypeConst, + title: self.title, + properties: self.properties, + required: if self.required.is_empty() { + None + } else { + Some(self.required) + }, + description: self.description, + }) + } + + /// Build the elicitation schema without validation (panics on invalid schema) + /// + /// # Panics + /// + /// Panics if required fields reference non-existent properties + pub fn build_unchecked(self) -> ElicitationSchema { + self.build().expect("Invalid elicitation schema") + } +} + +#[cfg(test)] +mod tests { + use serde_json::json; + + use super::*; + + #[test] + fn test_string_schema_serialization() { + let schema = StringSchema::email().description("Email address"); + let json = serde_json::to_value(&schema).unwrap(); + + assert_eq!(json["type"], "string"); + assert_eq!(json["format"], "email"); + assert_eq!(json["description"], "Email address"); + } + + #[test] + fn test_number_schema_serialization() { + let schema = NumberSchema::new() + .range(0.0, 100.0) + .description("Percentage"); + let json = serde_json::to_value(&schema).unwrap(); + + assert_eq!(json["type"], "number"); + assert_eq!(json["minimum"], 0.0); + assert_eq!(json["maximum"], 100.0); + } + + #[test] + fn test_integer_schema_serialization() { + let schema = IntegerSchema::new().range(0, 150); + let json = serde_json::to_value(&schema).unwrap(); + + assert_eq!(json["type"], "integer"); + assert_eq!(json["minimum"], 0); + assert_eq!(json["maximum"], 150); + } + + #[test] + fn test_boolean_schema_serialization() { + let schema = BooleanSchema::new().with_default(true); + let json = serde_json::to_value(&schema).unwrap(); + + assert_eq!(json["type"], "boolean"); + assert_eq!(json["default"], true); + } + + #[test] + fn test_enum_schema_serialization() { + let schema = EnumSchema::new(vec!["US".to_string(), "UK".to_string()]) + .enum_names(vec![ + "United States".to_string(), + "United Kingdom".to_string(), + ]) + .description("Country code"); + let json = serde_json::to_value(&schema).unwrap(); + + assert_eq!(json["type"], "string"); + assert_eq!(json["enum"], json!(["US", "UK"])); + assert_eq!( + json["enumNames"], + json!(["United States", "United Kingdom"]) + ); + assert_eq!(json["description"], "Country code"); + } + + #[test] + fn test_elicitation_schema_builder_simple() { + let schema = ElicitationSchema::builder() + .required_email("email") + .optional_bool("newsletter", false) + .build() + .unwrap(); + + assert_eq!(schema.properties.len(), 2); + assert!(schema.properties.contains_key("email")); + assert!(schema.properties.contains_key("newsletter")); + assert_eq!(schema.required, Some(vec!["email".to_string()])); + } + + #[test] + fn test_elicitation_schema_builder_complex() { + let schema = ElicitationSchema::builder() + .required_string_with("name", |s| s.length(1, 100)) + .required_integer("age", 0, 150) + .optional_bool("newsletter", false) + .required_enum( + "country", + vec!["US".to_string(), "UK".to_string(), "CA".to_string()], + ) + .description("User registration") + .build() + .unwrap(); + + assert_eq!(schema.properties.len(), 4); + assert_eq!( + schema.required, + Some(vec![ + "name".to_string(), + "age".to_string(), + "country".to_string() + ]) + ); + assert_eq!(schema.description.as_deref(), Some("User registration")); + } + + #[test] + fn test_elicitation_schema_serialization() { + let schema = ElicitationSchema::builder() + .required_string_with("name", |s| s.length(1, 100)) + .build() + .unwrap(); + + let json = serde_json::to_value(&schema).unwrap(); + + assert_eq!(json["type"], "object"); + assert!(json["properties"]["name"].is_object()); + assert_eq!(json["required"], json!(["name"])); + } + + #[test] + #[should_panic(expected = "minimum must be <= maximum")] + fn test_integer_range_validation() { + IntegerSchema::new().range(10, 5); // Should panic + } + + #[test] + #[should_panic(expected = "min_length must be <= max_length")] + fn test_string_length_validation() { + StringSchema::new().length(10, 5); // Should panic + } + + #[test] + fn test_integer_range_validation_with_result() { + let result = IntegerSchema::new().with_range(10, 5); + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), "minimum must be <= maximum"); + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/model/extension.rs b/code-rs/third_party/rmcp-0.8.3/src/model/extension.rs new file mode 100644 index 00000000000..039fdf2eb7c --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/model/extension.rs @@ -0,0 +1,337 @@ +//! A container for those extra data could be carried on request or notification. +//! +//! This file is copied and modified from crate [http](https://github.com/hyperium/http). +//! +//! - Original code license: +//! - Original code: +use std::{ + any::{Any, TypeId}, + collections::HashMap, + fmt, + hash::{BuildHasherDefault, Hasher}, +}; + +type AnyMap = HashMap, BuildHasherDefault>; + +// With TypeIds as keys, there's no need to hash them. They are already hashes +// themselves, coming from the compiler. The IdHasher just holds the u64 of +// the TypeId, and then returns it, instead of doing any bit fiddling. +#[derive(Default)] +struct IdHasher(u64); + +impl Hasher for IdHasher { + fn write(&mut self, _: &[u8]) { + unreachable!("TypeId calls write_u64"); + } + + #[inline] + fn write_u64(&mut self, id: u64) { + self.0 = id; + } + + #[inline] + fn finish(&self) -> u64 { + self.0 + } +} + +/// A type map of protocol extensions. +/// +/// `Extensions` can be used by `Request` `Notification` and `Response` to store +/// extra data derived from the underlying protocol. +#[derive(Clone, Default)] +pub struct Extensions { + // If extensions are never used, no need to carry around an empty HashMap. + // That's 3 words. Instead, this is only 1 word. + map: Option>, +} + +impl Extensions { + /// Create an empty `Extensions`. + #[inline] + pub const fn new() -> Extensions { + Extensions { map: None } + } + + /// Insert a type into this `Extensions`. + /// + /// If a extension of this type already existed, it will + /// be returned and replaced with the new one. + /// + /// # Example + /// + /// ``` + /// # use rmcp::model::Extensions; + /// let mut ext = Extensions::new(); + /// assert!(ext.insert(5i32).is_none()); + /// assert!(ext.insert(4u8).is_none()); + /// assert_eq!(ext.insert(9i32), Some(5i32)); + /// ``` + pub fn insert(&mut self, val: T) -> Option { + self.map + .get_or_insert_with(Box::default) + .insert(TypeId::of::(), Box::new(val)) + .and_then(|boxed| boxed.into_any().downcast().ok().map(|boxed| *boxed)) + } + + /// Get a reference to a type previously inserted on this `Extensions`. + /// + /// # Example + /// + /// ``` + /// # use rmcp::model::Extensions; + /// let mut ext = Extensions::new(); + /// assert!(ext.get::().is_none()); + /// ext.insert(5i32); + /// + /// assert_eq!(ext.get::(), Some(&5i32)); + /// ``` + pub fn get(&self) -> Option<&T> { + self.map + .as_ref() + .and_then(|map| map.get(&TypeId::of::())) + .and_then(|boxed| (**boxed).as_any().downcast_ref()) + } + + /// Get a mutable reference to a type previously inserted on this `Extensions`. + /// + /// # Example + /// + /// ``` + /// # use rmcp::model::Extensions; + /// let mut ext = Extensions::new(); + /// ext.insert(String::from("Hello")); + /// ext.get_mut::().unwrap().push_str(" World"); + /// + /// assert_eq!(ext.get::().unwrap(), "Hello World"); + /// ``` + pub fn get_mut(&mut self) -> Option<&mut T> { + self.map + .as_mut() + .and_then(|map| map.get_mut(&TypeId::of::())) + .and_then(|boxed| (**boxed).as_any_mut().downcast_mut()) + } + + /// Get a mutable reference to a type, inserting `value` if not already present on this + /// `Extensions`. + /// + /// # Example + /// + /// ``` + /// # use rmcp::model::Extensions; + /// let mut ext = Extensions::new(); + /// *ext.get_or_insert(1i32) += 2; + /// + /// assert_eq!(*ext.get::().unwrap(), 3); + /// ``` + pub fn get_or_insert(&mut self, value: T) -> &mut T { + self.get_or_insert_with(|| value) + } + + /// Get a mutable reference to a type, inserting the value created by `f` if not already present + /// on this `Extensions`. + /// + /// # Example + /// + /// ``` + /// # use rmcp::model::Extensions; + /// let mut ext = Extensions::new(); + /// *ext.get_or_insert_with(|| 1i32) += 2; + /// + /// assert_eq!(*ext.get::().unwrap(), 3); + /// ``` + pub fn get_or_insert_with T>( + &mut self, + f: F, + ) -> &mut T { + let out = self + .map + .get_or_insert_with(Box::default) + .entry(TypeId::of::()) + .or_insert_with(|| Box::new(f())); + (**out).as_any_mut().downcast_mut().unwrap() + } + + /// Get a mutable reference to a type, inserting the type's default value if not already present + /// on this `Extensions`. + /// + /// # Example + /// + /// ``` + /// # use rmcp::model::Extensions; + /// let mut ext = Extensions::new(); + /// *ext.get_or_insert_default::() += 2; + /// + /// assert_eq!(*ext.get::().unwrap(), 2); + /// ``` + pub fn get_or_insert_default(&mut self) -> &mut T { + self.get_or_insert_with(T::default) + } + + /// Remove a type from this `Extensions`. + /// + /// If a extension of this type existed, it will be returned. + /// + /// # Example + /// + /// ``` + /// # use rmcp::model::Extensions; + /// let mut ext = Extensions::new(); + /// ext.insert(5i32); + /// assert_eq!(ext.remove::(), Some(5i32)); + /// assert!(ext.get::().is_none()); + /// ``` + pub fn remove(&mut self) -> Option { + self.map + .as_mut() + .and_then(|map| map.remove(&TypeId::of::())) + .and_then(|boxed| boxed.into_any().downcast().ok().map(|boxed| *boxed)) + } + + /// Clear the `Extensions` of all inserted extensions. + /// + /// # Example + /// + /// ``` + /// # use rmcp::model::Extensions; + /// let mut ext = Extensions::new(); + /// ext.insert(5i32); + /// ext.clear(); + /// + /// assert!(ext.get::().is_none()); + /// ``` + #[inline] + pub fn clear(&mut self) { + if let Some(ref mut map) = self.map { + map.clear(); + } + } + + /// Check whether the extension set is empty or not. + /// + /// # Example + /// + /// ``` + /// # use rmcp::model::Extensions; + /// let mut ext = Extensions::new(); + /// assert!(ext.is_empty()); + /// ext.insert(5i32); + /// assert!(!ext.is_empty()); + /// ``` + #[inline] + pub fn is_empty(&self) -> bool { + self.map.as_ref().is_none_or(|map| map.is_empty()) + } + + /// Get the number of extensions available. + /// + /// # Example + /// + /// ``` + /// # use rmcp::model::Extensions; + /// let mut ext = Extensions::new(); + /// assert_eq!(ext.len(), 0); + /// ext.insert(5i32); + /// assert_eq!(ext.len(), 1); + /// ``` + #[inline] + pub fn len(&self) -> usize { + self.map.as_ref().map_or(0, |map| map.len()) + } + + /// Extends `self` with another `Extensions`. + /// + /// If an instance of a specific type exists in both, the one in `self` is overwritten with the + /// one from `other`. + /// + /// # Example + /// + /// ``` + /// # use rmcp::model::Extensions; + /// let mut ext_a = Extensions::new(); + /// ext_a.insert(8u8); + /// ext_a.insert(16u16); + /// + /// let mut ext_b = Extensions::new(); + /// ext_b.insert(4u8); + /// ext_b.insert("hello"); + /// + /// ext_a.extend(ext_b); + /// assert_eq!(ext_a.len(), 3); + /// assert_eq!(ext_a.get::(), Some(&4u8)); + /// assert_eq!(ext_a.get::(), Some(&16u16)); + /// assert_eq!(ext_a.get::<&'static str>().copied(), Some("hello")); + /// ``` + pub fn extend(&mut self, other: Self) { + if let Some(other) = other.map { + if let Some(map) = &mut self.map { + map.extend(*other); + } else { + self.map = Some(other); + } + } + } +} + +impl fmt::Debug for Extensions { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Extensions").finish() + } +} + +trait AnyClone: Any { + fn clone_box(&self) -> Box; + fn as_any(&self) -> &dyn Any; + fn as_any_mut(&mut self) -> &mut dyn Any; + fn into_any(self: Box) -> Box; +} + +impl AnyClone for T { + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + fn into_any(self: Box) -> Box { + self + } +} + +impl Clone for Box { + fn clone(&self) -> Self { + (**self).clone_box() + } +} + +#[test] +fn test_extensions() { + #[derive(Clone, Debug, PartialEq)] + struct MyType(i32); + + let mut extensions = Extensions::new(); + + extensions.insert(5i32); + extensions.insert(MyType(10)); + + assert_eq!(extensions.get(), Some(&5i32)); + assert_eq!(extensions.get_mut(), Some(&mut 5i32)); + + let ext2 = extensions.clone(); + + assert_eq!(extensions.remove::(), Some(5i32)); + assert!(extensions.get::().is_none()); + + // clone still has it + assert_eq!(ext2.get(), Some(&5i32)); + assert_eq!(ext2.get(), Some(&MyType(10))); + + assert_eq!(extensions.get::(), None); + assert_eq!(extensions.get(), Some(&MyType(10))); +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/model/meta.rs b/code-rs/third_party/rmcp-0.8.3/src/model/meta.rs new file mode 100644 index 00000000000..fd93362b7c9 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/model/meta.rs @@ -0,0 +1,188 @@ +use std::ops::{Deref, DerefMut}; + +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +use super::{ + ClientNotification, ClientRequest, Extensions, JsonObject, JsonRpcMessage, NumberOrString, + ProgressToken, ServerNotification, ServerRequest, +}; + +pub trait GetMeta { + fn get_meta_mut(&mut self) -> &mut Meta; + fn get_meta(&self) -> &Meta; +} + +pub trait GetExtensions { + fn extensions(&self) -> &Extensions; + fn extensions_mut(&mut self) -> &mut Extensions; +} + +macro_rules! variant_extension { + ( + $Enum: ident { + $($variant: ident)* + } + ) => { + impl GetExtensions for $Enum { + fn extensions(&self) -> &Extensions { + match self { + $( + $Enum::$variant(v) => &v.extensions, + )* + } + } + fn extensions_mut(&mut self) -> &mut Extensions { + match self { + $( + $Enum::$variant(v) => &mut v.extensions, + )* + } + } + } + impl GetMeta for $Enum { + fn get_meta_mut(&mut self) -> &mut Meta { + self.extensions_mut().get_or_insert_default() + } + fn get_meta(&self) -> &Meta { + self.extensions().get::().unwrap_or(Meta::static_empty()) + } + } + }; +} + +variant_extension! { + ClientRequest { + PingRequest + InitializeRequest + CompleteRequest + SetLevelRequest + GetPromptRequest + ListPromptsRequest + ListResourcesRequest + ListResourceTemplatesRequest + ReadResourceRequest + SubscribeRequest + UnsubscribeRequest + CallToolRequest + ListToolsRequest + } +} + +variant_extension! { + ServerRequest { + PingRequest + CreateMessageRequest + ListRootsRequest + CreateElicitationRequest + } +} + +variant_extension! { + ClientNotification { + CancelledNotification + ProgressNotification + InitializedNotification + RootsListChangedNotification + } +} + +variant_extension! { + ServerNotification { + CancelledNotification + ProgressNotification + LoggingMessageNotification + ResourceUpdatedNotification + ResourceListChangedNotification + ToolListChangedNotification + PromptListChangedNotification + } +} +#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[serde(transparent)] +pub struct Meta(pub JsonObject); +const PROGRESS_TOKEN_FIELD: &str = "progressToken"; +impl Meta { + pub fn new() -> Self { + Self(JsonObject::new()) + } + + pub(crate) fn static_empty() -> &'static Self { + static EMPTY: std::sync::OnceLock = std::sync::OnceLock::new(); + EMPTY.get_or_init(Default::default) + } + + pub fn get_progress_token(&self) -> Option { + self.0.get(PROGRESS_TOKEN_FIELD).and_then(|v| match v { + Value::String(s) => Some(ProgressToken(NumberOrString::String(s.to_string().into()))), + Value::Number(n) => { + if let Some(i) = n.as_i64() { + Some(ProgressToken(NumberOrString::Number(i))) + } else if let Some(u) = n.as_u64() { + if u <= i64::MAX as u64 { + Some(ProgressToken(NumberOrString::Number(u as i64))) + } else { + None + } + } else { + None + } + } + _ => None, + }) + } + + pub fn set_progress_token(&mut self, token: ProgressToken) { + match token.0 { + NumberOrString::String(ref s) => self.0.insert( + PROGRESS_TOKEN_FIELD.to_string(), + Value::String(s.to_string()), + ), + NumberOrString::Number(n) => self + .0 + .insert(PROGRESS_TOKEN_FIELD.to_string(), Value::Number(n.into())), + }; + } + + pub fn extend(&mut self, other: Meta) { + for (k, v) in other.0.into_iter() { + self.0.insert(k, v); + } + } +} + +impl Deref for Meta { + type Target = JsonObject; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for Meta { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl JsonRpcMessage +where + Req: GetExtensions, + Noti: GetExtensions, +{ + pub fn insert_extension(&mut self, value: T) { + match self { + JsonRpcMessage::Request(json_rpc_request) => { + json_rpc_request.request.extensions_mut().insert(value); + } + JsonRpcMessage::Notification(json_rpc_notification) => { + json_rpc_notification + .notification + .extensions_mut() + .insert(value); + } + _ => {} + } + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/model/prompt.rs b/code-rs/third_party/rmcp-0.8.3/src/model/prompt.rs new file mode 100644 index 00000000000..3c69758ccd1 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/model/prompt.rs @@ -0,0 +1,266 @@ +use base64::engine::{Engine, general_purpose::STANDARD as BASE64_STANDARD}; +use serde::{Deserialize, Serialize}; + +use super::{ + AnnotateAble, Annotations, Icon, RawEmbeddedResource, RawImageContent, + content::{EmbeddedResource, ImageContent}, + resource::ResourceContents, +}; + +/// A prompt that can be used to generate text from a model +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct Prompt { + /// The name of the prompt + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, + /// Optional description of what the prompt does + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// Optional arguments that can be passed to customize the prompt + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option>, + /// Optional list of icons for the prompt + #[serde(skip_serializing_if = "Option::is_none")] + pub icons: Option>, +} + +impl Prompt { + /// Create a new prompt with the given name, description and arguments + pub fn new( + name: N, + description: Option, + arguments: Option>, + ) -> Self + where + N: Into, + D: Into, + { + Prompt { + name: name.into(), + title: None, + description: description.map(Into::into), + arguments, + icons: None, + } + } +} + +/// Represents a prompt argument that can be passed to customize the prompt +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct PromptArgument { + /// The name of the argument + pub name: String, + /// A human-readable title for the argument + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, + /// A description of what the argument is used for + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// Whether this argument is required + #[serde(skip_serializing_if = "Option::is_none")] + pub required: Option, +} + +/// Represents the role of a message sender in a prompt conversation +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub enum PromptMessageRole { + User, + Assistant, +} + +/// Content types that can be included in prompt messages +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub enum PromptMessageContent { + /// Plain text content + Text { text: String }, + /// Image content with base64-encoded data + Image { + #[serde(flatten)] + image: ImageContent, + }, + /// Embedded server-side resource + Resource { resource: EmbeddedResource }, + /// A link to a resource that can be fetched separately + ResourceLink { + #[serde(flatten)] + link: super::resource::Resource, + }, +} + +impl PromptMessageContent { + pub fn text(text: impl Into) -> Self { + Self::Text { text: text.into() } + } + + /// Create a resource link content + pub fn resource_link(resource: super::resource::Resource) -> Self { + Self::ResourceLink { link: resource } + } +} + +/// A message in a prompt conversation +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct PromptMessage { + /// The role of the message sender + pub role: PromptMessageRole, + /// The content of the message + pub content: PromptMessageContent, +} + +impl PromptMessage { + /// Create a new text message with the given role and text content + pub fn new_text>(role: PromptMessageRole, text: S) -> Self { + Self { + role, + content: PromptMessageContent::Text { text: text.into() }, + } + } + + /// Create a new image message. `meta` and `annotations` are optional. + #[cfg(feature = "base64")] + pub fn new_image( + role: PromptMessageRole, + data: &[u8], + mime_type: &str, + meta: Option, + annotations: Option, + ) -> Self { + let base64 = BASE64_STANDARD.encode(data); + Self { + role, + content: PromptMessageContent::Image { + image: RawImageContent { + data: base64, + mime_type: mime_type.into(), + meta, + } + .optional_annotate(annotations), + }, + } + } + + /// Create a new resource message. `resource_meta`, `resource_content_meta`, and `annotations` are optional. + pub fn new_resource( + role: PromptMessageRole, + uri: String, + mime_type: Option, + text: Option, + resource_meta: Option, + resource_content_meta: Option, + annotations: Option, + ) -> Self { + let resource_contents = match text { + Some(t) => ResourceContents::TextResourceContents { + uri, + mime_type, + text: t, + meta: resource_content_meta, + }, + None => ResourceContents::BlobResourceContents { + uri, + mime_type, + blob: String::new(), + meta: resource_content_meta, + }, + }; + Self { + role, + content: PromptMessageContent::Resource { + resource: RawEmbeddedResource { + meta: resource_meta, + resource: resource_contents, + } + .optional_annotate(annotations), + }, + } + } + + /// Note: PromptMessage text content does not carry protocol-level _meta per current schema. + /// This function exists for API symmetry but ignores the meta parameter. + pub fn new_text_with_meta>( + role: PromptMessageRole, + text: S, + _meta: Option, + ) -> Self { + Self::new_text(role, text) + } + + /// Create a new resource link message + pub fn new_resource_link(role: PromptMessageRole, resource: super::resource::Resource) -> Self { + Self { + role, + content: PromptMessageContent::ResourceLink { link: resource }, + } + } +} + +#[cfg(test)] +mod tests { + use serde_json; + + use super::*; + + #[test] + fn test_prompt_message_image_serialization() { + let image_content = RawImageContent { + data: "base64data".to_string(), + mime_type: "image/png".to_string(), + meta: None, + }; + + let json = serde_json::to_string(&image_content).unwrap(); + println!("PromptMessage ImageContent JSON: {}", json); + + // Verify it contains mimeType (camelCase) not mime_type (snake_case) + assert!(json.contains("mimeType")); + assert!(!json.contains("mime_type")); + } + + #[test] + fn test_prompt_message_resource_link_serialization() { + use super::super::resource::RawResource; + + let resource = RawResource::new("file:///test.txt", "test.txt"); + let message = + PromptMessage::new_resource_link(PromptMessageRole::User, resource.no_annotation()); + + let json = serde_json::to_string(&message).unwrap(); + println!("PromptMessage with ResourceLink JSON: {}", json); + + // Verify it contains the correct type tag + assert!(json.contains("\"type\":\"resource_link\"")); + assert!(json.contains("\"uri\":\"file:///test.txt\"")); + assert!(json.contains("\"name\":\"test.txt\"")); + } + + #[test] + fn test_prompt_message_content_resource_link_deserialization() { + let json = r#"{ + "type": "resource_link", + "uri": "file:///example.txt", + "name": "example.txt", + "description": "Example file", + "mimeType": "text/plain" + }"#; + + let content: PromptMessageContent = serde_json::from_str(json).unwrap(); + + if let PromptMessageContent::ResourceLink { link } = content { + assert_eq!(link.uri, "file:///example.txt"); + assert_eq!(link.name, "example.txt"); + assert_eq!(link.description, Some("Example file".to_string())); + assert_eq!(link.mime_type, Some("text/plain".to_string())); + } else { + panic!("Expected ResourceLink variant"); + } + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/model/resource.rs b/code-rs/third_party/rmcp-0.8.3/src/model/resource.rs new file mode 100644 index 00000000000..a342ad4e752 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/model/resource.rs @@ -0,0 +1,144 @@ +use serde::{Deserialize, Serialize}; + +use super::{Annotated, Icon, Meta}; + +/// Represents a resource in the extension with metadata +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct RawResource { + /// URI representing the resource location (e.g., "file:///path/to/file" or "str:///content") + pub uri: String, + /// Name of the resource + pub name: String, + /// Human-readable title of the resource + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, + /// Optional description of the resource + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// MIME type of the resource content ("text" or "blob") + #[serde(skip_serializing_if = "Option::is_none")] + pub mime_type: Option, + + /// The size of the raw resource content, in bytes (i.e., before base64 encoding or any tokenization), if known. + /// + /// This can be used by Hosts to display file sizes and estimate context window us + #[serde(skip_serializing_if = "Option::is_none")] + pub size: Option, + /// Optional list of icons for the resource + #[serde(skip_serializing_if = "Option::is_none")] + pub icons: Option>, +} + +pub type Resource = Annotated; + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct RawResourceTemplate { + pub uri_template: String, + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub mime_type: Option, +} + +pub type ResourceTemplate = Annotated; + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(untagged)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub enum ResourceContents { + #[serde(rename_all = "camelCase")] + TextResourceContents { + uri: String, + #[serde(skip_serializing_if = "Option::is_none")] + mime_type: Option, + text: String, + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + meta: Option, + }, + #[serde(rename_all = "camelCase")] + BlobResourceContents { + uri: String, + #[serde(skip_serializing_if = "Option::is_none")] + mime_type: Option, + blob: String, + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + meta: Option, + }, +} + +impl ResourceContents { + pub fn text(text: impl Into, uri: impl Into) -> Self { + Self::TextResourceContents { + uri: uri.into(), + mime_type: Some("text".into()), + text: text.into(), + meta: None, + } + } +} + +impl RawResource { + /// Creates a new Resource from a URI with explicit mime type + pub fn new(uri: impl Into, name: impl Into) -> Self { + Self { + uri: uri.into(), + name: name.into(), + title: None, + description: None, + mime_type: None, + size: None, + icons: None, + } + } +} + +#[cfg(test)] +mod tests { + use serde_json; + + use super::*; + + #[test] + fn test_resource_serialization() { + let resource = RawResource { + uri: "file:///test.txt".to_string(), + title: None, + name: "test".to_string(), + description: Some("Test resource".to_string()), + mime_type: Some("text/plain".to_string()), + size: Some(100), + icons: None, + }; + + let json = serde_json::to_string(&resource).unwrap(); + println!("Serialized JSON: {}", json); + + // Verify it contains mimeType (camelCase) not mime_type (snake_case) + assert!(json.contains("mimeType")); + assert!(!json.contains("mime_type")); + } + + #[test] + fn test_resource_contents_serialization() { + let text_contents = ResourceContents::TextResourceContents { + uri: "file:///test.txt".to_string(), + mime_type: Some("text/plain".to_string()), + text: "Hello world".to_string(), + meta: None, + }; + + let json = serde_json::to_string(&text_contents).unwrap(); + println!("ResourceContents JSON: {}", json); + + // Verify it contains mimeType (camelCase) not mime_type (snake_case) + assert!(json.contains("mimeType")); + assert!(!json.contains("mime_type")); + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/model/serde_impl.rs b/code-rs/third_party/rmcp-0.8.3/src/model/serde_impl.rs new file mode 100644 index 00000000000..09222d52188 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/model/serde_impl.rs @@ -0,0 +1,267 @@ +use std::borrow::Cow; + +use serde::{Deserialize, Serialize}; + +use super::{ + Extensions, Meta, Notification, NotificationNoParam, Request, RequestNoParam, + RequestOptionalParam, +}; +#[derive(Serialize, Deserialize)] +struct WithMeta<'a, P> { + #[serde(skip_serializing_if = "Option::is_none")] + _meta: Option>, + #[serde(flatten)] + _rest: P, +} + +#[derive(Serialize, Deserialize)] +struct Proxy<'a, M, P> { + method: M, + params: WithMeta<'a, P>, +} + +#[derive(Serialize, Deserialize)] +struct ProxyOptionalParam<'a, M, P> { + method: M, + params: Option>, +} + +#[derive(Serialize, Deserialize)] +struct ProxyNoParam { + method: M, +} + +impl Serialize for Request +where + M: Serialize, + R: Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let extensions = &self.extensions; + let _meta = extensions.get::().map(Cow::Borrowed); + Proxy::serialize( + &Proxy { + method: &self.method, + params: WithMeta { + _rest: &self.params, + _meta, + }, + }, + serializer, + ) + } +} + +impl<'de, M, R> Deserialize<'de> for Request +where + M: Deserialize<'de>, + R: Deserialize<'de>, +{ + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let body = Proxy::deserialize(deserializer)?; + let _meta = body.params._meta.map(|m| m.into_owned()); + let mut extensions = Extensions::new(); + if let Some(meta) = _meta { + extensions.insert(meta); + } + Ok(Request { + extensions, + method: body.method, + params: body.params._rest, + }) + } +} + +impl Serialize for RequestOptionalParam +where + M: Serialize, + R: Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let extensions = &self.extensions; + let _meta = extensions.get::().map(Cow::Borrowed); + Proxy::serialize( + &Proxy { + method: &self.method, + params: WithMeta { + _rest: &self.params, + _meta, + }, + }, + serializer, + ) + } +} + +impl<'de, M, R> Deserialize<'de> for RequestOptionalParam +where + M: Deserialize<'de>, + R: Deserialize<'de>, +{ + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let body = ProxyOptionalParam::<'_, _, Option>::deserialize(deserializer)?; + let mut params = None; + let mut _meta = None; + if let Some(body_params) = body.params { + params = body_params._rest; + _meta = body_params._meta.map(|m| m.into_owned()); + } + let mut extensions = Extensions::new(); + if let Some(meta) = _meta { + extensions.insert(meta); + } + Ok(RequestOptionalParam { + extensions, + method: body.method, + params, + }) + } +} + +impl Serialize for RequestNoParam +where + M: Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let extensions = &self.extensions; + let _meta = extensions.get::().map(Cow::Borrowed); + ProxyNoParam::serialize( + &ProxyNoParam { + method: &self.method, + }, + serializer, + ) + } +} + +impl<'de, M> Deserialize<'de> for RequestNoParam +where + M: Deserialize<'de>, +{ + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let body = ProxyNoParam::<_>::deserialize(deserializer)?; + let extensions = Extensions::new(); + Ok(RequestNoParam { + extensions, + method: body.method, + }) + } +} + +impl Serialize for Notification +where + M: Serialize, + R: Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let extensions = &self.extensions; + let _meta = extensions.get::().map(Cow::Borrowed); + Proxy::serialize( + &Proxy { + method: &self.method, + params: WithMeta { + _rest: &self.params, + _meta, + }, + }, + serializer, + ) + } +} + +impl<'de, M, R> Deserialize<'de> for Notification +where + M: Deserialize<'de>, + R: Deserialize<'de>, +{ + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let body = Proxy::deserialize(deserializer)?; + let _meta = body.params._meta.map(|m| m.into_owned()); + let mut extensions = Extensions::new(); + if let Some(meta) = _meta { + extensions.insert(meta); + } + Ok(Notification { + extensions, + method: body.method, + params: body.params._rest, + }) + } +} + +impl Serialize for NotificationNoParam +where + M: Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let extensions = &self.extensions; + let _meta = extensions.get::().map(Cow::Borrowed); + ProxyNoParam::serialize( + &ProxyNoParam { + method: &self.method, + }, + serializer, + ) + } +} + +impl<'de, M> Deserialize<'de> for NotificationNoParam +where + M: Deserialize<'de>, +{ + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let body = ProxyNoParam::<_>::deserialize(deserializer)?; + let extensions = Extensions::new(); + Ok(NotificationNoParam { + extensions, + method: body.method, + }) + } +} + +#[cfg(test)] +mod test { + use serde_json::json; + + use crate::model::ListToolsRequest; + + #[test] + fn test_deserialize_lost_tools_request() { + let _req: ListToolsRequest = serde_json::from_value(json!( + { + "method": "tools/list", + } + )) + .unwrap(); + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/model/tool.rs b/code-rs/third_party/rmcp-0.8.3/src/model/tool.rs new file mode 100644 index 00000000000..282d8aa3c8a --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/model/tool.rs @@ -0,0 +1,179 @@ +use std::{borrow::Cow, sync::Arc}; + +use schemars::JsonSchema; +/// Tools represent a routine that a server can execute +/// Tool calls represent requests from the client to execute one +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +use super::{Icon, JsonObject}; + +/// A tool that can be used by a model. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct Tool { + /// The name of the tool + pub name: Cow<'static, str>, + /// A human-readable title for the tool + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, + /// A description of what the tool does + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option>, + /// A JSON Schema object defining the expected parameters for the tool + pub input_schema: Arc, + /// An optional JSON Schema object defining the structure of the tool's output + #[serde(skip_serializing_if = "Option::is_none")] + pub output_schema: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + /// Optional additional tool information. + pub annotations: Option, + /// Optional list of icons for the tool + #[serde(skip_serializing_if = "Option::is_none")] + pub icons: Option>, +} + +/// Additional properties describing a Tool to clients. +/// +/// NOTE: all properties in ToolAnnotations are **hints**. +/// They are not guaranteed to provide a faithful description of +/// tool behavior (including descriptive properties like `title`). +/// +/// Clients should never make tool use decisions based on ToolAnnotations +/// received from untrusted servers. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct ToolAnnotations { + /// A human-readable title for the tool. + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, + + /// If true, the tool does not modify its environment. + /// + /// Default: false + #[serde(skip_serializing_if = "Option::is_none")] + pub read_only_hint: Option, + + /// If true, the tool may perform destructive updates to its environment. + /// If false, the tool performs only additive updates. + /// + /// (This property is meaningful only when `readOnlyHint == false`) + /// + /// Default: true + /// A human-readable description of the tool's purpose. + #[serde(skip_serializing_if = "Option::is_none")] + pub destructive_hint: Option, + + /// If true, calling the tool repeatedly with the same arguments + /// will have no additional effect on the its environment. + /// + /// (This property is meaningful only when `readOnlyHint == false`) + /// + /// Default: false. + #[serde(skip_serializing_if = "Option::is_none")] + pub idempotent_hint: Option, + + /// If true, this tool may interact with an "open world" of external + /// entities. If false, the tool's domain of interaction is closed. + /// For example, the world of a web search tool is open, whereas that + /// of a memory tool is not. + /// + /// Default: true + #[serde(skip_serializing_if = "Option::is_none")] + pub open_world_hint: Option, +} + +impl ToolAnnotations { + pub fn new() -> Self { + Self::default() + } + pub fn with_title(title: T) -> Self + where + T: Into, + { + ToolAnnotations { + title: Some(title.into()), + ..Self::default() + } + } + pub fn read_only(self, read_only: bool) -> Self { + ToolAnnotations { + read_only_hint: Some(read_only), + ..self + } + } + pub fn destructive(self, destructive: bool) -> Self { + ToolAnnotations { + destructive_hint: Some(destructive), + ..self + } + } + pub fn idempotent(self, idempotent: bool) -> Self { + ToolAnnotations { + idempotent_hint: Some(idempotent), + ..self + } + } + pub fn open_world(self, open_world: bool) -> Self { + ToolAnnotations { + open_world_hint: Some(open_world), + ..self + } + } + + /// If not set, defaults to true. + pub fn is_destructive(&self) -> bool { + self.destructive_hint.unwrap_or(true) + } + + /// If not set, defaults to false. + pub fn is_idempotent(&self) -> bool { + self.idempotent_hint.unwrap_or(false) + } +} + +impl Tool { + /// Create a new tool with the given name and description + pub fn new(name: N, description: D, input_schema: S) -> Self + where + N: Into>, + D: Into>, + S: Into>, + { + Tool { + name: name.into(), + title: None, + description: Some(description.into()), + input_schema: input_schema.into(), + output_schema: None, + annotations: None, + icons: None, + } + } + + pub fn annotate(self, annotations: ToolAnnotations) -> Self { + Tool { + annotations: Some(annotations), + ..self + } + } + + /// Set the output schema using a type that implements JsonSchema + pub fn with_output_schema(mut self) -> Self { + self.output_schema = Some(crate::handler::server::tool::cached_schema_for_type::()); + self + } + + /// Set the input schema using a type that implements JsonSchema + pub fn with_input_schema(mut self) -> Self { + self.input_schema = crate::handler::server::tool::cached_schema_for_type::(); + self + } + + /// Get the schema as json value + pub fn schema_as_json_value(&self) -> Value { + Value::Object(self.input_schema.as_ref().clone()) + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/service.rs b/code-rs/third_party/rmcp-0.8.3/src/service.rs new file mode 100644 index 00000000000..5fc8934fa8e --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/service.rs @@ -0,0 +1,854 @@ +use futures::{FutureExt, future::BoxFuture}; +use thiserror::Error; + +use crate::{ + error::ErrorData as McpError, + model::{ + CancelledNotification, CancelledNotificationParam, Extensions, GetExtensions, GetMeta, + JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, Meta, + NumberOrString, ProgressToken, RequestId, ServerJsonRpcMessage, + }, + transport::{DynamicTransportError, IntoTransport, Transport}, +}; +#[cfg(feature = "client")] +#[cfg_attr(docsrs, doc(cfg(feature = "client")))] +mod client; +#[cfg(feature = "client")] +#[cfg_attr(docsrs, doc(cfg(feature = "client")))] +pub use client::*; +#[cfg(feature = "server")] +#[cfg_attr(docsrs, doc(cfg(feature = "server")))] +mod server; +#[cfg(feature = "server")] +#[cfg_attr(docsrs, doc(cfg(feature = "server")))] +pub use server::*; +#[cfg(feature = "tower")] +#[cfg_attr(docsrs, doc(cfg(feature = "tower")))] +mod tower; +use tokio_util::sync::{CancellationToken, DropGuard}; +#[cfg(feature = "tower")] +#[cfg_attr(docsrs, doc(cfg(feature = "tower")))] +pub use tower::*; +use tracing::{Instrument as _, instrument}; +#[derive(Error, Debug)] +#[non_exhaustive] +pub enum ServiceError { + #[error("Mcp error: {0}")] + McpError(McpError), + #[error("Transport send error: {0}")] + TransportSend(DynamicTransportError), + #[error("Transport closed")] + TransportClosed, + #[error("Unexpected response type")] + UnexpectedResponse, + #[error("task cancelled for reason {}", reason.as_deref().unwrap_or(""))] + Cancelled { reason: Option }, + #[error("request timeout after {}", chrono::Duration::from_std(*timeout).unwrap_or_default())] + Timeout { timeout: Duration }, +} + +trait TransferObject: + std::fmt::Debug + Clone + serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static +{ +} + +impl TransferObject for T where + T: std::fmt::Debug + + serde::Serialize + + serde::de::DeserializeOwned + + Send + + Sync + + 'static + + Clone +{ +} + +#[allow(private_bounds, reason = "there's no the third implementation")] +pub trait ServiceRole: std::fmt::Debug + Send + Sync + 'static + Copy + Clone { + type Req: TransferObject + GetMeta + GetExtensions; + type Resp: TransferObject; + type Not: TryInto + + From + + TransferObject; + type PeerReq: TransferObject + GetMeta + GetExtensions; + type PeerResp: TransferObject; + type PeerNot: TryInto + + From + + TransferObject + + GetMeta + + GetExtensions; + type InitializeError; + const IS_CLIENT: bool; + type Info: TransferObject; + type PeerInfo: TransferObject; +} + +pub type TxJsonRpcMessage = + JsonRpcMessage<::Req, ::Resp, ::Not>; +pub type RxJsonRpcMessage = JsonRpcMessage< + ::PeerReq, + ::PeerResp, + ::PeerNot, +>; + +pub trait Service: Send + Sync + 'static { + fn handle_request( + &self, + request: R::PeerReq, + context: RequestContext, + ) -> impl Future> + Send + '_; + fn handle_notification( + &self, + notification: R::PeerNot, + context: NotificationContext, + ) -> impl Future> + Send + '_; + fn get_info(&self) -> R::Info; +} + +pub trait ServiceExt: Service + Sized { + /// Convert this service to a dynamic boxed service + /// + /// This could be very helpful when you want to store the services in a collection + fn into_dyn(self) -> Box> { + Box::new(self) + } + fn serve( + self, + transport: T, + ) -> impl Future, R::InitializeError>> + Send + where + T: IntoTransport, + E: std::error::Error + Send + Sync + 'static, + Self: Sized, + { + Self::serve_with_ct(self, transport, Default::default()) + } + fn serve_with_ct( + self, + transport: T, + ct: CancellationToken, + ) -> impl Future, R::InitializeError>> + Send + where + T: IntoTransport, + E: std::error::Error + Send + Sync + 'static, + Self: Sized; +} + +impl Service for Box> { + fn handle_request( + &self, + request: R::PeerReq, + context: RequestContext, + ) -> impl Future> + Send + '_ { + DynService::handle_request(self.as_ref(), request, context) + } + + fn handle_notification( + &self, + notification: R::PeerNot, + context: NotificationContext, + ) -> impl Future> + Send + '_ { + DynService::handle_notification(self.as_ref(), notification, context) + } + + fn get_info(&self) -> R::Info { + DynService::get_info(self.as_ref()) + } +} + +pub trait DynService: Send + Sync { + fn handle_request( + &self, + request: R::PeerReq, + context: RequestContext, + ) -> BoxFuture<'_, Result>; + fn handle_notification( + &self, + notification: R::PeerNot, + context: NotificationContext, + ) -> BoxFuture<'_, Result<(), McpError>>; + fn get_info(&self) -> R::Info; +} + +impl> DynService for S { + fn handle_request( + &self, + request: R::PeerReq, + context: RequestContext, + ) -> BoxFuture<'_, Result> { + Box::pin(self.handle_request(request, context)) + } + fn handle_notification( + &self, + notification: R::PeerNot, + context: NotificationContext, + ) -> BoxFuture<'_, Result<(), McpError>> { + Box::pin(self.handle_notification(notification, context)) + } + fn get_info(&self) -> R::Info { + self.get_info() + } +} + +use std::{ + collections::{HashMap, VecDeque}, + ops::Deref, + sync::{Arc, atomic::AtomicU64}, + time::Duration, +}; + +use tokio::sync::mpsc; + +pub trait RequestIdProvider: Send + Sync + 'static { + fn next_request_id(&self) -> RequestId; +} + +pub trait ProgressTokenProvider: Send + Sync + 'static { + fn next_progress_token(&self) -> ProgressToken; +} + +pub type AtomicU32RequestIdProvider = AtomicU32Provider; +pub type AtomicU32ProgressTokenProvider = AtomicU32Provider; + +#[derive(Debug, Default)] +pub struct AtomicU32Provider { + id: AtomicU64, +} + +impl RequestIdProvider for AtomicU32Provider { + fn next_request_id(&self) -> RequestId { + let id = self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + // Safe conversion: we start at 0 and increment by 1, so we won't overflow i64::MAX in practice + RequestId::Number(id as i64) + } +} + +impl ProgressTokenProvider for AtomicU32Provider { + fn next_progress_token(&self) -> ProgressToken { + let id = self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + ProgressToken(NumberOrString::Number(id as i64)) + } +} + +type Responder = tokio::sync::oneshot::Sender; + +/// A handle to a remote request +/// +/// You can cancel it by call [`RequestHandle::cancel`] with a reason, +/// +/// or wait for response by call [`RequestHandle::await_response`] +#[derive(Debug)] +pub struct RequestHandle { + pub rx: tokio::sync::oneshot::Receiver>, + pub options: PeerRequestOptions, + pub peer: Peer, + pub id: RequestId, + pub progress_token: ProgressToken, +} + +impl RequestHandle { + pub const REQUEST_TIMEOUT_REASON: &str = "request timeout"; + pub async fn await_response(self) -> Result { + if let Some(timeout) = self.options.timeout { + let timeout_result = tokio::time::timeout(timeout, async move { + self.rx.await.map_err(|_e| ServiceError::TransportClosed)? + }) + .await; + match timeout_result { + Ok(response) => response, + Err(_) => { + let error = Err(ServiceError::Timeout { timeout }); + // cancel this request + let notification = CancelledNotification { + params: CancelledNotificationParam { + request_id: self.id, + reason: Some(Self::REQUEST_TIMEOUT_REASON.to_owned()), + }, + method: crate::model::CancelledNotificationMethod, + extensions: Default::default(), + }; + let _ = self.peer.send_notification(notification.into()).await; + error + } + } + } else { + self.rx.await.map_err(|_e| ServiceError::TransportClosed)? + } + } + + /// Cancel this request + pub async fn cancel(self, reason: Option) -> Result<(), ServiceError> { + let notification = CancelledNotification { + params: CancelledNotificationParam { + request_id: self.id, + reason, + }, + method: crate::model::CancelledNotificationMethod, + extensions: Default::default(), + }; + self.peer.send_notification(notification.into()).await?; + Ok(()) + } +} + +#[derive(Debug)] +pub(crate) enum PeerSinkMessage { + Request { + request: R::Req, + id: RequestId, + responder: Responder>, + }, + Notification { + notification: R::Not, + responder: Responder>, + }, +} + +/// An interface to fetch the remote client or server +/// +/// For general purpose, call [`Peer::send_request`] or [`Peer::send_notification`] to send message to remote peer. +/// +/// To create a cancellable request, call [`Peer::send_request_with_option`]. +#[derive(Clone)] +pub struct Peer { + tx: mpsc::Sender>, + request_id_provider: Arc, + progress_token_provider: Arc, + info: Arc>, +} + +impl std::fmt::Debug for Peer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PeerSink") + .field("tx", &self.tx) + .field("is_client", &R::IS_CLIENT) + .finish() + } +} + +type ProxyOutbound = mpsc::Receiver>; + +#[derive(Debug, Default)] +pub struct PeerRequestOptions { + pub timeout: Option, + pub meta: Option, +} + +impl PeerRequestOptions { + pub fn no_options() -> Self { + Self::default() + } +} + +impl Peer { + const CLIENT_CHANNEL_BUFFER_SIZE: usize = 1024; + pub(crate) fn new( + request_id_provider: Arc, + peer_info: Option, + ) -> (Peer, ProxyOutbound) { + let (tx, rx) = mpsc::channel(Self::CLIENT_CHANNEL_BUFFER_SIZE); + ( + Self { + tx, + request_id_provider, + progress_token_provider: Arc::new(AtomicU32ProgressTokenProvider::default()), + info: Arc::new(tokio::sync::OnceCell::new_with(peer_info)), + }, + rx, + ) + } + pub async fn send_notification(&self, notification: R::Not) -> Result<(), ServiceError> { + let (responder, receiver) = tokio::sync::oneshot::channel(); + self.tx + .send(PeerSinkMessage::Notification { + notification, + responder, + }) + .await + .map_err(|_m| ServiceError::TransportClosed)?; + receiver.await.map_err(|_e| ServiceError::TransportClosed)? + } + pub async fn send_request(&self, request: R::Req) -> Result { + self.send_request_with_option(request, PeerRequestOptions::no_options()) + .await? + .await_response() + .await + } + + pub async fn send_cancellable_request( + &self, + request: R::Req, + options: PeerRequestOptions, + ) -> Result, ServiceError> { + self.send_request_with_option(request, options).await + } + + pub async fn send_request_with_option( + &self, + mut request: R::Req, + options: PeerRequestOptions, + ) -> Result, ServiceError> { + let id = self.request_id_provider.next_request_id(); + let progress_token = self.progress_token_provider.next_progress_token(); + request + .get_meta_mut() + .set_progress_token(progress_token.clone()); + if let Some(meta) = options.meta.clone() { + request.get_meta_mut().extend(meta); + } + let (responder, receiver) = tokio::sync::oneshot::channel(); + self.tx + .send(PeerSinkMessage::Request { + request, + id: id.clone(), + responder, + }) + .await + .map_err(|_m| ServiceError::TransportClosed)?; + Ok(RequestHandle { + id, + rx: receiver, + progress_token, + options, + peer: self.clone(), + }) + } + pub fn peer_info(&self) -> Option<&R::PeerInfo> { + self.info.get() + } + + pub fn set_peer_info(&self, info: R::PeerInfo) { + if self.info.initialized() { + tracing::warn!("trying to set peer info, which is already initialized"); + } else { + let _ = self.info.set(info); + } + } + + pub fn is_transport_closed(&self) -> bool { + self.tx.is_closed() + } +} + +#[derive(Debug)] +pub struct RunningService> { + service: Arc, + peer: Peer, + handle: tokio::task::JoinHandle, + cancellation_token: CancellationToken, + dg: DropGuard, +} +impl> Deref for RunningService { + type Target = Peer; + + fn deref(&self) -> &Self::Target { + &self.peer + } +} + +impl> RunningService { + #[inline] + pub fn peer(&self) -> &Peer { + &self.peer + } + #[inline] + pub fn service(&self) -> &S { + self.service.as_ref() + } + #[inline] + pub fn cancellation_token(&self) -> RunningServiceCancellationToken { + RunningServiceCancellationToken(self.cancellation_token.clone()) + } + #[inline] + pub async fn waiting(self) -> Result { + self.handle.await + } + pub async fn cancel(self) -> Result { + let RunningService { dg, handle, .. } = self; + dg.disarm().cancel(); + handle.await + } +} + +// use a wrapper type so we can tweak the implementation if needed +pub struct RunningServiceCancellationToken(CancellationToken); + +impl RunningServiceCancellationToken { + pub fn cancel(self) { + self.0.cancel(); + } +} + +#[derive(Debug)] +pub enum QuitReason { + Cancelled, + Closed, + JoinError(tokio::task::JoinError), +} + +/// Request execution context +#[derive(Debug, Clone)] +pub struct RequestContext { + /// this token will be cancelled when the [`CancelledNotification`] is received. + pub ct: CancellationToken, + pub id: RequestId, + pub meta: Meta, + pub extensions: Extensions, + /// An interface to fetch the remote client or server + pub peer: Peer, +} + +/// Request execution context +#[derive(Debug, Clone)] +pub struct NotificationContext { + pub meta: Meta, + pub extensions: Extensions, + /// An interface to fetch the remote client or server + pub peer: Peer, +} + +/// Use this function to skip initialization process +pub fn serve_directly( + service: S, + transport: T, + peer_info: Option, +) -> RunningService +where + R: ServiceRole, + S: Service, + T: IntoTransport, + E: std::error::Error + Send + Sync + 'static, +{ + serve_directly_with_ct(service, transport, peer_info, Default::default()) +} + +/// Use this function to skip initialization process +pub fn serve_directly_with_ct( + service: S, + transport: T, + peer_info: Option, + ct: CancellationToken, +) -> RunningService +where + R: ServiceRole, + S: Service, + T: IntoTransport, + E: std::error::Error + Send + Sync + 'static, +{ + let (peer, peer_rx) = Peer::new(Arc::new(AtomicU32RequestIdProvider::default()), peer_info); + serve_inner(service, transport.into_transport(), peer, peer_rx, ct) +} + +#[instrument(skip_all)] +fn serve_inner( + service: S, + transport: T, + peer: Peer, + mut peer_rx: tokio::sync::mpsc::Receiver>, + ct: CancellationToken, +) -> RunningService +where + R: ServiceRole, + S: Service, + T: Transport + 'static, +{ + const SINK_PROXY_BUFFER_SIZE: usize = 64; + let (sink_proxy_tx, mut sink_proxy_rx) = + tokio::sync::mpsc::channel::>(SINK_PROXY_BUFFER_SIZE); + let peer_info = peer.peer_info(); + if R::IS_CLIENT { + tracing::info!(?peer_info, "Service initialized as client"); + } else { + tracing::info!(?peer_info, "Service initialized as server"); + } + + let mut local_responder_pool = + HashMap::>>::new(); + let mut local_ct_pool = HashMap::::new(); + let shared_service = Arc::new(service); + // for return + let service = shared_service.clone(); + + // let message_sink = tokio::sync:: + // let mut stream = std::pin::pin!(stream); + let serve_loop_ct = ct.child_token(); + let peer_return: Peer = peer.clone(); + let current_span = tracing::Span::current(); + let handle = tokio::spawn(async move { + let mut transport = transport.into_transport(); + let mut batch_messages = VecDeque::>::new(); + let mut send_task_set = tokio::task::JoinSet::::new(); + #[derive(Debug)] + enum SendTaskResult { + Request { + id: RequestId, + result: Result<(), DynamicTransportError>, + }, + Notification { + responder: Responder>, + cancellation_param: Option, + result: Result<(), DynamicTransportError>, + }, + } + #[derive(Debug)] + enum Event { + ProxyMessage(PeerSinkMessage), + PeerMessage(RxJsonRpcMessage), + ToSink(TxJsonRpcMessage), + SendTaskResult(SendTaskResult), + } + + let quit_reason = loop { + let evt = if let Some(m) = batch_messages.pop_front() { + Event::PeerMessage(m) + } else { + tokio::select! { + m = sink_proxy_rx.recv(), if !sink_proxy_rx.is_closed() => { + if let Some(m) = m { + Event::ToSink(m) + } else { + continue + } + } + m = transport.receive() => { + if let Some(m) = m { + Event::PeerMessage(m) + } else { + // input stream closed + tracing::info!("input stream terminated"); + break QuitReason::Closed + } + } + m = peer_rx.recv(), if !peer_rx.is_closed() => { + if let Some(m) = m { + Event::ProxyMessage(m) + } else { + continue + } + } + m = send_task_set.join_next(), if !send_task_set.is_empty() => { + let Some(result) = m else { + continue + }; + match result { + Err(e) => { + // join error, which is serious, we should quit. + tracing::error!(%e, "send request task encounter a tokio join error"); + break QuitReason::JoinError(e) + } + Ok(result) => { + Event::SendTaskResult(result) + } + } + } + _ = serve_loop_ct.cancelled() => { + tracing::info!("task cancelled"); + break QuitReason::Cancelled + } + } + }; + + tracing::trace!(?evt, "new event"); + match evt { + Event::SendTaskResult(SendTaskResult::Request { id, result }) => { + if let Err(e) = result { + if let Some(responder) = local_responder_pool.remove(&id) { + let _ = responder.send(Err(ServiceError::TransportSend(e))); + } + } + } + Event::SendTaskResult(SendTaskResult::Notification { + responder, + result, + cancellation_param, + }) => { + let response = if let Err(e) = result { + Err(ServiceError::TransportSend(e)) + } else { + Ok(()) + }; + let _ = responder.send(response); + if let Some(param) = cancellation_param { + if let Some(responder) = local_responder_pool.remove(¶m.request_id) { + tracing::info!(id = %param.request_id, reason = param.reason, "cancelled"); + let _response_result = responder.send(Err(ServiceError::Cancelled { + reason: param.reason.clone(), + })); + } + } + } + // response and error + Event::ToSink(m) => { + if let Some(id) = match &m { + JsonRpcMessage::Response(response) => Some(&response.id), + JsonRpcMessage::Error(error) => Some(&error.id), + _ => None, + } { + if let Some(ct) = local_ct_pool.remove(id) { + ct.cancel(); + } + let send = transport.send(m); + let current_span = tracing::Span::current(); + tokio::spawn(async move { + let send_result = send.await; + if let Err(error) = send_result { + tracing::error!(%error, "fail to response message"); + } + }.instrument(current_span)); + } + } + Event::ProxyMessage(PeerSinkMessage::Request { + request, + id, + responder, + }) => { + local_responder_pool.insert(id.clone(), responder); + let send = transport.send(JsonRpcMessage::request(request, id.clone())); + { + let id = id.clone(); + let current_span = tracing::Span::current(); + send_task_set.spawn(send.map(move |r| SendTaskResult::Request { + id, + result: r.map_err(DynamicTransportError::new::), + }).instrument(current_span)); + } + } + Event::ProxyMessage(PeerSinkMessage::Notification { + notification, + responder, + }) => { + // catch cancellation notification + let mut cancellation_param = None; + let notification = match notification.try_into() { + Ok::(cancelled) => { + cancellation_param.replace(cancelled.params.clone()); + cancelled.into() + } + Err(notification) => notification, + }; + let send = transport.send(JsonRpcMessage::notification(notification)); + let current_span = tracing::Span::current(); + send_task_set.spawn(send.map(move |result| SendTaskResult::Notification { + responder, + cancellation_param, + result: result.map_err(DynamicTransportError::new::), + }).instrument(current_span)); + } + Event::PeerMessage(JsonRpcMessage::Request(JsonRpcRequest { + id, + mut request, + .. + })) => { + tracing::debug!(%id, ?request, "received request"); + { + let service = shared_service.clone(); + let sink = sink_proxy_tx.clone(); + let request_ct = serve_loop_ct.child_token(); + let context_ct = request_ct.child_token(); + local_ct_pool.insert(id.clone(), request_ct); + let mut extensions = Extensions::new(); + let mut meta = Meta::new(); + // avoid clone + // swap meta firstly, otherwise progress token will be lost + std::mem::swap(&mut meta, request.get_meta_mut()); + std::mem::swap(&mut extensions, request.extensions_mut()); + let context = RequestContext { + ct: context_ct, + id: id.clone(), + peer: peer.clone(), + meta, + extensions, + }; + let current_span = tracing::Span::current(); + tokio::spawn(async move { + let result = service + .handle_request(request, context) + .await; + let response = match result { + Ok(result) => { + tracing::debug!(%id, ?result, "response message"); + JsonRpcMessage::response(result, id) + } + Err(error) => { + tracing::warn!(%id, ?error, "response error"); + JsonRpcMessage::error(error, id) + } + }; + let _send_result = sink.send(response).await; + }.instrument(current_span)); + } + } + Event::PeerMessage(JsonRpcMessage::Notification(JsonRpcNotification { + notification, + .. + })) => { + tracing::info!(?notification, "received notification"); + // catch cancelled notification + let mut notification = match notification.try_into() { + Ok::(cancelled) => { + if let Some(ct) = local_ct_pool.remove(&cancelled.params.request_id) { + tracing::info!(id = %cancelled.params.request_id, reason = cancelled.params.reason, "cancelled"); + ct.cancel(); + } + cancelled.into() + } + Err(notification) => notification, + }; + { + let service = shared_service.clone(); + let mut extensions = Extensions::new(); + let mut meta = Meta::new(); + // avoid clone + std::mem::swap(&mut extensions, notification.extensions_mut()); + std::mem::swap(&mut meta, notification.get_meta_mut()); + let context = NotificationContext { + peer: peer.clone(), + meta, + extensions, + }; + let current_span = tracing::Span::current(); + tokio::spawn(async move { + let result = service.handle_notification(notification, context).await; + if let Err(error) = result { + tracing::warn!(%error, "Error sending notification"); + } + }.instrument(current_span)); + } + } + Event::PeerMessage(JsonRpcMessage::Response(JsonRpcResponse { + result, + id, + .. + })) => { + if let Some(responder) = local_responder_pool.remove(&id) { + let response_result = responder.send(Ok(result)); + if let Err(_error) = response_result { + tracing::warn!(%id, "Error sending response"); + } + } + } + Event::PeerMessage(JsonRpcMessage::Error(JsonRpcError { error, id, .. })) => { + if let Some(responder) = local_responder_pool.remove(&id) { + let _response_result = responder.send(Err(ServiceError::McpError(error))); + if let Err(_error) = _response_result { + tracing::warn!(%id, "Error sending response"); + } + } + } + } + }; + let sink_close_result = transport.close().await; + if let Err(e) = sink_close_result { + tracing::error!(%e, "fail to close sink"); + } + tracing::info!(?quit_reason, "serve finished"); + quit_reason + }.instrument(current_span)); + RunningService { + service, + peer: peer_return, + handle, + cancellation_token: ct.clone(), + dg: ct.drop_guard(), + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/service/client.rs b/code-rs/third_party/rmcp-0.8.3/src/service/client.rs new file mode 100644 index 00000000000..75cdc8fa615 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/service/client.rs @@ -0,0 +1,534 @@ +use std::borrow::Cow; + +use thiserror::Error; + +use super::*; +use crate::{ + model::{ + ArgumentInfo, CallToolRequest, CallToolRequestParam, CallToolResult, CancelledNotification, + CancelledNotificationParam, ClientInfo, ClientJsonRpcMessage, ClientNotification, + ClientRequest, ClientResult, CompleteRequest, CompleteRequestParam, CompleteResult, + CompletionContext, CompletionInfo, GetPromptRequest, GetPromptRequestParam, + GetPromptResult, InitializeRequest, InitializedNotification, JsonRpcResponse, + ListPromptsRequest, ListPromptsResult, ListResourceTemplatesRequest, + ListResourceTemplatesResult, ListResourcesRequest, ListResourcesResult, ListToolsRequest, + ListToolsResult, PaginatedRequestParam, ProgressNotification, ProgressNotificationParam, + ReadResourceRequest, ReadResourceRequestParam, ReadResourceResult, Reference, RequestId, + RootsListChangedNotification, ServerInfo, ServerJsonRpcMessage, ServerNotification, + ServerRequest, ServerResult, SetLevelRequest, SetLevelRequestParam, SubscribeRequest, + SubscribeRequestParam, UnsubscribeRequest, UnsubscribeRequestParam, + }, + transport::DynamicTransportError, +}; + +/// It represents the error that may occur when serving the client. +/// +/// if you want to handle the error, you can use `serve_client_with_ct` or `serve_client` with `Result, ClientError>` +#[derive(Error, Debug)] +pub enum ClientInitializeError { + #[error("expect initialized response, but received: {0:?}")] + ExpectedInitResponse(Option), + + #[error("expect initialized result, but received: {0:?}")] + ExpectedInitResult(Option), + + #[error("conflict initialized response id: expected {0}, got {1}")] + ConflictInitResponseId(RequestId, RequestId), + + #[error("connection closed: {0}")] + ConnectionClosed(String), + + #[error("Send message error {error}, when {context}")] + TransportError { + error: DynamicTransportError, + context: Cow<'static, str>, + }, + + #[error("Cancelled")] + Cancelled, +} + +impl ClientInitializeError { + pub fn transport + 'static>( + error: T::Error, + context: impl Into>, + ) -> Self { + Self::TransportError { + error: DynamicTransportError::new::(error), + context: context.into(), + } + } +} + +/// Helper function to get the next message from the stream +async fn expect_next_message( + transport: &mut T, + context: &str, +) -> Result +where + T: Transport, +{ + transport + .receive() + .await + .ok_or_else(|| ClientInitializeError::ConnectionClosed(context.to_string())) +} + +/// Helper function to expect a response from the stream +async fn expect_response( + transport: &mut T, + context: &str, + service: &S, + peer: Peer, +) -> Result<(ServerResult, RequestId), ClientInitializeError> +where + T: Transport, + S: Service, +{ + loop { + let message = expect_next_message(transport, context).await?; + match message { + // Expected message to complete the initialization + ServerJsonRpcMessage::Response(JsonRpcResponse { id, result, .. }) => { + break Ok((result, id)); + } + // Server could send logging messages before handshake + ServerJsonRpcMessage::Notification(mut notification) => { + let ServerNotification::LoggingMessageNotification(logging) = + &mut notification.notification + else { + tracing::warn!(?notification, "Received unexpected message"); + continue; + }; + + let mut context = NotificationContext { + peer: peer.clone(), + meta: Meta::default(), + extensions: Extensions::default(), + }; + + if let Some(meta) = logging.extensions.get_mut::() { + std::mem::swap(&mut context.meta, meta); + } + std::mem::swap(&mut context.extensions, &mut logging.extensions); + + if let Err(error) = service + .handle_notification(notification.notification, context) + .await + { + tracing::warn!(?error, "Handle logging before handshake failed."); + } + } + // Server could send pings before handshake + ServerJsonRpcMessage::Request(ref request) + if matches!(request.request, ServerRequest::PingRequest(_)) => + { + tracing::trace!("Received ping request. Ignored.") + } + // Server SHOULD NOT send any other messages before handshake. We ignore them anyway + _ => tracing::warn!(?message, "Received unexpected message"), + } + } +} + +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub struct RoleClient; + +impl ServiceRole for RoleClient { + type Req = ClientRequest; + type Resp = ClientResult; + type Not = ClientNotification; + type PeerReq = ServerRequest; + type PeerResp = ServerResult; + type PeerNot = ServerNotification; + type Info = ClientInfo; + type PeerInfo = ServerInfo; + type InitializeError = ClientInitializeError; + const IS_CLIENT: bool = true; +} + +pub type ServerSink = Peer; + +impl> ServiceExt for S { + fn serve_with_ct( + self, + transport: T, + ct: CancellationToken, + ) -> impl Future, ClientInitializeError>> + Send + where + T: IntoTransport, + E: std::error::Error + Send + Sync + 'static, + Self: Sized, + { + serve_client_with_ct(self, transport, ct) + } +} + +pub async fn serve_client( + service: S, + transport: T, +) -> Result, ClientInitializeError> +where + S: Service, + T: IntoTransport, + E: std::error::Error + Send + Sync + 'static, +{ + serve_client_with_ct(service, transport, Default::default()).await +} + +pub async fn serve_client_with_ct( + service: S, + transport: T, + ct: CancellationToken, +) -> Result, ClientInitializeError> +where + S: Service, + T: IntoTransport, + E: std::error::Error + Send + Sync + 'static, +{ + tokio::select! { + result = serve_client_with_ct_inner(service, transport.into_transport(), ct.clone()) => { result } + _ = ct.cancelled() => { + Err(ClientInitializeError::Cancelled) + } + } +} + +async fn serve_client_with_ct_inner( + service: S, + transport: T, + ct: CancellationToken, +) -> Result, ClientInitializeError> +where + S: Service, + T: Transport + 'static, +{ + let mut transport = transport.into_transport(); + let id_provider = >::default(); + + // service + let id = id_provider.next_request_id(); + let init_request = InitializeRequest { + method: Default::default(), + params: service.get_info(), + extensions: Default::default(), + }; + transport + .send(ClientJsonRpcMessage::request( + ClientRequest::InitializeRequest(init_request), + id.clone(), + )) + .await + .map_err(|error| ClientInitializeError::TransportError { + error: DynamicTransportError::new::(error), + context: "send initialize request".into(), + })?; + + let (peer, peer_rx) = Peer::new(id_provider, None); + + let (response, response_id) = expect_response( + &mut transport, + "initialize response", + &service, + peer.clone(), + ) + .await?; + + if id != response_id { + return Err(ClientInitializeError::ConflictInitResponseId( + id, + response_id, + )); + } + + let ServerResult::InitializeResult(initialize_result) = response else { + return Err(ClientInitializeError::ExpectedInitResult(Some(response))); + }; + peer.set_peer_info(initialize_result); + + // send notification + let notification = ClientJsonRpcMessage::notification( + ClientNotification::InitializedNotification(InitializedNotification { + method: Default::default(), + extensions: Default::default(), + }), + ); + transport.send(notification).await.map_err(|error| { + ClientInitializeError::transport::(error, "send initialized notification") + })?; + Ok(serve_inner(service, transport, peer, peer_rx, ct)) +} + +macro_rules! method { + (peer_req $method:ident $Req:ident() => $Resp: ident ) => { + pub async fn $method(&self) -> Result<$Resp, ServiceError> { + let result = self + .send_request(ClientRequest::$Req($Req { + method: Default::default(), + })) + .await?; + match result { + ServerResult::$Resp(result) => Ok(result), + _ => Err(ServiceError::UnexpectedResponse), + } + } + }; + (peer_req $method:ident $Req:ident($Param: ident) => $Resp: ident ) => { + pub async fn $method(&self, params: $Param) -> Result<$Resp, ServiceError> { + let result = self + .send_request(ClientRequest::$Req($Req { + method: Default::default(), + params, + extensions: Default::default(), + })) + .await?; + match result { + ServerResult::$Resp(result) => Ok(result), + _ => Err(ServiceError::UnexpectedResponse), + } + } + }; + (peer_req $method:ident $Req:ident($Param: ident)? => $Resp: ident ) => { + pub async fn $method(&self, params: Option<$Param>) -> Result<$Resp, ServiceError> { + let result = self + .send_request(ClientRequest::$Req($Req { + method: Default::default(), + params, + extensions: Default::default(), + })) + .await?; + match result { + ServerResult::$Resp(result) => Ok(result), + _ => Err(ServiceError::UnexpectedResponse), + } + } + }; + (peer_req $method:ident $Req:ident($Param: ident)) => { + pub async fn $method(&self, params: $Param) -> Result<(), ServiceError> { + let result = self + .send_request(ClientRequest::$Req($Req { + method: Default::default(), + params, + extensions: Default::default(), + })) + .await?; + match result { + ServerResult::EmptyResult(_) => Ok(()), + _ => Err(ServiceError::UnexpectedResponse), + } + } + }; + + (peer_not $method:ident $Not:ident($Param: ident)) => { + pub async fn $method(&self, params: $Param) -> Result<(), ServiceError> { + self.send_notification(ClientNotification::$Not($Not { + method: Default::default(), + params, + extensions: Default::default(), + })) + .await?; + Ok(()) + } + }; + (peer_not $method:ident $Not:ident) => { + pub async fn $method(&self) -> Result<(), ServiceError> { + self.send_notification(ClientNotification::$Not($Not { + method: Default::default(), + extensions: Default::default(), + })) + .await?; + Ok(()) + } + }; +} + +impl Peer { + method!(peer_req complete CompleteRequest(CompleteRequestParam) => CompleteResult); + method!(peer_req set_level SetLevelRequest(SetLevelRequestParam)); + method!(peer_req get_prompt GetPromptRequest(GetPromptRequestParam) => GetPromptResult); + method!(peer_req list_prompts ListPromptsRequest(PaginatedRequestParam)? => ListPromptsResult); + method!(peer_req list_resources ListResourcesRequest(PaginatedRequestParam)? => ListResourcesResult); + method!(peer_req list_resource_templates ListResourceTemplatesRequest(PaginatedRequestParam)? => ListResourceTemplatesResult); + method!(peer_req read_resource ReadResourceRequest(ReadResourceRequestParam) => ReadResourceResult); + method!(peer_req subscribe SubscribeRequest(SubscribeRequestParam) ); + method!(peer_req unsubscribe UnsubscribeRequest(UnsubscribeRequestParam)); + method!(peer_req call_tool CallToolRequest(CallToolRequestParam) => CallToolResult); + method!(peer_req list_tools ListToolsRequest(PaginatedRequestParam)? => ListToolsResult); + + method!(peer_not notify_cancelled CancelledNotification(CancelledNotificationParam)); + method!(peer_not notify_progress ProgressNotification(ProgressNotificationParam)); + method!(peer_not notify_initialized InitializedNotification); + method!(peer_not notify_roots_list_changed RootsListChangedNotification); +} + +impl Peer { + /// A wrapper method for [`Peer::list_tools`]. + /// + /// This function will call [`Peer::list_tools`] multiple times until all tools are listed. + pub async fn list_all_tools(&self) -> Result, ServiceError> { + let mut tools = Vec::new(); + let mut cursor = None; + loop { + let result = self + .list_tools(Some(PaginatedRequestParam { cursor })) + .await?; + tools.extend(result.tools); + cursor = result.next_cursor; + if cursor.is_none() { + break; + } + } + Ok(tools) + } + + /// A wrapper method for [`Peer::list_prompts`]. + /// + /// This function will call [`Peer::list_prompts`] multiple times until all prompts are listed. + pub async fn list_all_prompts(&self) -> Result, ServiceError> { + let mut prompts = Vec::new(); + let mut cursor = None; + loop { + let result = self + .list_prompts(Some(PaginatedRequestParam { cursor })) + .await?; + prompts.extend(result.prompts); + cursor = result.next_cursor; + if cursor.is_none() { + break; + } + } + Ok(prompts) + } + + /// A wrapper method for [`Peer::list_resources`]. + /// + /// This function will call [`Peer::list_resources`] multiple times until all resources are listed. + pub async fn list_all_resources(&self) -> Result, ServiceError> { + let mut resources = Vec::new(); + let mut cursor = None; + loop { + let result = self + .list_resources(Some(PaginatedRequestParam { cursor })) + .await?; + resources.extend(result.resources); + cursor = result.next_cursor; + if cursor.is_none() { + break; + } + } + Ok(resources) + } + + /// A wrapper method for [`Peer::list_resource_templates`]. + /// + /// This function will call [`Peer::list_resource_templates`] multiple times until all resource templates are listed. + pub async fn list_all_resource_templates( + &self, + ) -> Result, ServiceError> { + let mut resource_templates = Vec::new(); + let mut cursor = None; + loop { + let result = self + .list_resource_templates(Some(PaginatedRequestParam { cursor })) + .await?; + resource_templates.extend(result.resource_templates); + cursor = result.next_cursor; + if cursor.is_none() { + break; + } + } + Ok(resource_templates) + } + + /// Convenient method to get completion suggestions for a prompt argument + /// + /// # Arguments + /// * `prompt_name` - Name of the prompt being completed + /// * `argument_name` - Name of the argument being completed + /// * `current_value` - Current partial value of the argument + /// * `context` - Optional context with previously resolved arguments + /// + /// # Returns + /// CompletionInfo with suggestions for the specified prompt argument + pub async fn complete_prompt_argument( + &self, + prompt_name: impl Into, + argument_name: impl Into, + current_value: impl Into, + context: Option, + ) -> Result { + let request = CompleteRequestParam { + r#ref: Reference::for_prompt(prompt_name), + argument: ArgumentInfo { + name: argument_name.into(), + value: current_value.into(), + }, + context, + }; + + let result = self.complete(request).await?; + Ok(result.completion) + } + + /// Convenient method to get completion suggestions for a resource URI argument + /// + /// # Arguments + /// * `uri_template` - URI template pattern being completed + /// * `argument_name` - Name of the URI parameter being completed + /// * `current_value` - Current partial value of the parameter + /// * `context` - Optional context with previously resolved arguments + /// + /// # Returns + /// CompletionInfo with suggestions for the specified resource URI argument + pub async fn complete_resource_argument( + &self, + uri_template: impl Into, + argument_name: impl Into, + current_value: impl Into, + context: Option, + ) -> Result { + let request = CompleteRequestParam { + r#ref: Reference::for_resource(uri_template), + argument: ArgumentInfo { + name: argument_name.into(), + value: current_value.into(), + }, + context, + }; + + let result = self.complete(request).await?; + Ok(result.completion) + } + + /// Simple completion for a prompt argument without context + /// + /// This is a convenience wrapper around `complete_prompt_argument` for + /// simple completion scenarios that don't require context awareness. + pub async fn complete_prompt_simple( + &self, + prompt_name: impl Into, + argument_name: impl Into, + current_value: impl Into, + ) -> Result, ServiceError> { + let completion = self + .complete_prompt_argument(prompt_name, argument_name, current_value, None) + .await?; + Ok(completion.values) + } + + /// Simple completion for a resource URI argument without context + /// + /// This is a convenience wrapper around `complete_resource_argument` for + /// simple completion scenarios that don't require context awareness. + pub async fn complete_resource_simple( + &self, + uri_template: impl Into, + argument_name: impl Into, + current_value: impl Into, + ) -> Result, ServiceError> { + let completion = self + .complete_resource_argument(uri_template, argument_name, current_value, None) + .await?; + Ok(completion.values) + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/service/server.rs b/code-rs/third_party/rmcp-0.8.3/src/service/server.rs new file mode 100644 index 00000000000..82a7e7d82e3 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/service/server.rs @@ -0,0 +1,716 @@ +use std::borrow::Cow; + +use thiserror::Error; + +use super::*; +#[cfg(feature = "elicitation")] +use crate::model::{ + CreateElicitationRequest, CreateElicitationRequestParam, CreateElicitationResult, +}; +use crate::{ + model::{ + CancelledNotification, CancelledNotificationParam, ClientInfo, ClientJsonRpcMessage, + ClientNotification, ClientRequest, ClientResult, CreateMessageRequest, + CreateMessageRequestParam, CreateMessageResult, ErrorData, ListRootsRequest, + ListRootsResult, LoggingMessageNotification, LoggingMessageNotificationParam, + ProgressNotification, ProgressNotificationParam, PromptListChangedNotification, + ProtocolVersion, ResourceListChangedNotification, ResourceUpdatedNotification, + ResourceUpdatedNotificationParam, ServerInfo, ServerNotification, ServerRequest, + ServerResult, ToolListChangedNotification, + }, + transport::DynamicTransportError, +}; + +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub struct RoleServer; + +impl ServiceRole for RoleServer { + type Req = ServerRequest; + type Resp = ServerResult; + type Not = ServerNotification; + type PeerReq = ClientRequest; + type PeerResp = ClientResult; + type PeerNot = ClientNotification; + type Info = ServerInfo; + type PeerInfo = ClientInfo; + + type InitializeError = ServerInitializeError; + const IS_CLIENT: bool = false; +} + +/// It represents the error that may occur when serving the server. +/// +/// if you want to handle the error, you can use `serve_server_with_ct` or `serve_server` with `Result, ServerError>` +#[derive(Error, Debug)] +pub enum ServerInitializeError { + #[error("expect initialized request, but received: {0:?}")] + ExpectedInitializeRequest(Option), + + #[error("expect initialized notification, but received: {0:?}")] + ExpectedInitializedNotification(Option), + + #[error("connection closed: {0}")] + ConnectionClosed(String), + + #[error("unexpected initialize result: {0:?}")] + UnexpectedInitializeResponse(ServerResult), + + #[error("initialize failed: {0}")] + InitializeFailed(ErrorData), + + #[error("unsupported protocol version: {0}")] + UnsupportedProtocolVersion(ProtocolVersion), + + #[error("Send message error {error}, when {context}")] + TransportError { + error: DynamicTransportError, + context: Cow<'static, str>, + }, + + #[error("Cancelled")] + Cancelled, +} + +impl ServerInitializeError { + pub fn transport + 'static>( + error: T::Error, + context: impl Into>, + ) -> Self { + Self::TransportError { + error: DynamicTransportError::new::(error), + context: context.into(), + } + } +} +pub type ClientSink = Peer; + +impl> ServiceExt for S { + fn serve_with_ct( + self, + transport: T, + ct: CancellationToken, + ) -> impl Future, ServerInitializeError>> + Send + where + T: IntoTransport, + E: std::error::Error + Send + Sync + 'static, + Self: Sized, + { + serve_server_with_ct(self, transport, ct) + } +} + +pub async fn serve_server( + service: S, + transport: T, +) -> Result, ServerInitializeError> +where + S: Service, + T: IntoTransport, + E: std::error::Error + Send + Sync + 'static, +{ + serve_server_with_ct(service, transport, CancellationToken::new()).await +} + +/// Helper function to get the next message from the stream +async fn expect_next_message( + transport: &mut T, + context: &str, +) -> Result +where + T: Transport, +{ + transport + .receive() + .await + .ok_or_else(|| ServerInitializeError::ConnectionClosed(context.to_string())) +} + +/// Helper function to expect a request from the stream +async fn expect_request( + transport: &mut T, + context: &str, +) -> Result<(ClientRequest, RequestId), ServerInitializeError> +where + T: Transport, +{ + let msg = expect_next_message(transport, context).await?; + let msg_clone = msg.clone(); + msg.into_request() + .ok_or(ServerInitializeError::ExpectedInitializeRequest(Some( + msg_clone, + ))) +} + +/// Helper function to expect a notification from the stream +async fn expect_notification( + transport: &mut T, + context: &str, +) -> Result +where + T: Transport, +{ + let msg = expect_next_message(transport, context).await?; + let msg_clone = msg.clone(); + msg.into_notification() + .ok_or(ServerInitializeError::ExpectedInitializedNotification( + Some(msg_clone), + )) +} + +pub async fn serve_server_with_ct( + service: S, + transport: T, + ct: CancellationToken, +) -> Result, ServerInitializeError> +where + S: Service, + T: IntoTransport, + E: std::error::Error + Send + Sync + 'static, +{ + tokio::select! { + result = serve_server_with_ct_inner(service, transport.into_transport(), ct.clone()) => { result } + _ = ct.cancelled() => { + Err(ServerInitializeError::Cancelled) + } + } +} + +async fn serve_server_with_ct_inner( + service: S, + transport: T, + ct: CancellationToken, +) -> Result, ServerInitializeError> +where + S: Service, + T: Transport + 'static, +{ + let mut transport = transport.into_transport(); + let id_provider = >::default(); + + // Get initialize request + let (request, id) = expect_request(&mut transport, "initialized request").await?; + + let ClientRequest::InitializeRequest(peer_info) = &request else { + return Err(ServerInitializeError::ExpectedInitializeRequest(Some( + ClientJsonRpcMessage::request(request, id), + ))); + }; + let (peer, peer_rx) = Peer::new(id_provider, Some(peer_info.params.clone())); + let context = RequestContext { + ct: ct.child_token(), + id: id.clone(), + meta: request.get_meta().clone(), + extensions: request.extensions().clone(), + peer: peer.clone(), + }; + // Send initialize response + let init_response = service.handle_request(request.clone(), context).await; + let mut init_response = match init_response { + Ok(ServerResult::InitializeResult(init_response)) => init_response, + Ok(result) => { + return Err(ServerInitializeError::UnexpectedInitializeResponse(result)); + } + Err(e) => { + transport + .send(ServerJsonRpcMessage::error(e.clone(), id)) + .await + .map_err(|error| { + ServerInitializeError::transport::(error, "sending error response") + })?; + return Err(ServerInitializeError::InitializeFailed(e)); + } + }; + let peer_protocol_version = peer_info.params.protocol_version.clone(); + let protocol_version = match peer_protocol_version + .partial_cmp(&init_response.protocol_version) + .ok_or(ServerInitializeError::UnsupportedProtocolVersion( + peer_protocol_version, + ))? { + std::cmp::Ordering::Less => peer_info.params.protocol_version.clone(), + _ => init_response.protocol_version, + }; + init_response.protocol_version = protocol_version; + transport + .send(ServerJsonRpcMessage::response( + ServerResult::InitializeResult(init_response), + id, + )) + .await + .map_err(|error| { + ServerInitializeError::transport::(error, "sending initialize response") + })?; + + // Wait for initialize notification + let notification = expect_notification(&mut transport, "initialize notification").await?; + let ClientNotification::InitializedNotification(_) = notification else { + return Err(ServerInitializeError::ExpectedInitializedNotification( + Some(ClientJsonRpcMessage::notification(notification)), + )); + }; + let context = NotificationContext { + meta: notification.get_meta().clone(), + extensions: notification.extensions().clone(), + peer: peer.clone(), + }; + let _ = service.handle_notification(notification, context).await; + // Continue processing service + Ok(serve_inner(service, transport, peer, peer_rx, ct)) +} + +macro_rules! method { + (peer_req $method:ident $Req:ident() => $Resp: ident ) => { + pub async fn $method(&self) -> Result<$Resp, ServiceError> { + let result = self + .send_request(ServerRequest::$Req($Req { + method: Default::default(), + extensions: Default::default(), + })) + .await?; + match result { + ClientResult::$Resp(result) => Ok(result), + _ => Err(ServiceError::UnexpectedResponse), + } + } + }; + (peer_req $method:ident $Req:ident($Param: ident) => $Resp: ident ) => { + pub async fn $method(&self, params: $Param) -> Result<$Resp, ServiceError> { + let result = self + .send_request(ServerRequest::$Req($Req { + method: Default::default(), + params, + extensions: Default::default(), + })) + .await?; + match result { + ClientResult::$Resp(result) => Ok(result), + _ => Err(ServiceError::UnexpectedResponse), + } + } + }; + (peer_req $method:ident $Req:ident($Param: ident)) => { + pub fn $method( + &self, + params: $Param, + ) -> impl Future> + Send + '_ { + async move { + let result = self + .send_request(ServerRequest::$Req($Req { + method: Default::default(), + params, + })) + .await?; + match result { + ClientResult::EmptyResult(_) => Ok(()), + _ => Err(ServiceError::UnexpectedResponse), + } + } + } + }; + + (peer_not $method:ident $Not:ident($Param: ident)) => { + pub async fn $method(&self, params: $Param) -> Result<(), ServiceError> { + self.send_notification(ServerNotification::$Not($Not { + method: Default::default(), + params, + extensions: Default::default(), + })) + .await?; + Ok(()) + } + }; + (peer_not $method:ident $Not:ident) => { + pub async fn $method(&self) -> Result<(), ServiceError> { + self.send_notification(ServerNotification::$Not($Not { + method: Default::default(), + extensions: Default::default(), + })) + .await?; + Ok(()) + } + }; + + // Timeout-only variants (base method should be created separately with peer_req) + (peer_req_with_timeout $method_with_timeout:ident $Req:ident() => $Resp: ident) => { + pub async fn $method_with_timeout( + &self, + timeout: Option, + ) -> Result<$Resp, ServiceError> { + let request = ServerRequest::$Req($Req { + method: Default::default(), + extensions: Default::default(), + }); + let options = crate::service::PeerRequestOptions { + timeout, + meta: None, + }; + let result = self + .send_request_with_option(request, options) + .await? + .await_response() + .await?; + match result { + ClientResult::$Resp(result) => Ok(result), + _ => Err(ServiceError::UnexpectedResponse), + } + } + }; + + (peer_req_with_timeout $method_with_timeout:ident $Req:ident($Param: ident) => $Resp: ident) => { + pub async fn $method_with_timeout( + &self, + params: $Param, + timeout: Option, + ) -> Result<$Resp, ServiceError> { + let request = ServerRequest::$Req($Req { + method: Default::default(), + params, + extensions: Default::default(), + }); + let options = crate::service::PeerRequestOptions { + timeout, + meta: None, + }; + let result = self + .send_request_with_option(request, options) + .await? + .await_response() + .await?; + match result { + ClientResult::$Resp(result) => Ok(result), + _ => Err(ServiceError::UnexpectedResponse), + } + } + }; +} + +impl Peer { + pub async fn create_message( + &self, + params: CreateMessageRequestParam, + ) -> Result { + let result = self + .send_request(ServerRequest::CreateMessageRequest(CreateMessageRequest { + method: Default::default(), + params, + extensions: Default::default(), + })) + .await?; + match result { + ClientResult::CreateMessageResult(result) => Ok(*result), + _ => Err(ServiceError::UnexpectedResponse), + } + } + method!(peer_req list_roots ListRootsRequest() => ListRootsResult); + #[cfg(feature = "elicitation")] + method!(peer_req create_elicitation CreateElicitationRequest(CreateElicitationRequestParam) => CreateElicitationResult); + #[cfg(feature = "elicitation")] + method!(peer_req_with_timeout create_elicitation_with_timeout CreateElicitationRequest(CreateElicitationRequestParam) => CreateElicitationResult); + + method!(peer_not notify_cancelled CancelledNotification(CancelledNotificationParam)); + method!(peer_not notify_progress ProgressNotification(ProgressNotificationParam)); + method!(peer_not notify_logging_message LoggingMessageNotification(LoggingMessageNotificationParam)); + method!(peer_not notify_resource_updated ResourceUpdatedNotification(ResourceUpdatedNotificationParam)); + method!(peer_not notify_resource_list_changed ResourceListChangedNotification); + method!(peer_not notify_tool_list_changed ToolListChangedNotification); + method!(peer_not notify_prompt_list_changed PromptListChangedNotification); +} + +// ============================================================================= +// ELICITATION CONVENIENCE METHODS +// These methods are specific to server role and provide typed elicitation functionality +// ============================================================================= + +/// Errors that can occur during typed elicitation operations +#[cfg(feature = "elicitation")] +#[derive(Error, Debug)] +pub enum ElicitationError { + /// The elicitation request failed at the service level + #[error("Service error: {0}")] + Service(#[from] ServiceError), + + /// User explicitly declined to provide the requested information + /// This indicates a conscious decision by the user to reject the request + /// (e.g., clicked "Reject", "Decline", "No", etc.) + #[error("User explicitly declined the request")] + UserDeclined, + + /// User dismissed the request without making an explicit choice + /// This indicates the user cancelled without explicitly declining + /// (e.g., closed dialog, clicked outside, pressed Escape, etc.) + #[error("User cancelled/dismissed the request")] + UserCancelled, + + /// The response data could not be parsed into the requested type + #[error("Failed to parse response data: {error}\nReceived data: {data}")] + ParseError { + error: serde_json::Error, + data: serde_json::Value, + }, + + /// No response content was provided by the user + #[error("No response content provided")] + NoContent, + + /// Client does not support elicitation capability + #[error("Client does not support elicitation - capability not declared during initialization")] + CapabilityNotSupported, +} + +/// Marker trait to ensure that elicitation types generate object-type JSON schemas. +/// +/// This trait provides compile-time safety to ensure that types used with +/// `elicit()` methods will generate JSON schemas of type "object", which +/// aligns with MCP client expectations for structured data input. +/// +/// # Type Safety Rationale +/// +/// MCP clients typically expect JSON objects for elicitation schemas to +/// provide structured forms and validation. This trait prevents common +/// mistakes like: +/// +/// ```compile_fail +/// // These would not compile due to missing ElicitationSafe bound: +/// let name: String = server.elicit("Enter name").await?; // Primitive +/// let items: Vec = server.elicit("Enter items").await?; // Array +/// ``` +#[cfg(feature = "elicitation")] +pub trait ElicitationSafe: schemars::JsonSchema {} + +/// Macro to mark types as safe for elicitation by verifying they generate object schemas. +/// +/// This macro automatically implements the `ElicitationSafe` trait for struct types +/// that should be used with `elicit()` methods. +/// +/// # Example +/// +/// ```rust +/// use rmcp::elicit_safe; +/// use schemars::JsonSchema; +/// use serde::{Deserialize, Serialize}; +/// +/// #[derive(Serialize, Deserialize, JsonSchema)] +/// struct UserProfile { +/// name: String, +/// email: String, +/// } +/// +/// elicit_safe!(UserProfile); +/// +/// // Now safe to use in async context: +/// // let profile: UserProfile = server.elicit("Enter profile").await?; +/// ``` +#[cfg(feature = "elicitation")] +#[macro_export] +macro_rules! elicit_safe { + ($($t:ty),* $(,)?) => { + $( + impl $crate::service::ElicitationSafe for $t {} + )* + }; +} + +#[cfg(feature = "elicitation")] +impl Peer { + /// Check if the client supports elicitation capability + /// + /// Returns true if the client declared elicitation capability during initialization, + /// false otherwise. According to MCP 2025-06-18 specification, clients that support + /// elicitation MUST declare the capability during initialization. + pub fn supports_elicitation(&self) -> bool { + if let Some(client_info) = self.peer_info() { + client_info.capabilities.elicitation.is_some() + } else { + false + } + } + + /// Request typed data from the user with automatic schema generation. + /// + /// This method automatically generates the JSON schema from the Rust type using `schemars`, + /// eliminating the need to manually create schemas. The response is automatically parsed + /// into the requested type. + /// + /// **Requires the `elicitation` feature to be enabled.** + /// + /// # Type Requirements + /// The type `T` must implement: + /// - `schemars::JsonSchema` - for automatic schema generation + /// - `serde::Deserialize` - for parsing the response + /// + /// # Arguments + /// * `message` - The prompt message for the user + /// + /// # Returns + /// * `Ok(Some(data))` if user provided valid data that matches type T + /// * `Err(ElicitationError::UserDeclined)` if user explicitly declined the request + /// * `Err(ElicitationError::UserCancelled)` if user cancelled/dismissed the request + /// * `Err(ElicitationError::ParseError { .. })` if response data couldn't be parsed into type T + /// * `Err(ElicitationError::NoContent)` if no response content was provided + /// * `Err(ElicitationError::Service(_))` if the underlying service call failed + /// + /// # Example + /// + /// Add to your `Cargo.toml`: + /// ```toml + /// [dependencies] + /// rmcp = { version = "0.3", features = ["elicitation"] } + /// serde = { version = "1.0", features = ["derive"] } + /// schemars = "1.0" + /// ``` + /// + /// ```rust,no_run + /// # use rmcp::*; + /// # use rmcp::service::ElicitationError; + /// # use serde::{Deserialize, Serialize}; + /// # use schemars::JsonSchema; + /// # + /// #[derive(Debug, Serialize, Deserialize, JsonSchema)] + /// struct UserProfile { + /// #[schemars(description = "Full name")] + /// name: String, + /// #[schemars(description = "Email address")] + /// email: String, + /// #[schemars(description = "Age")] + /// age: u8, + /// } + /// + /// // Mark as safe for elicitation (generates object schema) + /// rmcp::elicit_safe!(UserProfile); + /// + /// # async fn example(peer: Peer) -> Result<(), Box> { + /// match peer.elicit::("Please enter your profile information").await { + /// Ok(Some(profile)) => { + /// println!("Name: {}, Email: {}, Age: {}", profile.name, profile.email, profile.age); + /// } + /// Ok(None) => { + /// println!("User provided no content"); + /// } + /// Err(ElicitationError::UserDeclined) => { + /// println!("User explicitly declined to provide information"); + /// // Handle explicit decline - perhaps offer alternatives + /// } + /// Err(ElicitationError::UserCancelled) => { + /// println!("User cancelled the request"); + /// // Handle cancellation - perhaps prompt again later + /// } + /// Err(ElicitationError::ParseError { error, data }) => { + /// println!("Failed to parse response: {}\nData: {}", error, data); + /// } + /// Err(e) => return Err(e.into()), + /// } + /// # Ok(()) + /// # } + /// ``` + #[cfg(all(feature = "schemars", feature = "elicitation"))] + pub async fn elicit(&self, message: impl Into) -> Result, ElicitationError> + where + T: ElicitationSafe + for<'de> serde::Deserialize<'de>, + { + self.elicit_with_timeout(message, None).await + } + + /// Request typed data from the user with custom timeout. + /// + /// Same as `elicit()` but allows specifying a custom timeout for the request. + /// If the user doesn't respond within the timeout, the request will be cancelled. + /// + /// # Arguments + /// * `message` - The prompt message for the user + /// * `timeout` - Optional timeout duration. If None, uses default timeout behavior + /// + /// # Returns + /// Same as `elicit()` but may also return `ServiceError::Timeout` if timeout expires + /// + /// # Example + /// ```rust,no_run + /// # use rmcp::*; + /// # use rmcp::service::ElicitationError; + /// # use serde::{Deserialize, Serialize}; + /// # use schemars::JsonSchema; + /// # use std::time::Duration; + /// # + /// #[derive(Debug, Serialize, Deserialize, JsonSchema)] + /// struct QuickResponse { + /// answer: String, + /// } + /// + /// // Mark as safe for elicitation + /// rmcp::elicit_safe!(QuickResponse); + /// + /// # async fn example(peer: Peer) -> Result<(), Box> { + /// // Give user 30 seconds to respond + /// let timeout = Some(Duration::from_secs(30)); + /// match peer.elicit_with_timeout::( + /// "Quick question - what's your answer?", + /// timeout + /// ).await { + /// Ok(Some(response)) => println!("Got answer: {}", response.answer), + /// Ok(None) => println!("User provided no content"), + /// Err(ElicitationError::UserDeclined) => { + /// println!("User explicitly declined"); + /// // Handle explicit decline + /// } + /// Err(ElicitationError::UserCancelled) => { + /// println!("User cancelled/dismissed"); + /// // Handle cancellation + /// } + /// Err(ElicitationError::Service(ServiceError::Timeout { .. })) => { + /// println!("User didn't respond in time"); + /// } + /// Err(e) => return Err(e.into()), + /// } + /// # Ok(()) + /// # } + /// ``` + #[cfg(all(feature = "schemars", feature = "elicitation"))] + pub async fn elicit_with_timeout( + &self, + message: impl Into, + timeout: Option, + ) -> Result, ElicitationError> + where + T: ElicitationSafe + for<'de> serde::Deserialize<'de>, + { + // Check if client supports elicitation capability + if !self.supports_elicitation() { + return Err(ElicitationError::CapabilityNotSupported); + } + + // Generate schema automatically from type + let schema = crate::model::ElicitationSchema::from_type::().map_err(|e| { + ElicitationError::Service(ServiceError::McpError(crate::ErrorData::invalid_params( + format!( + "Invalid schema for type {}: {}", + std::any::type_name::(), + e + ), + None, + ))) + })?; + + let response = self + .create_elicitation_with_timeout( + CreateElicitationRequestParam { + message: message.into(), + requested_schema: schema, + }, + timeout, + ) + .await?; + + match response.action { + crate::model::ElicitationAction::Accept => { + if let Some(value) = response.content { + match serde_json::from_value::(value.clone()) { + Ok(parsed) => Ok(Some(parsed)), + Err(error) => Err(ElicitationError::ParseError { error, data: value }), + } + } else { + Err(ElicitationError::NoContent) + } + } + crate::model::ElicitationAction::Decline => Err(ElicitationError::UserDeclined), + crate::model::ElicitationAction::Cancel => Err(ElicitationError::UserCancelled), + } + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/service/tower.rs b/code-rs/third_party/rmcp-0.8.3/src/service/tower.rs new file mode 100644 index 00000000000..ac4a66f00da --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/service/tower.rs @@ -0,0 +1,54 @@ +use std::{future::poll_fn, marker::PhantomData}; + +use tower_service::Service as TowerService; + +use super::NotificationContext; +use crate::service::{RequestContext, Service, ServiceRole}; + +pub struct TowerHandler { + pub service: S, + pub info: R::Info, + role: PhantomData, +} + +impl TowerHandler { + pub fn new(service: S, info: R::Info) -> Self { + Self { + service, + role: PhantomData, + info, + } + } +} + +impl Service for TowerHandler +where + S: TowerService + Sync + Send + Clone + 'static, + S::Error: Into, + S::Future: Send, +{ + async fn handle_request( + &self, + request: R::PeerReq, + _context: RequestContext, + ) -> Result { + let mut service = self.service.clone(); + poll_fn(|cx| service.poll_ready(cx)) + .await + .map_err(Into::into)?; + let resp = service.call(request).await.map_err(Into::into)?; + Ok(resp) + } + + fn handle_notification( + &self, + _notification: R::PeerNot, + _context: NotificationContext, + ) -> impl Future> + Send + '_ { + std::future::ready(Ok(())) + } + + fn get_info(&self) -> R::Info { + self.info.clone() + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/transport.rs b/code-rs/third_party/rmcp-0.8.3/src/transport.rs new file mode 100644 index 00000000000..81286fede97 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/transport.rs @@ -0,0 +1,280 @@ +//! # Transport +//! The transport type must implemented [`Transport`] trait, which allow it send message concurrently and receive message sequentially. +//! +//! ## Standard Transport Types +//! There are 3 pairs of standard transport types: +//! +//! | transport | client | server | +//! |:-: |:-: |:-: | +//! | std IO | [`child_process::TokioChildProcess`] | [`io::stdio`] | +//! | streamable http | [`streamable_http_client::StreamableHttpClientTransport`] | [`streamable_http_server::StreamableHttpService`] | +//! | sse | [`sse_client::SseClientTransport`] | [`sse_server::SseServer`] | +//! +//!## Helper Transport Types +//! Thers are several helper transport types that can help you to create transport quickly. +//! +//! ### [Worker Transport](`worker::WorkerTransport`) +//! Which allows you to run a worker and process messages in another tokio task. +//! +//! ### [Async Read/Write Transport](`async_rw::AsyncRwTransport`) +//! You need to enable `transport-async-rw` feature to use this transport. +//! +//! This transport is used to create a transport from a byte stream which implemented [`tokio::io::AsyncRead`] and [`tokio::io::AsyncWrite`]. +//! +//! This could be very helpful when you want to create a transport from a byte stream, such as a file or a tcp connection. +//! +//! ### [Sink/Stream Transport](`sink_stream::SinkStreamTransport`) +//! This transport is used to create a transport from a sink and a stream. +//! +//! This could be very helpful when you want to create a transport from a duplex object stream, such as a websocket connection. +//! +//! ## [IntoTransport](`IntoTransport`) trait +//! [`IntoTransport`] is a helper trait that implicitly convert a type into a transport type. +//! +//! ### These types is automatically implemented [`IntoTransport`] trait +//! 1. A type that already implement both [`futures::Sink`] and [`futures::Stream`] trait, or a tuple `(Tx, Rx)` where `Tx` is [`futures::Sink`] and `Rx` is [`futures::Stream`]. +//! 2. A type that implement both [`tokio::io::AsyncRead`] and [`tokio::io::AsyncWrite`] trait. or a tuple `(R, W)` where `R` is [`tokio::io::AsyncRead`] and `W` is [`tokio::io::AsyncWrite`]. +//! 3. A type that implement [Worker](`worker::Worker`) trait. +//! 4. A type that implement [`Transport`] trait. +//! +//! ## Examples +//! +//! ```rust +//! # use rmcp::{ +//! # ServiceExt, serve_client, serve_server, +//! # }; +//! +//! // create transport from tcp stream +//! async fn client() -> Result<(), Box> { +//! let stream = tokio::net::TcpSocket::new_v4()? +//! .connect("127.0.0.1:8001".parse()?) +//! .await?; +//! let client = ().serve(stream).await?; +//! let tools = client.peer().list_tools(Default::default()).await?; +//! println!("{:?}", tools); +//! Ok(()) +//! } +//! +//! // create transport from std io +//! async fn io() -> Result<(), Box> { +//! let client = ().serve((tokio::io::stdin(), tokio::io::stdout())).await?; +//! let tools = client.peer().list_tools(Default::default()).await?; +//! println!("{:?}", tools); +//! Ok(()) +//! } +//! ``` + +use std::{borrow::Cow, sync::Arc}; + +use crate::service::{RxJsonRpcMessage, ServiceRole, TxJsonRpcMessage}; + +pub mod sink_stream; + +#[cfg(feature = "transport-async-rw")] +#[cfg_attr(docsrs, doc(cfg(feature = "transport-async-rw")))] +pub mod async_rw; + +#[cfg(feature = "transport-worker")] +#[cfg_attr(docsrs, doc(cfg(feature = "transport-worker")))] +pub mod worker; +#[cfg(feature = "transport-worker")] +#[cfg_attr(docsrs, doc(cfg(feature = "transport-worker")))] +pub use worker::WorkerTransport; + +#[cfg(feature = "transport-child-process")] +#[cfg_attr(docsrs, doc(cfg(feature = "transport-child-process")))] +pub mod child_process; +#[cfg(feature = "transport-child-process")] +#[cfg_attr(docsrs, doc(cfg(feature = "transport-child-process")))] +pub use child_process::{ConfigureCommandExt, TokioChildProcess}; + +#[cfg(feature = "transport-io")] +#[cfg_attr(docsrs, doc(cfg(feature = "transport-io")))] +pub mod io; +#[cfg(feature = "transport-io")] +#[cfg_attr(docsrs, doc(cfg(feature = "transport-io")))] +pub use io::stdio; + +#[cfg(feature = "transport-sse-client")] +#[cfg_attr(docsrs, doc(cfg(feature = "transport-sse-client")))] +pub mod sse_client; +#[cfg(feature = "transport-sse-client")] +#[cfg_attr(docsrs, doc(cfg(feature = "transport-sse-client")))] +pub use sse_client::SseClientTransport; + +#[cfg(feature = "transport-sse-server")] +#[cfg_attr(docsrs, doc(cfg(feature = "transport-sse-server")))] +pub mod sse_server; +#[cfg(feature = "transport-sse-server")] +#[cfg_attr(docsrs, doc(cfg(feature = "transport-sse-server")))] +pub use sse_server::SseServer; + +#[cfg(feature = "auth")] +#[cfg_attr(docsrs, doc(cfg(feature = "auth")))] +pub mod auth; +#[cfg(feature = "auth")] +#[cfg_attr(docsrs, doc(cfg(feature = "auth")))] +pub use auth::{AuthError, AuthorizationManager, AuthorizationSession, AuthorizedHttpClient}; + +// #[cfg(feature = "transport-ws")] +// #[cfg_attr(docsrs, doc(cfg(feature = "transport-ws")))] +// pub mod ws; +#[cfg(feature = "transport-streamable-http-server-session")] +#[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-server-session")))] +pub mod streamable_http_server; +#[cfg(feature = "transport-streamable-http-server")] +#[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-server")))] +pub use streamable_http_server::tower::{StreamableHttpServerConfig, StreamableHttpService}; + +#[cfg(feature = "transport-streamable-http-client")] +#[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-client")))] +pub mod streamable_http_client; +#[cfg(feature = "transport-streamable-http-client")] +#[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-client")))] +pub use streamable_http_client::StreamableHttpClientTransport; + +/// Common use codes +pub mod common; + +pub trait Transport: Send +where + R: ServiceRole, +{ + type Error: std::error::Error + Send + Sync + 'static; + fn name() -> Cow<'static, str> { + std::any::type_name::().into() + } + /// Send a message to the transport + /// + /// Notice that the future returned by this function should be `Send` and `'static`. + /// It's because the sending message could be executed concurrently. + /// + fn send( + &mut self, + item: TxJsonRpcMessage, + ) -> impl Future> + Send + 'static; + + /// Receive a message from the transport, this operation is sequential. + fn receive(&mut self) -> impl Future>> + Send; + + /// Close the transport + fn close(&mut self) -> impl Future> + Send; +} + +pub trait IntoTransport: Send + 'static +where + R: ServiceRole, + E: std::error::Error + Send + 'static, +{ + fn into_transport(self) -> impl Transport + 'static; +} + +pub enum TransportAdapterIdentity {} +impl IntoTransport for T +where + T: Transport + Send + 'static, + R: ServiceRole, + E: std::error::Error + Send + Sync + 'static, +{ + fn into_transport(self) -> impl Transport + 'static { + self + } +} + +/// A transport that can send a single message and then close itself +pub struct OneshotTransport +where + R: ServiceRole, +{ + message: Option>, + sender: tokio::sync::mpsc::Sender>, + finished_signal: Arc, +} + +impl OneshotTransport +where + R: ServiceRole, +{ + pub fn new( + message: RxJsonRpcMessage, + ) -> (Self, tokio::sync::mpsc::Receiver>) { + let (sender, receiver) = tokio::sync::mpsc::channel(16); + ( + Self { + message: Some(message), + sender, + finished_signal: Arc::new(tokio::sync::Notify::new()), + }, + receiver, + ) + } +} + +impl Transport for OneshotTransport +where + R: ServiceRole, +{ + type Error = tokio::sync::mpsc::error::SendError>; + + fn send( + &mut self, + item: TxJsonRpcMessage, + ) -> impl Future> + Send + 'static { + let sender = self.sender.clone(); + let terminate = matches!(item, TxJsonRpcMessage::::Response(_)) + || matches!(item, TxJsonRpcMessage::::Error(_)); + let signal = self.finished_signal.clone(); + async move { + sender.send(item).await?; + if terminate { + signal.notify_waiters(); + } + Ok(()) + } + } + + async fn receive(&mut self) -> Option> { + if self.message.is_none() { + self.finished_signal.notified().await; + } + self.message.take() + } + + fn close(&mut self) -> impl Future> + Send { + self.message.take(); + std::future::ready(Ok(())) + } +} + +#[derive(Debug, thiserror::Error)] +#[error("Transport [{transport_name}] error: {error}")] +pub struct DynamicTransportError { + pub transport_name: Cow<'static, str>, + pub transport_type_id: std::any::TypeId, + #[source] + pub error: Box, +} + +impl DynamicTransportError { + pub fn new + 'static, R: ServiceRole>(e: T::Error) -> Self { + Self { + transport_name: T::name(), + transport_type_id: std::any::TypeId::of::(), + error: Box::new(e), + } + } + pub fn downcast + 'static, R: ServiceRole>(self) -> Result { + if !self.is::() { + Err(self) + } else { + Ok(self + .error + .downcast::() + .map(|e| *e) + .expect("type is checked")) + } + } + pub fn is + 'static, R: ServiceRole>(&self) -> bool { + self.error.is::() && self.transport_type_id == std::any::TypeId::of::() + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/transport/async_rw.rs b/code-rs/third_party/rmcp-0.8.3/src/transport/async_rw.rs new file mode 100644 index 00000000000..acd1b4f65b5 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/transport/async_rw.rs @@ -0,0 +1,557 @@ +use std::{marker::PhantomData, sync::Arc}; + +// use crate::schema::*; +use futures::{SinkExt, StreamExt}; +use serde::{Serialize, de::DeserializeOwned}; +use thiserror::Error; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::Mutex, +}; +use tokio_util::{ + bytes::{Buf, BufMut, BytesMut}, + codec::{Decoder, Encoder, FramedRead, FramedWrite}, +}; + +use super::{IntoTransport, Transport}; +use crate::service::{RxJsonRpcMessage, ServiceRole, TxJsonRpcMessage}; + +pub enum TransportAdapterAsyncRW {} + +impl IntoTransport for (R, W) +where + Role: ServiceRole, + R: AsyncRead + Send + 'static + Unpin, + W: AsyncWrite + Send + 'static + Unpin, +{ + fn into_transport(self) -> impl Transport + 'static { + AsyncRwTransport::new(self.0, self.1) + } +} + +pub enum TransportAdapterAsyncCombinedRW {} +impl IntoTransport for S +where + Role: ServiceRole, + S: AsyncRead + AsyncWrite + Send + 'static, +{ + fn into_transport(self) -> impl Transport + 'static { + IntoTransport::::into_transport( + tokio::io::split(self), + ) + } +} + +pub type TransportWriter = FramedWrite>>; + +pub struct AsyncRwTransport { + read: FramedRead>>, + write: Arc>>>, +} + +impl AsyncRwTransport +where + R: Send + AsyncRead + Unpin, + W: Send + AsyncWrite + Unpin + 'static, +{ + pub fn new(read: R, write: W) -> Self { + let read = FramedRead::new( + read, + JsonRpcMessageCodec::>::default(), + ); + let write = Arc::new(Mutex::new(Some(FramedWrite::new( + write, + JsonRpcMessageCodec::>::default(), + )))); + Self { read, write } + } +} + +#[cfg(feature = "client")] +#[cfg_attr(docsrs, doc(cfg(feature = "client")))] +impl AsyncRwTransport +where + R: Send + AsyncRead + Unpin, + W: Send + AsyncWrite + Unpin + 'static, +{ + pub fn new_client(read: R, write: W) -> Self { + Self::new(read, write) + } +} + +#[cfg(feature = "server")] +#[cfg_attr(docsrs, doc(cfg(feature = "server")))] +impl AsyncRwTransport +where + R: Send + AsyncRead + Unpin, + W: Send + AsyncWrite + Unpin + 'static, +{ + pub fn new_server(read: R, write: W) -> Self { + Self::new(read, write) + } +} + +impl Transport for AsyncRwTransport +where + R: Send + AsyncRead + Unpin, + W: Send + AsyncWrite + Unpin + 'static, +{ + type Error = std::io::Error; + + fn send( + &mut self, + item: TxJsonRpcMessage, + ) -> impl Future> + Send + 'static { + let lock = self.write.clone(); + async move { + let mut write = lock.lock().await; + if let Some(ref mut write) = *write { + write.send(item).await.map_err(Into::into) + } else { + Err(std::io::Error::new( + std::io::ErrorKind::NotConnected, + "Transport is closed", + )) + } + } + } + + fn receive(&mut self) -> impl Future>> { + let next = self.read.next(); + async { + next.await.and_then(|e| { + e.inspect_err(|e| { + tracing::error!("Error reading from stream: {}", e); + }) + .ok() + }) + } + } + + async fn close(&mut self) -> Result<(), Self::Error> { + let mut write = self.write.lock().await; + drop(write.take()); + Ok(()) + } +} + +#[derive(Debug, Clone)] +pub struct JsonRpcMessageCodec { + _marker: PhantomData T>, + next_index: usize, + max_length: usize, + is_discarding: bool, +} + +impl Default for JsonRpcMessageCodec { + fn default() -> Self { + Self::new() + } +} + +impl JsonRpcMessageCodec { + pub fn new() -> Self { + Self { + _marker: PhantomData, + next_index: 0, + max_length: usize::MAX, + is_discarding: false, + } + } + + pub fn new_with_max_length(max_length: usize) -> Self { + Self { + max_length, + ..Self::new() + } + } + + pub fn max_length(&self) -> usize { + self.max_length + } +} + +fn without_carriage_return(s: &[u8]) -> &[u8] { + if let Some(&b'\r') = s.last() { + &s[..s.len() - 1] + } else { + s + } +} + +/// Check if a method is a standard MCP method (request, response, or notification). +/// This includes both requests and notifications defined in the MCP specification. +/// +/// Based on MCP specification 2025-06-18: https://modelcontextprotocol.io/specification/2025-06-18 +fn is_standard_method(method: &str) -> bool { + matches!( + method, + "initialize" + | "ping" + | "prompts/get" + | "prompts/list" + | "resources/list" + | "resources/read" + | "resources/subscribe" + | "resources/unsubscribe" + | "resources/templates/list" + | "tools/call" + | "tools/list" + | "completion/complete" + | "logging/setLevel" + | "roots/list" + | "sampling/createMessage" + ) || is_standard_notification(method) +} + +fn is_standard_notification(method: &str) -> bool { + matches!( + method, + "notifications/cancelled" + | "notifications/initialized" + | "notifications/message" + | "notifications/progress" + | "notifications/prompts/list_changed" + | "notifications/resources/list_changed" + | "notifications/resources/updated" + | "notifications/roots/list_changed" + | "notifications/tools/list_changed" + ) +} + +/// Determines if a notification should be ignored for compatibility. +fn should_ignore_notification(json_value: &serde_json::Value, method: &str) -> bool { + let is_notification = json_value.get("id").is_none(); + + // Ignore non-MCP notifications (like LSP messages) for compatibility + if is_notification && !is_standard_method(method) { + tracing::trace!( + "Ignoring non-MCP notification '{}' for compatibility", + method + ); + return true; + } + + // Ignore non-standard MCP notifications + matches!( + ( + method.starts_with("notifications/"), + is_standard_notification(method) + ), + (true, false) + ) +} + +/// Try to parse a message with compatibility handling for non-standard notifications +fn try_parse_with_compatibility( + line: &[u8], + context: &str, +) -> Result, JsonRpcMessageCodecError> { + if let Ok(line_str) = std::str::from_utf8(line) { + match serde_json::from_slice(line) { + Ok(item) => Ok(Some(item)), + Err(e) => { + // Check if this is a notification that should be ignored for compatibility + if let Ok(json_value) = serde_json::from_str::(line_str) { + if let Some(method) = + json_value.get("method").and_then(serde_json::Value::as_str) + { + if should_ignore_notification(&json_value, method) { + return Ok(None); + } + } + } + + tracing::debug!( + "Failed to parse message {}: {} | Error: {}", + context, + line_str, + e + ); + Err(JsonRpcMessageCodecError::Serde(e)) + } + } + } else { + serde_json::from_slice(line) + .map(Some) + .map_err(JsonRpcMessageCodecError::Serde) + } +} + +#[derive(Debug, Error)] +pub enum JsonRpcMessageCodecError { + #[error("max line length exceeded")] + MaxLineLengthExceeded, + #[error("serde error {0}")] + Serde(#[from] serde_json::Error), + #[error("io error {0}")] + Io(#[from] std::io::Error), +} + +impl From for std::io::Error { + fn from(value: JsonRpcMessageCodecError) -> Self { + match value { + JsonRpcMessageCodecError::MaxLineLengthExceeded => { + std::io::Error::new(std::io::ErrorKind::InvalidData, value) + } + JsonRpcMessageCodecError::Serde(e) => e.into(), + JsonRpcMessageCodecError::Io(e) => e, + } + } +} + +impl Decoder for JsonRpcMessageCodec { + type Item = T; + + type Error = JsonRpcMessageCodecError; + + fn decode( + &mut self, + buf: &mut BytesMut, + ) -> Result, JsonRpcMessageCodecError> { + loop { + // Determine how far into the buffer we'll search for a newline. If + // there's no max_length set, we'll read to the end of the buffer. + let read_to = std::cmp::min(self.max_length.saturating_add(1), buf.len()); + + let newline_offset = buf[self.next_index..read_to] + .iter() + .position(|b| *b == b'\n'); + + match (self.is_discarding, newline_offset) { + (true, Some(offset)) => { + // If we found a newline, discard up to that offset and + // then stop discarding. On the next iteration, we'll try + // to read a line normally. + buf.advance(offset + self.next_index + 1); + self.is_discarding = false; + self.next_index = 0; + } + (true, None) => { + // Otherwise, we didn't find a newline, so we'll discard + // everything we read. On the next iteration, we'll continue + // discarding up to max_len bytes unless we find a newline. + buf.advance(read_to); + self.next_index = 0; + if buf.is_empty() { + return Ok(None); + } + } + (false, Some(offset)) => { + // Found a line! + let newline_index = offset + self.next_index; + self.next_index = 0; + let line = buf.split_to(newline_index + 1); + let line = &line[..line.len() - 1]; + let line = without_carriage_return(line); + + // Use compatibility handling function + let item = match try_parse_with_compatibility(line, "decode")? { + Some(item) => item, + None => return Ok(None), // Skip non-standard message + }; + return Ok(Some(item)); + } + (false, None) if buf.len() > self.max_length => { + // Reached the maximum length without finding a + // newline, return an error and start discarding on the + // next call. + self.is_discarding = true; + return Err(JsonRpcMessageCodecError::MaxLineLengthExceeded); + } + (false, None) => { + // We didn't find a line or reach the length limit, so the next + // call will resume searching at the current offset. + self.next_index = read_to; + return Ok(None); + } + } + } + } + + fn decode_eof(&mut self, buf: &mut BytesMut) -> Result, JsonRpcMessageCodecError> { + Ok(match self.decode(buf)? { + Some(frame) => Some(frame), + None => { + self.next_index = 0; + // No terminating newline - return remaining data, if any + if buf.is_empty() || buf == &b"\r"[..] { + None + } else { + let line = buf.split_to(buf.len()); + let line = without_carriage_return(&line); + + // Use compatibility handling function + let item = match try_parse_with_compatibility(line, "decode_eof")? { + Some(item) => item, + None => return Ok(None), // Skip non-standard message + }; + Some(item) + } + } + }) + } +} + +impl Encoder for JsonRpcMessageCodec { + type Error = JsonRpcMessageCodecError; + + fn encode(&mut self, item: T, buf: &mut BytesMut) -> Result<(), JsonRpcMessageCodecError> { + serde_json::to_writer(buf.writer(), &item)?; + buf.put_u8(b'\n'); + Ok(()) + } +} + +#[cfg(test)] +mod test { + use futures::{Sink, Stream}; + + use super::*; + fn from_async_read(reader: R) -> impl Stream { + FramedRead::new(reader, JsonRpcMessageCodec::::default()).filter_map(|result| { + if let Err(e) = &result { + tracing::error!("Error reading from stream: {}", e); + } + futures::future::ready(result.ok()) + }) + } + + fn from_async_write( + writer: W, + ) -> impl Sink { + FramedWrite::new(writer, JsonRpcMessageCodec::::default()).sink_map_err(Into::into) + } + #[tokio::test] + async fn test_decode() { + use futures::StreamExt; + use tokio::io::BufReader; + + let data = r#"{"jsonrpc":"2.0","method":"subtract","params":[42,23],"id":1} + {"jsonrpc":"2.0","method":"subtract","params":[23,42],"id":2} + {"jsonrpc":"2.0","method":"subtract","params":[42,23],"id":3} + {"jsonrpc":"2.0","method":"subtract","params":[23,42],"id":4} + {"jsonrpc":"2.0","method":"subtract","params":[42,23],"id":5} + {"jsonrpc":"2.0","method":"subtract","params":[23,42],"id":6} + {"jsonrpc":"2.0","method":"subtract","params":[42,23],"id":7} + {"jsonrpc":"2.0","method":"subtract","params":[23,42],"id":8} + {"jsonrpc":"2.0","method":"subtract","params":[42,23],"id":9} + {"jsonrpc":"2.0","method":"subtract","params":[23,42],"id":10} + + "#; + + let mut cursor = BufReader::new(data.as_bytes()); + let mut stream = from_async_read::(&mut cursor); + + for i in 1..=10 { + let item = stream.next().await.unwrap(); + assert_eq!( + item, + serde_json::json!({ + "jsonrpc": "2.0", + "method": "subtract", + "params": if i % 2 != 0 { [42, 23] } else { [23, 42] }, + "id": i, + }) + ); + } + } + + #[tokio::test] + async fn test_encode() { + let test_messages = vec![ + serde_json::json!({ + "jsonrpc": "2.0", + "method": "subtract", + "params": [42, 23], + "id": 1, + }), + serde_json::json!({ + "jsonrpc": "2.0", + "method": "subtract", + "params": [23, 42], + "id": 2, + }), + ]; + + // Create a buffer to write to + let mut buffer = Vec::new(); + let mut writer = from_async_write(&mut buffer); + + // Write the test messages + for message in test_messages.iter() { + writer.send(message.clone()).await.unwrap(); + } + writer.close().await.unwrap(); + drop(writer); + // Parse the buffer back into lines and check each one + let output = String::from_utf8_lossy(&buffer); + let mut lines = output.lines(); + + for expected_message in test_messages { + let line = lines.next().unwrap(); + let parsed_message: serde_json::Value = serde_json::from_str(line).unwrap(); + assert_eq!(parsed_message, expected_message); + } + + // Make sure there are no extra lines + assert!(lines.next().is_none()); + } + + #[test] + fn test_standard_notification_check() { + // Test that all standard notifications are recognized + assert!(is_standard_notification("notifications/cancelled")); + assert!(is_standard_notification("notifications/initialized")); + assert!(is_standard_notification("notifications/progress")); + assert!(is_standard_notification( + "notifications/resources/list_changed" + )); + assert!(is_standard_notification("notifications/resources/updated")); + assert!(is_standard_notification( + "notifications/prompts/list_changed" + )); + assert!(is_standard_notification("notifications/tools/list_changed")); + assert!(is_standard_notification("notifications/message")); + assert!(is_standard_notification("notifications/roots/list_changed")); + + // Test that non-standard notifications are not recognized + assert!(!is_standard_notification("notifications/stderr")); + assert!(!is_standard_notification("notifications/custom")); + assert!(!is_standard_notification("notifications/debug")); + assert!(!is_standard_notification("some/other/method")); + } + + #[test] + fn test_compatibility_function() { + // Test the compatibility function directly + let stderr_message = + r#"{"method":"notifications/stderr","params":{"content":"stderr message"}}"#; + let custom_message = r#"{"method":"notifications/custom","params":{"data":"custom"}}"#; + let standard_message = + r#"{"method":"notifications/message","params":{"level":"info","data":"standard"}}"#; + let progress_message = r#"{"method":"notifications/progress","params":{"progressToken":"token","progress":50}}"#; + + // Test with valid JSON - all should parse successfully + let result1 = + try_parse_with_compatibility::(stderr_message.as_bytes(), "test"); + let result2 = + try_parse_with_compatibility::(custom_message.as_bytes(), "test"); + let result3 = + try_parse_with_compatibility::(standard_message.as_bytes(), "test"); + let result4 = + try_parse_with_compatibility::(progress_message.as_bytes(), "test"); + + // All should parse successfully since they're valid JSON + assert!(result1.is_ok()); + assert!(result2.is_ok()); + assert!(result3.is_ok()); + assert!(result4.is_ok()); + + // Standard notifications should return Some(value) + assert!(result3.unwrap().is_some()); + assert!(result4.unwrap().is_some()); + + println!("Standard notifications are preserved, non-standard are handled gracefully"); + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/transport/auth.rs b/code-rs/third_party/rmcp-0.8.3/src/transport/auth.rs new file mode 100644 index 00000000000..3b1aca34c92 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/transport/auth.rs @@ -0,0 +1,1280 @@ +use std::{ + collections::HashMap, + sync::Arc, + time::{Duration, Instant}, +}; + +use oauth2::{ + AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields, + PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError, Scope, + StandardTokenResponse, TokenResponse, TokenUrl, + basic::{BasicClient, BasicTokenType}, +}; +use reqwest::{ + Client as HttpClient, IntoUrl, StatusCode, Url, + header::{AUTHORIZATION, WWW_AUTHENTICATE}, +}; +use serde::{Deserialize, Serialize}; +use thiserror::Error; +use tokio::sync::{Mutex, RwLock}; +use tracing::{debug, error, warn}; + +const DEFAULT_EXCHANGE_URL: &str = "http://localhost"; + +/// sse client with oauth2 authorization +#[derive(Clone)] +pub struct AuthClient { + pub http_client: C, + pub auth_manager: Arc>, +} + +impl std::fmt::Debug for AuthClient { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AuthorizedClient") + .field("http_client", &self.http_client) + .field("auth_manager", &"...") + .finish() + } +} + +impl AuthClient { + /// create new authorized sse client + pub fn new(http_client: C, auth_manager: AuthorizationManager) -> Self { + Self { + http_client, + auth_manager: Arc::new(Mutex::new(auth_manager)), + } + } +} + +impl AuthClient { + pub fn get_access_token(&self) -> impl Future> + Send { + let auth_manager = self.auth_manager.clone(); + async move { auth_manager.lock().await.get_access_token().await } + } +} + +/// Auth error +#[derive(Debug, Error)] +pub enum AuthError { + #[error("OAuth authorization required")] + AuthorizationRequired, + + #[error("OAuth authorization failed: {0}")] + AuthorizationFailed(String), + + #[error("OAuth token exchange failed: {0}")] + TokenExchangeFailed(String), + + #[error("OAuth token refresh failed: {0}")] + TokenRefreshFailed(String), + + #[error("HTTP error: {0}")] + HttpError(#[from] reqwest::Error), + + #[error("OAuth error: {0}")] + OAuthError(String), + + #[error("Metadata error: {0}")] + MetadataError(String), + + #[error("URL parse error: {0}")] + UrlError(#[from] url::ParseError), + + #[error("No authorization support detected")] + NoAuthorizationSupport, + + #[error("Internal error: {0}")] + InternalError(String), + + #[error("Invalid token type: {0}")] + InvalidTokenType(String), + + #[error("Token expired")] + TokenExpired, + + #[error("Invalid scope: {0}")] + InvalidScope(String), + + #[error("Registration failed: {0}")] + RegistrationFailed(String), +} + +/// oauth2 metadata +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct AuthorizationMetadata { + pub authorization_endpoint: String, + pub token_endpoint: String, + pub registration_endpoint: Option, + pub issuer: Option, + pub jwks_uri: Option, + pub scopes_supported: Option>, + // allow additional fields + #[serde(flatten)] + pub additional_fields: HashMap, +} + +#[derive(Debug, Clone, Deserialize)] +struct ResourceServerMetadata { + authorization_server: Option, + authorization_servers: Option>, +} + +/// oauth2 client config +#[derive(Debug, Clone)] +pub struct OAuthClientConfig { + pub client_id: String, + pub client_secret: Option, + pub scopes: Vec, + pub redirect_uri: String, +} + +// add type aliases for oauth2 types +type OAuthErrorResponse = oauth2::StandardErrorResponse; +pub type OAuthTokenResponse = StandardTokenResponse; +type OAuthTokenIntrospection = + oauth2::StandardTokenIntrospectionResponse; +type OAuthRevocableToken = oauth2::StandardRevocableToken; +type OAuthRevocationError = oauth2::StandardErrorResponse; +type OAuthClient = oauth2::Client< + OAuthErrorResponse, + OAuthTokenResponse, + OAuthTokenIntrospection, + OAuthRevocableToken, + OAuthRevocationError, + oauth2::EndpointSet, + oauth2::EndpointNotSet, + oauth2::EndpointNotSet, + oauth2::EndpointNotSet, + oauth2::EndpointSet, +>; +type Credentials = (String, Option); + +/// oauth2 auth manager +pub struct AuthorizationManager { + http_client: HttpClient, + metadata: Option, + oauth_client: Option, + credentials: RwLock>, + state: RwLock>, + expires_at: RwLock>, + base_url: Url, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClientRegistrationRequest { + pub client_name: String, + pub redirect_uris: Vec, + pub grant_types: Vec, + pub token_endpoint_auth_method: String, + pub response_types: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClientRegistrationResponse { + pub client_id: String, + pub client_secret: Option, + pub client_name: Option, + pub redirect_uris: Vec, + // allow additional fields + #[serde(flatten)] + pub additional_fields: HashMap, +} + +#[derive(Debug)] +struct AuthorizationState { + pkce_verifier: PkceCodeVerifier, + csrf_token: CsrfToken, +} + +impl AuthorizationManager { + fn well_known_paths(base_path: &str, resource: &str) -> Vec { + let trimmed = base_path.trim_start_matches('/').trim_end_matches('/'); + let mut candidates = Vec::new(); + + let mut push_candidate = |candidate: String| { + if !candidates.contains(&candidate) { + candidates.push(candidate); + } + }; + + let canonical = format!("/.well-known/{resource}"); + + if trimmed.is_empty() { + push_candidate(canonical); + return candidates; + } + + // This follows the RFC 8414 recommendation for well-known URI discovery + push_candidate(format!("{canonical}/{trimmed}")); + // This is a common pattern used by some identity providers + push_candidate(format!("/{trimmed}/.well-known/{resource}")); + // The canonical path should always be the last fallback + push_candidate(canonical); + + candidates + } + + /// create new auth manager with base url + pub async fn new(base_url: U) -> Result { + let base_url = base_url.into_url()?; + let http_client = HttpClient::builder() + .timeout(Duration::from_secs(30)) + .build() + .map_err(|e| AuthError::InternalError(e.to_string()))?; + + let manager = Self { + http_client, + metadata: None, + oauth_client: None, + credentials: RwLock::new(None), + state: RwLock::new(None), + expires_at: RwLock::new(None), + base_url, + }; + + Ok(manager) + } + + pub fn with_client(&mut self, http_client: HttpClient) -> Result<(), AuthError> { + self.http_client = http_client; + Ok(()) + } + + /// discover oauth2 metadata + pub async fn discover_metadata(&self) -> Result { + if let Some(metadata) = self.try_discover_oauth_server(&self.base_url).await? { + return Ok(metadata); + } + + if let Some(metadata) = self.discover_oauth_server_via_resource_metadata().await? { + return Ok(metadata); + } + + warn!("No valid authorization metadata found, falling back to default endpoints"); + + // fallback to default endpoints + let mut auth_base = self.base_url.clone(); + // discard the path part, only keep scheme, host, port + auth_base.set_path(""); + + // Helper function to create endpoint URL + let create_endpoint = |path: &str| -> String { + let mut url = auth_base.clone(); + url.set_path(path); + url.to_string() + }; + + Ok(AuthorizationMetadata { + authorization_endpoint: create_endpoint("authorize"), + token_endpoint: create_endpoint("token"), + registration_endpoint: None, + issuer: None, + jwks_uri: None, + scopes_supported: None, + additional_fields: HashMap::new(), + }) + } + + /// get client id and credentials + pub async fn get_credentials(&self) -> Result { + let credentials = self.credentials.read().await; + let client_id = self + .oauth_client + .as_ref() + .ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))? + .client_id(); + Ok((client_id.to_string(), credentials.clone())) + } + + /// configure oauth2 client with client credentials + pub fn configure_client(&mut self, config: OAuthClientConfig) -> Result<(), AuthError> { + if self.metadata.is_none() { + return Err(AuthError::NoAuthorizationSupport); + } + + let metadata = self.metadata.as_ref().unwrap(); + + let auth_url = AuthUrl::new(metadata.authorization_endpoint.clone()) + .map_err(|e| AuthError::OAuthError(format!("Invalid authorization URL: {}", e)))?; + + let token_url = TokenUrl::new(metadata.token_endpoint.clone()) + .map_err(|e| AuthError::OAuthError(format!("Invalid token URL: {}", e)))?; + + let client_id = ClientId::new(config.client_id); + let redirect_url = RedirectUrl::new(config.redirect_uri.clone()) + .map_err(|e| AuthError::OAuthError(format!("Invalid re URL: {}", e)))?; + + let mut client_builder = BasicClient::new(client_id.clone()) + .set_auth_uri(auth_url) + .set_token_uri(token_url) + .set_redirect_uri(redirect_url); + + if let Some(secret) = config.client_secret { + client_builder = client_builder.set_client_secret(ClientSecret::new(secret)); + } + + self.oauth_client = Some(client_builder); + Ok(()) + } + + /// dynamic register oauth2 client + pub async fn register_client( + &mut self, + name: &str, + redirect_uri: &str, + ) -> Result { + if self.metadata.is_none() { + return Err(AuthError::NoAuthorizationSupport); + } + + let metadata = self.metadata.as_ref().unwrap(); + let Some(registration_url) = metadata.registration_endpoint.as_ref() else { + return Err(AuthError::RegistrationFailed( + "Dynamic client registration not supported".to_string(), + )); + }; + + // prepare registration request + let registration_request = ClientRegistrationRequest { + client_name: name.to_string(), + redirect_uris: vec![redirect_uri.to_string()], + grant_types: vec![ + "authorization_code".to_string(), + "refresh_token".to_string(), + ], + token_endpoint_auth_method: "none".to_string(), // public client + response_types: vec!["code".to_string()], + }; + + let response = match self + .http_client + .post(registration_url) + .json(®istration_request) + .send() + .await + { + Ok(response) => response, + Err(e) => { + return Err(AuthError::RegistrationFailed(format!( + "HTTP request error: {}", + e + ))); + } + }; + + if !response.status().is_success() { + let status = response.status(); + let error_text = match response.text().await { + Ok(text) => text, + Err(_) => "cannot get error details".to_string(), + }; + + return Err(AuthError::RegistrationFailed(format!( + "HTTP {}: {}", + status, error_text + ))); + } + + debug!("registration response: {:?}", response); + let reg_response = match response.json::().await { + Ok(response) => response, + Err(e) => { + return Err(AuthError::RegistrationFailed(format!( + "analyze response error: {}", + e + ))); + } + }; + + let config = OAuthClientConfig { + client_id: reg_response.client_id, + // Some IdP returns a response where the field 'client_secret' is present but with empty string value. + // In that case, the interpretation is that the client is a public client and does not have a secret during the + // registration phase here, e.g. dynamic client registrations. + // + // Even though whether or not the empty string is valid is outside of the scope of Oauth2 spec, + // we should treat it as no secret since otherwise we end up authenticating with a valid client_id with an empty client_secret + // as a password, which is not a goal of the client secret. + client_secret: reg_response.client_secret.filter(|s| !s.is_empty()), + redirect_uri: redirect_uri.to_string(), + scopes: vec![], + }; + + self.configure_client(config.clone())?; + Ok(config) + } + + /// use provided client id to configure oauth2 client instead of dynamic registration + /// this is useful when you have a stored client id from previous registration + pub fn configure_client_id(&mut self, client_id: &str) -> Result<(), AuthError> { + let config = OAuthClientConfig { + client_id: client_id.to_string(), + client_secret: None, + scopes: vec![], + redirect_uri: self.base_url.to_string(), + }; + self.configure_client(config) + } + + /// generate authorization url + pub async fn get_authorization_url(&self, scopes: &[&str]) -> Result { + let oauth_client = self + .oauth_client + .as_ref() + .ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?; + + // generate pkce challenge + let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); + + // build authorization request + let mut auth_request = oauth_client + .authorize_url(CsrfToken::new_random) + .set_pkce_challenge(pkce_challenge); + + // add request scopes + for scope in scopes { + auth_request = auth_request.add_scope(Scope::new(scope.to_string())); + } + + let (auth_url, csrf_token) = auth_request.url(); + + // store pkce verifier for later use + *self.state.write().await = Some(AuthorizationState { + pkce_verifier, + csrf_token, + }); + + Ok(auth_url.to_string()) + } + + /// exchange authorization code for access token + pub async fn exchange_code_for_token( + &self, + code: &str, + csrf_token: &str, + ) -> Result, AuthError> { + debug!("start exchange code for token: {:?}", code); + let oauth_client = self + .oauth_client + .as_ref() + .ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?; + + let AuthorizationState { + pkce_verifier, + csrf_token: expected_csrf_token, + } = + self.state.write().await.take().ok_or_else(|| { + AuthError::InternalError("Authorization state not found".to_string()) + })?; + + if csrf_token != expected_csrf_token.secret() { + return Err(AuthError::InternalError("CSRF token mismatch".to_string())); + } + + let http_client = reqwest::ClientBuilder::new() + .redirect(reqwest::redirect::Policy::none()) + .build() + .map_err(|e| AuthError::InternalError(e.to_string()))?; + debug!("client_id: {:?}", oauth_client.client_id()); + + // exchange token + let token_result = match oauth_client + .exchange_code(AuthorizationCode::new(code.to_string())) + .set_pkce_verifier(pkce_verifier) + .request_async(&http_client) + .await + { + Ok(token) => token, + Err(RequestTokenError::Parse(_, body)) => { + match serde_json::from_slice::(&body) { + Ok(parsed) => { + warn!( + "token exchange failed to parse completely but included a valid token response. Accepting it." + ); + parsed + } + Err(parse_err) => { + return Err(AuthError::TokenExchangeFailed(parse_err.to_string())); + } + } + } + Err(e) => { + return Err(AuthError::TokenExchangeFailed(e.to_string())); + } + }; + + // get expires_in from token response + let expires_in = token_result.expires_in(); + if let Some(expires_in) = expires_in { + let expires_at = Instant::now() + expires_in; + *self.expires_at.write().await = Some(expires_at); + } + debug!("exchange token result: {:?}", token_result); + // store credentials + *self.credentials.write().await = Some(token_result.clone()); + + Ok(token_result) + } + + /// get access token, if expired, refresh it automatically + pub async fn get_access_token(&self) -> Result { + let credentials = self.credentials.read().await; + + if let Some(creds) = credentials.as_ref() { + // check if the token is expire + if let Some(expires_at) = *self.expires_at.read().await { + if expires_at < Instant::now() { + // token expired, try to refresh , release the lock + drop(credentials); + let new_creds = self.refresh_token().await?; + return Ok(new_creds.access_token().secret().to_string()); + } + } + + Ok(creds.access_token().secret().to_string()) + } else { + Err(AuthError::AuthorizationRequired) + } + } + + /// refresh access token + pub async fn refresh_token( + &self, + ) -> Result, AuthError> { + let oauth_client = self + .oauth_client + .as_ref() + .ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?; + + let current_credentials = self + .credentials + .read() + .await + .clone() + .ok_or_else(|| AuthError::AuthorizationRequired)?; + + let refresh_token = current_credentials.refresh_token().ok_or_else(|| { + AuthError::TokenRefreshFailed("No refresh token available".to_string()) + })?; + debug!("refresh token: {:?}", refresh_token); + // refresh token + let token_result = oauth_client + .exchange_refresh_token(&RefreshToken::new(refresh_token.secret().to_string())) + .request_async(&self.http_client) + .await + .map_err(|e| AuthError::TokenRefreshFailed(e.to_string()))?; + + // store new credentials + *self.credentials.write().await = Some(token_result.clone()); + + // get expires_in from token response + let expires_in = token_result.expires_in(); + if let Some(expires_in) = expires_in { + let expires_at = Instant::now() + expires_in; + *self.expires_at.write().await = Some(expires_at); + } + Ok(token_result) + } + + /// prepare request, add authorization header + pub async fn prepare_request( + &self, + request: reqwest::RequestBuilder, + ) -> Result { + let token = self.get_access_token().await?; + Ok(request.header(AUTHORIZATION, format!("Bearer {}", token))) + } + + /// handle response, check if need to re-authorize + pub async fn handle_response( + &self, + response: reqwest::Response, + ) -> Result { + if response.status() == StatusCode::UNAUTHORIZED { + // 401 Unauthorized, need to re-authorize + Err(AuthError::AuthorizationRequired) + } else { + Ok(response) + } + } + + async fn try_discover_oauth_server( + &self, + base_url: &Url, + ) -> Result, AuthError> { + let query = base_url.query().map(|q| q.to_string()); + for candidate_path in Self::well_known_paths(base_url.path(), "oauth-authorization-server") + { + let mut discovery_url = base_url.clone(); + discovery_url.set_fragment(None); + discovery_url.set_path(&candidate_path); + if let Some(query) = query.as_deref() { + discovery_url.set_query(Some(query)); + } else { + discovery_url.set_query(None); + } + if let Some(metadata) = self.fetch_authorization_metadata(&discovery_url).await? { + return Ok(Some(metadata)); + } + } + Ok(None) + } + + async fn fetch_authorization_metadata( + &self, + discovery_url: &Url, + ) -> Result, AuthError> { + debug!("discovery url: {:?}", discovery_url); + let response = match self + .http_client + .get(discovery_url.clone()) + .header("MCP-Protocol-Version", "2024-11-05") + .send() + .await + { + Ok(r) => r, + Err(e) => { + debug!("discovery request failed: {}", e); + return Ok(None); + } + }; + + if response.status() != StatusCode::OK { + debug!("discovery returned non-200: {}", response.status()); + return Ok(None); + } + + let metadata = response + .json::() + .await + .map_err(|e| AuthError::MetadataError(format!("Failed to parse metadata: {}", e)))?; + debug!("metadata: {:?}", metadata); + Ok(Some(metadata)) + } + + async fn discover_oauth_server_via_resource_metadata( + &self, + ) -> Result, AuthError> { + let Some(resource_metadata_url) = self.fetch_resource_metadata_url().await? else { + return Ok(None); + }; + + let Some(resource_metadata) = self + .fetch_resource_metadata_from_url(&resource_metadata_url) + .await? + else { + return Ok(None); + }; + + let mut candidates = Vec::new(); + + if let Some(single) = resource_metadata.authorization_server { + candidates.push(single); + } + if let Some(list) = resource_metadata.authorization_servers { + candidates.extend(list); + } + + for candidate in candidates { + let candidate = candidate.trim(); + if candidate.is_empty() { + continue; + } + + let candidate_url = match Url::parse(candidate) { + Ok(url) => url, + Err(_) => match resource_metadata_url.join(candidate) { + Ok(url) => url, + Err(e) => { + debug!("Failed to resolve authorization server URL `{candidate}`: {e}"); + continue; + } + }, + }; + + if candidate_url.path().contains("/.well-known/") { + if let Some(metadata) = self.fetch_authorization_metadata(&candidate_url).await? { + return Ok(Some(metadata)); + } + continue; + } + + if let Some(metadata) = self.try_discover_oauth_server(&candidate_url).await? { + return Ok(Some(metadata)); + } + } + + Ok(None) + } + + /// Extract the resource metadata url from the WWW-Authenticate header value. + /// https://www.rfc-editor.org/rfc/rfc9728.html#name-use-of-www-authenticate-for + async fn fetch_resource_metadata_url(&self) -> Result, AuthError> { + let response = match self + .http_client + .get(self.base_url.clone()) + .header("MCP-Protocol-Version", "2024-11-05") + .send() + .await + { + Ok(r) => r, + Err(e) => { + debug!("resource metadata probe failed: {}", e); + return Ok(None); + } + }; + + if response.status() != StatusCode::UNAUTHORIZED { + debug!( + "resource metadata probe returned unexpected status: {}", + response.status() + ); + return Ok(None); + } + + let mut parsed_url = None; + for value in response.headers().get_all(WWW_AUTHENTICATE).iter() { + let Ok(value_str) = value.to_str() else { + continue; + }; + if let Some(url) = + Self::extract_resource_metadata_url_from_header(value_str, &self.base_url) + { + parsed_url = Some(url); + break; + } + } + + Ok(parsed_url) + } + + async fn fetch_resource_metadata_from_url( + &self, + resource_metadata_url: &Url, + ) -> Result, AuthError> { + debug!( + "resource metadata discovery url: {:?}", + resource_metadata_url + ); + let response = match self + .http_client + .get(resource_metadata_url.clone()) + .header("MCP-Protocol-Version", "2024-11-05") + .send() + .await + { + Ok(r) => r, + Err(e) => { + debug!("resource metadata request failed: {}", e); + return Ok(None); + } + }; + + if response.status() != StatusCode::OK { + debug!( + "resource metadata request returned non-200: {}", + response.status() + ); + return Ok(None); + } + + let metadata = response + .json::() + .await + .map_err(|e| { + AuthError::MetadataError(format!("Failed to parse resource metadata: {}", e)) + })?; + Ok(Some(metadata)) + } + + /// Extracts a url following `resource_metadata=` in a header value + fn extract_resource_metadata_url_from_header(header: &str, base_url: &Url) -> Option { + let header_lowercase = header.to_ascii_lowercase(); + let fragment_key = "resource_metadata="; + let mut search_offset = 0; + + while let Some(pos) = header_lowercase[search_offset..].find(fragment_key) { + let global_pos = search_offset + pos + fragment_key.len(); + let value_slice = &header[global_pos..]; + if let Some((value, consumed)) = Self::parse_next_header_value(value_slice) { + if let Ok(url) = Url::parse(&value) { + return Some(url); + } + if let Ok(url) = base_url.join(&value) { + return Some(url); + } + debug!("failed to parse resource metadata value `{value}` as URL"); + search_offset = global_pos + consumed; + continue; + } else { + break; + } + } + + None + } + + /// Parses an authentication parameter value from a `WWW-Authenticate` header fragment. + /// The header fragment should start with the header value after the `=` character and then + /// reads until the value ends. + /// + /// Returns the extracted value together with the number of bytes consumed from the provided + /// fragment. Quoted values support escaped characters (e.g. `\"`). The parser skips leading + /// whitespace before reading either a quoted or token value. If no well-formed value is found, + /// `None` is returned. + fn parse_next_header_value(header_fragment: &str) -> Option<(String, usize)> { + let trimmed = header_fragment.trim_start(); + let leading_ws = header_fragment.len() - trimmed.len(); + + if let Some(stripped) = trimmed.strip_prefix('"') { + let mut escaped = false; + let mut result = String::new(); + #[allow(clippy::manual_strip)] + for (idx, ch) in stripped.char_indices() { + if escaped { + result.push(ch); + escaped = false; + continue; + } + match ch { + '\\' => escaped = true, + '"' => return Some((result, leading_ws + idx + 2)), + _ => result.push(ch), + } + } + None + } else { + let end = trimmed + .find(|c: char| c == ',' || c == ';' || c.is_whitespace()) + .unwrap_or(trimmed.len()); + Some((trimmed[..end].to_string(), leading_ws + end)) + } + } +} + +/// oauth2 authorization session, for guiding user to complete the authorization process +pub struct AuthorizationSession { + pub auth_manager: AuthorizationManager, + pub auth_url: String, + pub redirect_uri: String, +} + +impl AuthorizationSession { + /// create new authorization session + pub async fn new( + mut auth_manager: AuthorizationManager, + scopes: &[&str], + redirect_uri: &str, + client_name: Option<&str>, + ) -> Result { + // Default client config + let config = OAuthClientConfig { + client_id: "mcp-client".to_string(), + client_secret: None, + scopes: scopes.iter().map(|s| s.to_string()).collect(), + redirect_uri: redirect_uri.to_string(), + }; + + // try to dynamic register client + let config = match auth_manager + .register_client(client_name.unwrap_or("MCP Client"), redirect_uri) + .await + { + Ok(config) => config, + Err(e) => { + warn!( + "Dynamic registration failed: {}, fallback to default config", + e + ); + // fallback to default config + config + } + }; + // reset client config + auth_manager.configure_client(config)?; + let auth_url = auth_manager.get_authorization_url(scopes).await?; + + Ok(Self { + auth_manager, + auth_url, + redirect_uri: redirect_uri.to_string(), + }) + } + + /// get client_id and credentials + pub async fn get_credentials(&self) -> Result { + self.auth_manager.get_credentials().await + } + + /// get authorization url + pub fn get_authorization_url(&self) -> &str { + &self.auth_url + } + + /// handle authorization code callback + pub async fn handle_callback( + &self, + code: &str, + csrf_token: &str, + ) -> Result, AuthError> { + self.auth_manager + .exchange_code_for_token(code, csrf_token) + .await + } +} + +/// http client extension, automatically add authorization header +pub struct AuthorizedHttpClient { + auth_manager: Arc, + inner_client: HttpClient, +} + +impl AuthorizedHttpClient { + /// create new authorized http client + pub fn new(auth_manager: Arc, client: Option) -> Self { + let inner_client = client.unwrap_or_default(); + Self { + auth_manager, + inner_client, + } + } + + /// send authorized request + pub async fn request( + &self, + method: reqwest::Method, + url: U, + ) -> Result { + let request = self.inner_client.request(method, url); + self.auth_manager.prepare_request(request).await + } + + /// send get request + pub async fn get(&self, url: U) -> Result { + let request = self.request(reqwest::Method::GET, url).await?; + let response = request.send().await?; + self.auth_manager.handle_response(response).await + } + + /// send post request + pub async fn post(&self, url: U) -> Result { + self.request(reqwest::Method::POST, url).await + } +} + +/// OAuth state machine +/// Use the OAuthState to manage the OAuth client is more recommend +/// But also you can use the AuthorizationManager,AuthorizationSession,AuthorizedHttpClient directly +pub enum OAuthState { + /// the AuthorizationManager + Unauthorized(AuthorizationManager), + /// the AuthorizationSession + Session(AuthorizationSession), + /// the authd AuthorizationManager + Authorized(AuthorizationManager), + /// the authd http client + AuthorizedHttpClient(AuthorizedHttpClient), +} + +impl OAuthState { + /// Create new OAuth state machine + pub async fn new( + base_url: U, + client: Option, + ) -> Result { + let mut manager = AuthorizationManager::new(base_url).await?; + if let Some(client) = client { + manager.with_client(client)?; + } + + Ok(OAuthState::Unauthorized(manager)) + } + + /// Get client_id and OAuth credentials + pub async fn get_credentials(&self) -> Result { + // return client_id and credentials + match self { + OAuthState::Unauthorized(manager) | OAuthState::Authorized(manager) => { + manager.get_credentials().await + } + OAuthState::Session(session) => session.get_credentials().await, + OAuthState::AuthorizedHttpClient(client) => client.auth_manager.get_credentials().await, + } + } + + /// Manually set credentials and move into authorized state + /// Useful if you're caching credentials externally and wish to reuse them + pub async fn set_credentials( + &mut self, + client_id: &str, + credentials: OAuthTokenResponse, + ) -> Result<(), AuthError> { + if let OAuthState::Unauthorized(manager) = self { + let mut manager = std::mem::replace( + manager, + AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?, + ); + + // write credentials + *manager.credentials.write().await = Some(credentials); + + // discover metadata + let metadata = manager.discover_metadata().await?; + manager.metadata = Some(metadata); + + // set client id and secret + manager.configure_client_id(client_id)?; + + *self = OAuthState::Authorized(manager); + Ok(()) + } else { + Err(AuthError::InternalError( + "Cannot set credentials in this state".to_string(), + )) + } + } + + /// start authorization + pub async fn start_authorization( + &mut self, + scopes: &[&str], + redirect_uri: &str, + client_name: Option<&str>, + ) -> Result<(), AuthError> { + if let OAuthState::Unauthorized(mut manager) = std::mem::replace( + self, + OAuthState::Unauthorized(AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?), + ) { + debug!("start discovery"); + let metadata = manager.discover_metadata().await?; + manager.metadata = Some(metadata); + debug!("start session"); + let session = + AuthorizationSession::new(manager, scopes, redirect_uri, client_name).await?; + *self = OAuthState::Session(session); + Ok(()) + } else { + Err(AuthError::InternalError( + "Already in session state".to_string(), + )) + } + } + + /// complete authorization + pub async fn complete_authorization(&mut self) -> Result<(), AuthError> { + if let OAuthState::Session(session) = std::mem::replace( + self, + OAuthState::Unauthorized(AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?), + ) { + *self = OAuthState::Authorized(session.auth_manager); + Ok(()) + } else { + Err(AuthError::InternalError("Not in session state".to_string())) + } + } + /// covert to authorized http client + pub async fn to_authorized_http_client(&mut self) -> Result<(), AuthError> { + if let OAuthState::Authorized(manager) = std::mem::replace( + self, + OAuthState::Authorized(AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?), + ) { + *self = OAuthState::AuthorizedHttpClient(AuthorizedHttpClient::new( + Arc::new(manager), + None, + )); + Ok(()) + } else { + Err(AuthError::InternalError( + "Not in authorized state".to_string(), + )) + } + } + /// get current authorization url + pub async fn get_authorization_url(&self) -> Result { + match self { + OAuthState::Session(session) => Ok(session.get_authorization_url().to_string()), + OAuthState::Unauthorized(_) => { + Err(AuthError::InternalError("Not in session state".to_string())) + } + OAuthState::Authorized(_) => { + Err(AuthError::InternalError("Already authorized".to_string())) + } + OAuthState::AuthorizedHttpClient(_) => { + Err(AuthError::InternalError("Already authorized".to_string())) + } + } + } + + /// handle authorization callback + pub async fn handle_callback(&mut self, code: &str, csrf_token: &str) -> Result<(), AuthError> { + match self { + OAuthState::Session(session) => { + session.handle_callback(code, csrf_token).await?; + self.complete_authorization().await + } + OAuthState::Unauthorized(_) => { + Err(AuthError::InternalError("Not in session state".to_string())) + } + OAuthState::Authorized(_) => { + Err(AuthError::InternalError("Already authorized".to_string())) + } + OAuthState::AuthorizedHttpClient(_) => { + Err(AuthError::InternalError("Already authorized".to_string())) + } + } + } + + /// get access token + pub async fn get_access_token(&self) -> Result { + match self { + OAuthState::Unauthorized(manager) => manager.get_access_token().await, + OAuthState::Session(_) => { + Err(AuthError::InternalError("Not in manager state".to_string())) + } + OAuthState::Authorized(_) => { + Err(AuthError::InternalError("Already authorized".to_string())) + } + OAuthState::AuthorizedHttpClient(_) => { + Err(AuthError::InternalError("Already authorized".to_string())) + } + } + } + + /// refresh access token + pub async fn refresh_token(&self) -> Result<(), AuthError> { + match self { + OAuthState::Unauthorized(_) => { + Err(AuthError::InternalError("Not in manager state".to_string())) + } + OAuthState::Session(_) => { + Err(AuthError::InternalError("Not in manager state".to_string())) + } + OAuthState::Authorized(manager) => { + manager.refresh_token().await?; + Ok(()) + } + OAuthState::AuthorizedHttpClient(_) => { + Err(AuthError::InternalError("Already authorized".to_string())) + } + } + } + + pub fn into_authorization_manager(self) -> Option { + match self { + OAuthState::Authorized(manager) => Some(manager), + _ => None, + } + } +} + +#[cfg(test)] +mod tests { + use url::Url; + + use super::AuthorizationManager; + + #[test] + fn parses_resource_metadata_parameter() { + let header = r#"Bearer error="invalid_request", error_description="missing token", resource_metadata="https://example.com/.well-known/oauth-protected-resource/api""#; + let base = Url::parse("https://example.com/api").unwrap(); + let parsed = AuthorizationManager::extract_resource_metadata_url_from_header(header, &base); + assert_eq!( + parsed.unwrap().as_str(), + "https://example.com/.well-known/oauth-protected-resource/api" + ); + } + + #[test] + fn parses_relative_resource_metadata_parameter() { + let header = r#"Bearer error="invalid_request", resource_metadata="/.well-known/oauth-protected-resource/api""#; + let base = Url::parse("https://example.com/api").unwrap(); + let parsed = AuthorizationManager::extract_resource_metadata_url_from_header(header, &base); + assert_eq!( + parsed.unwrap().as_str(), + "https://example.com/.well-known/oauth-protected-resource/api" + ); + } + + #[test] + fn parse_auth_param_value_handles_quoted_string() { + let fragment = r#""example", realm="foo""#; + let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap(); + assert_eq!(parsed.0, "example"); + assert_eq!(parsed.1, 9); + } + + #[test] + fn parse_auth_param_value_handles_escaped_quotes_and_whitespace() { + let fragment = r#" "a\"b\\c" ,next=value"#; + let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap(); + assert_eq!(parsed.0, r#"a"b\c"#); + assert_eq!(parsed.1, 12); + } + + #[test] + fn parse_auth_param_value_handles_token_values() { + let fragment = " token,next"; + let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap(); + assert_eq!(parsed.0, "token"); + assert_eq!(parsed.1, 7); + } + + #[test] + fn parse_auth_param_value_handles_semicolon_separated_tokens() { + let fragment = r#" https://example.com/meta; error="invalid_token""#; + let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap(); + assert_eq!(parsed.0, "https://example.com/meta"); + assert_eq!(&fragment[..parsed.1], " https://example.com/meta"); + } + + #[test] + fn parse_auth_param_value_handles_semicolon_after_quoted_value() { + let fragment = r#" "https://example.com/meta"; error="invalid_token""#; + let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap(); + assert_eq!(parsed.0, "https://example.com/meta"); + assert_eq!(&fragment[..parsed.1], r#" "https://example.com/meta""#); + } + + #[test] + fn parse_auth_param_value_returns_none_for_unterminated_quotes() { + let fragment = r#""unterminated,value"#; + assert!(AuthorizationManager::parse_next_header_value(fragment).is_none()); + } + + #[test] + fn well_known_paths_root() { + let paths = AuthorizationManager::well_known_paths("/", "oauth-authorization-server"); + assert_eq!( + paths, + vec!["/.well-known/oauth-authorization-server".to_string()] + ); + } + + #[test] + fn well_known_paths_with_suffix() { + let paths = AuthorizationManager::well_known_paths("/mcp", "oauth-authorization-server"); + assert_eq!( + paths, + vec![ + "/.well-known/oauth-authorization-server/mcp".to_string(), + "/mcp/.well-known/oauth-authorization-server".to_string(), + "/.well-known/oauth-authorization-server".to_string(), + ] + ); + } + + #[test] + fn well_known_paths_trailing_slash() { + let paths = + AuthorizationManager::well_known_paths("/v1/mcp/", "oauth-authorization-server"); + assert_eq!( + paths, + vec![ + "/.well-known/oauth-authorization-server/v1/mcp".to_string(), + "/v1/mcp/.well-known/oauth-authorization-server".to_string(), + "/.well-known/oauth-authorization-server".to_string(), + ] + ); + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/transport/child_process.rs b/code-rs/third_party/rmcp-0.8.3/src/transport/child_process.rs new file mode 100644 index 00000000000..d117c09d1ef --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/transport/child_process.rs @@ -0,0 +1,309 @@ +use std::process::Stdio; + +use futures::future::Future; +use process_wrap::tokio::{TokioChildWrapper, TokioCommandWrap}; +use tokio::{ + io::AsyncRead, + process::{ChildStderr, ChildStdin, ChildStdout}, +}; + +use super::{RxJsonRpcMessage, Transport, TxJsonRpcMessage, async_rw::AsyncRwTransport}; +use crate::RoleClient; + +const MAX_WAIT_ON_DROP_SECS: u64 = 3; +/// The parts of a child process. +type ChildProcessParts = ( + Box, + ChildStdout, + ChildStdin, + Option, +); + +/// Extract the stdio handles from a spawned child. +/// Returns `(child, stdout, stdin, stderr)` where `stderr` is `Some` only +/// if the process was spawned with `Stdio::piped()`. +#[inline] +fn child_process(mut child: Box) -> std::io::Result { + let child_stdin = match child.inner_mut().stdin().take() { + Some(stdin) => stdin, + None => return Err(std::io::Error::other("stdin was already taken")), + }; + let child_stdout = match child.inner_mut().stdout().take() { + Some(stdout) => stdout, + None => return Err(std::io::Error::other("stdout was already taken")), + }; + let child_stderr = child.inner_mut().stderr().take(); + Ok((child, child_stdout, child_stdin, child_stderr)) +} + +pub struct TokioChildProcess { + child: ChildWithCleanup, + transport: AsyncRwTransport, +} + +pub struct ChildWithCleanup { + inner: Option>, +} + +impl Drop for ChildWithCleanup { + fn drop(&mut self) { + // We should not use start_kill(), instead we should use kill() to avoid zombies + if let Some(mut inner) = self.inner.take() { + // We don't care about the result, just try to kill it + tokio::spawn(async move { + if let Err(e) = Box::into_pin(inner.kill()).await { + tracing::warn!("Error killing child process: {}", e); + } + }); + } + } +} + +// we hold the child process with stdout, for it's easier to implement AsyncRead +pin_project_lite::pin_project! { + pub struct TokioChildProcessOut { + child: ChildWithCleanup, + #[pin] + child_stdout: ChildStdout, + } +} + +impl TokioChildProcessOut { + /// Get the process ID of the child process. + pub fn id(&self) -> Option { + self.child.inner.as_ref()?.id() + } +} + +impl AsyncRead for TokioChildProcessOut { + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + self.project().child_stdout.poll_read(cx, buf) + } +} + +impl TokioChildProcess { + /// Convenience: spawn with default `piped` stdio + pub fn new(command: impl Into) -> std::io::Result { + let (proc, _ignored) = TokioChildProcessBuilder::new(command).spawn()?; + Ok(proc) + } + + /// Builder entry-point allowing fine-grained stdio control. + pub fn builder(command: impl Into) -> TokioChildProcessBuilder { + TokioChildProcessBuilder::new(command) + } + + /// Get the process ID of the child process. + pub fn id(&self) -> Option { + self.child.inner.as_ref()?.id() + } + + /// Gracefully shutdown the child process + /// + /// This will first close the transport to the child process (the server), + /// and wait for the child process to exit normally with a timeout. + /// If the child process doesn't exit within the timeout, it will be killed. + pub async fn graceful_shutdown(&mut self) -> std::io::Result<()> { + if let Some(mut child) = self.child.inner.take() { + self.transport.close().await?; + + let wait_fut = Box::into_pin(child.wait()); + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS)) => { + if let Err(e) = Box::into_pin(child.kill()).await { + tracing::warn!("Error killing child: {e}"); + return Err(e); + } + }, + res = wait_fut => { + match res { + Ok(status) => { + tracing::info!("Child exited gracefully {}", status); + } + Err(e) => { + tracing::warn!("Error waiting for child: {e}"); + return Err(e); + } + } + } + } + } + Ok(()) + } + + /// Take ownership of the inner child process + pub fn into_inner(mut self) -> Option> { + self.child.inner.take() + } + + /// Split this helper into a reader (stdout) and writer (stdin). + #[deprecated( + since = "0.5.0", + note = "use the Transport trait implementation instead" + )] + pub fn split(self) -> (TokioChildProcessOut, ChildStdin) { + unimplemented!("This method is deprecated, use the Transport trait implementation instead"); + } +} + +/// Builder for `TokioChildProcess` allowing custom `Stdio` configuration. +pub struct TokioChildProcessBuilder { + cmd: TokioCommandWrap, + stdin: Stdio, + stdout: Stdio, + stderr: Stdio, +} + +impl TokioChildProcessBuilder { + fn new(cmd: impl Into) -> Self { + Self { + cmd: cmd.into(), + stdin: Stdio::piped(), + stdout: Stdio::piped(), + stderr: Stdio::inherit(), + } + } + + /// Override the child stdin configuration. + pub fn stdin(mut self, io: impl Into) -> Self { + self.stdin = io.into(); + self + } + /// Override the child stdout configuration. + pub fn stdout(mut self, io: impl Into) -> Self { + self.stdout = io.into(); + self + } + /// Override the child stderr configuration. + pub fn stderr(mut self, io: impl Into) -> Self { + self.stderr = io.into(); + self + } + + /// Spawn the child process. Returns the transport plus an optional captured stderr handle. + pub fn spawn(mut self) -> std::io::Result<(TokioChildProcess, Option)> { + self.cmd + .command_mut() + .stdin(self.stdin) + .stdout(self.stdout) + .stderr(self.stderr); + + let (child, stdout, stdin, stderr_opt) = child_process(self.cmd.spawn()?)?; + + let transport = AsyncRwTransport::new(stdout, stdin); + let proc = TokioChildProcess { + child: ChildWithCleanup { inner: Some(child) }, + transport, + }; + Ok((proc, stderr_opt)) + } +} + +impl Transport for TokioChildProcess { + type Error = std::io::Error; + + fn send( + &mut self, + item: TxJsonRpcMessage, + ) -> impl Future> + Send + 'static { + self.transport.send(item) + } + + fn receive(&mut self) -> impl Future>> + Send { + self.transport.receive() + } + + fn close(&mut self) -> impl Future> + Send { + self.graceful_shutdown() + } +} + +pub trait ConfigureCommandExt { + fn configure(self, f: impl FnOnce(&mut Self)) -> Self; +} + +impl ConfigureCommandExt for tokio::process::Command { + fn configure(mut self, f: impl FnOnce(&mut Self)) -> Self { + f(&mut self); + self + } +} + +#[cfg(unix)] +#[cfg(test)] +mod tests { + use tokio::process::Command; + + use super::*; + + #[tokio::test] + async fn test_tokio_child_process_drop() { + let r = TokioChildProcess::new(Command::new("sleep").configure(|cmd| { + cmd.arg("30"); + })); + assert!(r.is_ok()); + let child_process = r.unwrap(); + let id = child_process.id(); + assert!(id.is_some()); + let id = id.unwrap(); + // Drop the child process + drop(child_process); + // Wait a moment to allow the cleanup task to run + tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS + 1)).await; + // Check if the process is still running + let status = Command::new("ps") + .arg("-p") + .arg(id.to_string()) + .status() + .await; + match status { + Ok(status) => { + assert!( + !status.success(), + "Process with PID {} is still running", + id + ); + } + Err(e) => { + panic!("Failed to check process status: {}", e); + } + } + } + + #[tokio::test] + async fn test_tokio_child_process_graceful_shutdown() { + let r = TokioChildProcess::new(Command::new("sleep").configure(|cmd| { + cmd.arg("30"); + })); + assert!(r.is_ok()); + let mut child_process = r.unwrap(); + let id = child_process.id(); + assert!(id.is_some()); + let id = id.unwrap(); + child_process.graceful_shutdown().await.unwrap(); + // Wait a moment to allow the cleanup task to run + tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS + 1)).await; + // Check if the process is still running + let status = Command::new("ps") + .arg("-p") + .arg(id.to_string()) + .status() + .await; + match status { + Ok(status) => { + assert!( + !status.success(), + "Process with PID {} is still running", + id + ); + } + Err(e) => { + panic!("Failed to check process status: {}", e); + } + } + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/transport/common.rs b/code-rs/third_party/rmcp-0.8.3/src/transport/common.rs new file mode 100644 index 00000000000..401c6f2d510 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/transport/common.rs @@ -0,0 +1,19 @@ +#[cfg(any( + feature = "transport-streamable-http-server", + feature = "transport-sse-server" +))] +pub mod server_side_http; + +pub mod http_header; + +#[cfg(feature = "__reqwest")] +#[cfg_attr(docsrs, doc(cfg(feature = "reqwest")))] +mod reqwest; + +#[cfg(feature = "client-side-sse")] +#[cfg_attr(docsrs, doc(cfg(feature = "client-side-sse")))] +pub mod client_side_sse; + +#[cfg(feature = "auth")] +#[cfg_attr(docsrs, doc(cfg(feature = "auth")))] +pub mod auth; diff --git a/code-rs/third_party/rmcp-0.8.3/src/transport/common/auth.rs b/code-rs/third_party/rmcp-0.8.3/src/transport/common/auth.rs new file mode 100644 index 00000000000..5395d571e93 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/transport/common/auth.rs @@ -0,0 +1,7 @@ +#[cfg(feature = "transport-streamable-http-client")] +#[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-client")))] +mod streamable_http_client; + +#[cfg(feature = "transport-sse-client")] +#[cfg_attr(docsrs, doc(cfg(feature = "transport-sse-client")))] +mod sse_client; diff --git a/code-rs/third_party/rmcp-0.8.3/src/transport/common/auth/sse_client.rs b/code-rs/third_party/rmcp-0.8.3/src/transport/common/auth/sse_client.rs new file mode 100644 index 00000000000..009593e1370 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/transport/common/auth/sse_client.rs @@ -0,0 +1,45 @@ +use http::Uri; + +use crate::transport::{ + auth::AuthClient, + sse_client::{SseClient, SseTransportError}, +}; +impl SseClient for AuthClient +where + C: SseClient, +{ + type Error = SseTransportError; + + async fn post_message( + &self, + uri: Uri, + message: crate::model::ClientJsonRpcMessage, + mut auth_token: Option, + ) -> Result<(), SseTransportError> { + if auth_token.is_none() { + auth_token = Some(self.get_access_token().await?); + } + self.http_client + .post_message(uri, message, auth_token) + .await + .map_err(SseTransportError::Client) + } + + async fn get_stream( + &self, + uri: Uri, + last_event_id: Option, + mut auth_token: Option, + ) -> Result< + crate::transport::common::client_side_sse::BoxedSseResponse, + SseTransportError, + > { + if auth_token.is_none() { + auth_token = Some(self.get_access_token().await?); + } + self.http_client + .get_stream(uri, last_event_id, auth_token) + .await + .map_err(SseTransportError::Client) + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/transport/common/auth/streamable_http_client.rs b/code-rs/third_party/rmcp-0.8.3/src/transport/common/auth/streamable_http_client.rs new file mode 100644 index 00000000000..49ebefcd649 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/transport/common/auth/streamable_http_client.rs @@ -0,0 +1,61 @@ +use crate::transport::{ + auth::AuthClient, + streamable_http_client::{StreamableHttpClient, StreamableHttpError}, +}; +impl StreamableHttpClient for AuthClient +where + C: StreamableHttpClient + Send + Sync, +{ + type Error = C::Error; + + async fn delete_session( + &self, + uri: std::sync::Arc, + session_id: std::sync::Arc, + mut auth_token: Option, + ) -> Result<(), crate::transport::streamable_http_client::StreamableHttpError> + { + if auth_token.is_none() { + auth_token = Some(self.get_access_token().await?); + } + self.http_client + .delete_session(uri, session_id, auth_token) + .await + } + + async fn get_stream( + &self, + uri: std::sync::Arc, + session_id: std::sync::Arc, + last_event_id: Option, + mut auth_token: Option, + ) -> Result< + futures::stream::BoxStream<'static, Result>, + crate::transport::streamable_http_client::StreamableHttpError, + > { + if auth_token.is_none() { + auth_token = Some(self.get_access_token().await?); + } + self.http_client + .get_stream(uri, session_id, last_event_id, auth_token) + .await + } + + async fn post_message( + &self, + uri: std::sync::Arc, + message: crate::model::ClientJsonRpcMessage, + session_id: Option>, + mut auth_token: Option, + ) -> Result< + crate::transport::streamable_http_client::StreamableHttpPostResponse, + StreamableHttpError, + > { + if auth_token.is_none() { + auth_token = Some(self.get_access_token().await?); + } + self.http_client + .post_message(uri, message, session_id, auth_token) + .await + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/transport/common/client_side_sse.rs b/code-rs/third_party/rmcp-0.8.3/src/transport/common/client_side_sse.rs new file mode 100644 index 00000000000..4e01994fa45 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/transport/common/client_side_sse.rs @@ -0,0 +1,310 @@ +use std::{ + pin::Pin, + sync::Arc, + task::{Poll, ready}, + time::Duration, +}; + +use futures::{Stream, stream::BoxStream}; +use sse_stream::{Error as SseError, Sse}; + +use crate::model::ServerJsonRpcMessage; + +pub type BoxedSseResponse = BoxStream<'static, Result>; + +pub trait SseRetryPolicy: std::fmt::Debug + Send + Sync { + fn retry(&self, current_times: usize) -> Option; +} + +#[derive(Debug, Clone)] +pub struct FixedInterval { + pub max_times: Option, + pub duration: Duration, +} + +impl SseRetryPolicy for FixedInterval { + fn retry(&self, current_times: usize) -> Option { + if let Some(max_times) = self.max_times { + if current_times >= max_times { + return None; + } + } + Some(self.duration) + } +} + +impl FixedInterval { + pub const DEFAULT_MIN_DURATION: Duration = Duration::from_millis(1000); +} + +impl Default for FixedInterval { + fn default() -> Self { + Self { + max_times: None, + duration: Self::DEFAULT_MIN_DURATION, + } + } +} + +#[derive(Debug, Clone)] +pub struct ExponentialBackoff { + pub max_times: Option, + pub base_duration: Duration, +} + +impl ExponentialBackoff { + pub const DEFAULT_DURATION: Duration = Duration::from_millis(1000); +} + +impl Default for ExponentialBackoff { + fn default() -> Self { + Self { + max_times: None, + base_duration: Self::DEFAULT_DURATION, + } + } +} + +impl SseRetryPolicy for ExponentialBackoff { + fn retry(&self, current_times: usize) -> Option { + if let Some(max_times) = self.max_times { + if current_times >= max_times { + return None; + } + } + Some(self.base_duration * (2u32.pow(current_times as u32))) + } +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct NeverRetry; + +impl SseRetryPolicy for NeverRetry { + fn retry(&self, _current_times: usize) -> Option { + None + } +} + +#[derive(Debug, Default)] +pub struct NeverReconnect { + error: Option, +} + +impl SseStreamReconnect for NeverReconnect { + type Error = E; + type Future = futures::future::Ready>; + fn retry_connection(&mut self, _last_event_id: Option<&str>) -> Self::Future { + futures::future::ready(Err(self.error.take().expect("should not be called again"))) + } +} + +/// Abstraction for SSE reconnection logic. Implementors can hook into +/// [`handle_control_event`](Self::handle_control_event) to consume control +/// frames (e.g. `event: endpoint`) that arrive when a server restarts an SSE +/// stream. The default implementation is a no-op, keeping existing behaviour +/// intact. +pub(crate) trait SseStreamReconnect { + type Error: std::error::Error; + type Future: Future> + Send; + fn retry_connection(&mut self, last_event_id: Option<&str>) -> Self::Future; + fn handle_control_event(&mut self, _event: &Sse) -> Result<(), Self::Error> { + Ok(()) + } + fn handle_stream_error( + &mut self, + error: &(dyn std::error::Error + 'static), + last_event_id: Option<&str>, + ) { + if let Some(id) = last_event_id { + tracing::warn!(%id, "sse stream error: {error}"); + } else { + tracing::warn!("sse stream error: {error}"); + } + } +} + +pin_project_lite::pin_project! { + pub(crate) struct SseAutoReconnectStream + where R: SseStreamReconnect + { + retry_policy: Arc, + last_event_id: Option, + server_retry_interval: Option, + connector: R, + #[pin] + state: SseAutoReconnectStreamState, + } +} + +impl SseAutoReconnectStream { + pub fn new( + stream: BoxedSseResponse, + connector: R, + retry_policy: Arc, + ) -> Self { + Self { + retry_policy, + last_event_id: None, + server_retry_interval: None, + connector, + state: SseAutoReconnectStreamState::Connected { stream }, + } + } +} + +impl SseAutoReconnectStream> { + #[allow(dead_code)] + pub(crate) fn never_reconnect(stream: BoxedSseResponse, error_when_reconnect: E) -> Self { + Self { + retry_policy: Arc::new(NeverRetry), + last_event_id: None, + server_retry_interval: None, + connector: NeverReconnect { + error: Some(error_when_reconnect), + }, + state: SseAutoReconnectStreamState::Connected { stream }, + } + } +} + +pin_project_lite::pin_project! { + #[project = SseAutoReconnectStreamStateProj] + pub enum SseAutoReconnectStreamState { + Connected { + #[pin] + stream: BoxedSseResponse, + }, + Retrying { + retry_times: usize, + #[pin] + retrying: F, + }, + WaitingNextRetry { + #[pin] + sleep: tokio::time::Sleep, + retry_times: usize, + }, + Terminated, + } +} + +impl Stream for SseAutoReconnectStream +where + R: SseStreamReconnect, +{ + type Item = Result; + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let mut this = self.as_mut().project(); + // let this_state = this.state.as_mut().project() + let state = this.state.as_mut().project(); + let next_state = match state { + SseAutoReconnectStreamStateProj::Connected { stream } => { + match ready!(stream.poll_next(cx)) { + Some(Ok(sse)) => { + if let Some(new_server_retry) = sse.retry { + *this.server_retry_interval = + Some(Duration::from_millis(new_server_retry)); + } + if let Some(ref event_id) = sse.id { + *this.last_event_id = Some(event_id.clone()); + } + // Only treat blank/`message` events as JSON-RPC payloads. + // Other control frames (endpoint, ping, etc.) are passed to + // the reconnection handler. + let is_message_event = + matches!(sse.event.as_deref(), None | Some("") | Some("message")); + if !is_message_event { + match this.connector.handle_control_event(&sse) { + Ok(()) => return self.poll_next(cx), + Err(e) => { + this.state.set(SseAutoReconnectStreamState::Terminated); + return Poll::Ready(Some(Err(e))); + } + } + } + if let Some(data) = sse.data { + match serde_json::from_str::(&data) { + Err(e) => { + // Downgrade to debug to avoid noisy logs when servers emit + // non-JSON payloads as message frames. Include last_event_id + // to aid troubleshooting while keeping default behaviour. + let last_id = this.last_event_id.as_deref().unwrap_or(""); + tracing::debug!(last_event_id=%last_id, "failed to deserialize server message: {e}"); + return self.poll_next(cx); + } + Ok(message) => { + return Poll::Ready(Some(Ok(message))); + } + }; + } else { + return self.poll_next(cx); + } + } + Some(Err(e)) => { + this.connector + .handle_stream_error(&e, this.last_event_id.as_deref()); + let retrying = this + .connector + .retry_connection(this.last_event_id.as_deref()); + SseAutoReconnectStreamState::Retrying { + retry_times: 0, + retrying, + } + } + None => { + tracing::debug!("sse stream terminated"); + return Poll::Ready(None); + } + } + } + SseAutoReconnectStreamStateProj::Retrying { + retry_times, + retrying, + } => { + let retry_result = ready!(retrying.poll(cx)); + match retry_result { + Ok(new_stream) => SseAutoReconnectStreamState::Connected { stream: new_stream }, + Err(e) => { + tracing::debug!("retry sse stream error: {e}"); + *retry_times += 1; + if let Some(interval) = this.retry_policy.retry(*retry_times) { + let interval = this + .server_retry_interval + .map(|server_retry_interval| server_retry_interval.max(interval)) + .unwrap_or(interval); + let sleep = tokio::time::sleep(interval); + SseAutoReconnectStreamState::WaitingNextRetry { + sleep, + retry_times: *retry_times, + } + } else { + tracing::error!("sse stream error: {e}, max retry times reached"); + this.state.set(SseAutoReconnectStreamState::Terminated); + return Poll::Ready(Some(Err(e))); + } + } + } + } + SseAutoReconnectStreamStateProj::WaitingNextRetry { sleep, retry_times } => { + ready!(sleep.poll(cx)); + let retrying = this + .connector + .retry_connection(this.last_event_id.as_deref()); + let retry_times = *retry_times; + SseAutoReconnectStreamState::Retrying { + retry_times, + retrying, + } + } + SseAutoReconnectStreamStateProj::Terminated => { + return Poll::Ready(None); + } + }; + // update the state + this.state.set(next_state); + self.poll_next(cx) + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/transport/common/http_header.rs b/code-rs/third_party/rmcp-0.8.3/src/transport/common/http_header.rs new file mode 100644 index 00000000000..84bc7bfb2cd --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/transport/common/http_header.rs @@ -0,0 +1,4 @@ +pub const HEADER_SESSION_ID: &str = "Mcp-Session-Id"; +pub const HEADER_LAST_EVENT_ID: &str = "Last-Event-Id"; +pub const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream"; +pub const JSON_MIME_TYPE: &str = "application/json"; diff --git a/code-rs/third_party/rmcp-0.8.3/src/transport/common/reqwest.rs b/code-rs/third_party/rmcp-0.8.3/src/transport/common/reqwest.rs new file mode 100644 index 00000000000..4f9dc0dc578 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/transport/common/reqwest.rs @@ -0,0 +1,7 @@ +#[cfg(feature = "transport-streamable-http-client-reqwest")] +#[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-client-reqwest")))] +mod streamable_http_client; + +#[cfg(feature = "transport-sse-client-reqwest")] +#[cfg_attr(docsrs, doc(cfg(feature = "transport-sse-client-reqwest")))] +mod sse_client; diff --git a/code-rs/third_party/rmcp-0.8.3/src/transport/common/reqwest/sse_client.rs b/code-rs/third_party/rmcp-0.8.3/src/transport/common/reqwest/sse_client.rs new file mode 100644 index 00000000000..a5362d79cfc --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/transport/common/reqwest/sse_client.rs @@ -0,0 +1,118 @@ +use std::sync::Arc; + +use futures::StreamExt; +use http::Uri; +use reqwest::header::ACCEPT; +use sse_stream::SseStream; + +use crate::transport::{ + SseClientTransport, + common::http_header::{EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID}, + sse_client::{SseClient, SseClientConfig, SseTransportError}, +}; + +impl From for SseTransportError { + fn from(e: reqwest::Error) -> Self { + SseTransportError::Client(e) + } +} + +impl SseClient for reqwest::Client { + type Error = reqwest::Error; + + async fn post_message( + &self, + uri: Uri, + message: crate::model::ClientJsonRpcMessage, + auth_token: Option, + ) -> Result<(), SseTransportError> { + let mut request_builder = self.post(uri.to_string()).json(&message); + if let Some(auth_header) = auth_token { + request_builder = request_builder.bearer_auth(auth_header); + } + request_builder + .send() + .await + .and_then(|resp| resp.error_for_status()) + .map_err(SseTransportError::from) + .map(drop) + } + + async fn get_stream( + &self, + uri: Uri, + last_event_id: Option, + auth_token: Option, + ) -> Result< + crate::transport::common::client_side_sse::BoxedSseResponse, + SseTransportError, + > { + let mut request_builder = self + .get(uri.to_string()) + .header(ACCEPT, EVENT_STREAM_MIME_TYPE); + if let Some(auth_header) = auth_token { + request_builder = request_builder.bearer_auth(auth_header); + } + if let Some(last_event_id) = last_event_id { + request_builder = request_builder.header(HEADER_LAST_EVENT_ID, last_event_id); + } + let response = request_builder.send().await?; + let response = response.error_for_status()?; + match response.headers().get(reqwest::header::CONTENT_TYPE) { + Some(ct) => { + if !ct.as_bytes().starts_with(EVENT_STREAM_MIME_TYPE.as_bytes()) { + return Err(SseTransportError::UnexpectedContentType(Some( + String::from_utf8_lossy(ct.as_bytes()).to_string(), + ))); + } + } + None => { + return Err(SseTransportError::UnexpectedContentType(None)); + } + } + let event_stream = SseStream::from_byte_stream(response.bytes_stream()).boxed(); + Ok(event_stream) + } +} + +impl SseClientTransport { + /// Creates a new transport using reqwest with the specified SSE endpoint. + /// + /// This is a convenience method that creates a transport using the default + /// reqwest client. This method is only available when the + /// `transport-sse-client-reqwest` feature is enabled. + /// + /// # Arguments + /// + /// * `uri` - The SSE endpoint to connect to + /// + /// # Example + /// + /// ```rust + /// use rmcp::transport::SseClientTransport; + /// + /// // Enable the reqwest feature in Cargo.toml: + /// // rmcp = { version = "0.5", features = ["transport-sse-client-reqwest"] } + /// + /// # async fn example() -> Result<(), Box> { + /// let transport = SseClientTransport::start("http://localhost:8000/sse").await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// # Feature requirement + /// + /// This method requires the `transport-sse-client-reqwest` feature. + pub async fn start( + uri: impl Into>, + ) -> Result> { + SseClientTransport::start_with_client( + reqwest::Client::default(), + SseClientConfig { + sse_endpoint: uri.into(), + ..Default::default() + }, + ) + .await + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/transport/common/reqwest/streamable_http_client.rs b/code-rs/third_party/rmcp-0.8.3/src/transport/common/reqwest/streamable_http_client.rs new file mode 100644 index 00000000000..6026e0564f0 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/transport/common/reqwest/streamable_http_client.rs @@ -0,0 +1,198 @@ +use std::{borrow::Cow, sync::Arc}; + +use futures::{StreamExt, stream::BoxStream}; +use http::header::WWW_AUTHENTICATE; +use reqwest::header::ACCEPT; +use sse_stream::{Sse, SseStream}; + +use crate::{ + model::{ClientJsonRpcMessage, ServerJsonRpcMessage}, + transport::{ + common::http_header::{ + EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_SESSION_ID, JSON_MIME_TYPE, + }, + streamable_http_client::*, + }, +}; + +impl From for StreamableHttpError { + fn from(e: reqwest::Error) -> Self { + StreamableHttpError::Client(e) + } +} + +impl StreamableHttpClient for reqwest::Client { + type Error = reqwest::Error; + + async fn get_stream( + &self, + uri: Arc, + session_id: Arc, + last_event_id: Option, + auth_token: Option, + ) -> Result>, StreamableHttpError> { + let mut request_builder = self + .get(uri.as_ref()) + .header(ACCEPT, EVENT_STREAM_MIME_TYPE) + .header(HEADER_SESSION_ID, session_id.as_ref()); + if let Some(last_event_id) = last_event_id { + request_builder = request_builder.header(HEADER_LAST_EVENT_ID, last_event_id); + } + if let Some(auth_header) = auth_token { + request_builder = request_builder.bearer_auth(auth_header); + } + let response = request_builder.send().await?; + if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED { + return Err(StreamableHttpError::ServerDoesNotSupportSse); + } + let response = response.error_for_status()?; + match response.headers().get(reqwest::header::CONTENT_TYPE) { + Some(ct) => { + if !ct.as_bytes().starts_with(EVENT_STREAM_MIME_TYPE.as_bytes()) { + return Err(StreamableHttpError::UnexpectedContentType(Some( + String::from_utf8_lossy(ct.as_bytes()).to_string(), + ))); + } + } + None => { + return Err(StreamableHttpError::UnexpectedContentType(None)); + } + } + let event_stream = SseStream::from_byte_stream(response.bytes_stream()).boxed(); + Ok(event_stream) + } + + async fn delete_session( + &self, + uri: Arc, + session: Arc, + auth_token: Option, + ) -> Result<(), StreamableHttpError> { + let mut request_builder = self.delete(uri.as_ref()); + if let Some(auth_header) = auth_token { + request_builder = request_builder.bearer_auth(auth_header); + } + let response = request_builder + .header(HEADER_SESSION_ID, session.as_ref()) + .send() + .await?; + + // if method no allowed + if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED { + tracing::debug!("this server doesn't support deleting session"); + return Ok(()); + } + let _response = response.error_for_status()?; + Ok(()) + } + + async fn post_message( + &self, + uri: Arc, + message: ClientJsonRpcMessage, + session_id: Option>, + auth_token: Option, + ) -> Result> { + let mut request = self + .post(uri.as_ref()) + .header(ACCEPT, [EVENT_STREAM_MIME_TYPE, JSON_MIME_TYPE].join(", ")); + if let Some(auth_header) = auth_token { + request = request.bearer_auth(auth_header); + } + if let Some(session_id) = session_id { + request = request.header(HEADER_SESSION_ID, session_id.as_ref()); + } + let response = request.json(&message).send().await?; + if response.status() == reqwest::StatusCode::UNAUTHORIZED { + if let Some(header) = response.headers().get(WWW_AUTHENTICATE) { + let header = header + .to_str() + .map_err(|_| { + StreamableHttpError::UnexpectedServerResponse(Cow::from( + "invalid www-authenticate header value", + )) + })? + .to_string(); + return Err(StreamableHttpError::AuthRequired(AuthRequiredError { + www_authenticate_header: header, + })); + } + } + let status = response.status(); + let response = response.error_for_status()?; + if matches!( + status, + reqwest::StatusCode::ACCEPTED | reqwest::StatusCode::NO_CONTENT + ) { + return Ok(StreamableHttpPostResponse::Accepted); + } + let content_type = response.headers().get(reqwest::header::CONTENT_TYPE); + let session_id = response.headers().get(HEADER_SESSION_ID); + let session_id = session_id + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()); + match content_type { + Some(ct) if ct.as_bytes().starts_with(EVENT_STREAM_MIME_TYPE.as_bytes()) => { + let event_stream = SseStream::from_byte_stream(response.bytes_stream()).boxed(); + Ok(StreamableHttpPostResponse::Sse(event_stream, session_id)) + } + Some(ct) if ct.as_bytes().starts_with(JSON_MIME_TYPE.as_bytes()) => { + let message: ServerJsonRpcMessage = response.json().await?; + Ok(StreamableHttpPostResponse::Json(message, session_id)) + } + _ => { + // unexpected content type + tracing::error!("unexpected content type: {:?}", content_type); + Err(StreamableHttpError::UnexpectedContentType( + content_type.map(|ct| String::from_utf8_lossy(ct.as_bytes()).to_string()), + )) + } + } + } +} + +impl StreamableHttpClientTransport { + /// Creates a new transport using reqwest with the specified URI. + /// + /// This is a convenience method that creates a transport using the default + /// reqwest client. This method is only available when the + /// `transport-streamable-http-client-reqwest` feature is enabled. + /// + /// # Arguments + /// + /// * `uri` - The server URI to connect to + /// + /// # Example + /// + /// ```rust,no_run + /// use rmcp::transport::StreamableHttpClientTransport; + /// + /// // Enable the reqwest feature in Cargo.toml: + /// // rmcp = { version = "0.5", features = ["transport-streamable-http-client-reqwest"] } + /// + /// let transport = StreamableHttpClientTransport::from_uri("http://localhost:8000/mcp"); + /// ``` + /// + /// # Feature requirement + /// + /// This method requires the `transport-streamable-http-client-reqwest` feature. + pub fn from_uri(uri: impl Into>) -> Self { + StreamableHttpClientTransport::with_client( + reqwest::Client::default(), + StreamableHttpClientTransportConfig { + uri: uri.into(), + auth_header: None, + ..Default::default() + }, + ) + } + + /// Build this transport form a config + /// + /// # Arguments + /// + /// * `config` - The config to use with this transport + pub fn from_config(config: StreamableHttpClientTransportConfig) -> Self { + StreamableHttpClientTransport::with_client(reqwest::Client::default(), config) + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/transport/common/server_side_http.rs b/code-rs/third_party/rmcp-0.8.3/src/transport/common/server_side_http.rs new file mode 100644 index 00000000000..693d8b34b7c --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/transport/common/server_side_http.rs @@ -0,0 +1,145 @@ +#![allow(dead_code)] +use std::{convert::Infallible, fmt::Display, sync::Arc, time::Duration}; + +use bytes::{Buf, Bytes}; +use http::Response; +use http_body::Body; +use http_body_util::{BodyExt, Empty, Full, combinators::BoxBody}; +use sse_stream::{KeepAlive, Sse, SseBody}; + +use super::http_header::EVENT_STREAM_MIME_TYPE; +use crate::model::{ClientJsonRpcMessage, ServerJsonRpcMessage}; + +pub type SessionId = Arc; + +pub fn session_id() -> SessionId { + uuid::Uuid::new_v4().to_string().into() +} + +pub const DEFAULT_AUTO_PING_INTERVAL: Duration = Duration::from_secs(15); + +pub(crate) type BoxResponse = Response>; + +pub(crate) fn accepted_response() -> Response> { + Response::builder() + .status(http::StatusCode::ACCEPTED) + .body(Empty::new().boxed()) + .expect("valid response") +} +pin_project_lite::pin_project! { + struct TokioTimer { + #[pin] + sleep: tokio::time::Sleep, + } +} +impl Future for TokioTimer { + type Output = (); + + fn poll( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + let this = self.project(); + this.sleep.poll(cx) + } +} +impl sse_stream::Timer for TokioTimer { + fn from_duration(duration: Duration) -> Self { + Self { + sleep: tokio::time::sleep(duration), + } + } + + fn reset(self: std::pin::Pin<&mut Self>, when: std::time::Instant) { + let this = self.project(); + this.sleep.reset(tokio::time::Instant::from_std(when)); + } +} + +#[derive(Debug, Clone)] +pub struct ServerSseMessage { + pub event_id: Option, + pub message: Arc, +} + +pub(crate) fn sse_stream_response( + stream: impl futures::Stream + Send + Sync + 'static, + keep_alive: Option, +) -> Response> { + use futures::StreamExt; + let stream = SseBody::new(stream.map(|message| { + let data = serde_json::to_string(&message.message).expect("valid message"); + let mut sse = Sse::default().data(data); + sse.id = message.event_id; + Result::::Ok(sse) + })); + let stream = match keep_alive { + Some(duration) => stream + .with_keep_alive::(KeepAlive::new().interval(duration)) + .boxed(), + None => stream.boxed(), + }; + Response::builder() + .status(http::StatusCode::OK) + .header(http::header::CONTENT_TYPE, EVENT_STREAM_MIME_TYPE) + .header(http::header::CACHE_CONTROL, "no-cache") + .body(stream) + .expect("valid response") +} + +pub(crate) const fn internal_error_response( + context: &str, +) -> impl FnOnce(E) -> Response> { + move |error| { + tracing::error!("Internal server error when {context}: {error}"); + Response::builder() + .status(http::StatusCode::INTERNAL_SERVER_ERROR) + .body( + Full::new(Bytes::from(format!( + "Encounter an error when {context}: {error}" + ))) + .boxed(), + ) + .expect("valid response") + } +} + +pub(crate) fn unexpected_message_response(expect: &str) -> Response> { + Response::builder() + .status(http::StatusCode::UNPROCESSABLE_ENTITY) + .body(Full::new(Bytes::from(format!("Unexpected message, expect {expect}"))).boxed()) + .expect("valid response") +} + +pub(crate) async fn expect_json( + body: B, +) -> Result>> +where + B: Body + Send + 'static, + B::Error: Display, +{ + match body.collect().await { + Ok(bytes) => { + match serde_json::from_reader::<_, ClientJsonRpcMessage>(bytes.aggregate().reader()) { + Ok(message) => Ok(message), + Err(e) => { + let response = Response::builder() + .status(http::StatusCode::UNSUPPORTED_MEDIA_TYPE) + .body( + Full::new(Bytes::from(format!("fail to deserialize request body {e}"))) + .boxed(), + ) + .expect("valid response"); + Err(response) + } + } + } + Err(e) => { + let response = Response::builder() + .status(http::StatusCode::INTERNAL_SERVER_ERROR) + .body(Full::new(Bytes::from(format!("Failed to read request body: {e}"))).boxed()) + .expect("valid response"); + Err(response) + } + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/transport/io.rs b/code-rs/third_party/rmcp-0.8.3/src/transport/io.rs new file mode 100644 index 00000000000..adc2dba0ef6 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/transport/io.rs @@ -0,0 +1,6 @@ +/// # StdIO Transport +/// +/// Create a pair of [`tokio::io::Stdin`] and [`tokio::io::Stdout`]. +pub fn stdio() -> (tokio::io::Stdin, tokio::io::Stdout) { + (tokio::io::stdin(), tokio::io::stdout()) +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/transport/sink_stream.rs b/code-rs/third_party/rmcp-0.8.3/src/transport/sink_stream.rs new file mode 100644 index 00000000000..f31743922ba --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/transport/sink_stream.rs @@ -0,0 +1,78 @@ +use std::sync::Arc; + +use futures::{Sink, Stream}; +use tokio::sync::Mutex; + +use super::{IntoTransport, Transport}; +use crate::service::{RxJsonRpcMessage, ServiceRole, TxJsonRpcMessage}; + +pub struct SinkStreamTransport { + stream: St, + sink: Arc>, +} + +impl SinkStreamTransport { + pub fn new(sink: Si, stream: St) -> Self { + Self { + stream, + sink: Arc::new(Mutex::new(sink)), + } + } +} + +impl Transport for SinkStreamTransport +where + St: Send + Stream> + Unpin, + Si: Send + Sink> + Unpin + 'static, + Si::Error: std::error::Error + Send + Sync + 'static, +{ + type Error = Si::Error; + + fn send( + &mut self, + item: TxJsonRpcMessage, + ) -> impl Future> + Send + 'static { + use futures::SinkExt; + let lock = self.sink.clone(); + async move { + let mut write = lock.lock().await; + write.send(item).await + } + } + + fn receive(&mut self) -> impl Future>> { + use futures::StreamExt; + self.stream.next() + } + + async fn close(&mut self) -> Result<(), Self::Error> { + Ok(()) + } +} + +pub enum TransportAdapterSinkStream {} + +impl IntoTransport for (Si, St) +where + Role: ServiceRole, + Si: Send + Sink> + Unpin + 'static, + St: Send + Stream> + Unpin + 'static, + Si::Error: std::error::Error + Send + Sync + 'static, +{ + fn into_transport(self) -> impl Transport + 'static { + SinkStreamTransport::new(self.0, self.1) + } +} + +pub enum TransportAdapterAsyncCombinedRW {} +impl IntoTransport for S +where + Role: ServiceRole, + S: Sink> + Stream> + Send + 'static, + S::Error: std::error::Error + Send + Sync + 'static, +{ + fn into_transport(self) -> impl Transport + 'static { + use futures::StreamExt; + IntoTransport::::into_transport(self.split()) + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/transport/sse_client.rs b/code-rs/third_party/rmcp-0.8.3/src/transport/sse_client.rs new file mode 100644 index 00000000000..7b61e3d8657 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/transport/sse_client.rs @@ -0,0 +1,458 @@ +//! Reference: +use std::{ + pin::Pin, + sync::{Arc, RwLock}, +}; + +use futures::{StreamExt, future::BoxFuture}; +use http::Uri; +use sse_stream::{Error as SseError, Sse}; +use thiserror::Error; + +use super::{ + Transport, + common::client_side_sse::{BoxedSseResponse, SseRetryPolicy, SseStreamReconnect}, +}; +use crate::{ + RoleClient, + model::{ClientJsonRpcMessage, ServerJsonRpcMessage}, + transport::common::client_side_sse::SseAutoReconnectStream, +}; + +#[derive(Error, Debug)] +pub enum SseTransportError { + #[error("SSE error: {0}")] + Sse(#[from] SseError), + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + #[error("Client error: {0}")] + Client(E), + #[error("unexpected end of stream")] + UnexpectedEndOfStream, + #[error("Unexpected content type: {0:?}")] + UnexpectedContentType(Option), + #[cfg(feature = "auth")] + #[cfg_attr(docsrs, doc(cfg(feature = "auth")))] + #[error("Auth error: {0}")] + Auth(#[from] crate::transport::auth::AuthError), + #[error("Invalid uri: {0}")] + InvalidUri(#[from] http::uri::InvalidUri), + #[error("Invalid uri parts: {0}")] + InvalidUriParts(#[from] http::uri::InvalidUriParts), +} + +pub trait SseClient: Clone + Send + Sync + 'static { + type Error: std::error::Error + Send + Sync + 'static; + fn post_message( + &self, + uri: Uri, + message: ClientJsonRpcMessage, + auth_token: Option, + ) -> impl Future>> + Send + '_; + fn get_stream( + &self, + uri: Uri, + last_event_id: Option, + auth_token: Option, + ) -> impl Future>> + Send + '_; +} + +/// Helper that refreshes the POST endpoint whenever the server emits +/// control frames during SSE reconnect; used together with +/// [`SseAutoReconnectStream`]. +struct SseClientReconnect { + pub client: C, + pub uri: Uri, + pub message_endpoint: Arc>, +} + +impl SseStreamReconnect for SseClientReconnect { + type Error = SseTransportError; + type Future = BoxFuture<'static, Result>; + fn retry_connection(&mut self, last_event_id: Option<&str>) -> Self::Future { + let client = self.client.clone(); + let uri = self.uri.clone(); + let last_event_id = last_event_id.map(|s| s.to_owned()); + Box::pin(async move { client.get_stream(uri, last_event_id, None).await }) + } + + fn handle_control_event(&mut self, event: &Sse) -> Result<(), Self::Error> { + if event.event.as_deref() != Some("endpoint") { + return Ok(()); + } + let Some(data) = event.data.as_ref() else { + return Ok(()); + }; + // Servers typically resend the message POST endpoint (often with a new + // sessionId) when a stream reconnects. Reuse `message_endpoint` helper + // to resolve it and update the shared URI. + let new_endpoint = message_endpoint(self.uri.clone(), data.clone()) + .map_err(SseTransportError::InvalidUri)?; + *self + .message_endpoint + .write() + .expect("message endpoint lock poisoned") = new_endpoint; + Ok(()) + } + + fn handle_stream_error( + &mut self, + error: &(dyn std::error::Error + 'static), + last_event_id: Option<&str>, + ) { + tracing::warn!( + uri = %self.uri, + last_event_id = last_event_id.unwrap_or(""), + "sse stream error: {error}" + ); + } +} +type ServerMessageStream = Pin>>>; + +/// A client-agnostic SSE transport for RMCP that supports Server-Sent Events. +/// +/// This transport allows you to choose your preferred HTTP client implementation +/// by implementing the [`SseClient`] trait. The transport handles SSE streaming +/// and automatic reconnection. +/// +/// # Usage +/// +/// ## Using reqwest +/// +/// ```rust,ignore +/// use rmcp::transport::SseClientTransport; +/// +/// // Enable the reqwest feature in Cargo.toml: +/// // rmcp = { version = "0.5", features = ["transport-sse-client-reqwest"] } +/// +/// # async fn example() -> Result<(), Box> { +/// let transport = SseClientTransport::start("http://localhost:8000/sse").await?; +/// # Ok(()) +/// # } +/// ``` +/// +/// ## Using a custom HTTP client +/// +/// ```rust,ignore +/// use rmcp::transport::sse_client::{SseClient, SseClientTransport, SseClientConfig}; +/// use std::sync::Arc; +/// use futures::stream::BoxStream; +/// use rmcp::model::ClientJsonRpcMessage; +/// use sse_stream::{Sse, Error as SseError}; +/// use http::Uri; +/// +/// #[derive(Clone)] +/// struct MyHttpClient; +/// +/// #[derive(Debug, thiserror::Error)] +/// struct MyError; +/// +/// impl std::fmt::Display for MyError { +/// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +/// write!(f, "MyError") +/// } +/// } +/// +/// impl SseClient for MyHttpClient { +/// type Error = MyError; +/// +/// async fn post_message( +/// &self, +/// _uri: Uri, +/// _message: ClientJsonRpcMessage, +/// _auth_token: Option, +/// ) -> Result<(), rmcp::transport::sse_client::SseTransportError> { +/// todo!() +/// } +/// +/// async fn get_stream( +/// &self, +/// _uri: Uri, +/// _last_event_id: Option, +/// _auth_token: Option, +/// ) -> Result>, rmcp::transport::sse_client::SseTransportError> { +/// todo!() +/// } +/// } +/// +/// # async fn example() -> Result<(), Box> { +/// let config = SseClientConfig { +/// sse_endpoint: "http://localhost:8000/sse".into(), +/// ..Default::default() +/// }; +/// let transport = SseClientTransport::start_with_client(MyHttpClient, config).await?; +/// # Ok(()) +/// # } +/// ``` +/// +/// # Feature Flags +/// +/// - `transport-sse-client`: Base feature providing the generic transport infrastructure +/// - `transport-sse-client-reqwest`: Includes reqwest HTTP client support with convenience methods +pub struct SseClientTransport { + client: C, + config: SseClientConfig, + /// Current POST endpoint; refreshed when the server sends new endpoint + /// control frames. + message_endpoint: Arc>, + stream: Option>, +} + +impl Transport for SseClientTransport { + type Error = SseTransportError; + async fn receive(&mut self) -> Option { + self.stream.as_mut()?.next().await?.ok() + } + fn send( + &mut self, + item: crate::service::TxJsonRpcMessage, + ) -> impl Future> + Send + 'static { + let client = self.client.clone(); + let message_endpoint = self.message_endpoint.clone(); + async move { + let uri = { + let guard = message_endpoint + .read() + .expect("message endpoint lock poisoned"); + guard.clone() + }; + client.post_message(uri, item, None).await + } + } + async fn close(&mut self) -> Result<(), Self::Error> { + self.stream.take(); + Ok(()) + } +} + +impl std::fmt::Debug for SseClientTransport { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SseClientWorker") + .field("client", &self.client) + .field("config", &self.config) + .finish() + } +} + +impl SseClientTransport { + pub async fn start_with_client( + client: C, + config: SseClientConfig, + ) -> Result> { + let sse_endpoint = config.sse_endpoint.as_ref().parse::()?; + + let mut sse_stream = client.get_stream(sse_endpoint.clone(), None, None).await?; + let initial_message_endpoint = if let Some(endpoint) = config.use_message_endpoint.clone() { + let ep = endpoint.parse::()?; + let mut sse_endpoint_parts = sse_endpoint.clone().into_parts(); + sse_endpoint_parts.path_and_query = ep.into_parts().path_and_query; + Uri::from_parts(sse_endpoint_parts)? + } else { + // wait the endpoint event + loop { + let sse = sse_stream + .next() + .await + .ok_or(SseTransportError::UnexpectedEndOfStream)??; + let Some("endpoint") = sse.event.as_deref() else { + continue; + }; + let ep = sse.data.unwrap_or_default(); + + break message_endpoint(sse_endpoint.clone(), ep)?; + } + }; + let message_endpoint = Arc::new(RwLock::new(initial_message_endpoint)); + + let stream = Box::pin(SseAutoReconnectStream::new( + sse_stream, + SseClientReconnect { + client: client.clone(), + uri: sse_endpoint.clone(), + message_endpoint: message_endpoint.clone(), + }, + config.retry_policy.clone(), + )); + Ok(Self { + client, + config, + message_endpoint, + stream: Some(stream), + }) + } +} + +fn message_endpoint(base: http::Uri, endpoint: String) -> Result { + // If endpoint is a full URL, parse and return it directly + if endpoint.starts_with("http://") || endpoint.starts_with("https://") { + return endpoint.parse::(); + } + + let mut base_parts = base.into_parts(); + let endpoint_clone = endpoint.clone(); + + if endpoint.starts_with("?") { + // Query only - keep base path and append query + if let Some(base_path_and_query) = &base_parts.path_and_query { + let base_path = base_path_and_query.path(); + base_parts.path_and_query = Some(format!("{}{}", base_path, endpoint).parse()?); + } else { + base_parts.path_and_query = Some(format!("/{}", endpoint).parse()?); + } + } else { + // Path (with optional query) - replace entire path_and_query + let path_to_use = if endpoint.starts_with("/") { + endpoint // Use absolute path as-is + } else { + format!("/{}", endpoint) // Make relative path absolute + }; + base_parts.path_and_query = Some(path_to_use.parse()?); + } + + http::Uri::from_parts(base_parts).map_err(|_| endpoint_clone.parse::().unwrap_err()) +} + +#[derive(Debug, Clone)] +pub struct SseClientConfig { + /// client sse endpoint + /// + /// # How this client resolve the message endpoint + /// if sse_endpoint has this format: ``, + /// then the message endpoint will be ``. + /// + /// For example, if you config the sse_endpoint as `http://example.com/some_path/sse`, + /// and the server send the message endpoint event as `message?session_id=123`, + /// then the message endpoint will be `http://example.com/message`. + /// + /// This follows the rules of JavaScript's [`new URL(url, base)`](https://developer.mozilla.org/en-US/docs/Web/API/URL/URL) + pub sse_endpoint: Arc, + pub retry_policy: Arc, + /// if this is settled, the client will use this endpoint to send message and skip get the endpoint event + pub use_message_endpoint: Option, +} + +impl Default for SseClientConfig { + fn default() -> Self { + Self { + sse_endpoint: "".into(), + retry_policy: Arc::new(super::common::client_side_sse::FixedInterval::default()), + use_message_endpoint: None, + } + } +} + +#[cfg(test)] +mod tests { + use futures::StreamExt; + use serde_json::{Value, json}; + + use super::*; + + #[derive(Clone)] + struct DummyClient; + + #[derive(Debug, thiserror::Error)] + #[error("dummy error")] + struct DummyError; + + impl SseClient for DummyClient { + type Error = DummyError; + + async fn post_message( + &self, + _uri: Uri, + _message: ClientJsonRpcMessage, + _auth_token: Option, + ) -> Result<(), SseTransportError> { + Ok(()) + } + + async fn get_stream( + &self, + _uri: Uri, + _last_event_id: Option, + _auth_token: Option, + ) -> Result> { + unreachable!("get_stream should not be called in this test") + } + } + + #[test] + fn test_message_endpoint() { + let base_url = "https://localhost/sse".parse::().unwrap(); + + // Query only + let result = message_endpoint(base_url.clone(), "?sessionId=x".to_string()).unwrap(); + assert_eq!(result.to_string(), "https://localhost/sse?sessionId=x"); + + // Relative path with query + let result = message_endpoint(base_url.clone(), "mypath?sessionId=x".to_string()).unwrap(); + assert_eq!(result.to_string(), "https://localhost/mypath?sessionId=x"); + + // Absolute path with query + let result = message_endpoint(base_url.clone(), "/xxx?sessionId=x".to_string()).unwrap(); + assert_eq!(result.to_string(), "https://localhost/xxx?sessionId=x"); + + // Full URL + let result = message_endpoint( + base_url.clone(), + "http://example.com/xxx?sessionId=x".to_string(), + ) + .unwrap(); + assert_eq!(result.to_string(), "http://example.com/xxx?sessionId=x"); + } + + #[test] + fn handle_endpoint_control_event_updates_uri() { + let initial_endpoint = "https://example.com/message?sessionId=old" + .parse::() + .unwrap(); + let shared_endpoint = Arc::new(RwLock::new(initial_endpoint)); + let mut reconnect = SseClientReconnect { + client: DummyClient, + uri: "https://example.com/sse".parse::().unwrap(), + message_endpoint: shared_endpoint.clone(), + }; + + let control_event = Sse::default() + .event("endpoint") + .data("/message?sessionId=new"); + + reconnect.handle_control_event(&control_event).unwrap(); + + let guard = shared_endpoint.read().expect("lock poisoned"); + assert_eq!( + guard.to_string(), + "https://example.com/message?sessionId=new" + ); + } + + #[tokio::test] + async fn control_event_frames_are_skipped() { + let payload = json!({ + "jsonrpc": "2.0", + "id": 1, + "result": {"ok": true} + }) + .to_string(); + + let events = vec![ + Ok(Sse::default() + .event("endpoint") + .data("/message?sessionId=reconnect")), + Ok(Sse::default().event("message").data(payload.clone())), + ]; + + let sse_src: BoxedSseResponse = futures::stream::iter(events).boxed(); + let reconn_stream = SseAutoReconnectStream::never_reconnect(sse_src, DummyError); + futures::pin_mut!(reconn_stream); + + let message = reconn_stream.next().await.expect("stream item").unwrap(); + let actual: Value = serde_json::to_value(message).expect("serialize actual message"); + // We only need to assert that a valid JSON-RPC response came through after + // skipping control frames. The exact `result` shape depends on the SDK's + // typed result enums and is not asserted here. + assert_eq!(actual.get("jsonrpc"), Some(&Value::String("2.0".into()))); + assert_eq!(actual.get("id"), Some(&Value::Number(1u64.into()))); + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/transport/sse_server.rs b/code-rs/third_party/rmcp-0.8.3/src/transport/sse_server.rs new file mode 100644 index 00000000000..15a65cb52ac --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/transport/sse_server.rs @@ -0,0 +1,343 @@ +use std::{collections::HashMap, io, net::SocketAddr, sync::Arc, time::Duration}; + +use axum::{ + Extension, Json, Router, + extract::{NestedPath, Query, State}, + http::{StatusCode, request::Parts}, + response::{ + Response, + sse::{Event, KeepAlive, Sse}, + }, + routing::{get, post}, +}; +use futures::{Sink, SinkExt, Stream}; +use tokio_stream::wrappers::ReceiverStream; +use tokio_util::sync::{CancellationToken, PollSender}; +use tracing::Instrument; + +use crate::{ + RoleServer, Service, + model::ClientJsonRpcMessage, + service::{RxJsonRpcMessage, TxJsonRpcMessage, serve_directly_with_ct}, + transport::common::server_side_http::{DEFAULT_AUTO_PING_INTERVAL, SessionId, session_id}, +}; + +type TxStore = + Arc>>>; +pub type TransportReceiver = ReceiverStream>; + +#[derive(Clone)] +struct App { + txs: TxStore, + transport_tx: tokio::sync::mpsc::UnboundedSender, + post_path: Arc, + sse_ping_interval: Duration, +} + +impl App { + pub fn new( + post_path: String, + sse_ping_interval: Duration, + ) -> ( + Self, + tokio::sync::mpsc::UnboundedReceiver, + ) { + let (transport_tx, transport_rx) = tokio::sync::mpsc::unbounded_channel(); + ( + Self { + txs: Default::default(), + transport_tx, + post_path: post_path.into(), + sse_ping_interval, + }, + transport_rx, + ) + } +} + +#[derive(Debug, serde::Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PostEventQuery { + pub session_id: String, +} + +async fn post_event_handler( + State(app): State, + Query(PostEventQuery { session_id }): Query, + parts: Parts, + Json(mut message): Json, +) -> Result { + tracing::debug!(session_id, ?parts, ?message, "new client message"); + let tx = { + let rg = app.txs.read().await; + rg.get(session_id.as_str()) + .ok_or(StatusCode::NOT_FOUND)? + .clone() + }; + message.insert_extension(parts); + if tx.send(message).await.is_err() { + tracing::error!("send message error"); + return Err(StatusCode::GONE); + } + Ok(StatusCode::ACCEPTED) +} + +async fn sse_handler( + State(app): State, + nested_path: Option>, + parts: Parts, +) -> Result>>, Response> { + let session = session_id(); + tracing::info!(%session, ?parts, "sse connection"); + use tokio_stream::{StreamExt, wrappers::ReceiverStream}; + use tokio_util::sync::PollSender; + let (from_client_tx, from_client_rx) = tokio::sync::mpsc::channel(64); + let (to_client_tx, to_client_rx) = tokio::sync::mpsc::channel(64); + let to_client_tx_clone = to_client_tx.clone(); + + app.txs + .write() + .await + .insert(session.clone(), from_client_tx); + let session = session.clone(); + let stream = ReceiverStream::new(from_client_rx); + let sink = PollSender::new(to_client_tx); + let transport = SseServerTransport { + stream, + sink, + session_id: session.clone(), + tx_store: app.txs.clone(), + }; + let transport_send_result = app.transport_tx.send(transport); + if transport_send_result.is_err() { + tracing::warn!("send transport out error"); + let mut response = + Response::new("fail to send out transport, it seems server is closed".to_string()); + *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + return Err(response); + } + let nested_path = nested_path.as_deref().map(NestedPath::as_str).unwrap_or(""); + let post_path = app.post_path.as_ref(); + let ping_interval = app.sse_ping_interval; + let stream = futures::stream::once(futures::future::ok( + Event::default() + .event("endpoint") + .data(format!("{nested_path}{post_path}?sessionId={session}")), + )) + .chain(ReceiverStream::new(to_client_rx).map(|message| { + match serde_json::to_string(&message) { + Ok(bytes) => Ok(Event::default().event("message").data(&bytes)), + Err(e) => Err(io::Error::new(io::ErrorKind::InvalidData, e)), + } + })); + + tokio::spawn(async move { + // Wait for connection closure + to_client_tx_clone.closed().await; + + // Clean up session + let session_id = session.clone(); + let tx_store = app.txs.clone(); + let mut txs = tx_store.write().await; + txs.remove(&session_id); + tracing::debug!(%session_id, "Closed session and cleaned up resources"); + }); + + Ok(Sse::new(stream).keep_alive(KeepAlive::new().interval(ping_interval))) +} + +pub struct SseServerTransport { + stream: ReceiverStream>, + sink: PollSender>, + session_id: SessionId, + tx_store: TxStore, +} + +impl Sink> for SseServerTransport { + type Error = io::Error; + + fn poll_ready( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.sink + .poll_ready_unpin(cx) + .map_err(std::io::Error::other) + } + + fn start_send( + mut self: std::pin::Pin<&mut Self>, + item: TxJsonRpcMessage, + ) -> Result<(), Self::Error> { + self.sink + .start_send_unpin(item) + .map_err(std::io::Error::other) + } + + fn poll_flush( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.sink + .poll_flush_unpin(cx) + .map_err(std::io::Error::other) + } + + fn poll_close( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let inner_close_result = self + .sink + .poll_close_unpin(cx) + .map_err(std::io::Error::other); + if inner_close_result.is_ready() { + let session_id = self.session_id.clone(); + let tx_store = self.tx_store.clone(); + tokio::spawn(async move { + tx_store.write().await.remove(&session_id); + }); + } + inner_close_result + } +} + +impl Stream for SseServerTransport { + type Item = RxJsonRpcMessage; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + use futures::StreamExt; + self.stream.poll_next_unpin(cx) + } +} + +#[derive(Debug, Clone)] +pub struct SseServerConfig { + pub bind: SocketAddr, + pub sse_path: String, + pub post_path: String, + pub ct: CancellationToken, + pub sse_keep_alive: Option, +} + +#[derive(Debug)] +pub struct SseServer { + transport_rx: tokio::sync::mpsc::UnboundedReceiver, + pub config: SseServerConfig, +} + +impl SseServer { + pub async fn serve(bind: SocketAddr) -> io::Result { + Self::serve_with_config(SseServerConfig { + bind, + sse_path: "/sse".to_string(), + post_path: "/message".to_string(), + ct: CancellationToken::new(), + sse_keep_alive: None, + }) + .await + } + pub async fn serve_with_config(config: SseServerConfig) -> io::Result { + let (sse_server, service) = Self::new(config); + let listener = tokio::net::TcpListener::bind(sse_server.config.bind).await?; + let ct = sse_server.config.ct.child_token(); + let server = axum::serve(listener, service).with_graceful_shutdown(async move { + ct.cancelled().await; + tracing::info!("sse server cancelled"); + }); + tokio::spawn( + async move { + if let Err(e) = server.await { + tracing::error!(error = %e, "sse server shutdown with error"); + } + } + .instrument(tracing::info_span!("sse-server", bind_address = %sse_server.config.bind)), + ); + Ok(sse_server) + } + + pub fn new(config: SseServerConfig) -> (SseServer, Router) { + let (app, transport_rx) = App::new( + config.post_path.clone(), + config.sse_keep_alive.unwrap_or(DEFAULT_AUTO_PING_INTERVAL), + ); + let router = Router::new() + .route(&config.sse_path, get(sse_handler)) + .route(&config.post_path, post(post_event_handler)) + .with_state(app); + + let server = SseServer { + transport_rx, + config, + }; + + (server, router) + } + + pub fn with_service(mut self, service_provider: F) -> CancellationToken + where + S: Service, + F: Fn() -> S + Send + 'static, + { + use crate::service::ServiceExt; + let ct = self.config.ct.clone(); + tokio::spawn(async move { + while let Some(transport) = self.next_transport().await { + let service = service_provider(); + let ct = self.config.ct.child_token(); + tokio::spawn(async move { + let server = service + .serve_with_ct(transport, ct) + .await + .map_err(std::io::Error::other)?; + server.waiting().await?; + tokio::io::Result::Ok(()) + }); + } + }); + ct + } + + /// This allows you to skip the initialization steps for incoming request. + pub fn with_service_directly(mut self, service_provider: F) -> CancellationToken + where + S: Service, + F: Fn() -> S + Send + 'static, + { + let ct = self.config.ct.clone(); + tokio::spawn(async move { + while let Some(transport) = self.next_transport().await { + let service = service_provider(); + let ct = self.config.ct.child_token(); + tokio::spawn(async move { + let server = serve_directly_with_ct(service, transport, None, ct); + server.waiting().await?; + tokio::io::Result::Ok(()) + }); + } + }); + ct + } + + pub fn cancel(&self) { + self.config.ct.cancel(); + } + + pub async fn next_transport(&mut self) -> Option { + self.transport_rx.recv().await + } +} + +impl Stream for SseServer { + type Item = SseServerTransport; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.transport_rx.poll_recv(cx) + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/transport/streamable_http_client.rs b/code-rs/third_party/rmcp-0.8.3/src/transport/streamable_http_client.rs new file mode 100644 index 00000000000..8d076c71f1e --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/transport/streamable_http_client.rs @@ -0,0 +1,749 @@ +use std::{borrow::Cow, sync::Arc, time::Duration}; + +use futures::{Stream, StreamExt, future::BoxFuture, stream::BoxStream}; +pub use sse_stream::Error as SseError; +use sse_stream::Sse; +use thiserror::Error; +use tokio_util::sync::CancellationToken; +use tracing::debug; + +use super::common::client_side_sse::{ExponentialBackoff, SseRetryPolicy, SseStreamReconnect}; +use crate::{ + RoleClient, + model::{ClientJsonRpcMessage, ServerJsonRpcMessage}, + transport::{ + common::client_side_sse::SseAutoReconnectStream, + worker::{Worker, WorkerQuitReason, WorkerSendRequest, WorkerTransport}, + }, +}; + +type BoxedSseStream = BoxStream<'static, Result>; + +#[derive(Debug)] +pub struct AuthRequiredError { + pub www_authenticate_header: String, +} + +#[derive(Error, Debug)] +pub enum StreamableHttpError { + #[error("SSE error: {0}")] + Sse(#[from] SseError), + #[error("Io error: {0}")] + Io(#[from] std::io::Error), + #[error("Client error: {0}")] + Client(E), + #[error("unexpected end of stream")] + UnexpectedEndOfStream, + #[error("unexpected server response: {0}")] + UnexpectedServerResponse(Cow<'static, str>), + #[error("Unexpected content type: {0:?}")] + UnexpectedContentType(Option), + #[error("Server does not support SSE")] + ServerDoesNotSupportSse, + #[error("Server does not support delete session")] + ServerDoesNotSupportDeleteSession, + #[error("Tokio join error: {0}")] + TokioJoinError(#[from] tokio::task::JoinError), + #[error("Deserialize error: {0}")] + Deserialize(#[from] serde_json::Error), + #[error("Transport channel closed")] + TransportChannelClosed, + #[error("Missing session id in HTTP response")] + MissingSessionIdInResponse, + #[cfg(feature = "auth")] + #[cfg_attr(docsrs, doc(cfg(feature = "auth")))] + #[error("Auth error: {0}")] + Auth(#[from] crate::transport::auth::AuthError), + #[error("Auth required")] + AuthRequired(AuthRequiredError), +} + +#[derive(Debug, Clone, Error)] +pub enum StreamableHttpProtocolError { + #[error("Missing session id in response")] + MissingSessionIdInResponse, +} + +#[allow(clippy::large_enum_variant)] +pub enum StreamableHttpPostResponse { + Accepted, + Json(ServerJsonRpcMessage, Option), + Sse(BoxedSseStream, Option), +} + +impl std::fmt::Debug for StreamableHttpPostResponse { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Accepted => write!(f, "Accepted"), + Self::Json(arg0, arg1) => f.debug_tuple("Json").field(arg0).field(arg1).finish(), + Self::Sse(_, arg1) => f.debug_tuple("Sse").field(arg1).finish(), + } + } +} + +impl StreamableHttpPostResponse { + pub async fn expect_initialized( + self, + ) -> Result<(ServerJsonRpcMessage, Option), StreamableHttpError> + where + E: std::error::Error + Send + Sync + 'static, + { + match self { + Self::Json(message, session_id) => Ok((message, session_id)), + Self::Sse(mut stream, session_id) => { + while let Some(event) = stream.next().await { + let event = event?; + let payload = event.data.unwrap_or_default(); + if payload.trim().is_empty() { + continue; + } + + let message: ServerJsonRpcMessage = serde_json::from_str(&payload)?; + + if matches!(message, ServerJsonRpcMessage::Response(_)) { + return Ok((message, session_id)); + } + + debug!( + ?message, + "received message before initialize response; continuing to drain stream" + ); + } + + Err(StreamableHttpError::UnexpectedServerResponse( + "empty sse stream".into(), + )) + } + _ => Err(StreamableHttpError::UnexpectedServerResponse( + "expect initialized, accepted".into(), + )), + } + } + + pub fn expect_json(self) -> Result> + where + E: std::error::Error + Send + Sync + 'static, + { + match self { + Self::Json(message, ..) => Ok(message), + got => Err(StreamableHttpError::UnexpectedServerResponse( + format!("expect json, got {got:?}").into(), + )), + } + } + + pub fn expect_accepted(self) -> Result<(), StreamableHttpError> + where + E: std::error::Error + Send + Sync + 'static, + { + match self { + Self::Accepted => Ok(()), + got => Err(StreamableHttpError::UnexpectedServerResponse( + format!("expect accepted, got {got:?}").into(), + )), + } + } +} + +pub trait StreamableHttpClient: Clone + Send + 'static { + type Error: std::error::Error + Send + Sync + 'static; + fn post_message( + &self, + uri: Arc, + message: ClientJsonRpcMessage, + session_id: Option>, + auth_header: Option, + ) -> impl Future>> + + Send + + '_; + fn delete_session( + &self, + uri: Arc, + session_id: Arc, + auth_header: Option, + ) -> impl Future>> + Send + '_; + fn get_stream( + &self, + uri: Arc, + session_id: Arc, + last_event_id: Option, + auth_header: Option, + ) -> impl Future< + Output = Result< + BoxStream<'static, Result>, + StreamableHttpError, + >, + > + Send + + '_; +} + +pub struct RetryConfig { + pub max_times: Option, + pub min_duration: Duration, +} + +struct StreamableHttpClientReconnect { + pub client: C, + pub session_id: Arc, + pub uri: Arc, + pub auth_header: Option, +} + +impl SseStreamReconnect for StreamableHttpClientReconnect { + type Error = StreamableHttpError; + type Future = BoxFuture<'static, Result>; + fn retry_connection(&mut self, last_event_id: Option<&str>) -> Self::Future { + let client = self.client.clone(); + let uri = self.uri.clone(); + let session_id = self.session_id.clone(); + let auth_header = self.auth_header.clone(); + let last_event_id = last_event_id.map(|s| s.to_owned()); + Box::pin(async move { + client + .get_stream(uri, session_id, last_event_id, auth_header) + .await + }) + } +} + +#[derive(Debug, Clone, Default)] +pub struct StreamableHttpClientWorker { + pub client: C, + pub config: StreamableHttpClientTransportConfig, +} + +impl StreamableHttpClientWorker { + pub fn new_simple(url: impl Into>) -> Self { + Self { + client: C::default(), + config: StreamableHttpClientTransportConfig { + uri: url.into(), + ..Default::default() + }, + } + } +} + +impl StreamableHttpClientWorker { + pub fn new(client: C, config: StreamableHttpClientTransportConfig) -> Self { + Self { client, config } + } +} + +impl StreamableHttpClientWorker { + async fn execute_sse_stream( + sse_stream: impl Stream>> + + Send + + 'static, + sse_worker_tx: tokio::sync::mpsc::Sender, + close_on_response: bool, + ct: CancellationToken, + ) -> Result<(), StreamableHttpError> { + let mut sse_stream = std::pin::pin!(sse_stream); + loop { + let message = tokio::select! { + event = sse_stream.next() => { + event + } + _ = ct.cancelled() => { + tracing::debug!("cancelled"); + break; + } + }; + let Some(message) = message.transpose()? else { + break; + }; + let is_response = matches!(message, ServerJsonRpcMessage::Response(_)); + let yield_result = sse_worker_tx.send(message).await; + if yield_result.is_err() { + tracing::trace!("streamable http transport worker dropped, exiting"); + break; + } + if close_on_response && is_response { + tracing::debug!("got response, closing sse stream"); + break; + } + } + Ok(()) + } +} + +impl Worker for StreamableHttpClientWorker { + type Role = RoleClient; + type Error = StreamableHttpError; + fn err_closed() -> Self::Error { + StreamableHttpError::TransportChannelClosed + } + fn err_join(e: tokio::task::JoinError) -> Self::Error { + StreamableHttpError::TokioJoinError(e) + } + fn config(&self) -> super::worker::WorkerConfig { + super::worker::WorkerConfig { + name: Some("StreamableHttpClientWorker".into()), + channel_buffer_capacity: self.config.channel_buffer_capacity, + } + } + async fn run( + self, + mut context: super::worker::WorkerContext, + ) -> Result<(), WorkerQuitReason> { + let channel_buffer_capacity = self.config.channel_buffer_capacity; + let (sse_worker_tx, mut sse_worker_rx) = + tokio::sync::mpsc::channel::(channel_buffer_capacity); + let config = self.config.clone(); + let transport_task_ct = context.cancellation_token.clone(); + let _drop_guard = transport_task_ct.clone().drop_guard(); + let WorkerSendRequest { + responder, + message: initialize_request, + } = context.recv_from_handler().await?; + let (message, session_id) = match self + .client + .post_message( + config.uri.clone(), + initialize_request, + None, + self.config.auth_header, + ) + .await + { + Ok(res) => { + let _ = responder.send(Ok(())); + res.expect_initialized::().await.map_err( + WorkerQuitReason::fatal_context("process initialize response"), + )? + } + Err(err) => { + let msg = format!("{:?}", err); + let _ = responder.send(Err(err)); + return Err(WorkerQuitReason::fatal( + StreamableHttpError::TransportChannelClosed, + msg, + )); + } + }; + let session_id: Option> = if let Some(session_id) = session_id { + Some(session_id.into()) + } else { + if !self.config.allow_stateless { + return Err(WorkerQuitReason::fatal( + StreamableHttpError::::MissingSessionIdInResponse, + "process initialize response", + )); + } + None + }; + // delete session when drop guard is dropped + if let Some(session_id) = &session_id { + let ct = transport_task_ct.clone(); + let client = self.client.clone(); + let session_id = session_id.clone(); + let url = config.uri.clone(); + let auth_header = config.auth_header.clone(); + tokio::spawn(async move { + ct.cancelled().await; + let delete_session_result = client + .delete_session(url, session_id.clone(), auth_header.clone()) + .await; + match delete_session_result { + Ok(_) => { + tracing::info!(session_id = session_id.as_ref(), "delete session success") + } + Err(StreamableHttpError::ServerDoesNotSupportDeleteSession) => { + tracing::info!( + session_id = session_id.as_ref(), + "server doesn't support delete session" + ) + } + Err(e) => { + tracing::error!( + session_id = session_id.as_ref(), + "fail to delete session: {e}" + ); + } + }; + }); + } + + context.send_to_handler(message).await?; + let initialized_notification = context.recv_from_handler().await?; + // expect a initialized response + self.client + .post_message( + config.uri.clone(), + initialized_notification.message, + session_id.clone(), + config.auth_header.clone(), + ) + .await + .map_err(WorkerQuitReason::fatal_context( + "send initialized notification", + ))? + .expect_accepted::() + .map_err(WorkerQuitReason::fatal_context( + "process initialized notification response", + ))?; + let _ = initialized_notification.responder.send(Ok(())); + enum Event { + ClientMessage(WorkerSendRequest), + ServerMessage(ServerJsonRpcMessage), + StreamResult(Result<(), StreamableHttpError>), + } + let mut streams = tokio::task::JoinSet::new(); + if let Some(session_id) = &session_id { + match self + .client + .get_stream( + config.uri.clone(), + session_id.clone(), + None, + config.auth_header.clone(), + ) + .await + { + Ok(stream) => { + let sse_stream = SseAutoReconnectStream::new( + stream, + StreamableHttpClientReconnect { + client: self.client.clone(), + session_id: session_id.clone(), + uri: config.uri.clone(), + auth_header: config.auth_header.clone(), + }, + self.config.retry_config.clone(), + ); + streams.spawn(Self::execute_sse_stream( + sse_stream, + sse_worker_tx.clone(), + false, + transport_task_ct.child_token(), + )); + tracing::debug!("got common stream"); + } + Err(StreamableHttpError::ServerDoesNotSupportSse) => { + tracing::debug!("server doesn't support sse, skip common stream"); + } + Err(e) => { + // fail to get common stream + tracing::error!("fail to get common stream: {e}"); + return Err(WorkerQuitReason::fatal( + e, + "get general purpose event stream", + )); + } + } + } + loop { + let event = tokio::select! { + _ = transport_task_ct.cancelled() => { + tracing::debug!("cancelled"); + return Err(WorkerQuitReason::Cancelled); + } + message = context.recv_from_handler() => { + let message = message?; + Event::ClientMessage(message) + }, + message = sse_worker_rx.recv() => { + let Some(message) = message else { + tracing::trace!("transport dropped, exiting"); + return Err(WorkerQuitReason::HandlerTerminated); + }; + Event::ServerMessage(message) + }, + terminated_stream = streams.join_next(), if !streams.is_empty() => { + match terminated_stream { + Some(result) => { + Event::StreamResult(result.map_err(StreamableHttpError::TokioJoinError).and_then(std::convert::identity)) + } + None => { + continue + } + } + } + }; + match event { + Event::ClientMessage(send_request) => { + let WorkerSendRequest { message, responder } = send_request; + let response = self + .client + .post_message( + config.uri.clone(), + message, + session_id.clone(), + config.auth_header.clone(), + ) + .await; + let send_result = match response { + Err(e) => Err(e), + Ok(StreamableHttpPostResponse::Accepted) => { + tracing::trace!("client message accepted"); + Ok(()) + } + Ok(StreamableHttpPostResponse::Json(message, ..)) => { + context.send_to_handler(message).await?; + Ok(()) + } + Ok(StreamableHttpPostResponse::Sse(stream, ..)) => { + if let Some(session_id) = &session_id { + let sse_stream = SseAutoReconnectStream::new( + stream, + StreamableHttpClientReconnect { + client: self.client.clone(), + session_id: session_id.clone(), + uri: config.uri.clone(), + auth_header: config.auth_header.clone(), + }, + self.config.retry_config.clone(), + ); + streams.spawn(Self::execute_sse_stream( + sse_stream, + sse_worker_tx.clone(), + true, + transport_task_ct.child_token(), + )); + } else { + let sse_stream = SseAutoReconnectStream::never_reconnect( + stream, + StreamableHttpError::::UnexpectedEndOfStream, + ); + streams.spawn(Self::execute_sse_stream( + sse_stream, + sse_worker_tx.clone(), + true, + transport_task_ct.child_token(), + )); + } + tracing::trace!("got new sse stream"); + Ok(()) + } + }; + let _ = responder.send(send_result); + } + Event::ServerMessage(json_rpc_message) => { + // send the message to the handler + context.send_to_handler(json_rpc_message).await?; + } + Event::StreamResult(result) => { + if result.is_err() { + tracing::warn!( + "sse client event stream terminated with error: {:?}", + result + ); + } + } + } + } + } +} + +/// A client-agnostic HTTP transport for RMCP that supports streaming responses. +/// +/// This transport allows you to choose your preferred HTTP client implementation +/// by implementing the [`StreamableHttpClient`] trait. The transport handles +/// session management, SSE streaming, and automatic reconnection. +/// +/// # Usage +/// +/// ## Using reqwest +/// +/// ```rust,no_run +/// use rmcp::transport::StreamableHttpClientTransport; +/// +/// // Enable the reqwest feature in Cargo.toml: +/// // rmcp = { version = "0.5", features = ["transport-streamable-http-client-reqwest"] } +/// +/// let transport = StreamableHttpClientTransport::from_uri("http://localhost:8000/mcp"); +/// ``` +/// +/// ## Using a custom HTTP client +/// +/// ```rust,no_run +/// use rmcp::transport::streamable_http_client::{ +/// StreamableHttpClient, +/// StreamableHttpClientTransport, +/// StreamableHttpClientTransportConfig +/// }; +/// use std::sync::Arc; +/// use futures::stream::BoxStream; +/// use rmcp::model::ClientJsonRpcMessage; +/// use sse_stream::{Sse, Error as SseError}; +/// +/// #[derive(Clone)] +/// struct MyHttpClient; +/// +/// #[derive(Debug, thiserror::Error)] +/// struct MyError; +/// +/// impl std::fmt::Display for MyError { +/// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +/// write!(f, "MyError") +/// } +/// } +/// +/// impl StreamableHttpClient for MyHttpClient { +/// type Error = MyError; +/// +/// async fn post_message( +/// &self, +/// _uri: Arc, +/// _message: ClientJsonRpcMessage, +/// _session_id: Option>, +/// _auth_header: Option, +/// ) -> Result> { +/// todo!() +/// } +/// +/// async fn delete_session( +/// &self, +/// _uri: Arc, +/// _session_id: Arc, +/// _auth_header: Option, +/// ) -> Result<(), rmcp::transport::streamable_http_client::StreamableHttpError> { +/// todo!() +/// } +/// +/// async fn get_stream( +/// &self, +/// _uri: Arc, +/// _session_id: Arc, +/// _last_event_id: Option, +/// _auth_header: Option, +/// ) -> Result>, rmcp::transport::streamable_http_client::StreamableHttpError> { +/// todo!() +/// } +/// } +/// +/// let transport = StreamableHttpClientTransport::with_client( +/// MyHttpClient, +/// StreamableHttpClientTransportConfig::with_uri("http://localhost:8000/mcp") +/// ); +/// ``` +/// +/// # Feature Flags +/// +/// - `transport-streamable-http-client`: Base feature providing the generic transport infrastructure +/// - `transport-streamable-http-client-reqwest`: Includes reqwest HTTP client support with convenience methods +pub type StreamableHttpClientTransport = WorkerTransport>; + +impl StreamableHttpClientTransport { + /// Creates a new transport with a custom HTTP client implementation. + /// + /// This method allows you to use any HTTP client that implements the [`StreamableHttpClient`] trait. + /// Use this when you want to use a custom HTTP client or when the reqwest feature is not enabled. + /// + /// # Arguments + /// + /// * `client` - Your HTTP client implementation + /// * `config` - Transport configuration including the server URI + /// + /// # Example + /// + /// ```rust,no_run + /// use rmcp::transport::streamable_http_client::{ + /// StreamableHttpClient, + /// StreamableHttpClientTransport, + /// StreamableHttpClientTransportConfig + /// }; + /// use std::sync::Arc; + /// use futures::stream::BoxStream; + /// use rmcp::model::ClientJsonRpcMessage; + /// use sse_stream::{Sse, Error as SseError}; + /// + /// // Define your custom client + /// #[derive(Clone)] + /// struct MyHttpClient; + /// + /// #[derive(Debug, thiserror::Error)] + /// struct MyError; + /// + /// impl std::fmt::Display for MyError { + /// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + /// write!(f, "MyError") + /// } + /// } + /// + /// impl StreamableHttpClient for MyHttpClient { + /// type Error = MyError; + /// + /// async fn post_message( + /// &self, + /// _uri: Arc, + /// _message: ClientJsonRpcMessage, + /// _session_id: Option>, + /// _auth_header: Option, + /// ) -> Result> { + /// todo!() + /// } + /// + /// async fn delete_session( + /// &self, + /// _uri: Arc, + /// _session_id: Arc, + /// _auth_header: Option, + /// ) -> Result<(), rmcp::transport::streamable_http_client::StreamableHttpError> { + /// todo!() + /// } + /// + /// async fn get_stream( + /// &self, + /// _uri: Arc, + /// _session_id: Arc, + /// _last_event_id: Option, + /// _auth_header: Option, + /// ) -> Result>, rmcp::transport::streamable_http_client::StreamableHttpError> { + /// todo!() + /// } + /// } + /// + /// let transport = StreamableHttpClientTransport::with_client( + /// MyHttpClient, + /// StreamableHttpClientTransportConfig::with_uri("http://localhost:8000/mcp") + /// ); + /// ``` + pub fn with_client(client: C, config: StreamableHttpClientTransportConfig) -> Self { + let worker = StreamableHttpClientWorker::new(client, config); + WorkerTransport::spawn(worker) + } +} +#[derive(Debug, Clone)] +pub struct StreamableHttpClientTransportConfig { + pub uri: Arc, + pub retry_config: Arc, + pub channel_buffer_capacity: usize, + /// if true, the transport will not require a session to be established + pub allow_stateless: bool, + /// The value to send in the authorization header + pub auth_header: Option, +} + +impl StreamableHttpClientTransportConfig { + pub fn with_uri(uri: impl Into>) -> Self { + Self { + uri: uri.into(), + ..Default::default() + } + } + + /// Set the authorization header to send with requests + /// + /// # Arguments + /// + /// * `value` - A bearer token without the `Bearer ` prefix + pub fn auth_header>(mut self, value: T) -> Self { + // set our authorization header + self.auth_header = Some(value.into()); + self + } +} + +impl Default for StreamableHttpClientTransportConfig { + fn default() -> Self { + Self { + uri: "localhost".into(), + retry_config: Arc::new(ExponentialBackoff::default()), + channel_buffer_capacity: 16, + allow_stateless: true, + auth_header: None, + } + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/transport/streamable_http_server.rs b/code-rs/third_party/rmcp-0.8.3/src/transport/streamable_http_server.rs new file mode 100644 index 00000000000..733fc5e518a --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/transport/streamable_http_server.rs @@ -0,0 +1,8 @@ +pub mod session; +#[cfg(feature = "transport-streamable-http-server")] +#[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-server")))] +pub mod tower; +pub use session::{SessionId, SessionManager}; +#[cfg(feature = "transport-streamable-http-server")] +#[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-server")))] +pub use tower::{StreamableHttpServerConfig, StreamableHttpService}; diff --git a/code-rs/third_party/rmcp-0.8.3/src/transport/streamable_http_server/session.rs b/code-rs/third_party/rmcp-0.8.3/src/transport/streamable_http_server/session.rs new file mode 100644 index 00000000000..a4a5fe43e55 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/transport/streamable_http_server/session.rs @@ -0,0 +1,54 @@ +use futures::Stream; + +pub use crate::transport::common::server_side_http::SessionId; +use crate::{ + RoleServer, + model::{ClientJsonRpcMessage, ServerJsonRpcMessage}, + transport::common::server_side_http::ServerSseMessage, +}; + +pub mod local; +pub mod never; + +pub trait SessionManager: Send + Sync + 'static { + type Error: std::error::Error + Send + 'static; + type Transport: crate::transport::Transport; + /// Create a new session with the given id and configuration. + fn create_session( + &self, + ) -> impl Future> + Send; + fn initialize_session( + &self, + id: &SessionId, + message: ClientJsonRpcMessage, + ) -> impl Future> + Send; + fn has_session(&self, id: &SessionId) + -> impl Future> + Send; + fn close_session(&self, id: &SessionId) + -> impl Future> + Send; + fn create_stream( + &self, + id: &SessionId, + message: ClientJsonRpcMessage, + ) -> impl Future< + Output = Result + Send + Sync + 'static, Self::Error>, + > + Send; + fn accept_message( + &self, + id: &SessionId, + message: ClientJsonRpcMessage, + ) -> impl Future> + Send; + fn create_standalone_stream( + &self, + id: &SessionId, + ) -> impl Future< + Output = Result + Send + Sync + 'static, Self::Error>, + > + Send; + fn resume( + &self, + id: &SessionId, + last_event_id: String, + ) -> impl Future< + Output = Result + Send + Sync + 'static, Self::Error>, + > + Send; +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/transport/streamable_http_server/session/local.rs b/code-rs/third_party/rmcp-0.8.3/src/transport/streamable_http_server/session/local.rs new file mode 100644 index 00000000000..1dca3fafab0 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/transport/streamable_http_server/session/local.rs @@ -0,0 +1,908 @@ +use std::{ + collections::{HashMap, HashSet, VecDeque}, + num::ParseIntError, + sync::Arc, + time::Duration, +}; + +use futures::Stream; +use thiserror::Error; +use tokio::sync::{ + mpsc::{Receiver, Sender}, + oneshot, +}; +use tokio_stream::wrappers::ReceiverStream; +use tracing::instrument; + +use crate::{ + RoleServer, + model::{ + CancelledNotificationParam, ClientJsonRpcMessage, ClientNotification, ClientRequest, + JsonRpcNotification, JsonRpcRequest, Notification, ProgressNotificationParam, + ProgressToken, RequestId, ServerJsonRpcMessage, ServerNotification, + }, + transport::{ + WorkerTransport, + common::server_side_http::{SessionId, session_id}, + worker::{Worker, WorkerContext, WorkerQuitReason, WorkerSendRequest}, + }, +}; + +#[derive(Debug, Default)] +pub struct LocalSessionManager { + pub sessions: tokio::sync::RwLock>, + pub session_config: SessionConfig, +} + +#[derive(Debug, Error)] +pub enum LocalSessionManagerError { + #[error("Session not found: {0}")] + SessionNotFound(SessionId), + #[error("Session error: {0}")] + SessionError(#[from] SessionError), + #[error("Invalid event id: {0}")] + InvalidEventId(#[from] EventIdParseError), +} +impl SessionManager for LocalSessionManager { + type Error = LocalSessionManagerError; + type Transport = WorkerTransport; + async fn create_session(&self) -> Result<(SessionId, Self::Transport), Self::Error> { + let id = session_id(); + let (handle, worker) = create_local_session(id.clone(), self.session_config.clone()); + self.sessions.write().await.insert(id.clone(), handle); + Ok((id, WorkerTransport::spawn(worker))) + } + async fn initialize_session( + &self, + id: &SessionId, + message: ClientJsonRpcMessage, + ) -> Result { + let sessions = self.sessions.read().await; + let handle = sessions + .get(id) + .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?; + let response = handle.initialize(message).await?; + Ok(response) + } + async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> { + let mut sessions = self.sessions.write().await; + if let Some(handle) = sessions.remove(id) { + handle.close().await?; + } + Ok(()) + } + async fn has_session(&self, id: &SessionId) -> Result { + let sessions = self.sessions.read().await; + Ok(sessions.contains_key(id)) + } + async fn create_stream( + &self, + id: &SessionId, + message: ClientJsonRpcMessage, + ) -> Result + Send + 'static, Self::Error> { + let sessions = self.sessions.read().await; + let handle = sessions + .get(id) + .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?; + let receiver = handle.establish_request_wise_channel().await?; + handle + .push_message(message, receiver.http_request_id) + .await?; + Ok(ReceiverStream::new(receiver.inner)) + } + + async fn create_standalone_stream( + &self, + id: &SessionId, + ) -> Result + Send + 'static, Self::Error> { + let sessions = self.sessions.read().await; + let handle = sessions + .get(id) + .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?; + let receiver = handle.establish_common_channel().await?; + Ok(ReceiverStream::new(receiver.inner)) + } + + async fn resume( + &self, + id: &SessionId, + last_event_id: String, + ) -> Result + Send + 'static, Self::Error> { + let sessions = self.sessions.read().await; + let handle = sessions + .get(id) + .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?; + let receiver = handle.resume(last_event_id.parse()?).await?; + Ok(ReceiverStream::new(receiver.inner)) + } + + async fn accept_message( + &self, + id: &SessionId, + message: ClientJsonRpcMessage, + ) -> Result<(), Self::Error> { + let sessions = self.sessions.read().await; + let handle = sessions + .get(id) + .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?; + handle.push_message(message, None).await?; + Ok(()) + } +} + +/// `/request_id>` +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct EventId { + http_request_id: Option, + index: usize, +} + +impl std::fmt::Display for EventId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.index)?; + match &self.http_request_id { + Some(http_request_id) => write!(f, "/{http_request_id}"), + None => write!(f, ""), + } + } +} + +#[derive(Debug, Clone, Error)] +pub enum EventIdParseError { + #[error("Invalid index: {0}")] + InvalidIndex(ParseIntError), + #[error("Invalid numeric request id: {0}")] + InvalidNumericRequestId(ParseIntError), + #[error("Missing request id type")] + InvalidRequestIdType, + #[error("Missing request id")] + MissingRequestId, +} + +impl std::str::FromStr for EventId { + type Err = EventIdParseError; + fn from_str(s: &str) -> Result { + if let Some((index, request_id)) = s.split_once("/") { + let index = usize::from_str(index).map_err(EventIdParseError::InvalidIndex)?; + let request_id = u64::from_str(request_id).map_err(EventIdParseError::InvalidIndex)?; + Ok(EventId { + http_request_id: Some(request_id), + index, + }) + } else { + let index = usize::from_str(s).map_err(EventIdParseError::InvalidIndex)?; + Ok(EventId { + http_request_id: None, + index, + }) + } + } +} + +use super::{ServerSseMessage, SessionManager}; + +struct CachedTx { + tx: Sender, + cache: VecDeque, + http_request_id: Option, + capacity: usize, +} + +impl CachedTx { + fn new(tx: Sender, http_request_id: Option) -> Self { + Self { + cache: VecDeque::with_capacity(tx.capacity()), + capacity: tx.capacity(), + tx, + http_request_id, + } + } + fn new_common(tx: Sender) -> Self { + Self::new(tx, None) + } + + async fn send(&mut self, message: ServerJsonRpcMessage) { + let index = self.cache.back().map_or(0, |m| { + m.event_id + .as_deref() + .unwrap_or_default() + .parse::() + .expect("valid event id") + .index + + 1 + }); + let event_id = EventId { + http_request_id: self.http_request_id, + index, + }; + let message = ServerSseMessage { + event_id: Some(event_id.to_string()), + message: Arc::new(message), + }; + if self.cache.len() >= self.capacity { + self.cache.pop_front(); + self.cache.push_back(message.clone()); + } else { + self.cache.push_back(message.clone()); + } + let _ = self.tx.send(message).await.inspect_err(|e| { + let event_id = &e.0.event_id; + tracing::trace!(?event_id, "trying to send message in a closed session") + }); + } + + async fn sync(&mut self, index: usize) -> Result<(), SessionError> { + let Some(front) = self.cache.front() else { + return Ok(()); + }; + let front_event_id = front + .event_id + .as_deref() + .unwrap_or_default() + .parse::()?; + let sync_index = index.saturating_sub(front_event_id.index); + if sync_index > self.cache.len() { + // invalid index + return Err(SessionError::InvalidEventId); + } + for message in self.cache.iter().skip(sync_index) { + let send_result = self.tx.send(message.clone()).await; + if send_result.is_err() { + let event_id: EventId = message.event_id.as_deref().unwrap_or_default().parse()?; + return Err(SessionError::ChannelClosed(Some(event_id.index as u64))); + } + } + Ok(()) + } +} + +struct HttpRequestWise { + resources: HashSet, + tx: CachedTx, +} + +type HttpRequestId = u64; +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +enum ResourceKey { + McpRequestId(RequestId), + ProgressToken(ProgressToken), +} + +pub struct LocalSessionWorker { + id: SessionId, + next_http_request_id: HttpRequestId, + tx_router: HashMap, + resource_router: HashMap, + common: CachedTx, + event_rx: Receiver, + session_config: SessionConfig, +} + +impl LocalSessionWorker { + pub fn id(&self) -> &SessionId { + &self.id + } +} + +#[derive(Debug, Error)] +pub enum SessionError { + #[error("Invalid request id: {0}")] + DuplicatedRequestId(HttpRequestId), + #[error("Channel closed: {0:?}")] + ChannelClosed(Option), + #[error("Cannot parse event id: {0}")] + EventIdParseError(#[from] EventIdParseError), + #[error("Session service terminated")] + SessionServiceTerminated, + #[error("Invalid event id")] + InvalidEventId, + #[error("IO error: {0}")] + Io(#[from] std::io::Error), +} + +impl From for std::io::Error { + fn from(value: SessionError) -> Self { + match value { + SessionError::Io(io) => io, + _ => std::io::Error::other(format!("Session error: {value}")), + } + } +} + +enum OutboundChannel { + RequestWise { id: HttpRequestId, close: bool }, + Common, +} +#[derive(Debug)] +pub struct StreamableHttpMessageReceiver { + pub http_request_id: Option, + pub inner: Receiver, +} + +impl LocalSessionWorker { + fn unregister_resource(&mut self, resource: &ResourceKey) { + if let Some(http_request_id) = self.resource_router.remove(resource) { + tracing::trace!(?resource, http_request_id, "unregister resource"); + if let Some(channel) = self.tx_router.get_mut(&http_request_id) { + // It's okey to do so, since we don't handle batch json rpc request anymore + // and this can be refactored after the batch request is removed in the coming version. + if channel.resources.is_empty() || matches!(resource, ResourceKey::McpRequestId(_)) + { + tracing::debug!(http_request_id, "close http request wise channel"); + if let Some(channel) = self.tx_router.remove(&http_request_id) { + for resource in channel.resources { + self.resource_router.remove(&resource); + } + } + } + } else { + tracing::warn!(http_request_id, "http request wise channel not found"); + } + } + } + fn register_resource(&mut self, resource: ResourceKey, http_request_id: HttpRequestId) { + tracing::trace!(?resource, http_request_id, "register resource"); + if let Some(channel) = self.tx_router.get_mut(&http_request_id) { + channel.resources.insert(resource.clone()); + self.resource_router.insert(resource, http_request_id); + } + } + fn register_request( + &mut self, + request: &JsonRpcRequest, + http_request_id: HttpRequestId, + ) { + use crate::model::GetMeta; + self.register_resource( + ResourceKey::McpRequestId(request.id.clone()), + http_request_id, + ); + if let Some(progress_token) = request.request.get_meta().get_progress_token() { + self.register_resource( + ResourceKey::ProgressToken(progress_token.clone()), + http_request_id, + ); + } + } + fn catch_cancellation_notification( + &mut self, + notification: &JsonRpcNotification, + ) { + if let ClientNotification::CancelledNotification(n) = ¬ification.notification { + let request_id = n.params.request_id.clone(); + let resource = ResourceKey::McpRequestId(request_id); + self.unregister_resource(&resource); + } + } + fn next_http_request_id(&mut self) -> HttpRequestId { + let id = self.next_http_request_id; + self.next_http_request_id = self.next_http_request_id.wrapping_add(1); + id + } + async fn establish_request_wise_channel( + &mut self, + ) -> Result { + let http_request_id = self.next_http_request_id(); + let (tx, rx) = tokio::sync::mpsc::channel(self.session_config.channel_capacity); + self.tx_router.insert( + http_request_id, + HttpRequestWise { + resources: Default::default(), + tx: CachedTx::new(tx, Some(http_request_id)), + }, + ); + tracing::debug!(http_request_id, "establish new request wise channel"); + Ok(StreamableHttpMessageReceiver { + http_request_id: Some(http_request_id), + inner: rx, + }) + } + fn resolve_outbound_channel(&self, message: &ServerJsonRpcMessage) -> OutboundChannel { + match &message { + ServerJsonRpcMessage::Request(_) => OutboundChannel::Common, + ServerJsonRpcMessage::Notification(JsonRpcNotification { + notification: + ServerNotification::ProgressNotification(Notification { + params: ProgressNotificationParam { progress_token, .. }, + .. + }), + .. + }) => { + let id = self + .resource_router + .get(&ResourceKey::ProgressToken(progress_token.clone())); + + if let Some(id) = id { + OutboundChannel::RequestWise { + id: *id, + close: false, + } + } else { + OutboundChannel::Common + } + } + ServerJsonRpcMessage::Notification(JsonRpcNotification { + notification: + ServerNotification::CancelledNotification(Notification { + params: CancelledNotificationParam { request_id, .. }, + .. + }), + .. + }) => { + if let Some(id) = self + .resource_router + .get(&ResourceKey::McpRequestId(request_id.clone())) + { + OutboundChannel::RequestWise { + id: *id, + close: false, + } + } else { + OutboundChannel::Common + } + } + ServerJsonRpcMessage::Notification(_) => OutboundChannel::Common, + ServerJsonRpcMessage::Response(json_rpc_response) => { + if let Some(id) = self + .resource_router + .get(&ResourceKey::McpRequestId(json_rpc_response.id.clone())) + { + OutboundChannel::RequestWise { + id: *id, + close: false, + } + } else { + OutboundChannel::Common + } + } + ServerJsonRpcMessage::Error(json_rpc_error) => { + if let Some(id) = self + .resource_router + .get(&ResourceKey::McpRequestId(json_rpc_error.id.clone())) + { + OutboundChannel::RequestWise { + id: *id, + close: false, + } + } else { + OutboundChannel::Common + } + } + } + } + async fn handle_server_message( + &mut self, + message: ServerJsonRpcMessage, + ) -> Result<(), SessionError> { + let outbound_channel = self.resolve_outbound_channel(&message); + match outbound_channel { + OutboundChannel::RequestWise { id, close } => { + if let Some(request_wise) = self.tx_router.get_mut(&id) { + request_wise.tx.send(message).await; + if close { + self.tx_router.remove(&id); + } + } else { + return Err(SessionError::ChannelClosed(Some(id))); + } + } + OutboundChannel::Common => self.common.send(message).await, + } + Ok(()) + } + async fn resume( + &mut self, + last_event_id: EventId, + ) -> Result { + match last_event_id.http_request_id { + Some(http_request_id) => { + let request_wise = self + .tx_router + .get_mut(&http_request_id) + .ok_or(SessionError::ChannelClosed(Some(http_request_id)))?; + let channel = tokio::sync::mpsc::channel(self.session_config.channel_capacity); + let (tx, rx) = channel; + request_wise.tx.tx = tx; + let index = last_event_id.index; + // sync messages after index + request_wise.tx.sync(index).await?; + Ok(StreamableHttpMessageReceiver { + http_request_id: Some(http_request_id), + inner: rx, + }) + } + None => { + let channel = tokio::sync::mpsc::channel(self.session_config.channel_capacity); + let (tx, rx) = channel; + self.common.tx = tx; + let index = last_event_id.index; + // sync messages after index + self.common.sync(index).await?; + Ok(StreamableHttpMessageReceiver { + http_request_id: None, + inner: rx, + }) + } + } + } +} +#[derive(Debug)] +pub enum SessionEvent { + ClientMessage { + message: ClientJsonRpcMessage, + http_request_id: Option, + }, + EstablishRequestWiseChannel { + responder: oneshot::Sender>, + }, + CloseRequestWiseChannel { + id: HttpRequestId, + responder: oneshot::Sender>, + }, + Resume { + last_event_id: EventId, + responder: oneshot::Sender>, + }, + InitializeRequest { + request: ClientJsonRpcMessage, + responder: oneshot::Sender>, + }, + Close, +} + +#[derive(Debug, Clone)] +pub enum SessionQuitReason { + ServiceTerminated, + ClientTerminated, + ExpectInitializeRequest, + ExpectInitializeResponse, + Cancelled, +} + +#[derive(Debug, Clone)] +pub struct LocalSessionHandle { + id: SessionId, + // after all event_tx drop, inner task will be terminated + event_tx: Sender, +} + +impl LocalSessionHandle { + /// Get the session id + pub fn id(&self) -> &SessionId { + &self.id + } + + /// Close the session + pub async fn close(&self) -> Result<(), SessionError> { + self.event_tx + .send(SessionEvent::Close) + .await + .map_err(|_| SessionError::SessionServiceTerminated)?; + Ok(()) + } + + /// Send a message to the session + pub async fn push_message( + &self, + message: ClientJsonRpcMessage, + http_request_id: Option, + ) -> Result<(), SessionError> { + self.event_tx + .send(SessionEvent::ClientMessage { + message, + http_request_id, + }) + .await + .map_err(|_| SessionError::SessionServiceTerminated)?; + Ok(()) + } + + /// establish a channel for a http-request, the corresponded message from server will be + /// sent through this channel. The channel will be closed when the request is completed, + /// or you can close it manually by calling [`LocalSessionHandle::close_request_wise_channel`]. + pub async fn establish_request_wise_channel( + &self, + ) -> Result { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.event_tx + .send(SessionEvent::EstablishRequestWiseChannel { responder: tx }) + .await + .map_err(|_| SessionError::SessionServiceTerminated)?; + rx.await + .map_err(|_| SessionError::SessionServiceTerminated)? + } + + /// close the http-request wise channel. + pub async fn close_request_wise_channel( + &self, + request_id: HttpRequestId, + ) -> Result<(), SessionError> { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.event_tx + .send(SessionEvent::CloseRequestWiseChannel { + id: request_id, + responder: tx, + }) + .await + .map_err(|_| SessionError::SessionServiceTerminated)?; + rx.await + .map_err(|_| SessionError::SessionServiceTerminated)? + } + + /// Establish a common channel for general purpose messages. + pub async fn establish_common_channel( + &self, + ) -> Result { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.event_tx + .send(SessionEvent::Resume { + last_event_id: EventId { + http_request_id: None, + index: 0, + }, + responder: tx, + }) + .await + .map_err(|_| SessionError::SessionServiceTerminated)?; + rx.await + .map_err(|_| SessionError::SessionServiceTerminated)? + } + + /// Resume streaming response by the last event id. This is suitable for both request wise and common channel. + pub async fn resume( + &self, + last_event_id: EventId, + ) -> Result { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.event_tx + .send(SessionEvent::Resume { + last_event_id, + responder: tx, + }) + .await + .map_err(|_| SessionError::SessionServiceTerminated)?; + rx.await + .map_err(|_| SessionError::SessionServiceTerminated)? + } + + /// Send an initialize request to the session. And wait for the initialized response. + /// + /// This is used to establish a session with the server. + pub async fn initialize( + &self, + request: ClientJsonRpcMessage, + ) -> Result { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.event_tx + .send(SessionEvent::InitializeRequest { + request, + responder: tx, + }) + .await + .map_err(|_| SessionError::SessionServiceTerminated)?; + rx.await + .map_err(|_| SessionError::SessionServiceTerminated)? + } +} + +pub type SessionTransport = WorkerTransport; + +#[allow(clippy::large_enum_variant)] +#[derive(Debug, Error)] +pub enum LocalSessionWorkerError { + #[error("transport terminated")] + TransportTerminated, + #[error("unexpected message: {0:?}")] + UnexpectedEvent(SessionEvent), + #[error("fail to send initialize request {0}")] + FailToSendInitializeRequest(SessionError), + #[error("fail to handle message: {0}")] + FailToHandleMessage(SessionError), + #[error("keep alive timeout after {}ms", _0.as_millis())] + KeepAliveTimeout(Duration), + #[error("Transport closed")] + TransportClosed, + #[error("Tokio join error {0}")] + TokioJoinError(#[from] tokio::task::JoinError), +} +impl Worker for LocalSessionWorker { + type Error = LocalSessionWorkerError; + type Role = RoleServer; + fn err_closed() -> Self::Error { + LocalSessionWorkerError::TransportClosed + } + fn err_join(e: tokio::task::JoinError) -> Self::Error { + LocalSessionWorkerError::TokioJoinError(e) + } + fn config(&self) -> crate::transport::worker::WorkerConfig { + crate::transport::worker::WorkerConfig { + name: Some(format!("streamable-http-session-{}", self.id)), + channel_buffer_capacity: self.session_config.channel_capacity, + } + } + #[instrument(name = "streamable_http_session", skip_all, fields(id = self.id.as_ref()))] + async fn run( + mut self, + mut context: WorkerContext, + ) -> Result<(), WorkerQuitReason> { + enum InnerEvent { + FromHttpService(SessionEvent), + FromHandler(WorkerSendRequest), + } + // waiting for initialize request + let evt = self.event_rx.recv().await.ok_or_else(|| { + WorkerQuitReason::fatal( + LocalSessionWorkerError::TransportTerminated, + "get initialize request", + ) + })?; + let SessionEvent::InitializeRequest { request, responder } = evt else { + return Err(WorkerQuitReason::fatal( + LocalSessionWorkerError::UnexpectedEvent(evt), + "get initialize request", + )); + }; + context.send_to_handler(request).await?; + let send_initialize_response = context.recv_from_handler().await?; + responder + .send(Ok(send_initialize_response.message)) + .map_err(|_| { + WorkerQuitReason::fatal( + LocalSessionWorkerError::FailToSendInitializeRequest( + SessionError::SessionServiceTerminated, + ), + "send initialize response", + ) + })?; + send_initialize_response + .responder + .send(Ok(())) + .map_err(|_| WorkerQuitReason::HandlerTerminated)?; + let ct = context.cancellation_token.clone(); + let keep_alive = self.session_config.keep_alive.unwrap_or(Duration::MAX); + loop { + let keep_alive_timeout = tokio::time::sleep(keep_alive); + let event = tokio::select! { + event = self.event_rx.recv() => { + if let Some(event) = event { + InnerEvent::FromHttpService(event) + } else { + return Err(WorkerQuitReason::fatal(LocalSessionWorkerError::TransportTerminated, "waiting next session event")) + } + }, + from_handler = context.recv_from_handler() => { + InnerEvent::FromHandler(from_handler?) + } + _ = ct.cancelled() => { + return Err(WorkerQuitReason::Cancelled) + } + _ = keep_alive_timeout => { + return Err(WorkerQuitReason::fatal(LocalSessionWorkerError::KeepAliveTimeout(keep_alive), "poll next session event")) + } + }; + match event { + InnerEvent::FromHandler(WorkerSendRequest { message, responder }) => { + // catch response + let to_unregister = match &message { + crate::model::JsonRpcMessage::Response(json_rpc_response) => { + let request_id = json_rpc_response.id.clone(); + Some(ResourceKey::McpRequestId(request_id)) + } + crate::model::JsonRpcMessage::Error(json_rpc_error) => { + let request_id = json_rpc_error.id.clone(); + Some(ResourceKey::McpRequestId(request_id)) + } + _ => { + None + // no need to unregister resource + } + }; + let handle_result = self + .handle_server_message(message) + .await + .map_err(LocalSessionWorkerError::FailToHandleMessage); + let _ = responder.send(handle_result).inspect_err(|error| { + tracing::warn!(?error, "failed to send message to http service handler"); + }); + if let Some(to_unregister) = to_unregister { + self.unregister_resource(&to_unregister); + } + } + InnerEvent::FromHttpService(SessionEvent::ClientMessage { + message: json_rpc_message, + http_request_id, + }) => { + match &json_rpc_message { + crate::model::JsonRpcMessage::Request(request) => { + if let Some(http_request_id) = http_request_id { + self.register_request(request, http_request_id) + } + } + crate::model::JsonRpcMessage::Notification(notification) => { + self.catch_cancellation_notification(notification) + } + _ => {} + } + context.send_to_handler(json_rpc_message).await?; + } + InnerEvent::FromHttpService(SessionEvent::EstablishRequestWiseChannel { + responder, + }) => { + let handle_result = self.establish_request_wise_channel().await; + let _ = responder.send(handle_result); + } + InnerEvent::FromHttpService(SessionEvent::CloseRequestWiseChannel { + id, + responder, + }) => { + let _handle_result = self.tx_router.remove(&id); + let _ = responder.send(Ok(())); + } + InnerEvent::FromHttpService(SessionEvent::Resume { + last_event_id, + responder, + }) => { + let handle_result = self.resume(last_event_id).await; + let _ = responder.send(handle_result); + } + InnerEvent::FromHttpService(SessionEvent::Close) => { + return Err(WorkerQuitReason::TransportClosed); + } + _ => { + // ignore + } + } + } + } +} + +#[derive(Debug, Clone)] +pub struct SessionConfig { + /// the capacity of the channel for the session. Default is 16. + pub channel_capacity: usize, + /// if set, the session will be closed after this duration of inactivity. + pub keep_alive: Option, +} + +impl SessionConfig { + pub const DEFAULT_CHANNEL_CAPACITY: usize = 16; +} + +impl Default for SessionConfig { + fn default() -> Self { + Self { + channel_capacity: Self::DEFAULT_CHANNEL_CAPACITY, + keep_alive: None, + } + } +} + +/// Create a new session with the given id and configuration. +/// +/// This function will return a pair of [`LocalSessionHandle`] and [`LocalSessionWorker`]. +/// +/// You can run the [`LocalSessionWorker`] as a transport for mcp server. And use the [`LocalSessionHandle`] operate the session. +pub fn create_local_session( + id: impl Into, + config: SessionConfig, +) -> (LocalSessionHandle, LocalSessionWorker) { + let id = id.into(); + let (event_tx, event_rx) = tokio::sync::mpsc::channel(config.channel_capacity); + let (common_tx, _) = tokio::sync::mpsc::channel(config.channel_capacity); + let common = CachedTx::new_common(common_tx); + tracing::info!(session_id = ?id, "create new session"); + let handle = LocalSessionHandle { + event_tx, + id: id.clone(), + }; + let session_worker = LocalSessionWorker { + next_http_request_id: 0, + id, + tx_router: HashMap::new(), + resource_router: HashMap::new(), + common, + event_rx, + session_config: config.clone(), + }; + (handle, session_worker) +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/transport/streamable_http_server/session/never.rs b/code-rs/third_party/rmcp-0.8.3/src/transport/streamable_http_server/session/never.rs new file mode 100644 index 00000000000..436d4cfce2e --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/transport/streamable_http_server/session/never.rs @@ -0,0 +1,107 @@ +use futures::Stream; +use thiserror::Error; + +use super::{ServerSseMessage, SessionId, SessionManager}; +use crate::{ + RoleServer, + model::{ClientJsonRpcMessage, ServerJsonRpcMessage}, + transport::Transport, +}; + +#[derive(Debug, Clone, Error)] +#[error("Session management is not supported")] +pub struct ErrorSessionManagementNotSupported; +#[derive(Debug, Clone, Default)] +pub struct NeverSessionManager {} +pub enum NeverTransport {} +impl Transport for NeverTransport { + type Error = ErrorSessionManagementNotSupported; + + fn send( + &mut self, + _item: ServerJsonRpcMessage, + ) -> impl Future> + Send + 'static { + futures::future::ready(Err(ErrorSessionManagementNotSupported)) + } + + fn receive(&mut self) -> impl Future> { + futures::future::ready(None) + } + + async fn close(&mut self) -> Result<(), Self::Error> { + Err(ErrorSessionManagementNotSupported) + } +} + +impl SessionManager for NeverSessionManager { + type Error = ErrorSessionManagementNotSupported; + type Transport = NeverTransport; + + fn create_session( + &self, + ) -> impl Future> + Send { + futures::future::ready(Err(ErrorSessionManagementNotSupported)) + } + + fn initialize_session( + &self, + _id: &SessionId, + _message: ClientJsonRpcMessage, + ) -> impl Future> + Send { + futures::future::ready(Err(ErrorSessionManagementNotSupported)) + } + + fn has_session( + &self, + _id: &SessionId, + ) -> impl Future> + Send { + futures::future::ready(Err(ErrorSessionManagementNotSupported)) + } + + fn close_session( + &self, + _id: &SessionId, + ) -> impl Future> + Send { + futures::future::ready(Err(ErrorSessionManagementNotSupported)) + } + + fn create_stream( + &self, + _id: &SessionId, + _message: ClientJsonRpcMessage, + ) -> impl Future< + Output = Result + Send + 'static, Self::Error>, + > + Send { + futures::future::ready(Result::, _>::Err( + ErrorSessionManagementNotSupported, + )) + } + fn create_standalone_stream( + &self, + _id: &SessionId, + ) -> impl Future< + Output = Result + Send + 'static, Self::Error>, + > + Send { + futures::future::ready(Result::, _>::Err( + ErrorSessionManagementNotSupported, + )) + } + fn resume( + &self, + _id: &SessionId, + _last_event_id: String, + ) -> impl Future< + Output = Result + Send + 'static, Self::Error>, + > + Send { + futures::future::ready(Result::, _>::Err( + ErrorSessionManagementNotSupported, + )) + } + fn accept_message( + &self, + _id: &SessionId, + _message: ClientJsonRpcMessage, + ) -> impl Future> + Send { + futures::future::ready(Err(ErrorSessionManagementNotSupported)) + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/transport/streamable_http_server/tower.rs b/code-rs/third_party/rmcp-0.8.3/src/transport/streamable_http_server/tower.rs new file mode 100644 index 00000000000..ba373d487c9 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/transport/streamable_http_server/tower.rs @@ -0,0 +1,453 @@ +use std::{convert::Infallible, fmt::Display, sync::Arc, time::Duration}; + +use bytes::Bytes; +use futures::{StreamExt, future::BoxFuture}; +use http::{Method, Request, Response, header::ALLOW}; +use http_body::Body; +use http_body_util::{BodyExt, Full, combinators::BoxBody}; +use tokio_stream::wrappers::ReceiverStream; + +use super::session::SessionManager; +use crate::{ + RoleServer, + model::{ClientJsonRpcMessage, ClientRequest, GetExtensions}, + serve_server, + service::serve_directly, + transport::{ + OneshotTransport, TransportAdapterIdentity, + common::{ + http_header::{ + EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_SESSION_ID, JSON_MIME_TYPE, + }, + server_side_http::{ + BoxResponse, ServerSseMessage, accepted_response, expect_json, + internal_error_response, sse_stream_response, unexpected_message_response, + }, + }, + }, +}; + +#[derive(Debug, Clone)] +pub struct StreamableHttpServerConfig { + /// The ping message duration for SSE connections. + pub sse_keep_alive: Option, + /// If true, the server will create a session for each request and keep it alive. + pub stateful_mode: bool, +} + +impl Default for StreamableHttpServerConfig { + fn default() -> Self { + Self { + sse_keep_alive: Some(Duration::from_secs(15)), + stateful_mode: true, + } + } +} + +/// # Streamable Http Server +/// +/// ## Extract information from raw http request +/// +/// The http service will consume the request body, however the rest part will be remain and injected into [`crate::model::Extensions`], +/// which you can get from [`crate::service::RequestContext`]. +/// ```rust +/// use rmcp::handler::server::tool::Extension; +/// use http::request::Parts; +/// async fn my_tool(Extension(parts): Extension) { +/// tracing::info!("http parts:{parts:?}") +/// } +/// ``` +pub struct StreamableHttpService { + pub config: StreamableHttpServerConfig, + session_manager: Arc, + service_factory: Arc Result + Send + Sync>, +} + +impl Clone for StreamableHttpService { + fn clone(&self) -> Self { + Self { + config: self.config.clone(), + session_manager: self.session_manager.clone(), + service_factory: self.service_factory.clone(), + } + } +} + +impl tower_service::Service> for StreamableHttpService +where + RequestBody: Body + Send + 'static, + S: crate::Service, + M: SessionManager, + RequestBody::Error: Display, + RequestBody::Data: Send + 'static, +{ + type Response = BoxResponse; + type Error = Infallible; + type Future = BoxFuture<'static, Result>; + fn call(&mut self, req: http::Request) -> Self::Future { + let service = self.clone(); + Box::pin(async move { + let response = service.handle(req).await; + Ok(response) + }) + } + fn poll_ready( + &mut self, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } +} + +impl StreamableHttpService +where + S: crate::Service + Send + 'static, + M: SessionManager, +{ + pub fn new( + service_factory: impl Fn() -> Result + Send + Sync + 'static, + session_manager: Arc, + config: StreamableHttpServerConfig, + ) -> Self { + Self { + config, + session_manager, + service_factory: Arc::new(service_factory), + } + } + fn get_service(&self) -> Result { + (self.service_factory)() + } + pub async fn handle(&self, request: Request) -> Response> + where + B: Body + Send + 'static, + B::Error: Display, + { + let method = request.method().clone(); + let allowed_methods = match self.config.stateful_mode { + true => "GET, POST, DELETE", + false => "POST", + }; + let result = match (method, self.config.stateful_mode) { + (Method::POST, _) => self.handle_post(request).await, + // if we're not in stateful mode, we don't support GET or DELETE because there is no session + (Method::GET, true) => self.handle_get(request).await, + (Method::DELETE, true) => self.handle_delete(request).await, + _ => { + // Handle other methods or return an error + let response = Response::builder() + .status(http::StatusCode::METHOD_NOT_ALLOWED) + .header(ALLOW, allowed_methods) + .body(Full::new(Bytes::from("Method Not Allowed")).boxed()) + .expect("valid response"); + return response; + } + }; + match result { + Ok(response) => response, + Err(response) => response, + } + } + async fn handle_get(&self, request: Request) -> Result + where + B: Body + Send + 'static, + B::Error: Display, + { + // check accept header + if !request + .headers() + .get(http::header::ACCEPT) + .and_then(|header| header.to_str().ok()) + .is_some_and(|header| header.contains(EVENT_STREAM_MIME_TYPE)) + { + return Ok(Response::builder() + .status(http::StatusCode::NOT_ACCEPTABLE) + .body( + Full::new(Bytes::from( + "Not Acceptable: Client must accept text/event-stream", + )) + .boxed(), + ) + .expect("valid response")); + } + // check session id + let session_id = request + .headers() + .get(HEADER_SESSION_ID) + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_owned().into()); + let Some(session_id) = session_id else { + // unauthorized + return Ok(Response::builder() + .status(http::StatusCode::UNAUTHORIZED) + .body(Full::new(Bytes::from("Unauthorized: Session ID is required")).boxed()) + .expect("valid response")); + }; + // check if session exists + let has_session = self + .session_manager + .has_session(&session_id) + .await + .map_err(internal_error_response("check session"))?; + if !has_session { + // unauthorized + return Ok(Response::builder() + .status(http::StatusCode::UNAUTHORIZED) + .body(Full::new(Bytes::from("Unauthorized: Session not found")).boxed()) + .expect("valid response")); + } + // check if last event id is provided + let last_event_id = request + .headers() + .get(HEADER_LAST_EVENT_ID) + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_owned()); + if let Some(last_event_id) = last_event_id { + // check if session has this event id + let stream = self + .session_manager + .resume(&session_id, last_event_id) + .await + .map_err(internal_error_response("resume session"))?; + Ok(sse_stream_response(stream, self.config.sse_keep_alive)) + } else { + // create standalone stream + let stream = self + .session_manager + .create_standalone_stream(&session_id) + .await + .map_err(internal_error_response("create standalone stream"))?; + Ok(sse_stream_response(stream, self.config.sse_keep_alive)) + } + } + + async fn handle_post(&self, request: Request) -> Result + where + B: Body + Send + 'static, + B::Error: Display, + { + // check accept header + if !request + .headers() + .get(http::header::ACCEPT) + .and_then(|header| header.to_str().ok()) + .is_some_and(|header| { + header.contains(JSON_MIME_TYPE) && header.contains(EVENT_STREAM_MIME_TYPE) + }) + { + return Ok(Response::builder() + .status(http::StatusCode::NOT_ACCEPTABLE) + .body(Full::new(Bytes::from("Not Acceptable: Client must accept both application/json and text/event-stream")).boxed()) + .expect("valid response")); + } + + // check content type + if !request + .headers() + .get(http::header::CONTENT_TYPE) + .and_then(|header| header.to_str().ok()) + .is_some_and(|header| header.starts_with(JSON_MIME_TYPE)) + { + return Ok(Response::builder() + .status(http::StatusCode::UNSUPPORTED_MEDIA_TYPE) + .body( + Full::new(Bytes::from( + "Unsupported Media Type: Content-Type must be application/json", + )) + .boxed(), + ) + .expect("valid response")); + } + + // json deserialize request body + let (part, body) = request.into_parts(); + let mut message = match expect_json(body).await { + Ok(message) => message, + Err(response) => return Ok(response), + }; + + if self.config.stateful_mode { + // do we have a session id? + let session_id = part + .headers + .get(HEADER_SESSION_ID) + .and_then(|v| v.to_str().ok()); + if let Some(session_id) = session_id { + let session_id = session_id.to_owned().into(); + let has_session = self + .session_manager + .has_session(&session_id) + .await + .map_err(internal_error_response("check session"))?; + if !has_session { + // unauthorized + return Ok(Response::builder() + .status(http::StatusCode::UNAUTHORIZED) + .body(Full::new(Bytes::from("Unauthorized: Session not found")).boxed()) + .expect("valid response")); + } + + // inject request part to extensions + match &mut message { + ClientJsonRpcMessage::Request(req) => { + req.request.extensions_mut().insert(part); + } + ClientJsonRpcMessage::Notification(not) => { + not.notification.extensions_mut().insert(part); + } + _ => { + // skip + } + } + + match message { + ClientJsonRpcMessage::Request(_) => { + let stream = self + .session_manager + .create_stream(&session_id, message) + .await + .map_err(internal_error_response("get session"))?; + Ok(sse_stream_response(stream, self.config.sse_keep_alive)) + } + ClientJsonRpcMessage::Notification(_) + | ClientJsonRpcMessage::Response(_) + | ClientJsonRpcMessage::Error(_) => { + // handle notification + self.session_manager + .accept_message(&session_id, message) + .await + .map_err(internal_error_response("accept message"))?; + Ok(accepted_response()) + } + } + } else { + let (session_id, transport) = self + .session_manager + .create_session() + .await + .map_err(internal_error_response("create session"))?; + if let ClientJsonRpcMessage::Request(req) = &mut message { + if !matches!(req.request, ClientRequest::InitializeRequest(_)) { + return Err(unexpected_message_response("initialize request")); + } + // inject request part to extensions + req.request.extensions_mut().insert(part); + } else { + return Err(unexpected_message_response("initialize request")); + } + let service = self + .get_service() + .map_err(internal_error_response("get service"))?; + // spawn a task to serve the session + tokio::spawn({ + let session_manager = self.session_manager.clone(); + let session_id = session_id.clone(); + async move { + let service = serve_server::( + service, transport, + ) + .await; + match service { + Ok(service) => { + // on service created + let _ = service.waiting().await; + } + Err(e) => { + tracing::error!("Failed to create service: {e}"); + } + } + let _ = session_manager + .close_session(&session_id) + .await + .inspect_err(|e| { + tracing::error!("Failed to close session {session_id}: {e}"); + }); + } + }); + // get initialize response + let response = self + .session_manager + .initialize_session(&session_id, message) + .await + .map_err(internal_error_response("create stream"))?; + let mut response = sse_stream_response( + futures::stream::once({ + async move { + ServerSseMessage { + event_id: None, + message: response.into(), + } + } + }), + self.config.sse_keep_alive, + ); + + response.headers_mut().insert( + HEADER_SESSION_ID, + session_id + .parse() + .map_err(internal_error_response("create session id header"))?, + ); + Ok(response) + } + } else { + let service = self + .get_service() + .map_err(internal_error_response("get service"))?; + match message { + ClientJsonRpcMessage::Request(mut request) => { + request.request.extensions_mut().insert(part); + let (transport, receiver) = + OneshotTransport::::new(ClientJsonRpcMessage::Request(request)); + let service = serve_directly(service, transport, None); + tokio::spawn(async move { + // on service created + let _ = service.waiting().await; + }); + Ok(sse_stream_response( + ReceiverStream::new(receiver).map(|message| { + tracing::info!(?message); + ServerSseMessage { + event_id: None, + message: message.into(), + } + }), + self.config.sse_keep_alive, + )) + } + ClientJsonRpcMessage::Notification(_notification) => { + // ignore + Ok(accepted_response()) + } + ClientJsonRpcMessage::Response(_json_rpc_response) => Ok(accepted_response()), + ClientJsonRpcMessage::Error(_json_rpc_error) => Ok(accepted_response()), + } + } + } + + async fn handle_delete(&self, request: Request) -> Result + where + B: Body + Send + 'static, + B::Error: Display, + { + // check session id + let session_id = request + .headers() + .get(HEADER_SESSION_ID) + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_owned().into()); + let Some(session_id) = session_id else { + // unauthorized + return Ok(Response::builder() + .status(http::StatusCode::UNAUTHORIZED) + .body(Full::new(Bytes::from("Unauthorized: Session ID is required")).boxed()) + .expect("valid response")); + }; + // close session + self.session_manager + .close_session(&session_id) + .await + .map_err(internal_error_response("close session"))?; + Ok(accepted_response()) + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/transport/worker.rs b/code-rs/third_party/rmcp-0.8.3/src/transport/worker.rs new file mode 100644 index 00000000000..769d448a51b --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/transport/worker.rs @@ -0,0 +1,208 @@ +use std::borrow::Cow; + +use tokio_util::sync::CancellationToken; +use tracing::{Instrument, Level}; + +use super::{IntoTransport, Transport}; +use crate::service::{RxJsonRpcMessage, ServiceRole, TxJsonRpcMessage}; + +#[derive(Debug, thiserror::Error)] +pub enum WorkerQuitReason { + #[error("Join error {0}")] + Join(#[from] tokio::task::JoinError), + #[error("Transport fatal {error}, when {context}")] + Fatal { + error: E, + context: Cow<'static, str>, + }, + #[error("Transport cancelled")] + Cancelled, + #[error("Transport closed")] + TransportClosed, + #[error("Handler terminated")] + HandlerTerminated, +} + +impl WorkerQuitReason { + pub fn fatal(error: E, context: impl Into>) -> Self { + Self::Fatal { + error, + context: context.into(), + } + } + pub fn fatal_context(context: impl Into>) -> impl FnOnce(E) -> Self { + |e| Self::Fatal { + error: e, + context: context.into(), + } + } +} + +pub trait Worker: Sized + Send + 'static { + type Error: std::error::Error + Send + Sync + 'static; + type Role: ServiceRole; + fn err_closed() -> Self::Error; + fn err_join(e: tokio::task::JoinError) -> Self::Error; + fn run( + self, + context: WorkerContext, + ) -> impl Future>> + Send; + fn config(&self) -> WorkerConfig { + WorkerConfig::default() + } +} + +pub struct WorkerSendRequest { + pub message: TxJsonRpcMessage, + pub responder: tokio::sync::oneshot::Sender>, +} + +pub struct WorkerTransport { + rx: tokio::sync::mpsc::Receiver>, + send_service: tokio::sync::mpsc::Sender>, + join_handle: Option>>>, + _drop_guard: tokio_util::sync::DropGuard, + ct: CancellationToken, +} + +pub struct WorkerConfig { + pub name: Option, + pub channel_buffer_capacity: usize, +} + +impl Default for WorkerConfig { + fn default() -> Self { + Self { + name: None, + channel_buffer_capacity: 16, + } + } +} +pub enum WorkerAdapter {} + +impl IntoTransport for W { + fn into_transport(self) -> impl Transport + 'static { + WorkerTransport::spawn(self) + } +} + +impl WorkerTransport { + pub fn cancel_token(&self) -> CancellationToken { + self.ct.clone() + } + pub fn spawn(worker: W) -> Self { + Self::spawn_with_ct(worker, CancellationToken::new()) + } + pub fn spawn_with_ct(worker: W, transport_task_ct: CancellationToken) -> Self { + let config = worker.config(); + let worker_name = config.name; + let (to_transport_tx, from_handler_rx) = + tokio::sync::mpsc::channel::>(config.channel_buffer_capacity); + let (to_handler_tx, from_transport_rx) = + tokio::sync::mpsc::channel::>(config.channel_buffer_capacity); + let context = WorkerContext { + to_handler_tx, + from_handler_rx, + cancellation_token: transport_task_ct.clone(), + }; + + let join_handle = tokio::spawn(async move { + worker + .run(context) + .instrument(tracing::span!( + Level::TRACE, + "transport_worker", + name = worker_name, + )) + .await + .inspect_err(|e| match e { + WorkerQuitReason::Cancelled + | WorkerQuitReason::TransportClosed + | WorkerQuitReason::HandlerTerminated => { + tracing::debug!("worker quit with reason: {:?}", e); + } + WorkerQuitReason::Join(e) => { + tracing::error!("worker quit with join error: {:?}", e); + } + WorkerQuitReason::Fatal { error, context } => { + tracing::error!("worker quit with fatal: {error}, when {context}"); + } + }) + .inspect(|_| { + tracing::debug!("worker quit"); + }) + }); + Self { + rx: from_transport_rx, + send_service: to_transport_tx, + join_handle: Some(join_handle), + ct: transport_task_ct.clone(), + _drop_guard: transport_task_ct.drop_guard(), + } + } +} + +pub struct SendRequest { + pub message: TxJsonRpcMessage, + pub responder: tokio::sync::oneshot::Sender>, +} + +pub struct WorkerContext { + pub to_handler_tx: tokio::sync::mpsc::Sender>, + pub from_handler_rx: tokio::sync::mpsc::Receiver>, + pub cancellation_token: CancellationToken, +} + +impl WorkerContext { + pub async fn send_to_handler( + &mut self, + item: RxJsonRpcMessage, + ) -> Result<(), WorkerQuitReason> { + self.to_handler_tx + .send(item) + .await + .map_err(|_| WorkerQuitReason::HandlerTerminated) + } + + pub async fn recv_from_handler( + &mut self, + ) -> Result, WorkerQuitReason> { + self.from_handler_rx + .recv() + .await + .ok_or(WorkerQuitReason::HandlerTerminated) + } +} + +impl Transport for WorkerTransport { + type Error = W::Error; + + fn send( + &mut self, + item: TxJsonRpcMessage, + ) -> impl Future> + Send + 'static { + let tx = self.send_service.clone(); + let (responder, receiver) = tokio::sync::oneshot::channel(); + let request = WorkerSendRequest { + message: item, + responder, + }; + async move { + tx.send(request).await.map_err(|_| W::err_closed())?; + receiver.await.map_err(|_| W::err_closed())??; + Ok(()) + } + } + async fn receive(&mut self) -> Option> { + self.rx.recv().await + } + async fn close(&mut self) -> Result<(), Self::Error> { + if let Some(handle) = self.join_handle.take() { + self.ct.cancel(); + let _quit_reason = handle.await.map_err(W::err_join)?; + Ok(()) + } else { + Ok(()) + } + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/src/transport/ws.rs b/code-rs/third_party/rmcp-0.8.3/src/transport/ws.rs new file mode 100644 index 00000000000..5ec9cbac9e6 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/src/transport/ws.rs @@ -0,0 +1 @@ +// Maybe we don't really need a ws implementation? diff --git a/code-rs/third_party/rmcp-0.8.3/tests/common/calculator.rs b/code-rs/third_party/rmcp-0.8.3/tests/common/calculator.rs new file mode 100644 index 00000000000..5b8cebf7aee --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/common/calculator.rs @@ -0,0 +1,62 @@ +#![allow(dead_code)] +use rmcp::{ + ServerHandler, + handler::server::{router::tool::ToolRouter, wrapper::Parameters}, + model::{ServerCapabilities, ServerInfo}, + schemars, tool, tool_router, +}; +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] +pub struct SumRequest { + #[schemars(description = "the left hand side number")] + pub a: i32, + pub b: i32, +} + +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] +pub struct SubRequest { + #[schemars(description = "the left hand side number")] + pub a: i32, + #[schemars(description = "the right hand side number")] + pub b: i32, +} +#[derive(Debug, Clone)] +pub struct Calculator { + tool_router: ToolRouter, +} + +impl Calculator { + pub fn new() -> Self { + Self { + tool_router: Self::tool_router(), + } + } +} + +impl Default for Calculator { + fn default() -> Self { + Self::new() + } +} + +#[tool_router] +impl Calculator { + #[tool(description = "Calculate the sum of two numbers")] + fn sum(&self, Parameters(SumRequest { a, b }): Parameters) -> String { + (a + b).to_string() + } + + #[tool(description = "Calculate the sub of two numbers")] + fn sub(&self, Parameters(SubRequest { a, b }): Parameters) -> String { + (a - b).to_string() + } +} + +impl ServerHandler for Calculator { + fn get_info(&self) -> ServerInfo { + ServerInfo { + instructions: Some("A simple calculator".into()), + capabilities: ServerCapabilities::builder().enable_tools().build(), + ..Default::default() + } + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/tests/common/handlers.rs b/code-rs/third_party/rmcp-0.8.3/tests/common/handlers.rs new file mode 100644 index 00000000000..db82d47152b --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/common/handlers.rs @@ -0,0 +1,183 @@ +use std::{ + future::Future, + sync::{Arc, Mutex}, +}; + +use rmcp::{ + ClientHandler, ErrorData as McpError, RoleClient, RoleServer, ServerHandler, + model::*, + service::{NotificationContext, RequestContext}, +}; +use serde_json::json; +use tokio::sync::Notify; + +#[derive(Clone)] +pub struct TestClientHandler { + pub honor_this_server: bool, + pub honor_all_servers: bool, + pub receive_signal: Arc, + pub received_messages: Arc>>, +} + +impl TestClientHandler { + #[allow(dead_code)] + pub fn new(honor_this_server: bool, honor_all_servers: bool) -> Self { + Self { + honor_this_server, + honor_all_servers, + receive_signal: Arc::new(Notify::new()), + received_messages: Arc::new(Mutex::new(Vec::new())), + } + } + + #[allow(dead_code)] + pub fn with_notification( + honor_this_server: bool, + honor_all_servers: bool, + receive_signal: Arc, + received_messages: Arc>>, + ) -> Self { + Self { + honor_this_server, + honor_all_servers, + receive_signal, + received_messages, + } + } +} + +impl ClientHandler for TestClientHandler { + async fn create_message( + &self, + params: CreateMessageRequestParam, + _context: RequestContext, + ) -> Result { + // First validate that there's at least one User message + if !params.messages.iter().any(|msg| msg.role == Role::User) { + return Err(McpError::invalid_request( + "Message sequence must contain at least one user message", + Some(json!({"messages": params.messages})), + )); + } + + // Create response based on context inclusion + let response = match params.include_context { + Some(ContextInclusion::ThisServer) if self.honor_this_server => { + "Test response with context: test context" + } + Some(ContextInclusion::AllServers) if self.honor_all_servers => { + "Test response with context: test context" + } + _ => "Test response without context", + }; + + Ok(CreateMessageResult { + message: SamplingMessage { + role: Role::Assistant, + content: Content::text(response.to_string()), + }, + model: "test-model".to_string(), + stop_reason: Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()), + }) + } + + fn on_logging_message( + &self, + params: LoggingMessageNotificationParam, + _context: NotificationContext, + ) -> impl Future + Send + '_ { + let receive_signal = self.receive_signal.clone(); + let received_messages = self.received_messages.clone(); + + async move { + println!("Client: Received log message: {:?}", params); + let mut messages = received_messages.lock().unwrap(); + messages.push(params); + receive_signal.notify_one(); + } + } +} + +pub struct TestServer {} + +impl TestServer { + #[allow(dead_code)] + pub fn new() -> Self { + Self {} + } +} + +impl ServerHandler for TestServer { + fn get_info(&self) -> ServerInfo { + ServerInfo { + capabilities: ServerCapabilities::builder().enable_logging().build(), + ..Default::default() + } + } + + fn set_level( + &self, + request: SetLevelRequestParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + let peer = context.peer; + async move { + let (data, logger) = match request.level { + LoggingLevel::Error => ( + serde_json::json!({ + "message": "Failed to process request", + "error_code": "E1001", + "error_details": "Connection timeout", + "timestamp": chrono::Utc::now().to_rfc3339(), + }), + Some("error_handler".to_string()), + ), + LoggingLevel::Debug => ( + serde_json::json!({ + "message": "Processing request", + "function": "handle_request", + "line": 42, + "context": { + "request_id": "req-123", + "user_id": "user-456" + }, + "timestamp": chrono::Utc::now().to_rfc3339(), + }), + Some("debug_logger".to_string()), + ), + LoggingLevel::Info => ( + serde_json::json!({ + "message": "System status update", + "status": "healthy", + "metrics": { + "requests_per_second": 150, + "average_latency_ms": 45, + "error_rate": 0.01 + }, + "timestamp": chrono::Utc::now().to_rfc3339(), + }), + Some("monitoring".to_string()), + ), + _ => ( + serde_json::json!({ + "message": format!("Message at level {:?}", request.level), + "timestamp": chrono::Utc::now().to_rfc3339(), + }), + None, + ), + }; + + if let Err(e) = peer + .notify_logging_message(LoggingMessageNotificationParam { + level: request.level, + data, + logger, + }) + .await + { + panic!("Failed to send notification: {}", e); + } + Ok(()) + } + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/tests/common/mod.rs b/code-rs/third_party/rmcp-0.8.3/tests/common/mod.rs new file mode 100644 index 00000000000..491960651d8 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/common/mod.rs @@ -0,0 +1,2 @@ +pub mod calculator; +pub mod handlers; diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_completion.rs b/code-rs/third_party/rmcp-0.8.3/tests/test_completion.rs new file mode 100644 index 00000000000..ea9f632fe97 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_completion.rs @@ -0,0 +1,218 @@ +use std::collections::HashMap; + +use rmcp::model::*; +use serde_json::json; + +#[test] +fn test_completion_context_serialization() { + let mut args = HashMap::new(); + args.insert("key1".to_string(), "value1".to_string()); + args.insert("key2".to_string(), "value2".to_string()); + + let context = CompletionContext::with_arguments(args); + + // Test serialization + let json = serde_json::to_value(&context).unwrap(); + let expected = json!({ + "arguments": { + "key1": "value1", + "key2": "value2" + } + }); + assert_eq!(json, expected); + + // Test deserialization + let deserialized: CompletionContext = serde_json::from_value(expected).unwrap(); + assert_eq!(deserialized, context); +} + +#[test] +fn test_completion_context_methods() { + let mut args = HashMap::new(); + args.insert("city".to_string(), "San Francisco".to_string()); + args.insert("country".to_string(), "USA".to_string()); + + let context = CompletionContext::with_arguments(args); + + assert!(context.has_arguments()); + assert_eq!( + context.get_argument("city"), + Some(&"San Francisco".to_string()) + ); + assert_eq!(context.get_argument("missing"), None); + + let names: Vec<&str> = context.argument_names().collect(); + assert!(names.contains(&"city")); + assert!(names.contains(&"country")); + assert_eq!(names.len(), 2); +} + +#[test] +fn test_complete_request_param_serialization() { + let mut args = HashMap::new(); + args.insert("previous_input".to_string(), "test".to_string()); + + let request = CompleteRequestParam { + r#ref: Reference::for_prompt("weather_prompt"), + argument: ArgumentInfo { + name: "location".to_string(), + value: "San".to_string(), + }, + context: Some(CompletionContext::with_arguments(args)), + }; + + let json = serde_json::to_value(&request).unwrap(); + assert!(json["ref"]["name"].as_str().unwrap() == "weather_prompt"); + assert!(json["argument"]["name"].as_str().unwrap() == "location"); + assert!(json["argument"]["value"].as_str().unwrap() == "San"); + assert!( + json["context"]["arguments"]["previous_input"] + .as_str() + .unwrap() + == "test" + ); +} + +#[test] +fn test_completion_info_validation() { + // Valid completion with less than max values + let values = vec!["option1".to_string(), "option2".to_string()]; + let completion = CompletionInfo::new(values.clone()).unwrap(); + assert_eq!(completion.values, values); + assert!(completion.validate().is_ok()); + + // Test max values limit + let many_values: Vec = (0..=CompletionInfo::MAX_VALUES) + .map(|i| format!("option_{}", i)) + .collect(); + let result = CompletionInfo::new(many_values); + assert!(result.is_err()); +} + +#[test] +fn test_completion_info_helper_methods() { + let values = vec!["test1".to_string(), "test2".to_string()]; + + // Test with_all_values + let completion = CompletionInfo::with_all_values(values.clone()).unwrap(); + assert_eq!(completion.values, values); + assert_eq!(completion.total, Some(2)); + assert_eq!(completion.has_more, Some(false)); + assert!(!completion.has_more_results()); + assert_eq!(completion.total_available(), Some(2)); + + // Test with_pagination + let paginated = CompletionInfo::with_pagination(values.clone(), Some(10), true).unwrap(); + assert_eq!(paginated.values, values); + assert_eq!(paginated.total, Some(10)); + assert_eq!(paginated.has_more, Some(true)); + assert!(paginated.has_more_results()); + assert_eq!(paginated.total_available(), Some(10)); +} + +#[test] +fn test_completion_info_bounds() { + // Test exactly at the limit + let max_values: Vec = (0..CompletionInfo::MAX_VALUES) + .map(|i| format!("value_{}", i)) + .collect(); + assert!(CompletionInfo::new(max_values).is_ok()); + + // Test over the limit + let over_limit: Vec = (0..=CompletionInfo::MAX_VALUES) + .map(|i| format!("value_{}", i)) + .collect(); + assert!(CompletionInfo::new(over_limit).is_err()); +} + +#[test] +fn test_reference_convenience_methods() { + let prompt_ref = Reference::for_prompt("test_prompt"); + assert_eq!(prompt_ref.reference_type(), "ref/prompt"); + assert_eq!(prompt_ref.as_prompt_name(), Some("test_prompt")); + assert_eq!(prompt_ref.as_resource_uri(), None); + + let resource_ref = Reference::for_resource("file://path/to/resource"); + assert_eq!(resource_ref.reference_type(), "ref/resource"); + assert_eq!( + resource_ref.as_resource_uri(), + Some("file://path/to/resource") + ); + assert_eq!(resource_ref.as_prompt_name(), None); +} + +#[test] +fn test_completion_serialization_format() { + // Test that completion follows MCP 2025-06-18 specification format + let completion = CompletionInfo { + values: vec!["value1".to_string(), "value2".to_string()], + total: Some(2), + has_more: Some(false), + }; + + let json = serde_json::to_value(&completion).unwrap(); + + // Verify JSON structure matches specification + assert!(json.is_object()); + assert!(json["values"].is_array()); + assert_eq!(json["values"].as_array().unwrap().len(), 2); + assert_eq!(json["total"].as_u64().unwrap(), 2); + assert!(!json["hasMore"].as_bool().unwrap()); +} + +#[test] +fn test_resource_reference() { + // Test that ResourceReference works correctly + let resource_ref = ResourceReference { + uri: "test://uri".to_string(), + }; + + // Test that ResourceReference works correctly + let another_ref = ResourceReference { + uri: "test://uri".to_string(), + }; + + // They should be equivalent + assert_eq!(resource_ref.uri, another_ref.uri); +} + +#[test] +fn test_complete_result_default() { + let result = CompleteResult::default(); + assert!(result.completion.values.is_empty()); + assert_eq!(result.completion.total, None); + assert_eq!(result.completion.has_more, None); +} + +#[test] +fn test_completion_context_empty() { + let context = CompletionContext::new(); + assert!(!context.has_arguments()); + assert_eq!(context.get_argument("any"), None); + assert!(context.argument_names().count() == 0); +} + +#[test] +fn test_mcp_schema_compliance() { + // Test that our types serialize correctly according to MCP specification + let request = CompleteRequestParam { + r#ref: Reference::for_resource("file://{path}"), + argument: ArgumentInfo { + name: "path".to_string(), + value: "src/".to_string(), + }, + context: None, + }; + + let json_str = serde_json::to_string(&request).unwrap(); + let parsed: serde_json::Value = serde_json::from_str(&json_str).unwrap(); + + // Verify key structure matches MCP spec + assert!(parsed["ref"].is_object()); + assert!(parsed["argument"].is_object()); + assert!(parsed["argument"]["name"].is_string()); + assert!(parsed["argument"]["value"].is_string()); + + // Verify type tag is correct + assert_eq!(parsed["ref"]["type"].as_str().unwrap(), "ref/resource"); +} diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_complex_schema.rs b/code-rs/third_party/rmcp-0.8.3/tests/test_complex_schema.rs new file mode 100644 index 00000000000..e3a49e8df02 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_complex_schema.rs @@ -0,0 +1,66 @@ +use rmcp::{ + ErrorData as McpError, handler::server::wrapper::Parameters, model::*, schemars, tool, + tool_router, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)] +pub enum ChatRole { + System, + User, + Assistant, + Tool, +} + +#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)] +pub struct ChatMessage { + pub role: ChatRole, + pub content: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)] +pub struct ChatRequest { + pub system: Option, + pub messages: Vec, +} + +#[derive(Clone, Default)] +pub struct Demo; + +#[tool_router] +impl Demo { + pub fn new() -> Self { + Self + } + + #[tool(description = "LLM")] + async fn chat( + &self, + chat_request: Parameters, + ) -> Result { + let content = Content::json(chat_request.0)?; + Ok(CallToolResult::success(vec![content])) + } +} + +#[test] +fn test_complex_schema() { + let attr = Demo::chat_tool_attr(); + let input_schema = attr.input_schema; + let enum_number = input_schema + .get("definitions") + .unwrap() + .as_object() + .unwrap() + .get("ChatRole") + .unwrap() + .as_object() + .unwrap() + .get("enum") + .unwrap() + .as_array() + .unwrap() + .len(); + assert_eq!(enum_number, 4); + println!("{}", serde_json::to_string_pretty(&input_schema).unwrap()); +} diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_deserialization.rs b/code-rs/third_party/rmcp-0.8.3/tests/test_deserialization.rs new file mode 100644 index 00000000000..73621f4873e --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_deserialization.rs @@ -0,0 +1,15 @@ +use rmcp::model::{JsonRpcResponse, ServerJsonRpcMessage, ServerResult}; +#[test] +fn test_tool_list_result() { + let json = std::fs::read("tests/test_deserialization/tool_list_result.json").unwrap(); + let result: ServerJsonRpcMessage = serde_json::from_slice(&json).unwrap(); + println!("{result:#?}"); + + assert!(matches!( + result, + ServerJsonRpcMessage::Response(JsonRpcResponse { + result: ServerResult::ListToolsResult(_), + .. + }) + )); +} diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_deserialization/tool_list_result.json b/code-rs/third_party/rmcp-0.8.3/tests/test_deserialization/tool_list_result.json new file mode 100644 index 00000000000..674fdc0583a --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_deserialization/tool_list_result.json @@ -0,0 +1,28 @@ +{ + "result": { + "tools": [ + { + "name": "add", + "inputSchema": { + "type": "object", + "properties": { + "a": { + "type": "number" + }, + "b": { + "type": "number" + } + }, + "required": [ + "a", + "b" + ], + "additionalProperties": false, + "$schema": "http://json-schema.org/draft-07/schema#" + } + } + ] + }, + "jsonrpc": "2.0", + "id": 2 +} \ No newline at end of file diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_elicitation.rs b/code-rs/third_party/rmcp-0.8.3/tests/test_elicitation.rs new file mode 100644 index 00000000000..8dc6f160828 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_elicitation.rs @@ -0,0 +1,1604 @@ +//cargo test --test test_elicitation --features "client server" + +use rmcp::{model::*, service::*}; +// For typed elicitation tests +#[cfg(feature = "schemars")] +use schemars::JsonSchema; +#[cfg(feature = "schemars")] +use serde::{Deserialize, Serialize}; +use serde_json::json; + +/// Test that elicitation data structures can be serialized and deserialized correctly +/// This ensures JSON-RPC compatibility with MCP 2025-06-18 specification +#[tokio::test] +async fn test_elicitation_serialization() { + // Test ElicitationAction enum serialization + let accept = ElicitationAction::Accept; + let decline = ElicitationAction::Decline; + let cancel = ElicitationAction::Cancel; + + assert_eq!(serde_json::to_string(&accept).unwrap(), "\"accept\""); + assert_eq!(serde_json::to_string(&decline).unwrap(), "\"decline\""); + assert_eq!(serde_json::to_string(&cancel).unwrap(), "\"cancel\""); + + // Test deserialization + assert_eq!( + serde_json::from_str::("\"accept\"").unwrap(), + ElicitationAction::Accept + ); + assert_eq!( + serde_json::from_str::("\"decline\"").unwrap(), + ElicitationAction::Decline + ); + assert_eq!( + serde_json::from_str::("\"cancel\"").unwrap(), + ElicitationAction::Cancel + ); +} + +/// Test CreateElicitationRequestParam structure serialization/deserialization +#[tokio::test] +async fn test_elicitation_request_param_serialization() { + let schema = ElicitationSchema::builder() + .required_property("email", PrimitiveSchema::String(StringSchema::email())) + .build() + .unwrap(); + + let request_param = CreateElicitationRequestParam { + message: "Please provide your email address".to_string(), + requested_schema: schema, + }; + + // Test serialization + let json = serde_json::to_value(&request_param).unwrap(); + let expected = json!({ + "message": "Please provide your email address", + "requestedSchema": { + "type": "object", + "properties": { + "email": { + "type": "string", + "format": "email" + } + }, + "required": ["email"] + } + }); + + assert_eq!(json, expected); + + // Test deserialization + let deserialized: CreateElicitationRequestParam = serde_json::from_value(expected).unwrap(); + assert_eq!(deserialized.message, request_param.message); + assert_eq!( + deserialized.requested_schema, + request_param.requested_schema + ); +} + +/// Test CreateElicitationResult structure with different action types +#[tokio::test] +async fn test_elicitation_result_serialization() { + // Test Accept with content + let accept_result = CreateElicitationResult { + action: ElicitationAction::Accept, + content: Some(json!({"email": "user@example.com"})), + }; + + let json = serde_json::to_value(&accept_result).unwrap(); + let expected = json!({ + "action": "accept", + "content": {"email": "user@example.com"} + }); + assert_eq!(json, expected); + + // Test Decline without content + let decline_result = CreateElicitationResult { + action: ElicitationAction::Decline, + content: None, + }; + + let json = serde_json::to_value(&decline_result).unwrap(); + let expected = json!({ + "action": "decline" + // content should be omitted when None due to skip_serializing_if + }); + assert_eq!(json, expected); + + // Test deserialization + let deserialized: CreateElicitationResult = serde_json::from_value(expected).unwrap(); + assert_eq!(deserialized.action, ElicitationAction::Decline); + assert_eq!(deserialized.content, None); +} + +/// Test that elicitation requests can be created and handled through the JSON-RPC protocol +#[tokio::test] +async fn test_elicitation_json_rpc_protocol() { + let schema = ElicitationSchema::builder() + .required_property( + "confirmation", + PrimitiveSchema::Boolean(BooleanSchema::new()), + ) + .build() + .unwrap(); + + // Create a complete JSON-RPC request for elicitation + let request = JsonRpcRequest { + jsonrpc: JsonRpcVersion2_0, + id: RequestId::Number(1), + request: CreateElicitationRequest { + method: ElicitationCreateRequestMethod, + params: CreateElicitationRequestParam { + message: "Do you want to continue?".to_string(), + requested_schema: schema, + }, + extensions: Default::default(), + }, + }; + + // Test serialization of complete request + let json = serde_json::to_value(&request).unwrap(); + assert_eq!(json["jsonrpc"], "2.0"); + assert_eq!(json["id"], 1); + assert_eq!(json["method"], "elicitation/create"); + assert_eq!(json["params"]["message"], "Do you want to continue?"); + + // Test deserialization + let deserialized: JsonRpcRequest = + serde_json::from_value(json).unwrap(); + assert_eq!(deserialized.id, RequestId::Number(1)); + assert_eq!( + deserialized.request.params.message, + "Do you want to continue?" + ); +} + +/// Test elicitation action types and their expected behavior +#[tokio::test] +async fn test_elicitation_action_types() { + // Test all three action types + let actions = [ + ElicitationAction::Accept, + ElicitationAction::Decline, + ElicitationAction::Cancel, + ]; + + // Each action should have a unique string representation + let serialized: Vec = actions + .iter() + .map(|action| serde_json::to_string(action).unwrap()) + .collect(); + + assert_eq!(serialized.len(), 3); + assert!(serialized.contains(&"\"accept\"".to_string())); + assert!(serialized.contains(&"\"decline\"".to_string())); + assert!(serialized.contains(&"\"cancel\"".to_string())); + + // Test round-trip serialization + for action in actions { + let json = serde_json::to_string(&action).unwrap(); + let deserialized: ElicitationAction = serde_json::from_str(&json).unwrap(); + assert_eq!(action, deserialized); + } +} + +/// Test MCP 2025-06-18 specification compliance +/// Ensures our implementation matches the latest MCP spec +#[tokio::test] +async fn test_elicitation_spec_compliance() { + // Test that method names match the specification + assert_eq!(ElicitationCreateRequestMethod::VALUE, "elicitation/create"); + assert_eq!( + ElicitationResponseNotificationMethod::VALUE, + "notifications/elicitation/response" + ); + + // Test that enum values match specification + let actions = [ + ElicitationAction::Accept, + ElicitationAction::Decline, + ElicitationAction::Cancel, + ]; + + let serialized: Vec = actions + .iter() + .map(|a| serde_json::to_string(a).unwrap()) + .collect(); + + assert_eq!(serialized, vec!["\"accept\"", "\"decline\"", "\"cancel\""]); +} + +/// Test error handling and edge cases for elicitation +#[tokio::test] +async fn test_elicitation_error_handling() { + // Test minimal schema handling (empty properties is technically valid) + let minimal_schema_request = CreateElicitationRequestParam { + message: "Test message".to_string(), + requested_schema: ElicitationSchema::builder().build().unwrap(), + }; + + // Should serialize without error + let _json = serde_json::to_value(&minimal_schema_request).unwrap(); + + // Test empty message + let empty_message_request = CreateElicitationRequestParam { + message: "".to_string(), + requested_schema: ElicitationSchema::builder() + .property("value", PrimitiveSchema::String(StringSchema::new())) + .build() + .unwrap(), + }; + + // Should serialize without error (validation is up to the implementation) + let _json = serde_json::to_value(&empty_message_request).unwrap(); + + // Test that we can deserialize invalid action types (should fail) + let invalid_action_json = json!("invalid_action"); + let result = serde_json::from_value::(invalid_action_json); + assert!(result.is_err()); +} + +/// Benchmark-style test for elicitation performance +#[tokio::test] +async fn test_elicitation_performance() { + let schema = ElicitationSchema::builder() + .property("data", PrimitiveSchema::String(StringSchema::new())) + .build() + .unwrap(); + + let request = CreateElicitationRequestParam { + message: "Performance test message".to_string(), + requested_schema: schema, + }; + + let start = std::time::Instant::now(); + + // Serialize/deserialize 1000 times + for _ in 0..1000 { + let json = serde_json::to_value(&request).unwrap(); + let _deserialized: CreateElicitationRequestParam = serde_json::from_value(json).unwrap(); + } + + let duration = start.elapsed(); + println!( + "1000 elicitation serialization/deserialization cycles took: {:?}", + duration + ); + + // Should complete in reasonable time (less than 100ms on modern hardware) + assert!( + duration.as_millis() < 1000, + "Performance test took too long: {:?}", + duration + ); +} + +/// Test elicitation capabilities integration +/// Ensures that elicitation capability can be properly configured and serialized +#[tokio::test] +async fn test_elicitation_capabilities() { + use rmcp::model::{ClientCapabilities, ElicitationCapability}; + + // Test basic elicitation capability + let mut elicitation_cap = ElicitationCapability::default(); + assert_eq!(elicitation_cap.schema_validation, None); + + // Test with schema validation enabled + elicitation_cap.schema_validation = Some(true); + + // Test serialization + let json = serde_json::to_value(&elicitation_cap).unwrap(); + let expected = json!({"schemaValidation": true}); + assert_eq!(json, expected); + + // Test deserialization + let deserialized: ElicitationCapability = serde_json::from_value(expected).unwrap(); + assert_eq!(deserialized.schema_validation, Some(true)); + + // Test ClientCapabilities builder with elicitation + let client_caps = ClientCapabilities::builder() + .enable_elicitation() + .enable_elicitation_schema_validation() + .build(); + + assert!(client_caps.elicitation.is_some()); + assert_eq!( + client_caps.elicitation.as_ref().unwrap().schema_validation, + Some(true) + ); + + // Test full client capabilities serialization + let json = serde_json::to_value(&client_caps).unwrap(); + assert!( + json["elicitation"]["schemaValidation"] + .as_bool() + .unwrap_or(false) + ); +} + +/// Test convenience methods for common elicitation scenarios +/// This ensures the helper methods create proper requests with expected schemas +#[tokio::test] +async fn test_elicitation_convenience_methods() { + // Test that convenience methods produce the expected request parameters + + // Test confirmation schema + let confirmation_schema = serde_json::json!({ + "type": "boolean", + "description": "User confirmation (true for yes, false for no)" + }); + + // Verify the schema structure for boolean confirmation + assert_eq!(confirmation_schema["type"], "boolean"); + assert!(confirmation_schema["description"].is_string()); + + // Test text input schema (non-required) + let text_schema = serde_json::json!({ + "type": "string", + "description": "User text input" + }); + + assert_eq!(text_schema["type"], "string"); + assert!(text_schema.get("minLength").is_none()); + + // Test text input schema (required) + let required_text_schema = serde_json::json!({ + "type": "string", + "description": "User text input", + "minLength": 1 + }); + + assert_eq!(required_text_schema["minLength"], 1); + + // Test choice schema + let options = ["Option A", "Option B", "Option C"]; + let choice_schema = serde_json::json!({ + "type": "integer", + "minimum": 0, + "maximum": options.len() - 1, + "description": format!("Choose an option: {}", options.join(", ")) + }); + + assert_eq!(choice_schema["type"], "integer"); + assert_eq!(choice_schema["minimum"], 0); + assert_eq!(choice_schema["maximum"], 2); + assert!( + choice_schema["description"] + .as_str() + .unwrap() + .contains("Option A") + ); + + // Test that CreateElicitationRequestParam can be created with type-safe schemas + let confirmation_request = CreateElicitationRequestParam { + message: "Test confirmation".to_string(), + requested_schema: ElicitationSchema::builder() + .property( + "confirmed", + PrimitiveSchema::Boolean( + BooleanSchema::new() + .description("User confirmation (true for yes, false for no)"), + ), + ) + .build() + .unwrap(), + }; + + // Test serialization of convenience method request + let json = serde_json::to_value(&confirmation_request).unwrap(); + assert_eq!(json["message"], "Test confirmation"); + assert_eq!(json["requestedSchema"]["type"], "object"); + assert_eq!( + json["requestedSchema"]["properties"]["confirmed"]["type"], + "boolean" + ); +} + +/// Test structured input with multiple primitive properties +/// Ensures that schemas with multiple primitive properties work correctly with elicitation +#[tokio::test] +async fn test_elicitation_structured_schemas() { + // Test schema with multiple primitive properties + let schema = ElicitationSchema::builder() + .required_string_with("name", |s| s.length(1, 100)) + .required_email("email") + .required_integer("age", 0, 150) + .optional_bool("newsletter", false) + .required_enum( + "country", + vec!["US".to_string(), "UK".to_string(), "CA".to_string()], + ) + .description("User registration information") + .build() + .unwrap(); + + let request = CreateElicitationRequestParam { + message: "Please provide your user information".to_string(), + requested_schema: schema, + }; + + // Test that complex schemas serialize/deserialize correctly + let json = serde_json::to_value(&request).unwrap(); + let deserialized: CreateElicitationRequestParam = serde_json::from_value(json).unwrap(); + + assert_eq!(deserialized.message, "Please provide your user information"); + assert_eq!(deserialized.requested_schema.properties.len(), 5); + assert!( + deserialized + .requested_schema + .properties + .contains_key("name") + ); + assert!( + deserialized + .requested_schema + .properties + .contains_key("email") + ); + assert!(deserialized.requested_schema.properties.contains_key("age")); + assert!( + deserialized + .requested_schema + .properties + .contains_key("newsletter") + ); + assert!( + deserialized + .requested_schema + .properties + .contains_key("country") + ); + assert_eq!( + deserialized.requested_schema.required, + Some(vec![ + "name".to_string(), + "email".to_string(), + "age".to_string(), + "country".to_string() + ]) + ); +} + +// Typed elicitation tests using the API with schemars +#[cfg(feature = "schemars")] +mod typed_elicitation_tests { + use super::*; + + /// Simple user confirmation with reason + #[derive(Debug, Serialize, Deserialize, JsonSchema, PartialEq)] + #[schemars(description = "User confirmation with optional reasoning")] + struct UserConfirmation { + #[schemars(description = "User's decision (true for yes, false for no)")] + confirmed: bool, + + #[schemars(description = "Optional reason for the decision")] + reason: Option, + } + + /// User profile with validation constraints + #[derive(Debug, Serialize, Deserialize, JsonSchema, PartialEq)] + #[schemars(description = "Complete user profile information")] + struct UserProfile { + #[schemars(description = "Full name")] + name: String, + + #[schemars(description = "Email address")] + email: String, + + #[schemars(description = "Age in years")] + age: u8, + + #[schemars(description = "User preferences")] + preferences: UserPreferences, + } + + /// User preferences + #[derive(Debug, Serialize, Deserialize, JsonSchema, PartialEq)] + struct UserPreferences { + #[schemars(description = "UI theme preference")] + theme: Theme, + + #[schemars(description = "Enable notifications")] + notifications: bool, + + #[schemars(description = "Language preference")] + language: String, + } + + /// UI theme options + #[derive(Debug, Serialize, Deserialize, JsonSchema, PartialEq)] + #[schemars(description = "Available UI themes")] + enum Theme { + #[schemars(description = "Light theme")] + Light, + #[schemars(description = "Dark theme")] + Dark, + #[schemars(description = "Auto-detect based on system")] + Auto, + } + + // Mark types as safe for elicitation (they generate object schemas) + rmcp::elicit_safe!(UserConfirmation, UserProfile, UserPreferences); + + /// Test automatic schema generation for simple types + #[tokio::test] + async fn test_typed_elicitation_simple_schema() { + // Test that schema generation works for simple types + let schema = rmcp::handler::server::tool::schema_for_type::(); + + // Verify schema contains expected fields + assert!(schema.contains_key("type")); + assert_eq!(schema.get("type"), Some(&json!("object"))); + assert!(schema.contains_key("properties")); + + if let Some(properties) = schema.get("properties") { + assert!(properties.is_object()); + let props = properties.as_object().unwrap(); + assert!(props.contains_key("confirmed")); + assert!(props.contains_key("reason")); + + // Check confirmed field is boolean + if let Some(confirmed_schema) = props.get("confirmed") { + let confirmed_obj = confirmed_schema.as_object().unwrap(); + assert_eq!(confirmed_obj.get("type"), Some(&json!("boolean"))); + } + + // Check reason field is optional string + if let Some(reason_schema) = props.get("reason") { + assert!(reason_schema.is_object()); + } + } + } + + /// Test automatic schema generation for complex nested types + #[tokio::test] + async fn test_typed_elicitation_complex_schema() { + // Test complex nested structure schema generation + let schema = rmcp::handler::server::tool::schema_for_type::(); + + // Verify schema structure + assert!(schema.contains_key("type")); + assert_eq!(schema.get("type"), Some(&json!("object"))); + + if let Some(properties) = schema.get("properties") { + let props = properties.as_object().unwrap(); + + // Check required fields exist + assert!(props.contains_key("name")); + assert!(props.contains_key("email")); + assert!(props.contains_key("age")); + assert!(props.contains_key("preferences")); + + // Check validation constraints for name + if let Some(name_schema) = props.get("name") { + let name_obj = name_schema.as_object().unwrap(); + assert_eq!(name_obj.get("type"), Some(&json!("string"))); + // Note: schemars might generate constraints differently + // The exact structure depends on schemars version + } + + // Check email format constraint + if let Some(email_schema) = props.get("email") { + let email_obj = email_schema.as_object().unwrap(); + assert_eq!(email_obj.get("type"), Some(&json!("string"))); + } + + // Check age numeric constraints + if let Some(age_schema) = props.get("age") { + let age_obj = age_schema.as_object().unwrap(); + assert_eq!(age_obj.get("type"), Some(&json!("integer"))); + } + } + } + + /// Test enum schema generation + #[tokio::test] + async fn test_enum_schema_generation() { + // Test enum schema generation + let schema = rmcp::handler::server::tool::schema_for_type::(); + + // Verify enum schema structure - schemars might use oneOf or enum depending on version + assert!( + schema.contains_key("type") + || schema.contains_key("oneOf") + || schema.contains_key("enum") + ); + + // The exact structure depends on schemars configuration, but it should be valid + let json = serde_json::to_string(&schema).unwrap(); + assert!(!json.is_empty()); + } + + /// Test that the schema generation for nested structures works + #[tokio::test] + async fn test_nested_structure_schema() { + // Test that nested structures generate proper schemas + let preferences_schema = rmcp::handler::server::tool::schema_for_type::(); + + // Verify basic structure + assert!(preferences_schema.contains_key("type")); + assert_eq!(preferences_schema.get("type"), Some(&json!("object"))); + + if let Some(properties) = preferences_schema.get("properties") { + let props = properties.as_object().unwrap(); + assert!(props.contains_key("theme")); + assert!(props.contains_key("notifications")); + assert!(props.contains_key("language")); + } + } +} + +// ============================================================================= +// ELICITATION DIRECTION TESTS (MCP 2025-06-18 COMPLIANCE) +// ============================================================================= + +/// Test that elicitation requests flow from server to client (not client to server) +/// This verifies compliance with MCP 2025-06-18 specification +#[cfg(all(feature = "client", feature = "server"))] +#[tokio::test] +async fn test_elicitation_direction_server_to_client() { + use rmcp::model::*; + use serde_json::json; + + // Test that server can create elicitation requests + let schema = ElicitationSchema::builder() + .property( + "name", + PrimitiveSchema::String(StringSchema::new().description("Enter your name")), + ) + .build() + .unwrap(); + + let elicitation_request = CreateElicitationRequestParam { + message: "Please enter your name".to_string(), + requested_schema: schema, + }; + + // Verify request can be serialized + let serialized = serde_json::to_value(&elicitation_request).unwrap(); + assert_eq!(serialized["message"], "Please enter your name"); + assert_eq!(serialized["requestedSchema"]["type"], "object"); + + // Test that elicitation requests are part of ServerRequest + let _server_request = ServerRequest::CreateElicitationRequest(CreateElicitationRequest { + method: ElicitationCreateRequestMethod, + params: elicitation_request, + extensions: Default::default(), + }); + + // Test that client can respond with elicitation results + let client_result = ClientResult::CreateElicitationResult(CreateElicitationResult { + action: ElicitationAction::Accept, + content: Some(json!("John Doe")), + }); + + // Verify client result can be serialized + match client_result { + ClientResult::CreateElicitationResult(result) => { + assert_eq!(result.action, ElicitationAction::Accept); + assert_eq!(result.content, Some(json!("John Doe"))); + } + _ => panic!("CreateElicitationResult should be part of ClientResult"), + } +} + +/// Test complete JSON-RPC message flow: Server → Client → Server +#[cfg(all(feature = "client", feature = "server"))] +#[tokio::test] +async fn test_elicitation_json_rpc_direction() { + use rmcp::model::*; + use serde_json::json; + + let schema = ElicitationSchema::builder() + .property( + "continue", + PrimitiveSchema::Boolean(BooleanSchema::new().description("Do you want to continue?")), + ) + .build() + .unwrap(); + + // 1. Server creates elicitation request + let server_request = ServerJsonRpcMessage::request( + ServerRequest::CreateElicitationRequest(CreateElicitationRequest { + method: ElicitationCreateRequestMethod, + params: CreateElicitationRequestParam { + message: "Do you want to continue?".to_string(), + requested_schema: schema, + }, + extensions: Default::default(), + }), + RequestId::Number(1), + ); + + // Serialize server request + let server_json = serde_json::to_value(&server_request).unwrap(); + assert_eq!(server_json["method"], "elicitation/create"); + assert_eq!(server_json["id"], 1); + assert_eq!(server_json["params"]["message"], "Do you want to continue?"); + + // 2. Client responds with elicitation result + let client_response = ClientJsonRpcMessage::response( + ClientResult::CreateElicitationResult(CreateElicitationResult { + action: ElicitationAction::Accept, + content: Some(json!(true)), + }), + RequestId::Number(1), + ); + + // Serialize client response + let client_json = serde_json::to_value(&client_response).unwrap(); + assert_eq!(client_json["id"], 1); + if let Some(result) = client_json["result"].as_object() { + assert_eq!(result["action"], "accept"); + assert_eq!(result["content"], true); + } else { + panic!("Client response should contain result"); + } +} + +/// Test all three elicitation actions according to MCP spec +#[cfg(all(feature = "client", feature = "server"))] +#[tokio::test] +async fn test_elicitation_actions_compliance() { + use rmcp::model::*; + + // Test all three elicitation actions according to MCP spec + let actions = [ + ElicitationAction::Accept, + ElicitationAction::Decline, + ElicitationAction::Cancel, + ]; + + for action in actions { + let result = CreateElicitationResult { + action: action.clone(), + content: match action { + ElicitationAction::Accept => Some(serde_json::json!("some data")), + _ => None, + }, + }; + + let json = serde_json::to_value(&result).unwrap(); + + match action { + ElicitationAction::Accept => { + assert_eq!(json["action"], "accept"); + assert!(json["content"].is_string()); + } + ElicitationAction::Decline => { + assert_eq!(json["action"], "decline"); + assert!(json.get("content").is_none() || json["content"].is_null()); + } + ElicitationAction::Cancel => { + assert_eq!(json["action"], "cancel"); + assert!(json.get("content").is_none() || json["content"].is_null()); + } + } + } +} + +/// Test that CreateElicitationResult IS in ClientResult (response compliance) +#[tokio::test] +async fn test_elicitation_result_in_client_result() { + use rmcp::model::*; + + // Test that clients can return elicitation results + let result = ClientResult::CreateElicitationResult(CreateElicitationResult { + action: ElicitationAction::Decline, + content: None, + }); + + match result { + ClientResult::CreateElicitationResult(elicit_result) => { + assert_eq!(elicit_result.action, ElicitationAction::Decline); + assert_eq!(elicit_result.content, None); + } + _ => panic!("CreateElicitationResult should be part of ClientResult"), + } +} + +// ============================================================================= +// ELICITATION CAPABILITIES TESTS +// ============================================================================= + +/// Test ElicitationCapability structure and serialization +#[tokio::test] +async fn test_elicitation_capability_structure() { + // Test default ElicitationCapability + let default_cap = ElicitationCapability::default(); + assert!(default_cap.schema_validation.is_none()); + + // Test ElicitationCapability with schema validation enabled + let cap_with_validation = ElicitationCapability { + schema_validation: Some(true), + }; + assert_eq!(cap_with_validation.schema_validation, Some(true)); + + // Test ElicitationCapability with schema validation disabled + let cap_without_validation = ElicitationCapability { + schema_validation: Some(false), + }; + assert_eq!(cap_without_validation.schema_validation, Some(false)); + + // Test JSON serialization + let json = serde_json::to_value(&cap_with_validation).unwrap(); + assert_eq!( + json, + serde_json::json!({ + "schemaValidation": true + }) + ); + + // Test JSON deserialization + let deserialized: ElicitationCapability = serde_json::from_value(json).unwrap(); + assert_eq!(deserialized.schema_validation, Some(true)); +} + +/// Test ClientCapabilities with elicitation capability +#[tokio::test] +async fn test_client_capabilities_with_elicitation() { + // Test ClientCapabilities with elicitation capability + let capabilities = ClientCapabilities { + elicitation: Some(ElicitationCapability { + schema_validation: Some(true), + }), + ..Default::default() + }; + + // Verify elicitation capability is present + assert!(capabilities.elicitation.is_some()); + assert_eq!( + capabilities.elicitation.as_ref().unwrap().schema_validation, + Some(true) + ); + + // Test JSON serialization + let json = serde_json::to_value(&capabilities).unwrap(); + assert!( + json["elicitation"]["schemaValidation"] + .as_bool() + .unwrap_or(false) + ); + + // Test ClientCapabilities without elicitation + let capabilities_without = ClientCapabilities { + elicitation: None, + ..Default::default() + }; + + assert!(capabilities_without.elicitation.is_none()); +} + +/// Test InitializeRequestParam with elicitation capability +#[tokio::test] +async fn test_initialize_request_with_elicitation() { + // Test InitializeRequestParam with elicitation capability + let init_param = InitializeRequestParam { + protocol_version: ProtocolVersion::LATEST, + capabilities: ClientCapabilities { + elicitation: Some(ElicitationCapability { + schema_validation: Some(true), + }), + ..Default::default() + }, + client_info: Implementation { + name: "test-client".to_string(), + version: "1.0.0".to_string(), + title: None, + website_url: None, + icons: None, + }, + }; + + // Verify the structure + assert!(init_param.capabilities.elicitation.is_some()); + assert_eq!( + init_param + .capabilities + .elicitation + .as_ref() + .unwrap() + .schema_validation, + Some(true) + ); + + // Test JSON serialization + let json = serde_json::to_value(&init_param).unwrap(); + assert!( + json["capabilities"]["elicitation"]["schemaValidation"] + .as_bool() + .unwrap_or(false) + ); +} + +/// Test capability checking logic (simulated) +#[tokio::test] +async fn test_capability_checking_logic() { + // Simulate the logic that would be used in supports_elicitation() + + // Case 1: Client with elicitation capability + let client_with_capability = InitializeRequestParam { + protocol_version: ProtocolVersion::LATEST, + capabilities: ClientCapabilities { + elicitation: Some(ElicitationCapability { + schema_validation: Some(true), + }), + ..Default::default() + }, + client_info: Implementation { + name: "test-client".to_string(), + version: "1.0.0".to_string(), + title: None, + website_url: None, + icons: None, + }, + }; + + // Simulate supports_elicitation() logic + let supports_elicitation = client_with_capability.capabilities.elicitation.is_some(); + assert!(supports_elicitation); + + // Case 2: Client without elicitation capability + let client_without_capability = InitializeRequestParam { + protocol_version: ProtocolVersion::LATEST, + capabilities: ClientCapabilities { + elicitation: None, + ..Default::default() + }, + client_info: Implementation { + name: "test-client".to_string(), + version: "1.0.0".to_string(), + title: None, + website_url: None, + icons: None, + }, + }; + let supports_elicitation = client_without_capability.capabilities.elicitation.is_some(); + assert!(!supports_elicitation); +} + +/// Test CapabilityNotSupported error message formatting +#[tokio::test] +async fn test_capability_not_supported_error_message() { + let error = ElicitationError::CapabilityNotSupported; + let message = format!("{}", error); + + assert_eq!( + message, + "Client does not support elicitation - capability not declared during initialization" + ); +} + +/// Test all ElicitationError variants and their messages +#[tokio::test] +async fn test_elicitation_error_variants() { + // Test CapabilityNotSupported + let capability_error = ElicitationError::CapabilityNotSupported; + assert_eq!( + format!("{}", capability_error), + "Client does not support elicitation - capability not declared during initialization" + ); + + // Test UserDeclined + let user_declined = ElicitationError::UserDeclined; + assert_eq!( + format!("{}", user_declined), + "User explicitly declined the request" + ); + + // Test UserCancelled + let user_cancelled = ElicitationError::UserCancelled; + assert_eq!( + format!("{}", user_cancelled), + "User cancelled/dismissed the request" + ); + + // Test NoContent + let no_content = ElicitationError::NoContent; + assert_eq!(format!("{}", no_content), "No response content provided"); + + // Test Service error + let service_error = ElicitationError::Service(ServiceError::UnexpectedResponse); + let message = format!("{}", service_error); + assert!(message.starts_with("Service error:")); + + // Test ParseError + let json_error = serde_json::from_str::("\"not_an_integer\"").unwrap_err(); + let data = serde_json::json!({"key": "value"}); + let parse_error = ElicitationError::ParseError { + error: json_error, + data: data.clone(), + }; + let message = format!("{}", parse_error); + assert!(message.starts_with("Failed to parse response data:")); + assert!(message.contains("Received data:")); + + // Test error matching + match capability_error { + ElicitationError::CapabilityNotSupported => {} // Expected + _ => panic!("Should match CapabilityNotSupported"), + } + + match user_declined { + ElicitationError::UserDeclined => {} // Expected + _ => panic!("Should match UserDeclined"), + } + + match user_cancelled { + ElicitationError::UserCancelled => {} // Expected + _ => panic!("Should match UserCancelled"), + } + + match no_content { + ElicitationError::NoContent => {} // Expected + _ => panic!("Should match NoContent"), + } +} + +/// Test ElicitationCapability serialization with schema validation +#[tokio::test] +async fn test_elicitation_capability_serialization() { + use rmcp::model::ElicitationCapability; + + // Test default capability (no schema validation) + let default_cap = ElicitationCapability::default(); + let json = serde_json::to_value(&default_cap).unwrap(); + + // Should serialize to empty object when no fields are set + assert_eq!(json, serde_json::json!({})); + + // Test capability with schema validation enabled + let cap_with_validation = ElicitationCapability { + schema_validation: Some(true), + }; + let json = serde_json::to_value(&cap_with_validation).unwrap(); + + assert_eq!( + json, + serde_json::json!({ + "schemaValidation": true + }) + ); + + // Test capability with schema validation disabled + let cap_without_validation = ElicitationCapability { + schema_validation: Some(false), + }; + let json = serde_json::to_value(&cap_without_validation).unwrap(); + + assert_eq!( + json, + serde_json::json!({ + "schemaValidation": false + }) + ); + + // Test deserialization + let deserialized: ElicitationCapability = serde_json::from_value(serde_json::json!({ + "schemaValidation": true + })) + .unwrap(); + + assert_eq!(deserialized.schema_validation, Some(true)); +} + +/// Test ClientCapabilities builder with elicitation capability methods +#[tokio::test] +async fn test_client_capabilities_elicitation_builder() { + use rmcp::model::{ClientCapabilities, ElicitationCapability}; + + // Test enabling elicitation capability + let caps = ClientCapabilities::builder().enable_elicitation().build(); + + assert!(caps.elicitation.is_some()); + assert_eq!(caps.elicitation.as_ref().unwrap().schema_validation, None); + + // Test enabling elicitation with schema validation + let caps_with_validation = ClientCapabilities::builder() + .enable_elicitation() + .enable_elicitation_schema_validation() + .build(); + + assert!(caps_with_validation.elicitation.is_some()); + assert_eq!( + caps_with_validation + .elicitation + .as_ref() + .unwrap() + .schema_validation, + Some(true) + ); + + // Test enabling elicitation with custom capability + let custom_elicitation = ElicitationCapability { + schema_validation: Some(false), + }; + + let caps_custom = ClientCapabilities::builder() + .enable_elicitation_with(custom_elicitation.clone()) + .build(); + + assert!(caps_custom.elicitation.is_some()); + assert_eq!( + caps_custom.elicitation.as_ref().unwrap(), + &custom_elicitation + ); +} + +// ============================================================================= +// TIMEOUT TESTS +// ============================================================================= + +/// Test basic timeout functionality for create_elicitation_with_timeout +#[tokio::test] +async fn test_create_elicitation_with_timeout_basic() { + use std::time::Duration; + + // This test verifies that the method accepts timeout parameter + let schema = ElicitationSchema::builder() + .required_property("name", PrimitiveSchema::String(StringSchema::new())) + .required_property("email", PrimitiveSchema::String(StringSchema::new())) + .build() + .unwrap(); + + let _params = CreateElicitationRequestParam { + message: "Enter your details".to_string(), + requested_schema: schema, + }; + + // Test different timeout values + let timeout_short = Duration::from_millis(100); + let timeout_long = Duration::from_secs(30); + let timeout_none: Option = None; + + // Verify timeout parameter types are correct + assert!(!timeout_short.is_zero()); + assert!(!timeout_long.is_zero()); + assert!(timeout_none.is_none()); + + // Verify timeout values are reasonable + assert_eq!(timeout_short.as_millis(), 100); + assert_eq!(timeout_long.as_secs(), 30); +} + +/// Test timeout behavior with elicit_with_timeout method +#[tokio::test] +async fn test_elicit_with_timeout_method_signature() { + use std::time::Duration; + + // Test that method signature works with different timeout values + let timeout_values = vec![ + None, + Some(Duration::from_millis(500)), + Some(Duration::from_secs(1)), + Some(Duration::from_secs(30)), + Some(Duration::from_secs(60)), + ]; + + for timeout in timeout_values { + // Verify timeout value is properly handled + match timeout { + None => assert!(timeout.is_none()), + Some(duration) => { + assert!(duration > Duration::from_millis(0)); + assert!(duration <= Duration::from_secs(300)); // Max 5 minutes + } + } + } +} + +/// Test timeout value validation +#[tokio::test] +async fn test_timeout_value_validation() { + use std::time::Duration; + + // Test valid timeout ranges + let valid_timeouts = vec![ + Duration::from_millis(1), // Minimum + Duration::from_millis(100), // Short + Duration::from_secs(1), // 1 second + Duration::from_secs(30), // 30 seconds + Duration::from_secs(300), // 5 minutes + ]; + + for timeout in valid_timeouts { + assert!(timeout >= Duration::from_millis(1)); + assert!(timeout <= Duration::from_secs(300)); + } + + // Test edge cases + let zero_timeout = Duration::from_millis(0); + let very_long_timeout = Duration::from_secs(3600); // 1 hour + + // Zero timeout should be handled gracefully + assert_eq!(zero_timeout, Duration::from_millis(0)); + + // Very long timeout should work but may not be practical + assert!(very_long_timeout > Duration::from_secs(300)); +} + +/// Test timeout error message formatting +#[tokio::test] +async fn test_timeout_error_formatting() { + use std::time::Duration; + + let timeout = Duration::from_secs(30); + + // Simulate a timeout error + let timeout_error = ServiceError::Timeout { timeout }; + + // Verify error contains timeout information + let error_string = format!("{}", timeout_error); + assert!(error_string.contains("timeout")); + assert!(error_string.contains("30")); +} + +/// Test elicitation error handling with timeout +#[tokio::test] +async fn test_elicitation_timeout_error_conversion() { + use std::time::Duration; + + let timeout = Duration::from_millis(500); + let service_timeout_error = ServiceError::Timeout { timeout }; + let elicitation_error = ElicitationError::Service(service_timeout_error); + + // Verify error chain is preserved + match elicitation_error { + ElicitationError::Service(ServiceError::Timeout { timeout: t }) => { + assert_eq!(t, timeout); + } + _ => panic!("Expected timeout error"), + } +} + +/// Test timeout parameter pass-through in PeerRequestOptions +#[tokio::test] +async fn test_peer_request_options_timeout() { + use std::time::Duration; + + let timeout = Some(Duration::from_secs(15)); + + let options = PeerRequestOptions { + timeout, + meta: None, + }; + + // Verify timeout is properly stored + assert_eq!(options.timeout, timeout); + assert!(options.meta.is_none()); + + // Test with no timeout + let options_no_timeout = PeerRequestOptions { + timeout: None, + meta: None, + }; + + assert!(options_no_timeout.timeout.is_none()); +} + +/// Test realistic timeout scenarios +#[tokio::test] +async fn test_realistic_timeout_scenarios() { + use std::time::Duration; + + // Test common timeout scenarios users might encounter + + // Quick response (5 seconds) + let quick_timeout = Duration::from_secs(5); + assert!(quick_timeout >= Duration::from_secs(1)); + assert!(quick_timeout <= Duration::from_secs(10)); + + // Normal interaction (30 seconds) + let normal_timeout = Duration::from_secs(30); + assert!(normal_timeout >= Duration::from_secs(10)); + assert!(normal_timeout <= Duration::from_secs(60)); + + // Long form input (2 minutes) + let long_timeout = Duration::from_secs(120); + assert!(long_timeout >= Duration::from_secs(60)); + assert!(long_timeout <= Duration::from_secs(300)); +} + +/// Test that different ElicitationAction values map to correct error types +#[tokio::test] +async fn test_elicitation_action_error_mapping() { + use rmcp::{model::ElicitationAction, service::ElicitationError}; + + // Test that each action type produces the expected error + let test_cases = vec![ + (ElicitationAction::Decline, "UserDeclined"), + (ElicitationAction::Cancel, "UserCancelled"), + ]; + + for (action, _expected_error_type) in test_cases { + // Verify that the action exists and has the expected semantics + match action { + ElicitationAction::Accept => { + // Accept should not produce an error (it provides content) + } + ElicitationAction::Decline => { + // Should map to UserDeclined error + let error = ElicitationError::UserDeclined; + assert!(format!("{}", error).contains("explicitly declined")); + } + ElicitationAction::Cancel => { + // Should map to UserCancelled error + let error = ElicitationError::UserCancelled; + assert!(format!("{}", error).contains("cancelled/dismissed")); + } + } + } +} + +/// Test elicitation action semantics according to MCP specification +#[tokio::test] +async fn test_elicitation_action_semantics() { + use rmcp::model::ElicitationAction; + + // According to MCP spec: + // - Accept: User explicitly approved and submitted with data + // - Decline: User explicitly declined the request + // - Cancel: User dismissed without making an explicit choice + + // Test that all three actions are available + let actions = vec![ + ElicitationAction::Accept, + ElicitationAction::Decline, + ElicitationAction::Cancel, + ]; + + assert_eq!(actions.len(), 3); + + // Test serialization/deserialization + for action in actions { + let serialized = serde_json::to_string(&action).expect("Should serialize"); + let deserialized: ElicitationAction = + serde_json::from_str(&serialized).expect("Should deserialize"); + + // Actions should round-trip correctly + match (action, deserialized) { + (ElicitationAction::Accept, ElicitationAction::Accept) => {} + (ElicitationAction::Decline, ElicitationAction::Decline) => {} + (ElicitationAction::Cancel, ElicitationAction::Cancel) => {} + _ => panic!("Action serialization round-trip failed"), + } + } +} + +/// Test compile-time type safety for elicitation +#[tokio::test] +async fn test_elicitation_type_safety() { + use rmcp::service::ElicitationSafe; + use schemars::JsonSchema; + + // Test that our types implement ElicitationSafe + #[derive(serde::Serialize, serde::Deserialize, JsonSchema)] + struct SafeType { + name: String, + value: i32, + } + + rmcp::elicit_safe!(SafeType); + + // Verify that SafeType implements the required traits + fn assert_elicitation_safe() {} + assert_elicitation_safe::(); + + // Test that SafeType can generate schema (compile-time check) + let _schema = schemars::schema_for!(SafeType); +} + +/// Test that elicit_safe! macro works with multiple types +#[tokio::test] +async fn test_elicit_safe_macro() { + use schemars::JsonSchema; + + #[derive(serde::Serialize, serde::Deserialize, JsonSchema)] + struct TypeA { + field_a: String, + } + + #[derive(serde::Serialize, serde::Deserialize, JsonSchema)] + struct TypeB { + field_b: i32, + } + + #[derive(serde::Serialize, serde::Deserialize, JsonSchema)] + struct TypeC { + field_c: bool, + } + + // Test macro with multiple types + rmcp::elicit_safe!(TypeA, TypeB, TypeC); + + // All should implement ElicitationSafe + fn assert_all_safe() {} + assert_all_safe::(); + assert_all_safe::(); + assert_all_safe::(); +} + +/// Test ElicitationSafe trait behavior +#[tokio::test] +async fn test_elicitation_safe_trait() { + use schemars::JsonSchema; + + // Test object type validation + #[derive(serde::Serialize, serde::Deserialize, JsonSchema)] + struct ObjectType { + name: String, + count: usize, + active: bool, + } + + rmcp::elicit_safe!(ObjectType); + + // Test that ObjectType can generate schema (compile-time check) + let _schema = schemars::schema_for!(ObjectType); +} + +/// Test documentation examples compile correctly +#[tokio::test] +async fn test_elicitation_examples_compile() { + use schemars::JsonSchema; + use serde::{Deserialize, Serialize}; + + // Example from trait documentation + #[allow(dead_code)] + #[derive(Serialize, Deserialize, JsonSchema)] + struct UserProfile { + name: String, + email: String, + } + + rmcp::elicit_safe!(UserProfile); + + // This should compile and work + fn _example_usage() { + fn _assert_safe() {} + _assert_safe::(); + } +} + +// ============================================================================= +// BUILD-TIME VALIDATION TESTS +// ============================================================================= + +/// Test that build() validates required fields exist in properties +#[tokio::test] +async fn test_build_validation_required_field_not_in_properties() { + // Try to mark a field as required that doesn't exist in properties + let result = ElicitationSchema::builder() + .property("email", PrimitiveSchema::String(StringSchema::email())) + .mark_required("nonexistent_field") + .build(); + + // Should return an error + assert!(result.is_err()); + assert_eq!( + result.unwrap_err(), + "Required field does not exist in properties" + ); +} + +/// Test that build() succeeds when all required fields exist +#[tokio::test] +async fn test_build_validation_required_field_exists() { + let result = ElicitationSchema::builder() + .property("email", PrimitiveSchema::String(StringSchema::email())) + .property("name", PrimitiveSchema::String(StringSchema::new())) + .mark_required("email") + .mark_required("name") + .build(); + + // Should succeed + assert!(result.is_ok()); + let schema = result.unwrap(); + assert_eq!(schema.properties.len(), 2); + assert_eq!( + schema.required, + Some(vec!["email".to_string(), "name".to_string()]) + ); +} + +/// Test that build_unchecked() panics on validation errors +#[tokio::test] +#[should_panic(expected = "Invalid elicitation schema")] +async fn test_build_unchecked_panics_on_invalid() { + // build_unchecked validates but panics instead of returning Result + let _schema = ElicitationSchema::builder() + .property("email", PrimitiveSchema::String(StringSchema::email())) + .mark_required("nonexistent_field") + .build_unchecked(); +} + +/// Test convenience methods handle validation correctly +#[tokio::test] +async fn test_convenience_methods_validation() { + // required_string_property should add both property and mark as required + let result = ElicitationSchema::builder() + .required_string_property("name", |s| s) + .required_email("email") + .build(); + + assert!(result.is_ok()); + let schema = result.unwrap(); + assert_eq!(schema.properties.len(), 2); + assert!( + schema + .required + .as_ref() + .unwrap() + .contains(&"name".to_string()) + ); + assert!( + schema + .required + .as_ref() + .unwrap() + .contains(&"email".to_string()) + ); +} + +/// Test typed property methods work correctly +#[tokio::test] +async fn test_typed_property_methods() { + let result = ElicitationSchema::builder() + .string_property("name", |s| s.length(1, 100)) + .number_property("price", |n| n.range(0.0, 1000.0)) + .integer_property("quantity", |i| i.range(1, 100)) + .bool_property("in_stock", |b| b.with_default(true)) + .build(); + + assert!(result.is_ok()); + let schema = result.unwrap(); + assert_eq!(schema.properties.len(), 4); + + // Verify types are correct + if let Some(PrimitiveSchema::String(_)) = schema.properties.get("name") { + // Expected + } else { + panic!("name should be StringSchema"); + } + + if let Some(PrimitiveSchema::Number(_)) = schema.properties.get("price") { + // Expected + } else { + panic!("price should be NumberSchema"); + } + + if let Some(PrimitiveSchema::Integer(_)) = schema.properties.get("quantity") { + // Expected + } else { + panic!("quantity should be IntegerSchema"); + } + + if let Some(PrimitiveSchema::Boolean(_)) = schema.properties.get("in_stock") { + // Expected + } else { + panic!("in_stock should be BooleanSchema"); + } +} + +/// Test required typed property methods +#[tokio::test] +async fn test_required_typed_property_methods() { + let result = ElicitationSchema::builder() + .required_string_property("name", |s| s) + .required_number_property("price", |n| n) + .required_integer_property("age", |i| i) + .required_bool_property("active", |b| b) + .build(); + + assert!(result.is_ok()); + let schema = result.unwrap(); + assert_eq!(schema.properties.len(), 4); + assert_eq!(schema.required.as_ref().unwrap().len(), 4); + + // All should be marked as required + let required = schema.required.as_ref().unwrap(); + assert!(required.contains(&"name".to_string())); + assert!(required.contains(&"price".to_string())); + assert!(required.contains(&"age".to_string())); + assert!(required.contains(&"active".to_string())); +} diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_embedded_resource_meta.rs b/code-rs/third_party/rmcp-0.8.3/tests/test_embedded_resource_meta.rs new file mode 100644 index 00000000000..7535e358f94 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_embedded_resource_meta.rs @@ -0,0 +1,122 @@ +use rmcp::model::{AnnotateAble, Content, Meta, RawContent, ResourceContents}; +use serde_json::json; + +#[test] +fn serialize_embedded_text_resource_with_meta() { + // Inner contents meta + let mut resource_content_meta = Meta::new(); + resource_content_meta.insert("inner".to_string(), json!(2)); + + // Top-level embedded resource meta + let mut resource_meta = Meta::new(); + resource_meta.insert("top".to_string(), json!(1)); + + let content: Content = RawContent::Resource(rmcp::model::RawEmbeddedResource { + meta: Some(resource_meta), + resource: ResourceContents::TextResourceContents { + uri: "str://example".to_string(), + mime_type: Some("text/plain".to_string()), + text: "hello".to_string(), + meta: Some(resource_content_meta), + }, + }) + .no_annotation(); + + let v = serde_json::to_value(&content).unwrap(); + + let expected = json!({ + "type": "resource", + "_meta": {"top": 1}, + "resource": { + "uri": "str://example", + "mimeType": "text/plain", + "text": "hello", + "_meta": {"inner": 2} + } + }); + + assert_eq!(v, expected); +} + +#[test] +fn serialize_embedded_text_resource_without_meta_omits_fields() { + let content: Content = RawContent::Resource(rmcp::model::RawEmbeddedResource { + meta: None, + resource: ResourceContents::TextResourceContents { + uri: "str://no-meta".to_string(), + mime_type: Some("text/plain".to_string()), + text: "hi".to_string(), + meta: None, + }, + }) + .no_annotation(); + + let v = serde_json::to_value(&content).unwrap(); + + assert_eq!(v.get("_meta"), None); + let inner = v.get("resource").and_then(|r| r.as_object()).unwrap(); + assert_eq!(inner.get("_meta"), None); +} + +#[test] +fn deserialize_embedded_text_resource_with_meta() { + let raw = json!({ + "type": "resource", + "_meta": {"x": true}, + "resource": { + "uri": "str://from-json", + "text": "ok", + "_meta": {"y": 42} + } + }); + + let content: Content = serde_json::from_value(raw).unwrap(); + + let raw = match &content.raw { + RawContent::Resource(er) => er, + _ => panic!("expected resource"), + }; + + // top-level _meta + let top = raw.meta.as_ref().expect("top-level meta missing"); + assert_eq!(top.get("x").unwrap(), &json!(true)); + + // inner contents _meta + match &raw.resource { + ResourceContents::TextResourceContents { + meta, uri, text, .. + } => { + assert_eq!(uri, "str://from-json"); + assert_eq!(text, "ok"); + let inner = meta.as_ref().expect("inner meta missing"); + assert_eq!(inner.get("y").unwrap(), &json!(42)); + } + _ => panic!("expected text resource contents"), + } +} + +#[test] +fn serialize_embedded_blob_resource_with_meta() { + let mut resource_content_meta = Meta::new(); + resource_content_meta.insert("blob_inner".to_string(), json!(true)); + + let mut resource_meta = Meta::new(); + resource_meta.insert("blob_top".to_string(), json!("t")); + + let content: Content = RawContent::Resource(rmcp::model::RawEmbeddedResource { + meta: Some(resource_meta), + resource: ResourceContents::BlobResourceContents { + uri: "str://blob".to_string(), + mime_type: Some("application/octet-stream".to_string()), + blob: "Zm9v".to_string(), + meta: Some(resource_content_meta), + }, + }) + .no_annotation(); + + let v = serde_json::to_value(&content).unwrap(); + + assert_eq!(v.get("_meta").unwrap(), &json!({"blob_top": "t"})); + let inner = v.get("resource").unwrap(); + assert_eq!(inner.get("_meta").unwrap(), &json!({"blob_inner": true})); +} diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_json_schema_detection.rs b/code-rs/third_party/rmcp-0.8.3/tests/test_json_schema_detection.rs new file mode 100644 index 00000000000..89dd858692e --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_json_schema_detection.rs @@ -0,0 +1,114 @@ +//cargo test --test test_json_schema_detection --features "client server macros" +use rmcp::{ + Json, ServerHandler, handler::server::router::tool::ToolRouter, tool, tool_handler, tool_router, +}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, JsonSchema)] +pub struct TestData { + pub value: String, +} + +#[tool_handler(router = self.tool_router)] +impl ServerHandler for TestServer {} + +#[derive(Debug, Clone)] +pub struct TestServer { + tool_router: ToolRouter, +} + +impl Default for TestServer { + fn default() -> Self { + Self::new() + } +} + +#[tool_router(router = tool_router)] +impl TestServer { + pub fn new() -> Self { + Self { + tool_router: Self::tool_router(), + } + } + + /// Tool that returns Json - should have output schema + #[tool(name = "with-json")] + pub async fn with_json(&self) -> Result, String> { + Ok(Json(TestData { + value: "test".to_string(), + })) + } + + /// Tool that returns regular type - should NOT have output schema + #[tool(name = "without-json")] + pub async fn without_json(&self) -> Result { + Ok("test".to_string()) + } + + /// Tool that returns Result with inner Json - should have output schema + #[tool(name = "result-with-json")] + pub async fn result_with_json(&self) -> Result, rmcp::ErrorData> { + Ok(Json(TestData { + value: "test".to_string(), + })) + } + + /// Tool with explicit output_schema attribute - should have output schema + #[tool(name = "explicit-schema", output_schema = rmcp::handler::server::tool::cached_schema_for_type::())] + pub async fn explicit_schema(&self) -> Result { + Ok("test".to_string()) + } +} + +#[tokio::test] +async fn test_json_type_generates_schema() { + let server = TestServer::new(); + let tools = server.tool_router.list_all(); + + // Find the with-json tool + let json_tool = tools.iter().find(|t| t.name == "with-json").unwrap(); + assert!( + json_tool.output_schema.is_some(), + "Json return type should generate output schema" + ); +} + +#[tokio::test] +async fn test_non_json_type_no_schema() { + let server = TestServer::new(); + let tools = server.tool_router.list_all(); + + // Find the without-json tool + let non_json_tool = tools.iter().find(|t| t.name == "without-json").unwrap(); + assert!( + non_json_tool.output_schema.is_none(), + "Regular return type should NOT generate output schema" + ); +} + +#[tokio::test] +async fn test_result_with_json_generates_schema() { + let server = TestServer::new(); + let tools = server.tool_router.list_all(); + + // Find the result-with-json tool + let result_json_tool = tools.iter().find(|t| t.name == "result-with-json").unwrap(); + assert!( + result_json_tool.output_schema.is_some(), + "Result, E> return type should generate output schema" + ); +} + +#[tokio::test] +async fn test_explicit_schema_override() { + let server = TestServer::new(); + let tools = server.tool_router.list_all(); + + // Find the explicit-schema tool + let explicit_tool = tools.iter().find(|t| t.name == "explicit-schema").unwrap(); + assert!( + explicit_tool.output_schema.is_some(), + "Explicit output_schema attribute should work" + ); +} diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_logging.rs b/code-rs/third_party/rmcp-0.8.3/tests/test_logging.rs new file mode 100644 index 00000000000..eb63773fcc7 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_logging.rs @@ -0,0 +1,351 @@ +// cargo test --features "server client" --package rmcp test_logging +mod common; + +use std::sync::{Arc, Mutex}; + +use common::handlers::{TestClientHandler, TestServer}; +use rmcp::{ + ServiceExt, + model::{LoggingLevel, LoggingMessageNotificationParam, SetLevelRequestParam}, +}; +use serde_json::json; +use tokio::sync::Notify; + +#[tokio::test] +async fn test_logging_spec_compliance() -> anyhow::Result<()> { + let (server_transport, client_transport) = tokio::io::duplex(4096); + let receive_signal = Arc::new(Notify::new()); + let received_messages = Arc::new(Mutex::new(Vec::::new())); + + // Start server in a separate task + let server_handle = tokio::spawn(async move { + let server = TestServer::new().serve(server_transport).await?; + + // Test server can send messages before level is set + server + .peer() + .notify_logging_message(LoggingMessageNotificationParam { + level: LoggingLevel::Info, + data: serde_json::json!({ + "message": "Server initiated message", + "timestamp": chrono::Utc::now().to_rfc3339(), + }), + logger: Some("test_server".to_string()), + }) + .await?; + + server.waiting().await?; + anyhow::Ok(()) + }); + + let client = TestClientHandler::with_notification( + true, + true, + receive_signal.clone(), + received_messages.clone(), + ) + .serve(client_transport) + .await?; + + // Wait for the initial server message + receive_signal.notified().await; + { + let mut messages = received_messages.lock().unwrap(); + assert_eq!(messages.len(), 1, "Should receive server-initiated message"); + messages.clear(); + } + + // Test level filtering and message format + for level in [ + LoggingLevel::Emergency, + LoggingLevel::Warning, + LoggingLevel::Debug, + ] { + client + .peer() + .set_level(SetLevelRequestParam { level }) + .await?; + + // Wait for each message response + receive_signal.notified().await; + + let mut messages = received_messages.lock().unwrap(); + let msg = messages.last().unwrap(); + + // Verify required fields + assert_eq!(msg.level, level); + assert!(msg.data.is_object()); + + // Verify data format + let data = msg.data.as_object().unwrap(); + assert!(data.contains_key("message")); + assert!(data.contains_key("timestamp")); + + // Verify timestamp + let timestamp = data["timestamp"].as_str().unwrap(); + chrono::DateTime::parse_from_rfc3339(timestamp).expect("RFC3339 timestamp"); + + messages.clear(); + } + + // Important: Cancel the client before ending the test + client.cancel().await?; + + // Wait for server to complete + server_handle.await??; + + Ok(()) +} + +#[tokio::test] +async fn test_logging_user_scenarios() -> anyhow::Result<()> { + let (server_transport, client_transport) = tokio::io::duplex(4096); + let receive_signal = Arc::new(Notify::new()); + let received_messages = Arc::new(Mutex::new(Vec::::new())); + + let server_handle = tokio::spawn(async move { + let server = TestServer::new().serve(server_transport).await?; + server.waiting().await?; + anyhow::Ok(()) + }); + + let client = TestClientHandler::with_notification( + true, + true, + receive_signal.clone(), + received_messages.clone(), + ) + .serve(client_transport) + .await?; + + // Test 1: Error reporting scenario + client + .peer() + .set_level(SetLevelRequestParam { + level: LoggingLevel::Error, + }) + .await?; + receive_signal.notified().await; // Wait for response + { + let messages = received_messages.lock().unwrap(); + let msg = &messages[0]; + let data = msg.data.as_object().unwrap(); + assert!( + data.contains_key("error_code"), + "Error should have an error code" + ); + assert!( + data.contains_key("error_details"), + "Error should have details" + ); + assert!( + data.contains_key("timestamp"), + "Should know when error occurred" + ); + } + + // Test 2: Debug scenario + client + .peer() + .set_level(SetLevelRequestParam { + level: LoggingLevel::Debug, + }) + .await?; + receive_signal.notified().await; // Wait for response + { + let messages = received_messages.lock().unwrap(); + let msg = messages.last().unwrap(); + let data = msg.data.as_object().unwrap(); + assert!( + data.contains_key("function"), + "Debug should show function name" + ); + assert!(data.contains_key("line"), "Debug should show line number"); + assert!( + data.contains_key("context"), + "Debug should show execution context" + ); + } + + // Test 3: Production monitoring scenario + client + .peer() + .set_level(SetLevelRequestParam { + level: LoggingLevel::Info, + }) + .await?; + receive_signal.notified().await; // Wait for response + { + let messages = received_messages.lock().unwrap(); + let msg = messages.last().unwrap(); + let data = msg.data.as_object().unwrap(); + assert!(data.contains_key("status"), "Should show system status"); + assert!(data.contains_key("metrics"), "Should include metrics"); + } + + // Important: Cancel client and wait for server before ending + client.cancel().await?; + server_handle.await??; + + Ok(()) +} + +#[test] +fn test_logging_level_serialization() { + // Test all levels match spec exactly + let test_cases = [ + (LoggingLevel::Alert, "alert"), + (LoggingLevel::Critical, "critical"), + (LoggingLevel::Debug, "debug"), + (LoggingLevel::Emergency, "emergency"), + (LoggingLevel::Error, "error"), + (LoggingLevel::Info, "info"), + (LoggingLevel::Notice, "notice"), + (LoggingLevel::Warning, "warning"), + ]; + + for (level, expected) in test_cases { + let serialized = serde_json::to_string(&level).unwrap(); + // Remove quotes from serialized string + let serialized = serialized.trim_matches('"'); + assert_eq!( + serialized, expected, + "LoggingLevel::{:?} should serialize to \"{}\"", + level, expected + ); + } + + // Test deserialization from spec strings + for (level, spec_string) in test_cases { + let deserialized: LoggingLevel = + serde_json::from_str(&format!("\"{}\"", spec_string)).unwrap(); + assert_eq!( + deserialized, level, + "\"{}\" should deserialize to LoggingLevel::{:?}", + spec_string, level + ); + } +} + +#[tokio::test] +async fn test_logging_edge_cases() -> anyhow::Result<()> { + let (server_transport, client_transport) = tokio::io::duplex(4096); + let receive_signal = Arc::new(Notify::new()); + let received_messages = Arc::new(Mutex::new(Vec::::new())); + + let server_handle = tokio::spawn(async move { + let server = TestServer::new().serve(server_transport).await?; + server.waiting().await?; + anyhow::Ok(()) + }); + + let client = TestClientHandler::with_notification( + true, + true, + receive_signal.clone(), + received_messages.clone(), + ) + .serve(client_transport) + .await?; + + // Test all logging levels from spec + for level in [ + LoggingLevel::Alert, + LoggingLevel::Critical, + LoggingLevel::Notice, // These weren't tested before + ] { + client + .peer() + .set_level(SetLevelRequestParam { level }) + .await?; + receive_signal.notified().await; + + let messages = received_messages.lock().unwrap(); + let msg = messages.last().unwrap(); + assert_eq!(msg.level, level); + } + + client.cancel().await?; + server_handle.await??; + Ok(()) +} + +#[tokio::test] +async fn test_logging_optional_fields() -> anyhow::Result<()> { + let (server_transport, client_transport) = tokio::io::duplex(4096); + let receive_signal = Arc::new(Notify::new()); + let received_messages = Arc::new(Mutex::new(Vec::::new())); + + let server_handle = tokio::spawn(async move { + let server = TestServer::new().serve(server_transport).await?; + + // Test message with and without optional logger field + for (level, has_logger) in [(LoggingLevel::Info, true), (LoggingLevel::Debug, false)] { + server + .peer() + .notify_logging_message(LoggingMessageNotificationParam { + level, + data: json!({"test": "data"}), + logger: has_logger.then(|| "test_logger".to_string()), + }) + .await?; + } + + server.waiting().await?; + anyhow::Ok(()) + }); + + let client = TestClientHandler::with_notification( + true, + true, + receive_signal.clone(), + received_messages.clone(), + ) + .serve(client_transport) + .await?; + + // Wait for the initial server message + receive_signal.notified().await; + { + let mut messages = received_messages.lock().unwrap(); + assert_eq!(messages.len(), 2, "Should receive two messages"); + messages.clear(); + } + + // Test level filtering and message format + for level in [LoggingLevel::Info, LoggingLevel::Debug] { + client + .peer() + .set_level(SetLevelRequestParam { level }) + .await?; + + // Wait for each message response + receive_signal.notified().await; + + let mut messages = received_messages.lock().unwrap(); + let msg = messages.last().unwrap(); + + // Verify required fields + assert_eq!(msg.level, level); + assert!(msg.data.is_object()); + + // Verify data format + let data = msg.data.as_object().unwrap(); + assert!(data.contains_key("message")); + assert!(data.contains_key("timestamp")); + + // Verify timestamp + let timestamp = data["timestamp"].as_str().unwrap(); + chrono::DateTime::parse_from_rfc3339(timestamp).expect("RFC3339 timestamp"); + + messages.clear(); + } + + // Important: Cancel the client before ending the test + client.cancel().await?; + + // Wait for server to complete + server_handle.await??; + + Ok(()) +} diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_message_protocol.rs b/code-rs/third_party/rmcp-0.8.3/tests/test_message_protocol.rs new file mode 100644 index 00000000000..602f93daba3 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_message_protocol.rs @@ -0,0 +1,559 @@ +//cargo test --test test_message_protocol --features "client server" + +mod common; +use common::handlers::{TestClientHandler, TestServer}; +use rmcp::{ + ServiceExt, + model::*, + service::{RequestContext, Service}, +}; +use tokio_util::sync::CancellationToken; + +// Tests start here +#[tokio::test] +async fn test_message_roles() { + let messages = vec![ + SamplingMessage { + role: Role::User, + content: Content::text("user message"), + }, + SamplingMessage { + role: Role::Assistant, + content: Content::text("assistant message"), + }, + ]; + + // Verify all roles can be serialized/deserialized correctly + let json = serde_json::to_string(&messages).unwrap(); + let deserialized: Vec = serde_json::from_str(&json).unwrap(); + assert_eq!(messages, deserialized); +} + +#[tokio::test] +async fn test_context_inclusion_integration() -> anyhow::Result<()> { + let (server_transport, client_transport) = tokio::io::duplex(4096); + + // Start server + let server_handle = tokio::spawn(async move { + let server = TestServer::new().serve(server_transport).await?; + server.waiting().await?; + anyhow::Ok(()) + }); + + // Start client that honors context requests + let handler = TestClientHandler::new(true, true); + let client = handler.clone().serve(client_transport).await?; + + // Test ThisServer context inclusion + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { + method: Default::default(), + params: CreateMessageRequestParam { + messages: vec![SamplingMessage { + role: Role::User, + content: Content::text("test message"), + }], + include_context: Some(ContextInclusion::ThisServer), + model_preferences: None, + system_prompt: None, + temperature: None, + max_tokens: 100, + stop_sequences: None, + metadata: None, + }, + extensions: Default::default(), + }); + + let result = handler + .handle_request( + request.clone(), + RequestContext { + peer: client.peer().clone(), + ct: CancellationToken::new(), + id: NumberOrString::Number(1), + meta: Default::default(), + extensions: Default::default(), + }, + ) + .await?; + + if let ClientResult::CreateMessageResult(result) = result { + let text = result.message.content.as_text().unwrap().text.as_str(); + assert!( + text.contains("test context"), + "Response should include context for ThisServer" + ); + } else { + panic!("Expected CreateMessageResult"); + } + + // Test AllServers context inclusion + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { + method: Default::default(), + params: CreateMessageRequestParam { + messages: vec![SamplingMessage { + role: Role::User, + content: Content::text("test message"), + }], + include_context: Some(ContextInclusion::AllServers), + model_preferences: None, + system_prompt: None, + temperature: None, + max_tokens: 100, + stop_sequences: None, + metadata: None, + }, + extensions: Default::default(), + }); + + let result = handler + .handle_request( + request.clone(), + RequestContext { + peer: client.peer().clone(), + ct: CancellationToken::new(), + id: NumberOrString::Number(2), + meta: Default::default(), + extensions: Default::default(), + }, + ) + .await?; + + if let ClientResult::CreateMessageResult(result) = result { + let text = result.message.content.as_text().unwrap().text.as_str(); + assert!( + text.contains("test context"), + "Response should include context for AllServers" + ); + } else { + panic!("Expected CreateMessageResult"); + } + + // Test No context inclusion + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { + method: Default::default(), + params: CreateMessageRequestParam { + messages: vec![SamplingMessage { + role: Role::User, + content: Content::text("test message"), + }], + include_context: Some(ContextInclusion::None), + model_preferences: None, + system_prompt: None, + temperature: None, + max_tokens: 100, + stop_sequences: None, + metadata: None, + }, + extensions: Default::default(), + }); + + let result = handler + .handle_request( + request.clone(), + RequestContext { + peer: client.peer().clone(), + ct: CancellationToken::new(), + id: NumberOrString::Number(3), + meta: Default::default(), + extensions: Default::default(), + }, + ) + .await?; + + if let ClientResult::CreateMessageResult(result) = result { + let text = result.message.content.as_text().unwrap().text.as_str(); + assert!( + !text.contains("test context"), + "Response should not include context for None" + ); + } else { + panic!("Expected CreateMessageResult"); + } + + client.cancel().await?; + server_handle.await??; + Ok(()) +} + +#[tokio::test] +async fn test_context_inclusion_ignored_integration() -> anyhow::Result<()> { + let (server_transport, client_transport) = tokio::io::duplex(4096); + + // Start server + let server_handle = tokio::spawn(async move { + let server = TestServer::new().serve(server_transport).await?; + server.waiting().await?; + anyhow::Ok(()) + }); + + // Start client that ignores context requests + let handler = TestClientHandler::new(false, false); + let client = handler.clone().serve(client_transport).await?; + + // Test that context requests are ignored + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { + method: Default::default(), + params: CreateMessageRequestParam { + messages: vec![SamplingMessage { + role: Role::User, + content: Content::text("test message"), + }], + include_context: Some(ContextInclusion::ThisServer), + model_preferences: None, + system_prompt: None, + temperature: None, + max_tokens: 100, + stop_sequences: None, + metadata: None, + }, + extensions: Default::default(), + }); + + let result = handler + .handle_request( + request.clone(), + RequestContext { + peer: client.peer().clone(), + ct: CancellationToken::new(), + id: NumberOrString::Number(1), + meta: Meta::default(), + extensions: Default::default(), + }, + ) + .await?; + + if let ClientResult::CreateMessageResult(result) = result { + let text = result.message.content.as_text().unwrap().text.as_str(); + assert!( + !text.contains("test context"), + "Context should be ignored when client chooses not to honor requests" + ); + } else { + panic!("Expected CreateMessageResult"); + } + + client.cancel().await?; + server_handle.await??; + Ok(()) +} + +#[tokio::test] +async fn test_message_sequence_integration() -> anyhow::Result<()> { + let (server_transport, client_transport) = tokio::io::duplex(4096); + + // Start server + let server_handle = tokio::spawn(async move { + let server = TestServer::new().serve(server_transport).await?; + server.waiting().await?; + anyhow::Ok(()) + }); + + // Start client + let handler = TestClientHandler::new(true, true); + let client = handler.clone().serve(client_transport).await?; + + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { + method: Default::default(), + params: CreateMessageRequestParam { + messages: vec![ + SamplingMessage { + role: Role::User, + content: Content::text("first message"), + }, + SamplingMessage { + role: Role::Assistant, + content: Content::text("second message"), + }, + ], + include_context: Some(ContextInclusion::ThisServer), + model_preferences: None, + system_prompt: None, + temperature: None, + max_tokens: 100, + stop_sequences: None, + metadata: None, + }, + extensions: Default::default(), + }); + + let result = handler + .handle_request( + request.clone(), + RequestContext { + peer: client.peer().clone(), + ct: CancellationToken::new(), + id: NumberOrString::Number(1), + meta: Meta::default(), + extensions: Default::default(), + }, + ) + .await?; + + if let ClientResult::CreateMessageResult(result) = result { + let text = result.message.content.as_text().unwrap().text.as_str(); + assert!( + text.contains("test context"), + "Response should include context when ThisServer is specified" + ); + assert_eq!(result.model, "test-model"); + assert_eq!( + result.stop_reason, + Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()) + ); + } else { + panic!("Expected CreateMessageResult"); + } + + client.cancel().await?; + server_handle.await??; + Ok(()) +} + +#[tokio::test] +async fn test_message_sequence_validation_integration() -> anyhow::Result<()> { + let (server_transport, client_transport) = tokio::io::duplex(4096); + + let server_handle = tokio::spawn(async move { + let server = TestServer::new().serve(server_transport).await?; + server.waiting().await?; + anyhow::Ok(()) + }); + + let handler = TestClientHandler::new(true, true); + let client = handler.clone().serve(client_transport).await?; + + // Test valid sequence: User -> Assistant -> User + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { + method: Default::default(), + params: CreateMessageRequestParam { + messages: vec![ + SamplingMessage { + role: Role::User, + content: Content::text("first user message"), + }, + SamplingMessage { + role: Role::Assistant, + content: Content::text("first assistant response"), + }, + SamplingMessage { + role: Role::User, + content: Content::text("second user message"), + }, + ], + include_context: None, + model_preferences: None, + system_prompt: None, + temperature: None, + max_tokens: 100, + stop_sequences: None, + metadata: None, + }, + extensions: Default::default(), + }); + + let result = handler + .handle_request( + request.clone(), + RequestContext { + peer: client.peer().clone(), + ct: CancellationToken::new(), + id: NumberOrString::Number(1), + meta: Meta::default(), + extensions: Default::default(), + }, + ) + .await?; + + assert!(matches!(result, ClientResult::CreateMessageResult(_))); + + // Test invalid: No user message + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { + method: Default::default(), + params: CreateMessageRequestParam { + messages: vec![SamplingMessage { + role: Role::Assistant, + content: Content::text("assistant message"), + }], + include_context: None, + model_preferences: None, + system_prompt: None, + temperature: None, + max_tokens: 100, + stop_sequences: None, + metadata: None, + }, + extensions: Default::default(), + }); + + let result = handler + .handle_request( + request.clone(), + RequestContext { + peer: client.peer().clone(), + ct: CancellationToken::new(), + id: NumberOrString::Number(2), + meta: Meta::default(), + extensions: Default::default(), + }, + ) + .await; + + assert!(result.is_err()); + + client.cancel().await?; + server_handle.await??; + Ok(()) +} + +#[tokio::test] +async fn test_selective_context_handling_integration() -> anyhow::Result<()> { + let (server_transport, client_transport) = tokio::io::duplex(4096); + + let server_handle = tokio::spawn(async move { + let server = TestServer::new().serve(server_transport).await?; + server.waiting().await?; + anyhow::Ok(()) + }); + + // Client that only honors ThisServer but ignores AllServers + let handler = TestClientHandler::new(true, false); + let client = handler.clone().serve(client_transport).await?; + + // Test ThisServer is honored + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { + method: Default::default(), + params: CreateMessageRequestParam { + messages: vec![SamplingMessage { + role: Role::User, + content: Content::text("test message"), + }], + include_context: Some(ContextInclusion::ThisServer), + model_preferences: None, + system_prompt: None, + temperature: None, + max_tokens: 100, + stop_sequences: None, + metadata: None, + }, + extensions: Default::default(), + }); + + let result = handler + .handle_request( + request.clone(), + RequestContext { + peer: client.peer().clone(), + ct: CancellationToken::new(), + id: NumberOrString::Number(1), + meta: Meta::default(), + extensions: Default::default(), + }, + ) + .await?; + + if let ClientResult::CreateMessageResult(result) = result { + let text = result.message.content.as_text().unwrap().text.as_str(); + assert!( + text.contains("test context"), + "ThisServer context request should be honored" + ); + } + + // Test AllServers is ignored + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { + method: Default::default(), + params: CreateMessageRequestParam { + messages: vec![SamplingMessage { + role: Role::User, + content: Content::text("test message"), + }], + include_context: Some(ContextInclusion::AllServers), + model_preferences: None, + system_prompt: None, + temperature: None, + max_tokens: 100, + stop_sequences: None, + metadata: None, + }, + extensions: Default::default(), + }); + + let result = handler + .handle_request( + request.clone(), + RequestContext { + peer: client.peer().clone(), + ct: CancellationToken::new(), + id: NumberOrString::Number(2), + meta: Meta::default(), + extensions: Default::default(), + }, + ) + .await?; + + if let ClientResult::CreateMessageResult(result) = result { + let text = result.message.content.as_text().unwrap().text.as_str(); + assert!( + !text.contains("test context"), + "AllServers context request should be ignored" + ); + } + + client.cancel().await?; + server_handle.await??; + Ok(()) +} + +#[tokio::test] +async fn test_context_inclusion() -> anyhow::Result<()> { + let (server_transport, client_transport) = tokio::io::duplex(4096); + let server_handle = tokio::spawn(async move { + let server = TestServer::new().serve(server_transport).await?; + server.waiting().await?; + anyhow::Ok(()) + }); + + let handler = TestClientHandler::new(true, true); + let client = handler.clone().serve(client_transport).await?; + + // Test context handling + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { + method: Default::default(), + params: CreateMessageRequestParam { + messages: vec![SamplingMessage { + role: Role::User, + content: Content::text("test"), + }], + include_context: Some(ContextInclusion::ThisServer), + model_preferences: None, + system_prompt: None, + temperature: None, + max_tokens: 100, + stop_sequences: None, + metadata: None, + }, + extensions: Default::default(), + }); + + let result = handler + .handle_request( + request.clone(), + RequestContext { + peer: client.peer().clone(), + ct: CancellationToken::new(), + id: NumberOrString::Number(1), + meta: Meta::default(), + extensions: Default::default(), + }, + ) + .await?; + + if let ClientResult::CreateMessageResult(result) = result { + let text = result.message.content.as_text().unwrap().text.as_str(); + assert!(text.contains("test context")); + } + + client.cancel().await?; + server_handle.await??; + Ok(()) +} diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_message_schema.rs b/code-rs/third_party/rmcp-0.8.3/tests/test_message_schema.rs new file mode 100644 index 00000000000..f05c263cc6d --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_message_schema.rs @@ -0,0 +1,78 @@ +mod tests { + use rmcp::model::{ClientJsonRpcMessage, ServerJsonRpcMessage}; + use schemars::generate::SchemaSettings; + + fn compare_schemas(name: &str, actual: &str, expected_file: &str) { + let expected = match std::fs::read_to_string(expected_file) { + Ok(content) => content, + Err(e) => { + panic!( + "Failed to read expected schema file {}: {}", + expected_file, e + ); + } + }; + + let actual_json: serde_json::Value = + serde_json::from_str(actual).expect("Failed to parse actual schema as JSON"); + let expected_json: serde_json::Value = + serde_json::from_str(&expected).expect("Failed to parse expected schema as JSON"); + + if actual_json == expected_json { + println!("{} schema matches expected", name); + return; + } + + // Write current schema to file for comparison + let current_file = expected_file.replace(".json", "_current.json"); + std::fs::write(¤t_file, actual).expect("Failed to write current schema"); + + println!("{} schema differs from expected", name); + println!("Expected: {}", expected_file); + println!("Current: {}", current_file); + println!( + "Run 'diff {} {}' to see differences", + expected_file, current_file + ); + + // UPDATE_SCHEMA=1 cargo test -p rmcp --test test_message_schema --features="server client schemars" + if std::env::var("UPDATE_SCHEMA").is_ok() { + println!("UPDATE_SCHEMA is set, updating expected file"); + std::fs::write(expected_file, actual).expect("Failed to update expected schema file"); + println!("Updated {}", expected_file); + } else { + println!("Set UPDATE_SCHEMA=1 to auto-update expected schemas"); + panic!("Schema validation failed"); + } + } + + #[test] + fn test_client_json_rpc_message_schema() { + let settings = SchemaSettings::draft07(); + let schema = settings + .into_generator() + .into_root_schema_for::(); + let schema_str = serde_json::to_string_pretty(&schema).expect("Failed to serialize schema"); + + compare_schemas( + "ClientJsonRpcMessage", + &schema_str, + "tests/test_message_schema/client_json_rpc_message_schema.json", + ); + } + + #[test] + fn test_server_json_rpc_message_schema() { + let settings = SchemaSettings::draft07(); + let schema = settings + .into_generator() + .into_root_schema_for::(); + let schema_str = serde_json::to_string_pretty(&schema).expect("Failed to serialize schema"); + + compare_schemas( + "ServerJsonRpcMessage", + &schema_str, + "tests/test_message_schema/server_json_rpc_message_schema.json", + ); + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_message_schema/client_json_rpc_message_schema.json b/code-rs/third_party/rmcp-0.8.3/tests/test_message_schema/client_json_rpc_message_schema.json new file mode 100644 index 00000000000..8fa565c0b17 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_message_schema/client_json_rpc_message_schema.json @@ -0,0 +1,1509 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "JsonRpcMessage", + "description": "Represents any JSON-RPC message that can be sent or received.\n\nThis enum covers all possible message types in the JSON-RPC protocol:\nindividual requests/responses, notifications, and errors.\nIt serves as the top-level message container for MCP communication.", + "anyOf": [ + { + "description": "A single request expecting a response", + "allOf": [ + { + "$ref": "#/definitions/JsonRpcRequest" + } + ] + }, + { + "description": "A response to a previous request", + "allOf": [ + { + "$ref": "#/definitions/JsonRpcResponse" + } + ] + }, + { + "description": "A one-way notification (no response expected)", + "allOf": [ + { + "$ref": "#/definitions/JsonRpcNotification" + } + ] + }, + { + "description": "An error response", + "allOf": [ + { + "$ref": "#/definitions/JsonRpcError" + } + ] + } + ], + "definitions": { + "Annotated": { + "type": "object", + "properties": { + "annotations": { + "anyOf": [ + { + "$ref": "#/definitions/Annotations" + }, + { + "type": "null" + } + ] + } + }, + "oneOf": [ + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "text" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawTextContent" + } + ], + "required": [ + "type" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "image" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawImageContent" + } + ], + "required": [ + "type" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "resource" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawEmbeddedResource" + } + ], + "required": [ + "type" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "audio" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawAudioContent" + } + ], + "required": [ + "type" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "resource_link" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawResource" + } + ], + "required": [ + "type" + ] + } + ] + }, + "Annotations": { + "type": "object", + "properties": { + "audience": { + "type": [ + "array", + "null" + ], + "items": { + "$ref": "#/definitions/Role" + } + }, + "lastModified": { + "type": [ + "string", + "null" + ], + "format": "date-time" + }, + "priority": { + "type": [ + "number", + "null" + ], + "format": "float" + } + } + }, + "ArgumentInfo": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "value": { + "type": "string" + } + }, + "required": [ + "name", + "value" + ] + }, + "CallToolRequestMethod": { + "type": "string", + "format": "const", + "const": "tools/call" + }, + "CallToolRequestParam": { + "description": "Parameters for calling a tool provided by an MCP server.\n\nContains the tool name and optional arguments needed to execute\nthe tool operation.", + "type": "object", + "properties": { + "arguments": { + "description": "Arguments to pass to the tool (must match the tool's input schema)", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "name": { + "description": "The name of the tool to call", + "type": "string" + } + }, + "required": [ + "name" + ] + }, + "CancelledNotificationMethod": { + "type": "string", + "format": "const", + "const": "notifications/cancelled" + }, + "CancelledNotificationParam": { + "type": "object", + "properties": { + "reason": { + "type": [ + "string", + "null" + ] + }, + "requestId": { + "$ref": "#/definitions/NumberOrString" + } + }, + "required": [ + "requestId" + ] + }, + "ClientCapabilities": { + "title": "Builder", + "description": "```rust\n# use rmcp::model::ClientCapabilities;\nlet cap = ClientCapabilities::builder()\n .enable_experimental()\n .enable_roots()\n .enable_roots_list_changed()\n .build();\n```", + "type": "object", + "properties": { + "elicitation": { + "description": "Capability to handle elicitation requests from servers for interactive user input", + "anyOf": [ + { + "$ref": "#/definitions/ElicitationCapability" + }, + { + "type": "null" + } + ] + }, + "experimental": { + "type": [ + "object", + "null" + ], + "additionalProperties": { + "type": "object", + "additionalProperties": true + } + }, + "roots": { + "anyOf": [ + { + "$ref": "#/definitions/RootsCapabilities" + }, + { + "type": "null" + } + ] + }, + "sampling": { + "type": [ + "object", + "null" + ], + "additionalProperties": true + } + } + }, + "ClientResult": { + "anyOf": [ + { + "$ref": "#/definitions/CreateMessageResult" + }, + { + "$ref": "#/definitions/ListRootsResult" + }, + { + "$ref": "#/definitions/CreateElicitationResult" + }, + { + "$ref": "#/definitions/EmptyObject" + } + ] + }, + "CompleteRequestMethod": { + "type": "string", + "format": "const", + "const": "completion/complete" + }, + "CompleteRequestParam": { + "type": "object", + "properties": { + "argument": { + "$ref": "#/definitions/ArgumentInfo" + }, + "context": { + "description": "Optional context containing previously resolved argument values", + "anyOf": [ + { + "$ref": "#/definitions/CompletionContext" + }, + { + "type": "null" + } + ] + }, + "ref": { + "$ref": "#/definitions/Reference" + } + }, + "required": [ + "ref", + "argument" + ] + }, + "CompletionContext": { + "description": "Context for completion requests providing previously resolved arguments.\n\nThis enables context-aware completion where subsequent argument completions\ncan take into account the values of previously resolved arguments.", + "type": "object", + "properties": { + "arguments": { + "description": "Previously resolved argument values that can inform completion suggestions", + "type": [ + "object", + "null" + ], + "additionalProperties": { + "type": "string" + } + } + } + }, + "CreateElicitationResult": { + "description": "The result returned by a client in response to an elicitation request.\n\nContains the user's decision (accept/decline/cancel) and optionally their input data\nif they chose to accept the request.", + "type": "object", + "properties": { + "action": { + "description": "The user's decision on how to handle the elicitation request", + "allOf": [ + { + "$ref": "#/definitions/ElicitationAction" + } + ] + }, + "content": { + "description": "The actual data provided by the user, if they accepted the request.\nMust conform to the JSON schema specified in the original request.\nOnly present when action is Accept." + } + }, + "required": [ + "action" + ] + }, + "CreateMessageResult": { + "description": "The result of a sampling/createMessage request containing the generated response.\n\nThis structure contains the generated message along with metadata about\nhow the generation was performed and why it stopped.", + "type": "object", + "properties": { + "content": { + "description": "The actual content of the message (text, image, etc.)", + "allOf": [ + { + "$ref": "#/definitions/Annotated" + } + ] + }, + "model": { + "description": "The identifier of the model that generated the response", + "type": "string" + }, + "role": { + "description": "The role of the message sender (User or Assistant)", + "allOf": [ + { + "$ref": "#/definitions/Role" + } + ] + }, + "stopReason": { + "description": "The reason why generation stopped (e.g., \"endTurn\", \"maxTokens\")", + "type": [ + "string", + "null" + ] + } + }, + "required": [ + "model", + "role", + "content" + ] + }, + "ElicitationAction": { + "description": "Represents the possible actions a user can take in response to an elicitation request.\n\nWhen a server requests user input through elicitation, the user can:\n- Accept: Provide the requested information and continue\n- Decline: Refuse to provide the information but continue the operation\n- Cancel: Stop the entire operation", + "oneOf": [ + { + "description": "User accepts the request and provides the requested information", + "type": "string", + "const": "accept" + }, + { + "description": "User declines to provide the information but allows the operation to continue", + "type": "string", + "const": "decline" + }, + { + "description": "User cancels the entire operation", + "type": "string", + "const": "cancel" + } + ] + }, + "ElicitationCapability": { + "description": "Capability for handling elicitation requests from servers.\n\nElicitation allows servers to request interactive input from users during tool execution.\nThis capability indicates that a client can handle elicitation requests and present\nappropriate UI to users for collecting the requested information.", + "type": "object", + "properties": { + "schemaValidation": { + "description": "Whether the client supports JSON Schema validation for elicitation responses.\nWhen true, the client will validate user input against the requested_schema\nbefore sending the response back to the server.", + "type": [ + "boolean", + "null" + ] + } + } + }, + "EmptyObject": { + "description": "This is commonly used for representing empty objects in MCP messages.\n\nwithout returning any specific data.", + "type": "object" + }, + "ErrorCode": { + "description": "Standard JSON-RPC error codes used throughout the MCP protocol.\n\nThese codes follow the JSON-RPC 2.0 specification and provide\nstandardized error reporting across all MCP implementations.", + "type": "integer", + "format": "int32" + }, + "ErrorData": { + "description": "Error information for JSON-RPC error responses.\n\nThis structure follows the JSON-RPC 2.0 specification for error reporting,\nproviding a standardized way to communicate errors between clients and servers.", + "type": "object", + "properties": { + "code": { + "description": "The error type that occurred (using standard JSON-RPC error codes)", + "allOf": [ + { + "$ref": "#/definitions/ErrorCode" + } + ] + }, + "data": { + "description": "Additional information about the error. The value of this member is defined by the\nsender (e.g. detailed error information, nested errors etc.)." + }, + "message": { + "description": "A short description of the error. The message SHOULD be limited to a concise single sentence.", + "type": "string" + } + }, + "required": [ + "code", + "message" + ] + }, + "GetPromptRequestMethod": { + "type": "string", + "format": "const", + "const": "prompts/get" + }, + "GetPromptRequestParam": { + "description": "Parameters for retrieving a specific prompt", + "type": "object", + "properties": { + "arguments": { + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "name": { + "type": "string" + } + }, + "required": [ + "name" + ] + }, + "Icon": { + "description": "A URL pointing to an icon resource or a base64-encoded data URI.\n\nClients that support rendering icons MUST support at least the following MIME types:\n- image/png - PNG images (safe, universal compatibility)\n- image/jpeg (and image/jpg) - JPEG images (safe, universal compatibility)\n\nClients that support rendering icons SHOULD also support:\n- image/svg+xml - SVG images (scalable but requires security precautions)\n- image/webp - WebP images (modern, efficient format)", + "type": "object", + "properties": { + "mimeType": { + "description": "Optional override if the server's MIME type is missing or generic", + "type": [ + "string", + "null" + ] + }, + "sizes": { + "description": "Size specification, each string should be in WxH format (e.g., `\\\"48x48\\\"`, `\\\"96x96\\\"`) or `\\\"any\\\"` for scalable formats like SVG", + "type": [ + "array", + "null" + ], + "items": { + "type": "string" + } + }, + "src": { + "description": "A standard URI pointing to an icon resource", + "type": "string" + } + }, + "required": [ + "src" + ] + }, + "Implementation": { + "type": "object", + "properties": { + "icons": { + "type": [ + "array", + "null" + ], + "items": { + "$ref": "#/definitions/Icon" + } + }, + "name": { + "type": "string" + }, + "title": { + "type": [ + "string", + "null" + ] + }, + "version": { + "type": "string" + }, + "websiteUrl": { + "type": [ + "string", + "null" + ] + } + }, + "required": [ + "name", + "version" + ] + }, + "InitializeRequestParam": { + "description": "Parameters sent by a client when initializing a connection to an MCP server.\n\nThis contains the client's protocol version, capabilities, and implementation\ninformation, allowing the server to understand what the client supports.", + "type": "object", + "properties": { + "capabilities": { + "description": "The capabilities this client supports (sampling, roots, etc.)", + "allOf": [ + { + "$ref": "#/definitions/ClientCapabilities" + } + ] + }, + "clientInfo": { + "description": "Information about the client implementation", + "allOf": [ + { + "$ref": "#/definitions/Implementation" + } + ] + }, + "protocolVersion": { + "description": "The MCP protocol version this client supports", + "allOf": [ + { + "$ref": "#/definitions/ProtocolVersion" + } + ] + } + }, + "required": [ + "protocolVersion", + "capabilities", + "clientInfo" + ] + }, + "InitializeResultMethod": { + "type": "string", + "format": "const", + "const": "initialize" + }, + "InitializedNotificationMethod": { + "type": "string", + "format": "const", + "const": "notifications/initialized" + }, + "JsonRpcError": { + "type": "object", + "properties": { + "error": { + "$ref": "#/definitions/ErrorData" + }, + "id": { + "$ref": "#/definitions/NumberOrString" + }, + "jsonrpc": { + "$ref": "#/definitions/JsonRpcVersion2_0" + } + }, + "required": [ + "jsonrpc", + "id", + "error" + ] + }, + "JsonRpcNotification": { + "type": "object", + "properties": { + "jsonrpc": { + "$ref": "#/definitions/JsonRpcVersion2_0" + } + }, + "anyOf": [ + { + "$ref": "#/definitions/Notification" + }, + { + "$ref": "#/definitions/Notification2" + }, + { + "$ref": "#/definitions/NotificationNoParam" + }, + { + "$ref": "#/definitions/NotificationNoParam2" + } + ], + "required": [ + "jsonrpc" + ] + }, + "JsonRpcRequest": { + "type": "object", + "properties": { + "id": { + "$ref": "#/definitions/NumberOrString" + }, + "jsonrpc": { + "$ref": "#/definitions/JsonRpcVersion2_0" + } + }, + "anyOf": [ + { + "$ref": "#/definitions/RequestNoParam" + }, + { + "$ref": "#/definitions/Request" + }, + { + "$ref": "#/definitions/Request2" + }, + { + "$ref": "#/definitions/Request3" + }, + { + "$ref": "#/definitions/Request4" + }, + { + "$ref": "#/definitions/RequestOptionalParam" + }, + { + "$ref": "#/definitions/RequestOptionalParam2" + }, + { + "$ref": "#/definitions/RequestOptionalParam3" + }, + { + "$ref": "#/definitions/Request5" + }, + { + "$ref": "#/definitions/Request6" + }, + { + "$ref": "#/definitions/Request7" + }, + { + "$ref": "#/definitions/Request8" + }, + { + "$ref": "#/definitions/RequestOptionalParam4" + } + ], + "required": [ + "jsonrpc", + "id" + ] + }, + "JsonRpcResponse": { + "type": "object", + "properties": { + "id": { + "$ref": "#/definitions/NumberOrString" + }, + "jsonrpc": { + "$ref": "#/definitions/JsonRpcVersion2_0" + }, + "result": { + "$ref": "#/definitions/ClientResult" + } + }, + "required": [ + "jsonrpc", + "id", + "result" + ] + }, + "JsonRpcVersion2_0": { + "type": "string", + "format": "const", + "const": "2.0" + }, + "ListPromptsRequestMethod": { + "type": "string", + "format": "const", + "const": "prompts/list" + }, + "ListResourceTemplatesRequestMethod": { + "type": "string", + "format": "const", + "const": "resources/templates/list" + }, + "ListResourcesRequestMethod": { + "type": "string", + "format": "const", + "const": "resources/list" + }, + "ListRootsResult": { + "type": "object", + "properties": { + "roots": { + "type": "array", + "items": { + "$ref": "#/definitions/Root" + } + } + }, + "required": [ + "roots" + ] + }, + "ListToolsRequestMethod": { + "type": "string", + "format": "const", + "const": "tools/list" + }, + "LoggingLevel": { + "description": "Logging levels supported by the MCP protocol", + "type": "string", + "enum": [ + "debug", + "info", + "notice", + "warning", + "error", + "critical", + "alert", + "emergency" + ] + }, + "Notification": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/CancelledNotificationMethod" + }, + "params": { + "$ref": "#/definitions/CancelledNotificationParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "Notification2": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/ProgressNotificationMethod" + }, + "params": { + "$ref": "#/definitions/ProgressNotificationParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "NotificationNoParam": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/InitializedNotificationMethod" + } + }, + "required": [ + "method" + ] + }, + "NotificationNoParam2": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/RootsListChangedNotificationMethod" + } + }, + "required": [ + "method" + ] + }, + "NumberOrString": { + "oneOf": [ + { + "type": "number" + }, + { + "type": "string" + } + ] + }, + "PaginatedRequestParam": { + "type": "object", + "properties": { + "cursor": { + "type": [ + "string", + "null" + ] + } + } + }, + "PingRequestMethod": { + "type": "string", + "format": "const", + "const": "ping" + }, + "ProgressNotificationMethod": { + "type": "string", + "format": "const", + "const": "notifications/progress" + }, + "ProgressNotificationParam": { + "type": "object", + "properties": { + "message": { + "description": "An optional message describing the current progress.", + "type": [ + "string", + "null" + ] + }, + "progress": { + "description": "The progress thus far. This should increase every time progress is made, even if the total is unknown.", + "type": "number", + "format": "double" + }, + "progressToken": { + "$ref": "#/definitions/ProgressToken" + }, + "total": { + "description": "Total number of items to process (or total progress required), if known", + "type": [ + "number", + "null" + ], + "format": "double" + } + }, + "required": [ + "progressToken", + "progress" + ] + }, + "ProgressToken": { + "description": "A token used to track the progress of long-running operations.\n\nProgress tokens allow clients and servers to associate progress notifications\nwith specific requests, enabling real-time updates on operation status.", + "allOf": [ + { + "$ref": "#/definitions/NumberOrString" + } + ] + }, + "PromptReference": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "title": { + "type": [ + "string", + "null" + ] + } + }, + "required": [ + "name" + ] + }, + "ProtocolVersion": { + "description": "Represents the MCP protocol version used for communication.\n\nThis ensures compatibility between clients and servers by specifying\nwhich version of the Model Context Protocol is being used.", + "type": "string" + }, + "RawAudioContent": { + "type": "object", + "properties": { + "data": { + "type": "string" + }, + "mimeType": { + "type": "string" + } + }, + "required": [ + "data", + "mimeType" + ] + }, + "RawEmbeddedResource": { + "type": "object", + "properties": { + "_meta": { + "description": "Optional protocol-level metadata for this content block", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "resource": { + "$ref": "#/definitions/ResourceContents" + } + }, + "required": [ + "resource" + ] + }, + "RawImageContent": { + "type": "object", + "properties": { + "_meta": { + "description": "Optional protocol-level metadata for this content block", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "data": { + "description": "The base64-encoded image", + "type": "string" + }, + "mimeType": { + "type": "string" + } + }, + "required": [ + "data", + "mimeType" + ] + }, + "RawResource": { + "description": "Represents a resource in the extension with metadata", + "type": "object", + "properties": { + "description": { + "description": "Optional description of the resource", + "type": [ + "string", + "null" + ] + }, + "icons": { + "description": "Optional list of icons for the resource", + "type": [ + "array", + "null" + ], + "items": { + "$ref": "#/definitions/Icon" + } + }, + "mimeType": { + "description": "MIME type of the resource content (\"text\" or \"blob\")", + "type": [ + "string", + "null" + ] + }, + "name": { + "description": "Name of the resource", + "type": "string" + }, + "size": { + "description": "The size of the raw resource content, in bytes (i.e., before base64 encoding or any tokenization), if known.\n\nThis can be used by Hosts to display file sizes and estimate context window us", + "type": [ + "integer", + "null" + ], + "format": "uint32", + "minimum": 0 + }, + "title": { + "description": "Human-readable title of the resource", + "type": [ + "string", + "null" + ] + }, + "uri": { + "description": "URI representing the resource location (e.g., \"file:///path/to/file\" or \"str:///content\")", + "type": "string" + } + }, + "required": [ + "uri", + "name" + ] + }, + "RawTextContent": { + "type": "object", + "properties": { + "_meta": { + "description": "Optional protocol-level metadata for this content block", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "text": { + "type": "string" + } + }, + "required": [ + "text" + ] + }, + "ReadResourceRequestMethod": { + "type": "string", + "format": "const", + "const": "resources/read" + }, + "ReadResourceRequestParam": { + "description": "Parameters for reading a specific resource", + "type": "object", + "properties": { + "uri": { + "description": "The URI of the resource to read", + "type": "string" + } + }, + "required": [ + "uri" + ] + }, + "Reference": { + "oneOf": [ + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "ref/resource" + } + }, + "allOf": [ + { + "$ref": "#/definitions/ResourceReference" + } + ], + "required": [ + "type" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "ref/prompt" + } + }, + "allOf": [ + { + "$ref": "#/definitions/PromptReference" + } + ], + "required": [ + "type" + ] + } + ] + }, + "Request": { + "description": "Represents a JSON-RPC request with method, parameters, and extensions.\n\nThis is the core structure for all MCP requests, containing:\n- `method`: The name of the method being called\n- `params`: The parameters for the method\n- `extensions`: Additional context data (similar to HTTP headers)", + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/InitializeResultMethod" + }, + "params": { + "$ref": "#/definitions/InitializeRequestParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "Request2": { + "description": "Represents a JSON-RPC request with method, parameters, and extensions.\n\nThis is the core structure for all MCP requests, containing:\n- `method`: The name of the method being called\n- `params`: The parameters for the method\n- `extensions`: Additional context data (similar to HTTP headers)", + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/CompleteRequestMethod" + }, + "params": { + "$ref": "#/definitions/CompleteRequestParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "Request3": { + "description": "Represents a JSON-RPC request with method, parameters, and extensions.\n\nThis is the core structure for all MCP requests, containing:\n- `method`: The name of the method being called\n- `params`: The parameters for the method\n- `extensions`: Additional context data (similar to HTTP headers)", + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/SetLevelRequestMethod" + }, + "params": { + "$ref": "#/definitions/SetLevelRequestParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "Request4": { + "description": "Represents a JSON-RPC request with method, parameters, and extensions.\n\nThis is the core structure for all MCP requests, containing:\n- `method`: The name of the method being called\n- `params`: The parameters for the method\n- `extensions`: Additional context data (similar to HTTP headers)", + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/GetPromptRequestMethod" + }, + "params": { + "$ref": "#/definitions/GetPromptRequestParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "Request5": { + "description": "Represents a JSON-RPC request with method, parameters, and extensions.\n\nThis is the core structure for all MCP requests, containing:\n- `method`: The name of the method being called\n- `params`: The parameters for the method\n- `extensions`: Additional context data (similar to HTTP headers)", + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/ReadResourceRequestMethod" + }, + "params": { + "$ref": "#/definitions/ReadResourceRequestParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "Request6": { + "description": "Represents a JSON-RPC request with method, parameters, and extensions.\n\nThis is the core structure for all MCP requests, containing:\n- `method`: The name of the method being called\n- `params`: The parameters for the method\n- `extensions`: Additional context data (similar to HTTP headers)", + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/SubscribeRequestMethod" + }, + "params": { + "$ref": "#/definitions/SubscribeRequestParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "Request7": { + "description": "Represents a JSON-RPC request with method, parameters, and extensions.\n\nThis is the core structure for all MCP requests, containing:\n- `method`: The name of the method being called\n- `params`: The parameters for the method\n- `extensions`: Additional context data (similar to HTTP headers)", + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/UnsubscribeRequestMethod" + }, + "params": { + "$ref": "#/definitions/UnsubscribeRequestParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "Request8": { + "description": "Represents a JSON-RPC request with method, parameters, and extensions.\n\nThis is the core structure for all MCP requests, containing:\n- `method`: The name of the method being called\n- `params`: The parameters for the method\n- `extensions`: Additional context data (similar to HTTP headers)", + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/CallToolRequestMethod" + }, + "params": { + "$ref": "#/definitions/CallToolRequestParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "RequestNoParam": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/PingRequestMethod" + } + }, + "required": [ + "method" + ] + }, + "RequestOptionalParam": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/ListPromptsRequestMethod" + }, + "params": { + "anyOf": [ + { + "$ref": "#/definitions/PaginatedRequestParam" + }, + { + "type": "null" + } + ] + } + }, + "required": [ + "method" + ] + }, + "RequestOptionalParam2": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/ListResourcesRequestMethod" + }, + "params": { + "anyOf": [ + { + "$ref": "#/definitions/PaginatedRequestParam" + }, + { + "type": "null" + } + ] + } + }, + "required": [ + "method" + ] + }, + "RequestOptionalParam3": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/ListResourceTemplatesRequestMethod" + }, + "params": { + "anyOf": [ + { + "$ref": "#/definitions/PaginatedRequestParam" + }, + { + "type": "null" + } + ] + } + }, + "required": [ + "method" + ] + }, + "RequestOptionalParam4": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/ListToolsRequestMethod" + }, + "params": { + "anyOf": [ + { + "$ref": "#/definitions/PaginatedRequestParam" + }, + { + "type": "null" + } + ] + } + }, + "required": [ + "method" + ] + }, + "ResourceContents": { + "anyOf": [ + { + "type": "object", + "properties": { + "_meta": { + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "mimeType": { + "type": [ + "string", + "null" + ] + }, + "text": { + "type": "string" + }, + "uri": { + "type": "string" + } + }, + "required": [ + "uri", + "text" + ] + }, + { + "type": "object", + "properties": { + "_meta": { + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "blob": { + "type": "string" + }, + "mimeType": { + "type": [ + "string", + "null" + ] + }, + "uri": { + "type": "string" + } + }, + "required": [ + "uri", + "blob" + ] + } + ] + }, + "ResourceReference": { + "type": "object", + "properties": { + "uri": { + "type": "string" + } + }, + "required": [ + "uri" + ] + }, + "Role": { + "description": "Represents the role of a participant in a conversation or message exchange.\n\nUsed in sampling and chat contexts to distinguish between different\ntypes of message senders in the conversation flow.", + "oneOf": [ + { + "description": "A human user or client making a request", + "type": "string", + "const": "user" + }, + { + "description": "An AI assistant or server providing a response", + "type": "string", + "const": "assistant" + } + ] + }, + "Root": { + "type": "object", + "properties": { + "name": { + "type": [ + "string", + "null" + ] + }, + "uri": { + "type": "string" + } + }, + "required": [ + "uri" + ] + }, + "RootsCapabilities": { + "type": "object", + "properties": { + "listChanged": { + "type": [ + "boolean", + "null" + ] + } + } + }, + "RootsListChangedNotificationMethod": { + "type": "string", + "format": "const", + "const": "notifications/roots/list_changed" + }, + "SetLevelRequestMethod": { + "type": "string", + "format": "const", + "const": "logging/setLevel" + }, + "SetLevelRequestParam": { + "description": "Parameters for setting the logging level", + "type": "object", + "properties": { + "level": { + "description": "The desired logging level", + "allOf": [ + { + "$ref": "#/definitions/LoggingLevel" + } + ] + } + }, + "required": [ + "level" + ] + }, + "SubscribeRequestMethod": { + "type": "string", + "format": "const", + "const": "resources/subscribe" + }, + "SubscribeRequestParam": { + "description": "Parameters for subscribing to resource updates", + "type": "object", + "properties": { + "uri": { + "description": "The URI of the resource to subscribe to", + "type": "string" + } + }, + "required": [ + "uri" + ] + }, + "UnsubscribeRequestMethod": { + "type": "string", + "format": "const", + "const": "resources/unsubscribe" + }, + "UnsubscribeRequestParam": { + "description": "Parameters for unsubscribing from resource updates", + "type": "object", + "properties": { + "uri": { + "description": "The URI of the resource to unsubscribe from", + "type": "string" + } + }, + "required": [ + "uri" + ] + } + } +} \ No newline at end of file diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_message_schema/client_json_rpc_message_schema_current.json b/code-rs/third_party/rmcp-0.8.3/tests/test_message_schema/client_json_rpc_message_schema_current.json new file mode 100644 index 00000000000..8fa565c0b17 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_message_schema/client_json_rpc_message_schema_current.json @@ -0,0 +1,1509 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "JsonRpcMessage", + "description": "Represents any JSON-RPC message that can be sent or received.\n\nThis enum covers all possible message types in the JSON-RPC protocol:\nindividual requests/responses, notifications, and errors.\nIt serves as the top-level message container for MCP communication.", + "anyOf": [ + { + "description": "A single request expecting a response", + "allOf": [ + { + "$ref": "#/definitions/JsonRpcRequest" + } + ] + }, + { + "description": "A response to a previous request", + "allOf": [ + { + "$ref": "#/definitions/JsonRpcResponse" + } + ] + }, + { + "description": "A one-way notification (no response expected)", + "allOf": [ + { + "$ref": "#/definitions/JsonRpcNotification" + } + ] + }, + { + "description": "An error response", + "allOf": [ + { + "$ref": "#/definitions/JsonRpcError" + } + ] + } + ], + "definitions": { + "Annotated": { + "type": "object", + "properties": { + "annotations": { + "anyOf": [ + { + "$ref": "#/definitions/Annotations" + }, + { + "type": "null" + } + ] + } + }, + "oneOf": [ + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "text" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawTextContent" + } + ], + "required": [ + "type" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "image" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawImageContent" + } + ], + "required": [ + "type" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "resource" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawEmbeddedResource" + } + ], + "required": [ + "type" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "audio" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawAudioContent" + } + ], + "required": [ + "type" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "resource_link" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawResource" + } + ], + "required": [ + "type" + ] + } + ] + }, + "Annotations": { + "type": "object", + "properties": { + "audience": { + "type": [ + "array", + "null" + ], + "items": { + "$ref": "#/definitions/Role" + } + }, + "lastModified": { + "type": [ + "string", + "null" + ], + "format": "date-time" + }, + "priority": { + "type": [ + "number", + "null" + ], + "format": "float" + } + } + }, + "ArgumentInfo": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "value": { + "type": "string" + } + }, + "required": [ + "name", + "value" + ] + }, + "CallToolRequestMethod": { + "type": "string", + "format": "const", + "const": "tools/call" + }, + "CallToolRequestParam": { + "description": "Parameters for calling a tool provided by an MCP server.\n\nContains the tool name and optional arguments needed to execute\nthe tool operation.", + "type": "object", + "properties": { + "arguments": { + "description": "Arguments to pass to the tool (must match the tool's input schema)", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "name": { + "description": "The name of the tool to call", + "type": "string" + } + }, + "required": [ + "name" + ] + }, + "CancelledNotificationMethod": { + "type": "string", + "format": "const", + "const": "notifications/cancelled" + }, + "CancelledNotificationParam": { + "type": "object", + "properties": { + "reason": { + "type": [ + "string", + "null" + ] + }, + "requestId": { + "$ref": "#/definitions/NumberOrString" + } + }, + "required": [ + "requestId" + ] + }, + "ClientCapabilities": { + "title": "Builder", + "description": "```rust\n# use rmcp::model::ClientCapabilities;\nlet cap = ClientCapabilities::builder()\n .enable_experimental()\n .enable_roots()\n .enable_roots_list_changed()\n .build();\n```", + "type": "object", + "properties": { + "elicitation": { + "description": "Capability to handle elicitation requests from servers for interactive user input", + "anyOf": [ + { + "$ref": "#/definitions/ElicitationCapability" + }, + { + "type": "null" + } + ] + }, + "experimental": { + "type": [ + "object", + "null" + ], + "additionalProperties": { + "type": "object", + "additionalProperties": true + } + }, + "roots": { + "anyOf": [ + { + "$ref": "#/definitions/RootsCapabilities" + }, + { + "type": "null" + } + ] + }, + "sampling": { + "type": [ + "object", + "null" + ], + "additionalProperties": true + } + } + }, + "ClientResult": { + "anyOf": [ + { + "$ref": "#/definitions/CreateMessageResult" + }, + { + "$ref": "#/definitions/ListRootsResult" + }, + { + "$ref": "#/definitions/CreateElicitationResult" + }, + { + "$ref": "#/definitions/EmptyObject" + } + ] + }, + "CompleteRequestMethod": { + "type": "string", + "format": "const", + "const": "completion/complete" + }, + "CompleteRequestParam": { + "type": "object", + "properties": { + "argument": { + "$ref": "#/definitions/ArgumentInfo" + }, + "context": { + "description": "Optional context containing previously resolved argument values", + "anyOf": [ + { + "$ref": "#/definitions/CompletionContext" + }, + { + "type": "null" + } + ] + }, + "ref": { + "$ref": "#/definitions/Reference" + } + }, + "required": [ + "ref", + "argument" + ] + }, + "CompletionContext": { + "description": "Context for completion requests providing previously resolved arguments.\n\nThis enables context-aware completion where subsequent argument completions\ncan take into account the values of previously resolved arguments.", + "type": "object", + "properties": { + "arguments": { + "description": "Previously resolved argument values that can inform completion suggestions", + "type": [ + "object", + "null" + ], + "additionalProperties": { + "type": "string" + } + } + } + }, + "CreateElicitationResult": { + "description": "The result returned by a client in response to an elicitation request.\n\nContains the user's decision (accept/decline/cancel) and optionally their input data\nif they chose to accept the request.", + "type": "object", + "properties": { + "action": { + "description": "The user's decision on how to handle the elicitation request", + "allOf": [ + { + "$ref": "#/definitions/ElicitationAction" + } + ] + }, + "content": { + "description": "The actual data provided by the user, if they accepted the request.\nMust conform to the JSON schema specified in the original request.\nOnly present when action is Accept." + } + }, + "required": [ + "action" + ] + }, + "CreateMessageResult": { + "description": "The result of a sampling/createMessage request containing the generated response.\n\nThis structure contains the generated message along with metadata about\nhow the generation was performed and why it stopped.", + "type": "object", + "properties": { + "content": { + "description": "The actual content of the message (text, image, etc.)", + "allOf": [ + { + "$ref": "#/definitions/Annotated" + } + ] + }, + "model": { + "description": "The identifier of the model that generated the response", + "type": "string" + }, + "role": { + "description": "The role of the message sender (User or Assistant)", + "allOf": [ + { + "$ref": "#/definitions/Role" + } + ] + }, + "stopReason": { + "description": "The reason why generation stopped (e.g., \"endTurn\", \"maxTokens\")", + "type": [ + "string", + "null" + ] + } + }, + "required": [ + "model", + "role", + "content" + ] + }, + "ElicitationAction": { + "description": "Represents the possible actions a user can take in response to an elicitation request.\n\nWhen a server requests user input through elicitation, the user can:\n- Accept: Provide the requested information and continue\n- Decline: Refuse to provide the information but continue the operation\n- Cancel: Stop the entire operation", + "oneOf": [ + { + "description": "User accepts the request and provides the requested information", + "type": "string", + "const": "accept" + }, + { + "description": "User declines to provide the information but allows the operation to continue", + "type": "string", + "const": "decline" + }, + { + "description": "User cancels the entire operation", + "type": "string", + "const": "cancel" + } + ] + }, + "ElicitationCapability": { + "description": "Capability for handling elicitation requests from servers.\n\nElicitation allows servers to request interactive input from users during tool execution.\nThis capability indicates that a client can handle elicitation requests and present\nappropriate UI to users for collecting the requested information.", + "type": "object", + "properties": { + "schemaValidation": { + "description": "Whether the client supports JSON Schema validation for elicitation responses.\nWhen true, the client will validate user input against the requested_schema\nbefore sending the response back to the server.", + "type": [ + "boolean", + "null" + ] + } + } + }, + "EmptyObject": { + "description": "This is commonly used for representing empty objects in MCP messages.\n\nwithout returning any specific data.", + "type": "object" + }, + "ErrorCode": { + "description": "Standard JSON-RPC error codes used throughout the MCP protocol.\n\nThese codes follow the JSON-RPC 2.0 specification and provide\nstandardized error reporting across all MCP implementations.", + "type": "integer", + "format": "int32" + }, + "ErrorData": { + "description": "Error information for JSON-RPC error responses.\n\nThis structure follows the JSON-RPC 2.0 specification for error reporting,\nproviding a standardized way to communicate errors between clients and servers.", + "type": "object", + "properties": { + "code": { + "description": "The error type that occurred (using standard JSON-RPC error codes)", + "allOf": [ + { + "$ref": "#/definitions/ErrorCode" + } + ] + }, + "data": { + "description": "Additional information about the error. The value of this member is defined by the\nsender (e.g. detailed error information, nested errors etc.)." + }, + "message": { + "description": "A short description of the error. The message SHOULD be limited to a concise single sentence.", + "type": "string" + } + }, + "required": [ + "code", + "message" + ] + }, + "GetPromptRequestMethod": { + "type": "string", + "format": "const", + "const": "prompts/get" + }, + "GetPromptRequestParam": { + "description": "Parameters for retrieving a specific prompt", + "type": "object", + "properties": { + "arguments": { + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "name": { + "type": "string" + } + }, + "required": [ + "name" + ] + }, + "Icon": { + "description": "A URL pointing to an icon resource or a base64-encoded data URI.\n\nClients that support rendering icons MUST support at least the following MIME types:\n- image/png - PNG images (safe, universal compatibility)\n- image/jpeg (and image/jpg) - JPEG images (safe, universal compatibility)\n\nClients that support rendering icons SHOULD also support:\n- image/svg+xml - SVG images (scalable but requires security precautions)\n- image/webp - WebP images (modern, efficient format)", + "type": "object", + "properties": { + "mimeType": { + "description": "Optional override if the server's MIME type is missing or generic", + "type": [ + "string", + "null" + ] + }, + "sizes": { + "description": "Size specification, each string should be in WxH format (e.g., `\\\"48x48\\\"`, `\\\"96x96\\\"`) or `\\\"any\\\"` for scalable formats like SVG", + "type": [ + "array", + "null" + ], + "items": { + "type": "string" + } + }, + "src": { + "description": "A standard URI pointing to an icon resource", + "type": "string" + } + }, + "required": [ + "src" + ] + }, + "Implementation": { + "type": "object", + "properties": { + "icons": { + "type": [ + "array", + "null" + ], + "items": { + "$ref": "#/definitions/Icon" + } + }, + "name": { + "type": "string" + }, + "title": { + "type": [ + "string", + "null" + ] + }, + "version": { + "type": "string" + }, + "websiteUrl": { + "type": [ + "string", + "null" + ] + } + }, + "required": [ + "name", + "version" + ] + }, + "InitializeRequestParam": { + "description": "Parameters sent by a client when initializing a connection to an MCP server.\n\nThis contains the client's protocol version, capabilities, and implementation\ninformation, allowing the server to understand what the client supports.", + "type": "object", + "properties": { + "capabilities": { + "description": "The capabilities this client supports (sampling, roots, etc.)", + "allOf": [ + { + "$ref": "#/definitions/ClientCapabilities" + } + ] + }, + "clientInfo": { + "description": "Information about the client implementation", + "allOf": [ + { + "$ref": "#/definitions/Implementation" + } + ] + }, + "protocolVersion": { + "description": "The MCP protocol version this client supports", + "allOf": [ + { + "$ref": "#/definitions/ProtocolVersion" + } + ] + } + }, + "required": [ + "protocolVersion", + "capabilities", + "clientInfo" + ] + }, + "InitializeResultMethod": { + "type": "string", + "format": "const", + "const": "initialize" + }, + "InitializedNotificationMethod": { + "type": "string", + "format": "const", + "const": "notifications/initialized" + }, + "JsonRpcError": { + "type": "object", + "properties": { + "error": { + "$ref": "#/definitions/ErrorData" + }, + "id": { + "$ref": "#/definitions/NumberOrString" + }, + "jsonrpc": { + "$ref": "#/definitions/JsonRpcVersion2_0" + } + }, + "required": [ + "jsonrpc", + "id", + "error" + ] + }, + "JsonRpcNotification": { + "type": "object", + "properties": { + "jsonrpc": { + "$ref": "#/definitions/JsonRpcVersion2_0" + } + }, + "anyOf": [ + { + "$ref": "#/definitions/Notification" + }, + { + "$ref": "#/definitions/Notification2" + }, + { + "$ref": "#/definitions/NotificationNoParam" + }, + { + "$ref": "#/definitions/NotificationNoParam2" + } + ], + "required": [ + "jsonrpc" + ] + }, + "JsonRpcRequest": { + "type": "object", + "properties": { + "id": { + "$ref": "#/definitions/NumberOrString" + }, + "jsonrpc": { + "$ref": "#/definitions/JsonRpcVersion2_0" + } + }, + "anyOf": [ + { + "$ref": "#/definitions/RequestNoParam" + }, + { + "$ref": "#/definitions/Request" + }, + { + "$ref": "#/definitions/Request2" + }, + { + "$ref": "#/definitions/Request3" + }, + { + "$ref": "#/definitions/Request4" + }, + { + "$ref": "#/definitions/RequestOptionalParam" + }, + { + "$ref": "#/definitions/RequestOptionalParam2" + }, + { + "$ref": "#/definitions/RequestOptionalParam3" + }, + { + "$ref": "#/definitions/Request5" + }, + { + "$ref": "#/definitions/Request6" + }, + { + "$ref": "#/definitions/Request7" + }, + { + "$ref": "#/definitions/Request8" + }, + { + "$ref": "#/definitions/RequestOptionalParam4" + } + ], + "required": [ + "jsonrpc", + "id" + ] + }, + "JsonRpcResponse": { + "type": "object", + "properties": { + "id": { + "$ref": "#/definitions/NumberOrString" + }, + "jsonrpc": { + "$ref": "#/definitions/JsonRpcVersion2_0" + }, + "result": { + "$ref": "#/definitions/ClientResult" + } + }, + "required": [ + "jsonrpc", + "id", + "result" + ] + }, + "JsonRpcVersion2_0": { + "type": "string", + "format": "const", + "const": "2.0" + }, + "ListPromptsRequestMethod": { + "type": "string", + "format": "const", + "const": "prompts/list" + }, + "ListResourceTemplatesRequestMethod": { + "type": "string", + "format": "const", + "const": "resources/templates/list" + }, + "ListResourcesRequestMethod": { + "type": "string", + "format": "const", + "const": "resources/list" + }, + "ListRootsResult": { + "type": "object", + "properties": { + "roots": { + "type": "array", + "items": { + "$ref": "#/definitions/Root" + } + } + }, + "required": [ + "roots" + ] + }, + "ListToolsRequestMethod": { + "type": "string", + "format": "const", + "const": "tools/list" + }, + "LoggingLevel": { + "description": "Logging levels supported by the MCP protocol", + "type": "string", + "enum": [ + "debug", + "info", + "notice", + "warning", + "error", + "critical", + "alert", + "emergency" + ] + }, + "Notification": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/CancelledNotificationMethod" + }, + "params": { + "$ref": "#/definitions/CancelledNotificationParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "Notification2": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/ProgressNotificationMethod" + }, + "params": { + "$ref": "#/definitions/ProgressNotificationParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "NotificationNoParam": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/InitializedNotificationMethod" + } + }, + "required": [ + "method" + ] + }, + "NotificationNoParam2": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/RootsListChangedNotificationMethod" + } + }, + "required": [ + "method" + ] + }, + "NumberOrString": { + "oneOf": [ + { + "type": "number" + }, + { + "type": "string" + } + ] + }, + "PaginatedRequestParam": { + "type": "object", + "properties": { + "cursor": { + "type": [ + "string", + "null" + ] + } + } + }, + "PingRequestMethod": { + "type": "string", + "format": "const", + "const": "ping" + }, + "ProgressNotificationMethod": { + "type": "string", + "format": "const", + "const": "notifications/progress" + }, + "ProgressNotificationParam": { + "type": "object", + "properties": { + "message": { + "description": "An optional message describing the current progress.", + "type": [ + "string", + "null" + ] + }, + "progress": { + "description": "The progress thus far. This should increase every time progress is made, even if the total is unknown.", + "type": "number", + "format": "double" + }, + "progressToken": { + "$ref": "#/definitions/ProgressToken" + }, + "total": { + "description": "Total number of items to process (or total progress required), if known", + "type": [ + "number", + "null" + ], + "format": "double" + } + }, + "required": [ + "progressToken", + "progress" + ] + }, + "ProgressToken": { + "description": "A token used to track the progress of long-running operations.\n\nProgress tokens allow clients and servers to associate progress notifications\nwith specific requests, enabling real-time updates on operation status.", + "allOf": [ + { + "$ref": "#/definitions/NumberOrString" + } + ] + }, + "PromptReference": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "title": { + "type": [ + "string", + "null" + ] + } + }, + "required": [ + "name" + ] + }, + "ProtocolVersion": { + "description": "Represents the MCP protocol version used for communication.\n\nThis ensures compatibility between clients and servers by specifying\nwhich version of the Model Context Protocol is being used.", + "type": "string" + }, + "RawAudioContent": { + "type": "object", + "properties": { + "data": { + "type": "string" + }, + "mimeType": { + "type": "string" + } + }, + "required": [ + "data", + "mimeType" + ] + }, + "RawEmbeddedResource": { + "type": "object", + "properties": { + "_meta": { + "description": "Optional protocol-level metadata for this content block", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "resource": { + "$ref": "#/definitions/ResourceContents" + } + }, + "required": [ + "resource" + ] + }, + "RawImageContent": { + "type": "object", + "properties": { + "_meta": { + "description": "Optional protocol-level metadata for this content block", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "data": { + "description": "The base64-encoded image", + "type": "string" + }, + "mimeType": { + "type": "string" + } + }, + "required": [ + "data", + "mimeType" + ] + }, + "RawResource": { + "description": "Represents a resource in the extension with metadata", + "type": "object", + "properties": { + "description": { + "description": "Optional description of the resource", + "type": [ + "string", + "null" + ] + }, + "icons": { + "description": "Optional list of icons for the resource", + "type": [ + "array", + "null" + ], + "items": { + "$ref": "#/definitions/Icon" + } + }, + "mimeType": { + "description": "MIME type of the resource content (\"text\" or \"blob\")", + "type": [ + "string", + "null" + ] + }, + "name": { + "description": "Name of the resource", + "type": "string" + }, + "size": { + "description": "The size of the raw resource content, in bytes (i.e., before base64 encoding or any tokenization), if known.\n\nThis can be used by Hosts to display file sizes and estimate context window us", + "type": [ + "integer", + "null" + ], + "format": "uint32", + "minimum": 0 + }, + "title": { + "description": "Human-readable title of the resource", + "type": [ + "string", + "null" + ] + }, + "uri": { + "description": "URI representing the resource location (e.g., \"file:///path/to/file\" or \"str:///content\")", + "type": "string" + } + }, + "required": [ + "uri", + "name" + ] + }, + "RawTextContent": { + "type": "object", + "properties": { + "_meta": { + "description": "Optional protocol-level metadata for this content block", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "text": { + "type": "string" + } + }, + "required": [ + "text" + ] + }, + "ReadResourceRequestMethod": { + "type": "string", + "format": "const", + "const": "resources/read" + }, + "ReadResourceRequestParam": { + "description": "Parameters for reading a specific resource", + "type": "object", + "properties": { + "uri": { + "description": "The URI of the resource to read", + "type": "string" + } + }, + "required": [ + "uri" + ] + }, + "Reference": { + "oneOf": [ + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "ref/resource" + } + }, + "allOf": [ + { + "$ref": "#/definitions/ResourceReference" + } + ], + "required": [ + "type" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "ref/prompt" + } + }, + "allOf": [ + { + "$ref": "#/definitions/PromptReference" + } + ], + "required": [ + "type" + ] + } + ] + }, + "Request": { + "description": "Represents a JSON-RPC request with method, parameters, and extensions.\n\nThis is the core structure for all MCP requests, containing:\n- `method`: The name of the method being called\n- `params`: The parameters for the method\n- `extensions`: Additional context data (similar to HTTP headers)", + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/InitializeResultMethod" + }, + "params": { + "$ref": "#/definitions/InitializeRequestParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "Request2": { + "description": "Represents a JSON-RPC request with method, parameters, and extensions.\n\nThis is the core structure for all MCP requests, containing:\n- `method`: The name of the method being called\n- `params`: The parameters for the method\n- `extensions`: Additional context data (similar to HTTP headers)", + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/CompleteRequestMethod" + }, + "params": { + "$ref": "#/definitions/CompleteRequestParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "Request3": { + "description": "Represents a JSON-RPC request with method, parameters, and extensions.\n\nThis is the core structure for all MCP requests, containing:\n- `method`: The name of the method being called\n- `params`: The parameters for the method\n- `extensions`: Additional context data (similar to HTTP headers)", + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/SetLevelRequestMethod" + }, + "params": { + "$ref": "#/definitions/SetLevelRequestParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "Request4": { + "description": "Represents a JSON-RPC request with method, parameters, and extensions.\n\nThis is the core structure for all MCP requests, containing:\n- `method`: The name of the method being called\n- `params`: The parameters for the method\n- `extensions`: Additional context data (similar to HTTP headers)", + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/GetPromptRequestMethod" + }, + "params": { + "$ref": "#/definitions/GetPromptRequestParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "Request5": { + "description": "Represents a JSON-RPC request with method, parameters, and extensions.\n\nThis is the core structure for all MCP requests, containing:\n- `method`: The name of the method being called\n- `params`: The parameters for the method\n- `extensions`: Additional context data (similar to HTTP headers)", + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/ReadResourceRequestMethod" + }, + "params": { + "$ref": "#/definitions/ReadResourceRequestParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "Request6": { + "description": "Represents a JSON-RPC request with method, parameters, and extensions.\n\nThis is the core structure for all MCP requests, containing:\n- `method`: The name of the method being called\n- `params`: The parameters for the method\n- `extensions`: Additional context data (similar to HTTP headers)", + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/SubscribeRequestMethod" + }, + "params": { + "$ref": "#/definitions/SubscribeRequestParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "Request7": { + "description": "Represents a JSON-RPC request with method, parameters, and extensions.\n\nThis is the core structure for all MCP requests, containing:\n- `method`: The name of the method being called\n- `params`: The parameters for the method\n- `extensions`: Additional context data (similar to HTTP headers)", + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/UnsubscribeRequestMethod" + }, + "params": { + "$ref": "#/definitions/UnsubscribeRequestParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "Request8": { + "description": "Represents a JSON-RPC request with method, parameters, and extensions.\n\nThis is the core structure for all MCP requests, containing:\n- `method`: The name of the method being called\n- `params`: The parameters for the method\n- `extensions`: Additional context data (similar to HTTP headers)", + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/CallToolRequestMethod" + }, + "params": { + "$ref": "#/definitions/CallToolRequestParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "RequestNoParam": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/PingRequestMethod" + } + }, + "required": [ + "method" + ] + }, + "RequestOptionalParam": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/ListPromptsRequestMethod" + }, + "params": { + "anyOf": [ + { + "$ref": "#/definitions/PaginatedRequestParam" + }, + { + "type": "null" + } + ] + } + }, + "required": [ + "method" + ] + }, + "RequestOptionalParam2": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/ListResourcesRequestMethod" + }, + "params": { + "anyOf": [ + { + "$ref": "#/definitions/PaginatedRequestParam" + }, + { + "type": "null" + } + ] + } + }, + "required": [ + "method" + ] + }, + "RequestOptionalParam3": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/ListResourceTemplatesRequestMethod" + }, + "params": { + "anyOf": [ + { + "$ref": "#/definitions/PaginatedRequestParam" + }, + { + "type": "null" + } + ] + } + }, + "required": [ + "method" + ] + }, + "RequestOptionalParam4": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/ListToolsRequestMethod" + }, + "params": { + "anyOf": [ + { + "$ref": "#/definitions/PaginatedRequestParam" + }, + { + "type": "null" + } + ] + } + }, + "required": [ + "method" + ] + }, + "ResourceContents": { + "anyOf": [ + { + "type": "object", + "properties": { + "_meta": { + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "mimeType": { + "type": [ + "string", + "null" + ] + }, + "text": { + "type": "string" + }, + "uri": { + "type": "string" + } + }, + "required": [ + "uri", + "text" + ] + }, + { + "type": "object", + "properties": { + "_meta": { + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "blob": { + "type": "string" + }, + "mimeType": { + "type": [ + "string", + "null" + ] + }, + "uri": { + "type": "string" + } + }, + "required": [ + "uri", + "blob" + ] + } + ] + }, + "ResourceReference": { + "type": "object", + "properties": { + "uri": { + "type": "string" + } + }, + "required": [ + "uri" + ] + }, + "Role": { + "description": "Represents the role of a participant in a conversation or message exchange.\n\nUsed in sampling and chat contexts to distinguish between different\ntypes of message senders in the conversation flow.", + "oneOf": [ + { + "description": "A human user or client making a request", + "type": "string", + "const": "user" + }, + { + "description": "An AI assistant or server providing a response", + "type": "string", + "const": "assistant" + } + ] + }, + "Root": { + "type": "object", + "properties": { + "name": { + "type": [ + "string", + "null" + ] + }, + "uri": { + "type": "string" + } + }, + "required": [ + "uri" + ] + }, + "RootsCapabilities": { + "type": "object", + "properties": { + "listChanged": { + "type": [ + "boolean", + "null" + ] + } + } + }, + "RootsListChangedNotificationMethod": { + "type": "string", + "format": "const", + "const": "notifications/roots/list_changed" + }, + "SetLevelRequestMethod": { + "type": "string", + "format": "const", + "const": "logging/setLevel" + }, + "SetLevelRequestParam": { + "description": "Parameters for setting the logging level", + "type": "object", + "properties": { + "level": { + "description": "The desired logging level", + "allOf": [ + { + "$ref": "#/definitions/LoggingLevel" + } + ] + } + }, + "required": [ + "level" + ] + }, + "SubscribeRequestMethod": { + "type": "string", + "format": "const", + "const": "resources/subscribe" + }, + "SubscribeRequestParam": { + "description": "Parameters for subscribing to resource updates", + "type": "object", + "properties": { + "uri": { + "description": "The URI of the resource to subscribe to", + "type": "string" + } + }, + "required": [ + "uri" + ] + }, + "UnsubscribeRequestMethod": { + "type": "string", + "format": "const", + "const": "resources/unsubscribe" + }, + "UnsubscribeRequestParam": { + "description": "Parameters for unsubscribing from resource updates", + "type": "object", + "properties": { + "uri": { + "description": "The URI of the resource to unsubscribe from", + "type": "string" + } + }, + "required": [ + "uri" + ] + } + } +} \ No newline at end of file diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_message_schema/server_json_rpc_message_schema.json b/code-rs/third_party/rmcp-0.8.3/tests/test_message_schema/server_json_rpc_message_schema.json new file mode 100644 index 00000000000..663a68941cb --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_message_schema/server_json_rpc_message_schema.json @@ -0,0 +1,2411 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "JsonRpcMessage", + "description": "Represents any JSON-RPC message that can be sent or received.\n\nThis enum covers all possible message types in the JSON-RPC protocol:\nindividual requests/responses, notifications, and errors.\nIt serves as the top-level message container for MCP communication.", + "anyOf": [ + { + "description": "A single request expecting a response", + "allOf": [ + { + "$ref": "#/definitions/JsonRpcRequest" + } + ] + }, + { + "description": "A response to a previous request", + "allOf": [ + { + "$ref": "#/definitions/JsonRpcResponse" + } + ] + }, + { + "description": "A one-way notification (no response expected)", + "allOf": [ + { + "$ref": "#/definitions/JsonRpcNotification" + } + ] + }, + { + "description": "An error response", + "allOf": [ + { + "$ref": "#/definitions/JsonRpcError" + } + ] + } + ], + "definitions": { + "Annotated": { + "type": "object", + "properties": { + "annotations": { + "anyOf": [ + { + "$ref": "#/definitions/Annotations" + }, + { + "type": "null" + } + ] + } + }, + "oneOf": [ + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "text" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawTextContent" + } + ], + "required": [ + "type" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "image" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawImageContent" + } + ], + "required": [ + "type" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "resource" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawEmbeddedResource" + } + ], + "required": [ + "type" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "audio" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawAudioContent" + } + ], + "required": [ + "type" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "resource_link" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawResource" + } + ], + "required": [ + "type" + ] + } + ] + }, + "Annotated2": { + "type": "object", + "properties": { + "_meta": { + "description": "Optional protocol-level metadata for this content block", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "annotations": { + "anyOf": [ + { + "$ref": "#/definitions/Annotations" + }, + { + "type": "null" + } + ] + }, + "resource": { + "$ref": "#/definitions/ResourceContents" + } + }, + "required": [ + "resource" + ] + }, + "Annotated3": { + "description": "Represents a resource in the extension with metadata", + "type": "object", + "properties": { + "annotations": { + "anyOf": [ + { + "$ref": "#/definitions/Annotations" + }, + { + "type": "null" + } + ] + }, + "description": { + "description": "Optional description of the resource", + "type": [ + "string", + "null" + ] + }, + "icons": { + "description": "Optional list of icons for the resource", + "type": [ + "array", + "null" + ], + "items": { + "$ref": "#/definitions/Icon" + } + }, + "mimeType": { + "description": "MIME type of the resource content (\"text\" or \"blob\")", + "type": [ + "string", + "null" + ] + }, + "name": { + "description": "Name of the resource", + "type": "string" + }, + "size": { + "description": "The size of the raw resource content, in bytes (i.e., before base64 encoding or any tokenization), if known.\n\nThis can be used by Hosts to display file sizes and estimate context window us", + "type": [ + "integer", + "null" + ], + "format": "uint32", + "minimum": 0 + }, + "title": { + "description": "Human-readable title of the resource", + "type": [ + "string", + "null" + ] + }, + "uri": { + "description": "URI representing the resource location (e.g., \"file:///path/to/file\" or \"str:///content\")", + "type": "string" + } + }, + "required": [ + "uri", + "name" + ] + }, + "Annotated4": { + "type": "object", + "properties": { + "annotations": { + "anyOf": [ + { + "$ref": "#/definitions/Annotations" + }, + { + "type": "null" + } + ] + }, + "description": { + "type": [ + "string", + "null" + ] + }, + "mimeType": { + "type": [ + "string", + "null" + ] + }, + "name": { + "type": "string" + }, + "title": { + "type": [ + "string", + "null" + ] + }, + "uriTemplate": { + "type": "string" + } + }, + "required": [ + "uriTemplate", + "name" + ] + }, + "Annotations": { + "type": "object", + "properties": { + "audience": { + "type": [ + "array", + "null" + ], + "items": { + "$ref": "#/definitions/Role" + } + }, + "lastModified": { + "type": [ + "string", + "null" + ], + "format": "date-time" + }, + "priority": { + "type": [ + "number", + "null" + ], + "format": "float" + } + } + }, + "BooleanSchema": { + "description": "Schema definition for boolean properties.", + "type": "object", + "properties": { + "default": { + "description": "Default value", + "type": [ + "boolean", + "null" + ] + }, + "description": { + "description": "Human-readable description", + "type": [ + "string", + "null" + ] + }, + "title": { + "description": "Optional title for the schema", + "type": [ + "string", + "null" + ] + }, + "type": { + "description": "Type discriminator", + "allOf": [ + { + "$ref": "#/definitions/BooleanTypeConst" + } + ] + } + }, + "required": [ + "type" + ] + }, + "BooleanTypeConst": { + "type": "string", + "format": "const", + "const": "boolean" + }, + "CallToolResult": { + "description": "The result of a tool call operation.\n\nContains the content returned by the tool execution and an optional\nflag indicating whether the operation resulted in an error.", + "type": "object", + "properties": { + "_meta": { + "description": "Optional protocol-level metadata for this result", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "content": { + "description": "The content returned by the tool (text, images, etc.)", + "type": "array", + "items": { + "$ref": "#/definitions/Annotated" + } + }, + "isError": { + "description": "Whether this result represents an error condition", + "type": [ + "boolean", + "null" + ] + }, + "structuredContent": { + "description": "An optional JSON object that represents the structured result of the tool call" + } + }, + "required": [ + "content" + ] + }, + "CancelledNotificationMethod": { + "type": "string", + "format": "const", + "const": "notifications/cancelled" + }, + "CancelledNotificationParam": { + "type": "object", + "properties": { + "reason": { + "type": [ + "string", + "null" + ] + }, + "requestId": { + "$ref": "#/definitions/NumberOrString" + } + }, + "required": [ + "requestId" + ] + }, + "CompleteResult": { + "type": "object", + "properties": { + "completion": { + "$ref": "#/definitions/CompletionInfo" + } + }, + "required": [ + "completion" + ] + }, + "CompletionInfo": { + "type": "object", + "properties": { + "hasMore": { + "type": [ + "boolean", + "null" + ] + }, + "total": { + "type": [ + "integer", + "null" + ], + "format": "uint32", + "minimum": 0 + }, + "values": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "required": [ + "values" + ] + }, + "ContextInclusion": { + "description": "Specifies how much context should be included in sampling requests.\n\nThis allows clients to control what additional context information\nshould be provided to the LLM when processing sampling requests.", + "oneOf": [ + { + "description": "Include context from all connected MCP servers", + "type": "string", + "const": "allServers" + }, + { + "description": "Include no additional context", + "type": "string", + "const": "none" + }, + { + "description": "Include context only from the requesting server", + "type": "string", + "const": "thisServer" + } + ] + }, + "CreateElicitationRequestParam": { + "description": "Parameters for creating an elicitation request to gather user input.\n\nThis structure contains everything needed to request interactive input from a user:\n- A human-readable message explaining what information is needed\n- A type-safe schema defining the expected structure of the response\n\n# Example\n\n```rust\nuse rmcp::model::*;\n\nlet params = CreateElicitationRequestParam {\n message: \"Please provide your email\".to_string(),\n requested_schema: ElicitationSchema::builder()\n .required_email(\"email\")\n .build()\n .unwrap(),\n};\n```", + "type": "object", + "properties": { + "message": { + "description": "Human-readable message explaining what input is needed from the user.\nThis should be clear and provide sufficient context for the user to understand\nwhat information they need to provide.", + "type": "string" + }, + "requestedSchema": { + "description": "Type-safe schema defining the expected structure and validation rules for the user's response.\nThis enforces the MCP 2025-06-18 specification that elicitation schemas must be objects\nwith primitive-typed properties.", + "allOf": [ + { + "$ref": "#/definitions/ElicitationSchema" + } + ] + } + }, + "required": [ + "message", + "requestedSchema" + ] + }, + "CreateElicitationResult": { + "description": "The result returned by a client in response to an elicitation request.\n\nContains the user's decision (accept/decline/cancel) and optionally their input data\nif they chose to accept the request.", + "type": "object", + "properties": { + "action": { + "description": "The user's decision on how to handle the elicitation request", + "allOf": [ + { + "$ref": "#/definitions/ElicitationAction" + } + ] + }, + "content": { + "description": "The actual data provided by the user, if they accepted the request.\nMust conform to the JSON schema specified in the original request.\nOnly present when action is Accept." + } + }, + "required": [ + "action" + ] + }, + "CreateMessageRequestMethod": { + "type": "string", + "format": "const", + "const": "sampling/createMessage" + }, + "CreateMessageRequestParam": { + "description": "Parameters for creating a message through LLM sampling.\n\nThis structure contains all the necessary information for a client to\ngenerate an LLM response, including conversation history, model preferences,\nand generation parameters.", + "type": "object", + "properties": { + "includeContext": { + "description": "How much context to include from MCP servers", + "anyOf": [ + { + "$ref": "#/definitions/ContextInclusion" + }, + { + "type": "null" + } + ] + }, + "maxTokens": { + "description": "Maximum number of tokens to generate", + "type": "integer", + "format": "uint32", + "minimum": 0 + }, + "messages": { + "description": "The conversation history and current messages", + "type": "array", + "items": { + "$ref": "#/definitions/SamplingMessage" + } + }, + "metadata": { + "description": "Additional metadata for the request" + }, + "modelPreferences": { + "description": "Preferences for model selection and behavior", + "anyOf": [ + { + "$ref": "#/definitions/ModelPreferences" + }, + { + "type": "null" + } + ] + }, + "stopSequences": { + "description": "Sequences that should stop generation", + "type": [ + "array", + "null" + ], + "items": { + "type": "string" + } + }, + "systemPrompt": { + "description": "System prompt to guide the model's behavior", + "type": [ + "string", + "null" + ] + }, + "temperature": { + "description": "Temperature for controlling randomness (0.0 to 1.0)", + "type": [ + "number", + "null" + ], + "format": "float" + } + }, + "required": [ + "messages", + "maxTokens" + ] + }, + "ElicitationAction": { + "description": "Represents the possible actions a user can take in response to an elicitation request.\n\nWhen a server requests user input through elicitation, the user can:\n- Accept: Provide the requested information and continue\n- Decline: Refuse to provide the information but continue the operation\n- Cancel: Stop the entire operation", + "oneOf": [ + { + "description": "User accepts the request and provides the requested information", + "type": "string", + "const": "accept" + }, + { + "description": "User declines to provide the information but allows the operation to continue", + "type": "string", + "const": "decline" + }, + { + "description": "User cancels the entire operation", + "type": "string", + "const": "cancel" + } + ] + }, + "ElicitationCreateRequestMethod": { + "type": "string", + "format": "const", + "const": "elicitation/create" + }, + "ElicitationSchema": { + "description": "Type-safe elicitation schema for requesting structured user input.\n\nThis enforces the MCP 2025-06-18 specification that elicitation schemas\nmust be objects with primitive-typed properties.\n\n# Example\n\n```rust\nuse rmcp::model::*;\n\nlet schema = ElicitationSchema::builder()\n .required_email(\"email\")\n .required_integer(\"age\", 0, 150)\n .optional_bool(\"newsletter\", false)\n .build();\n```", + "type": "object", + "properties": { + "description": { + "description": "Optional description of what this schema represents", + "type": [ + "string", + "null" + ] + }, + "properties": { + "description": "Property definitions (must be primitive types)", + "type": "object", + "additionalProperties": { + "$ref": "#/definitions/PrimitiveSchema" + } + }, + "required": { + "description": "List of required property names", + "type": [ + "array", + "null" + ], + "items": { + "type": "string" + } + }, + "title": { + "description": "Optional title for the schema", + "type": [ + "string", + "null" + ] + }, + "type": { + "description": "Always \"object\" for elicitation schemas", + "allOf": [ + { + "$ref": "#/definitions/ObjectTypeConst" + } + ] + } + }, + "required": [ + "type", + "properties" + ] + }, + "EmptyObject": { + "description": "This is commonly used for representing empty objects in MCP messages.\n\nwithout returning any specific data.", + "type": "object" + }, + "EnumSchema": { + "description": "Schema definition for enum properties.\n\nCompliant with MCP 2025-06-18 specification for elicitation schemas.\nEnums must have string type and can optionally include human-readable names.", + "type": "object", + "properties": { + "description": { + "description": "Human-readable description", + "type": [ + "string", + "null" + ] + }, + "enum": { + "description": "Allowed enum values (string values only per MCP spec)", + "type": "array", + "items": { + "type": "string" + } + }, + "enumNames": { + "description": "Optional human-readable names for each enum value", + "type": [ + "array", + "null" + ], + "items": { + "type": "string" + } + }, + "title": { + "description": "Optional title for the schema", + "type": [ + "string", + "null" + ] + }, + "type": { + "description": "Type discriminator (always \"string\" for enums)", + "allOf": [ + { + "$ref": "#/definitions/StringTypeConst" + } + ] + } + }, + "required": [ + "type", + "enum" + ] + }, + "ErrorCode": { + "description": "Standard JSON-RPC error codes used throughout the MCP protocol.\n\nThese codes follow the JSON-RPC 2.0 specification and provide\nstandardized error reporting across all MCP implementations.", + "type": "integer", + "format": "int32" + }, + "ErrorData": { + "description": "Error information for JSON-RPC error responses.\n\nThis structure follows the JSON-RPC 2.0 specification for error reporting,\nproviding a standardized way to communicate errors between clients and servers.", + "type": "object", + "properties": { + "code": { + "description": "The error type that occurred (using standard JSON-RPC error codes)", + "allOf": [ + { + "$ref": "#/definitions/ErrorCode" + } + ] + }, + "data": { + "description": "Additional information about the error. The value of this member is defined by the\nsender (e.g. detailed error information, nested errors etc.)." + }, + "message": { + "description": "A short description of the error. The message SHOULD be limited to a concise single sentence.", + "type": "string" + } + }, + "required": [ + "code", + "message" + ] + }, + "GetPromptResult": { + "type": "object", + "properties": { + "description": { + "type": [ + "string", + "null" + ] + }, + "messages": { + "type": "array", + "items": { + "$ref": "#/definitions/PromptMessage" + } + } + }, + "required": [ + "messages" + ] + }, + "Icon": { + "description": "A URL pointing to an icon resource or a base64-encoded data URI.\n\nClients that support rendering icons MUST support at least the following MIME types:\n- image/png - PNG images (safe, universal compatibility)\n- image/jpeg (and image/jpg) - JPEG images (safe, universal compatibility)\n\nClients that support rendering icons SHOULD also support:\n- image/svg+xml - SVG images (scalable but requires security precautions)\n- image/webp - WebP images (modern, efficient format)", + "type": "object", + "properties": { + "mimeType": { + "description": "Optional override if the server's MIME type is missing or generic", + "type": [ + "string", + "null" + ] + }, + "sizes": { + "description": "Size specification, each string should be in WxH format (e.g., `\\\"48x48\\\"`, `\\\"96x96\\\"`) or `\\\"any\\\"` for scalable formats like SVG", + "type": [ + "array", + "null" + ], + "items": { + "type": "string" + } + }, + "src": { + "description": "A standard URI pointing to an icon resource", + "type": "string" + } + }, + "required": [ + "src" + ] + }, + "Implementation": { + "type": "object", + "properties": { + "icons": { + "type": [ + "array", + "null" + ], + "items": { + "$ref": "#/definitions/Icon" + } + }, + "name": { + "type": "string" + }, + "title": { + "type": [ + "string", + "null" + ] + }, + "version": { + "type": "string" + }, + "websiteUrl": { + "type": [ + "string", + "null" + ] + } + }, + "required": [ + "name", + "version" + ] + }, + "InitializeResult": { + "description": "The server's response to an initialization request.\n\nContains the server's protocol version, capabilities, and implementation\ninformation, along with optional instructions for the client.", + "type": "object", + "properties": { + "capabilities": { + "description": "The capabilities this server provides (tools, resources, prompts, etc.)", + "allOf": [ + { + "$ref": "#/definitions/ServerCapabilities" + } + ] + }, + "instructions": { + "description": "Optional human-readable instructions about using this server", + "type": [ + "string", + "null" + ] + }, + "protocolVersion": { + "description": "The MCP protocol version this server supports", + "allOf": [ + { + "$ref": "#/definitions/ProtocolVersion" + } + ] + }, + "serverInfo": { + "description": "Information about the server implementation", + "allOf": [ + { + "$ref": "#/definitions/Implementation" + } + ] + } + }, + "required": [ + "protocolVersion", + "capabilities", + "serverInfo" + ] + }, + "IntegerSchema": { + "description": "Schema definition for integer properties.\n\nCompliant with MCP 2025-06-18 specification for elicitation schemas.\nSupports only the fields allowed by the MCP spec.", + "type": "object", + "properties": { + "description": { + "description": "Human-readable description", + "type": [ + "string", + "null" + ] + }, + "maximum": { + "description": "Maximum value (inclusive)", + "type": [ + "integer", + "null" + ], + "format": "int64" + }, + "minimum": { + "description": "Minimum value (inclusive)", + "type": [ + "integer", + "null" + ], + "format": "int64" + }, + "title": { + "description": "Optional title for the schema", + "type": [ + "string", + "null" + ] + }, + "type": { + "description": "Type discriminator", + "allOf": [ + { + "$ref": "#/definitions/IntegerTypeConst" + } + ] + } + }, + "required": [ + "type" + ] + }, + "IntegerTypeConst": { + "type": "string", + "format": "const", + "const": "integer" + }, + "JsonRpcError": { + "type": "object", + "properties": { + "error": { + "$ref": "#/definitions/ErrorData" + }, + "id": { + "$ref": "#/definitions/NumberOrString" + }, + "jsonrpc": { + "$ref": "#/definitions/JsonRpcVersion2_0" + } + }, + "required": [ + "jsonrpc", + "id", + "error" + ] + }, + "JsonRpcNotification": { + "type": "object", + "properties": { + "jsonrpc": { + "$ref": "#/definitions/JsonRpcVersion2_0" + } + }, + "anyOf": [ + { + "$ref": "#/definitions/Notification" + }, + { + "$ref": "#/definitions/Notification2" + }, + { + "$ref": "#/definitions/Notification3" + }, + { + "$ref": "#/definitions/Notification4" + }, + { + "$ref": "#/definitions/NotificationNoParam" + }, + { + "$ref": "#/definitions/NotificationNoParam2" + }, + { + "$ref": "#/definitions/NotificationNoParam3" + } + ], + "required": [ + "jsonrpc" + ] + }, + "JsonRpcRequest": { + "type": "object", + "properties": { + "id": { + "$ref": "#/definitions/NumberOrString" + }, + "jsonrpc": { + "$ref": "#/definitions/JsonRpcVersion2_0" + } + }, + "anyOf": [ + { + "$ref": "#/definitions/RequestNoParam" + }, + { + "$ref": "#/definitions/Request" + }, + { + "$ref": "#/definitions/RequestNoParam2" + }, + { + "$ref": "#/definitions/Request2" + } + ], + "required": [ + "jsonrpc", + "id" + ] + }, + "JsonRpcResponse": { + "type": "object", + "properties": { + "id": { + "$ref": "#/definitions/NumberOrString" + }, + "jsonrpc": { + "$ref": "#/definitions/JsonRpcVersion2_0" + }, + "result": { + "$ref": "#/definitions/ServerResult" + } + }, + "required": [ + "jsonrpc", + "id", + "result" + ] + }, + "JsonRpcVersion2_0": { + "type": "string", + "format": "const", + "const": "2.0" + }, + "ListPromptsResult": { + "type": "object", + "properties": { + "nextCursor": { + "type": [ + "string", + "null" + ] + }, + "prompts": { + "type": "array", + "items": { + "$ref": "#/definitions/Prompt" + } + } + }, + "required": [ + "prompts" + ] + }, + "ListResourceTemplatesResult": { + "type": "object", + "properties": { + "nextCursor": { + "type": [ + "string", + "null" + ] + }, + "resourceTemplates": { + "type": "array", + "items": { + "$ref": "#/definitions/Annotated4" + } + } + }, + "required": [ + "resourceTemplates" + ] + }, + "ListResourcesResult": { + "type": "object", + "properties": { + "nextCursor": { + "type": [ + "string", + "null" + ] + }, + "resources": { + "type": "array", + "items": { + "$ref": "#/definitions/Annotated3" + } + } + }, + "required": [ + "resources" + ] + }, + "ListRootsRequestMethod": { + "type": "string", + "format": "const", + "const": "roots/list" + }, + "ListToolsResult": { + "type": "object", + "properties": { + "nextCursor": { + "type": [ + "string", + "null" + ] + }, + "tools": { + "type": "array", + "items": { + "$ref": "#/definitions/Tool" + } + } + }, + "required": [ + "tools" + ] + }, + "LoggingLevel": { + "description": "Logging levels supported by the MCP protocol", + "type": "string", + "enum": [ + "debug", + "info", + "notice", + "warning", + "error", + "critical", + "alert", + "emergency" + ] + }, + "LoggingMessageNotificationMethod": { + "type": "string", + "format": "const", + "const": "notifications/message" + }, + "LoggingMessageNotificationParam": { + "description": "Parameters for a logging message notification", + "type": "object", + "properties": { + "data": { + "description": "The actual log data" + }, + "level": { + "description": "The severity level of this log message", + "allOf": [ + { + "$ref": "#/definitions/LoggingLevel" + } + ] + }, + "logger": { + "description": "Optional logger name that generated this message", + "type": [ + "string", + "null" + ] + } + }, + "required": [ + "level", + "data" + ] + }, + "ModelHint": { + "description": "A hint suggesting a preferred model name or family.\n\nModel hints are advisory suggestions that help clients choose appropriate\nmodels. They can be specific model names or general families like \"claude\" or \"gpt\".", + "type": "object", + "properties": { + "name": { + "description": "The suggested model name or family identifier", + "type": [ + "string", + "null" + ] + } + } + }, + "ModelPreferences": { + "description": "Preferences for model selection and behavior in sampling requests.\n\nThis allows servers to express their preferences for which model to use\nand how to balance different priorities when the client has multiple\nmodel options available.", + "type": "object", + "properties": { + "costPriority": { + "description": "Priority for cost optimization (0.0 to 1.0, higher = prefer cheaper models)", + "type": [ + "number", + "null" + ], + "format": "float" + }, + "hints": { + "description": "Specific model names or families to prefer (e.g., \"claude\", \"gpt\")", + "type": [ + "array", + "null" + ], + "items": { + "$ref": "#/definitions/ModelHint" + } + }, + "intelligencePriority": { + "description": "Priority for intelligence/capability (0.0 to 1.0, higher = prefer more capable models)", + "type": [ + "number", + "null" + ], + "format": "float" + }, + "speedPriority": { + "description": "Priority for speed/latency (0.0 to 1.0, higher = prefer faster models)", + "type": [ + "number", + "null" + ], + "format": "float" + } + } + }, + "Notification": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/CancelledNotificationMethod" + }, + "params": { + "$ref": "#/definitions/CancelledNotificationParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "Notification2": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/ProgressNotificationMethod" + }, + "params": { + "$ref": "#/definitions/ProgressNotificationParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "Notification3": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/LoggingMessageNotificationMethod" + }, + "params": { + "$ref": "#/definitions/LoggingMessageNotificationParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "Notification4": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/ResourceUpdatedNotificationMethod" + }, + "params": { + "$ref": "#/definitions/ResourceUpdatedNotificationParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "NotificationNoParam": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/ResourceListChangedNotificationMethod" + } + }, + "required": [ + "method" + ] + }, + "NotificationNoParam2": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/ToolListChangedNotificationMethod" + } + }, + "required": [ + "method" + ] + }, + "NotificationNoParam3": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/PromptListChangedNotificationMethod" + } + }, + "required": [ + "method" + ] + }, + "NumberOrString": { + "oneOf": [ + { + "type": "number" + }, + { + "type": "string" + } + ] + }, + "NumberSchema": { + "description": "Schema definition for number properties (floating-point).\n\nCompliant with MCP 2025-06-18 specification for elicitation schemas.\nSupports only the fields allowed by the MCP spec.", + "type": "object", + "properties": { + "description": { + "description": "Human-readable description", + "type": [ + "string", + "null" + ] + }, + "maximum": { + "description": "Maximum value (inclusive)", + "type": [ + "number", + "null" + ], + "format": "double" + }, + "minimum": { + "description": "Minimum value (inclusive)", + "type": [ + "number", + "null" + ], + "format": "double" + }, + "title": { + "description": "Optional title for the schema", + "type": [ + "string", + "null" + ] + }, + "type": { + "description": "Type discriminator", + "allOf": [ + { + "$ref": "#/definitions/NumberTypeConst" + } + ] + } + }, + "required": [ + "type" + ] + }, + "NumberTypeConst": { + "type": "string", + "format": "const", + "const": "number" + }, + "ObjectTypeConst": { + "type": "string", + "format": "const", + "const": "object" + }, + "PingRequestMethod": { + "type": "string", + "format": "const", + "const": "ping" + }, + "PrimitiveSchema": { + "description": "Primitive schema definition for elicitation properties.\n\nAccording to MCP 2025-06-18 specification, elicitation schemas must have\nproperties of primitive types only (string, number, integer, boolean, enum).", + "anyOf": [ + { + "description": "String property (with optional enum constraint)", + "allOf": [ + { + "$ref": "#/definitions/StringSchema" + } + ] + }, + { + "description": "Number property (with optional enum constraint)", + "allOf": [ + { + "$ref": "#/definitions/NumberSchema" + } + ] + }, + { + "description": "Integer property (with optional enum constraint)", + "allOf": [ + { + "$ref": "#/definitions/IntegerSchema" + } + ] + }, + { + "description": "Boolean property", + "allOf": [ + { + "$ref": "#/definitions/BooleanSchema" + } + ] + }, + { + "description": "Enum property (explicit enum schema)", + "allOf": [ + { + "$ref": "#/definitions/EnumSchema" + } + ] + } + ] + }, + "ProgressNotificationMethod": { + "type": "string", + "format": "const", + "const": "notifications/progress" + }, + "ProgressNotificationParam": { + "type": "object", + "properties": { + "message": { + "description": "An optional message describing the current progress.", + "type": [ + "string", + "null" + ] + }, + "progress": { + "description": "The progress thus far. This should increase every time progress is made, even if the total is unknown.", + "type": "number", + "format": "double" + }, + "progressToken": { + "$ref": "#/definitions/ProgressToken" + }, + "total": { + "description": "Total number of items to process (or total progress required), if known", + "type": [ + "number", + "null" + ], + "format": "double" + } + }, + "required": [ + "progressToken", + "progress" + ] + }, + "ProgressToken": { + "description": "A token used to track the progress of long-running operations.\n\nProgress tokens allow clients and servers to associate progress notifications\nwith specific requests, enabling real-time updates on operation status.", + "allOf": [ + { + "$ref": "#/definitions/NumberOrString" + } + ] + }, + "Prompt": { + "description": "A prompt that can be used to generate text from a model", + "type": "object", + "properties": { + "arguments": { + "description": "Optional arguments that can be passed to customize the prompt", + "type": [ + "array", + "null" + ], + "items": { + "$ref": "#/definitions/PromptArgument" + } + }, + "description": { + "description": "Optional description of what the prompt does", + "type": [ + "string", + "null" + ] + }, + "icons": { + "description": "Optional list of icons for the prompt", + "type": [ + "array", + "null" + ], + "items": { + "$ref": "#/definitions/Icon" + } + }, + "name": { + "description": "The name of the prompt", + "type": "string" + }, + "title": { + "type": [ + "string", + "null" + ] + } + }, + "required": [ + "name" + ] + }, + "PromptArgument": { + "description": "Represents a prompt argument that can be passed to customize the prompt", + "type": "object", + "properties": { + "description": { + "description": "A description of what the argument is used for", + "type": [ + "string", + "null" + ] + }, + "name": { + "description": "The name of the argument", + "type": "string" + }, + "required": { + "description": "Whether this argument is required", + "type": [ + "boolean", + "null" + ] + }, + "title": { + "description": "A human-readable title for the argument", + "type": [ + "string", + "null" + ] + } + }, + "required": [ + "name" + ] + }, + "PromptListChangedNotificationMethod": { + "type": "string", + "format": "const", + "const": "notifications/prompts/list_changed" + }, + "PromptMessage": { + "description": "A message in a prompt conversation", + "type": "object", + "properties": { + "content": { + "description": "The content of the message", + "allOf": [ + { + "$ref": "#/definitions/PromptMessageContent" + } + ] + }, + "role": { + "description": "The role of the message sender", + "allOf": [ + { + "$ref": "#/definitions/PromptMessageRole" + } + ] + } + }, + "required": [ + "role", + "content" + ] + }, + "PromptMessageContent": { + "description": "Content types that can be included in prompt messages", + "oneOf": [ + { + "description": "Plain text content", + "type": "object", + "properties": { + "text": { + "type": "string" + }, + "type": { + "type": "string", + "const": "text" + } + }, + "required": [ + "type", + "text" + ] + }, + { + "description": "Image content with base64-encoded data", + "type": "object", + "properties": { + "_meta": { + "description": "Optional protocol-level metadata for this content block", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "annotations": { + "anyOf": [ + { + "$ref": "#/definitions/Annotations" + }, + { + "type": "null" + } + ] + }, + "data": { + "description": "The base64-encoded image", + "type": "string" + }, + "mimeType": { + "type": "string" + }, + "type": { + "type": "string", + "const": "image" + } + }, + "required": [ + "type", + "data", + "mimeType" + ] + }, + { + "description": "Embedded server-side resource", + "type": "object", + "properties": { + "resource": { + "$ref": "#/definitions/Annotated2" + }, + "type": { + "type": "string", + "const": "resource" + } + }, + "required": [ + "type", + "resource" + ] + }, + { + "description": "A link to a resource that can be fetched separately", + "type": "object", + "properties": { + "annotations": { + "anyOf": [ + { + "$ref": "#/definitions/Annotations" + }, + { + "type": "null" + } + ] + }, + "description": { + "description": "Optional description of the resource", + "type": [ + "string", + "null" + ] + }, + "icons": { + "description": "Optional list of icons for the resource", + "type": [ + "array", + "null" + ], + "items": { + "$ref": "#/definitions/Icon" + } + }, + "mimeType": { + "description": "MIME type of the resource content (\"text\" or \"blob\")", + "type": [ + "string", + "null" + ] + }, + "name": { + "description": "Name of the resource", + "type": "string" + }, + "size": { + "description": "The size of the raw resource content, in bytes (i.e., before base64 encoding or any tokenization), if known.\n\nThis can be used by Hosts to display file sizes and estimate context window us", + "type": [ + "integer", + "null" + ], + "format": "uint32", + "minimum": 0 + }, + "title": { + "description": "Human-readable title of the resource", + "type": [ + "string", + "null" + ] + }, + "type": { + "type": "string", + "const": "resource_link" + }, + "uri": { + "description": "URI representing the resource location (e.g., \"file:///path/to/file\" or \"str:///content\")", + "type": "string" + } + }, + "required": [ + "type", + "uri", + "name" + ] + } + ] + }, + "PromptMessageRole": { + "description": "Represents the role of a message sender in a prompt conversation", + "type": "string", + "enum": [ + "user", + "assistant" + ] + }, + "PromptsCapability": { + "type": "object", + "properties": { + "listChanged": { + "type": [ + "boolean", + "null" + ] + } + } + }, + "ProtocolVersion": { + "description": "Represents the MCP protocol version used for communication.\n\nThis ensures compatibility between clients and servers by specifying\nwhich version of the Model Context Protocol is being used.", + "type": "string" + }, + "RawAudioContent": { + "type": "object", + "properties": { + "data": { + "type": "string" + }, + "mimeType": { + "type": "string" + } + }, + "required": [ + "data", + "mimeType" + ] + }, + "RawEmbeddedResource": { + "type": "object", + "properties": { + "_meta": { + "description": "Optional protocol-level metadata for this content block", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "resource": { + "$ref": "#/definitions/ResourceContents" + } + }, + "required": [ + "resource" + ] + }, + "RawImageContent": { + "type": "object", + "properties": { + "_meta": { + "description": "Optional protocol-level metadata for this content block", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "data": { + "description": "The base64-encoded image", + "type": "string" + }, + "mimeType": { + "type": "string" + } + }, + "required": [ + "data", + "mimeType" + ] + }, + "RawResource": { + "description": "Represents a resource in the extension with metadata", + "type": "object", + "properties": { + "description": { + "description": "Optional description of the resource", + "type": [ + "string", + "null" + ] + }, + "icons": { + "description": "Optional list of icons for the resource", + "type": [ + "array", + "null" + ], + "items": { + "$ref": "#/definitions/Icon" + } + }, + "mimeType": { + "description": "MIME type of the resource content (\"text\" or \"blob\")", + "type": [ + "string", + "null" + ] + }, + "name": { + "description": "Name of the resource", + "type": "string" + }, + "size": { + "description": "The size of the raw resource content, in bytes (i.e., before base64 encoding or any tokenization), if known.\n\nThis can be used by Hosts to display file sizes and estimate context window us", + "type": [ + "integer", + "null" + ], + "format": "uint32", + "minimum": 0 + }, + "title": { + "description": "Human-readable title of the resource", + "type": [ + "string", + "null" + ] + }, + "uri": { + "description": "URI representing the resource location (e.g., \"file:///path/to/file\" or \"str:///content\")", + "type": "string" + } + }, + "required": [ + "uri", + "name" + ] + }, + "RawTextContent": { + "type": "object", + "properties": { + "_meta": { + "description": "Optional protocol-level metadata for this content block", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "text": { + "type": "string" + } + }, + "required": [ + "text" + ] + }, + "ReadResourceResult": { + "description": "Result containing the contents of a read resource", + "type": "object", + "properties": { + "contents": { + "description": "The actual content of the resource", + "type": "array", + "items": { + "$ref": "#/definitions/ResourceContents" + } + } + }, + "required": [ + "contents" + ] + }, + "Request": { + "description": "Represents a JSON-RPC request with method, parameters, and extensions.\n\nThis is the core structure for all MCP requests, containing:\n- `method`: The name of the method being called\n- `params`: The parameters for the method\n- `extensions`: Additional context data (similar to HTTP headers)", + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/CreateMessageRequestMethod" + }, + "params": { + "$ref": "#/definitions/CreateMessageRequestParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "Request2": { + "description": "Represents a JSON-RPC request with method, parameters, and extensions.\n\nThis is the core structure for all MCP requests, containing:\n- `method`: The name of the method being called\n- `params`: The parameters for the method\n- `extensions`: Additional context data (similar to HTTP headers)", + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/ElicitationCreateRequestMethod" + }, + "params": { + "$ref": "#/definitions/CreateElicitationRequestParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "RequestNoParam": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/PingRequestMethod" + } + }, + "required": [ + "method" + ] + }, + "RequestNoParam2": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/ListRootsRequestMethod" + } + }, + "required": [ + "method" + ] + }, + "ResourceContents": { + "anyOf": [ + { + "type": "object", + "properties": { + "_meta": { + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "mimeType": { + "type": [ + "string", + "null" + ] + }, + "text": { + "type": "string" + }, + "uri": { + "type": "string" + } + }, + "required": [ + "uri", + "text" + ] + }, + { + "type": "object", + "properties": { + "_meta": { + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "blob": { + "type": "string" + }, + "mimeType": { + "type": [ + "string", + "null" + ] + }, + "uri": { + "type": "string" + } + }, + "required": [ + "uri", + "blob" + ] + } + ] + }, + "ResourceListChangedNotificationMethod": { + "type": "string", + "format": "const", + "const": "notifications/resources/list_changed" + }, + "ResourceUpdatedNotificationMethod": { + "type": "string", + "format": "const", + "const": "notifications/resources/updated" + }, + "ResourceUpdatedNotificationParam": { + "description": "Parameters for a resource update notification", + "type": "object", + "properties": { + "uri": { + "description": "The URI of the resource that was updated", + "type": "string" + } + }, + "required": [ + "uri" + ] + }, + "ResourcesCapability": { + "type": "object", + "properties": { + "listChanged": { + "type": [ + "boolean", + "null" + ] + }, + "subscribe": { + "type": [ + "boolean", + "null" + ] + } + } + }, + "Role": { + "description": "Represents the role of a participant in a conversation or message exchange.\n\nUsed in sampling and chat contexts to distinguish between different\ntypes of message senders in the conversation flow.", + "oneOf": [ + { + "description": "A human user or client making a request", + "type": "string", + "const": "user" + }, + { + "description": "An AI assistant or server providing a response", + "type": "string", + "const": "assistant" + } + ] + }, + "SamplingMessage": { + "description": "A message in a sampling conversation, containing a role and content.\n\nThis represents a single message in a conversation flow, used primarily\nin LLM sampling requests where the conversation history is important\nfor generating appropriate responses.", + "type": "object", + "properties": { + "content": { + "description": "The actual content of the message (text, image, etc.)", + "allOf": [ + { + "$ref": "#/definitions/Annotated" + } + ] + }, + "role": { + "description": "The role of the message sender (User or Assistant)", + "allOf": [ + { + "$ref": "#/definitions/Role" + } + ] + } + }, + "required": [ + "role", + "content" + ] + }, + "ServerCapabilities": { + "title": "Builder", + "description": "```rust\n# use rmcp::model::ServerCapabilities;\nlet cap = ServerCapabilities::builder()\n .enable_logging()\n .enable_experimental()\n .enable_prompts()\n .enable_resources()\n .enable_tools()\n .enable_tool_list_changed()\n .build();\n```", + "type": "object", + "properties": { + "completions": { + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "experimental": { + "type": [ + "object", + "null" + ], + "additionalProperties": { + "type": "object", + "additionalProperties": true + } + }, + "logging": { + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "prompts": { + "anyOf": [ + { + "$ref": "#/definitions/PromptsCapability" + }, + { + "type": "null" + } + ] + }, + "resources": { + "anyOf": [ + { + "$ref": "#/definitions/ResourcesCapability" + }, + { + "type": "null" + } + ] + }, + "tools": { + "anyOf": [ + { + "$ref": "#/definitions/ToolsCapability" + }, + { + "type": "null" + } + ] + } + } + }, + "ServerResult": { + "anyOf": [ + { + "$ref": "#/definitions/InitializeResult" + }, + { + "$ref": "#/definitions/CompleteResult" + }, + { + "$ref": "#/definitions/GetPromptResult" + }, + { + "$ref": "#/definitions/ListPromptsResult" + }, + { + "$ref": "#/definitions/ListResourcesResult" + }, + { + "$ref": "#/definitions/ListResourceTemplatesResult" + }, + { + "$ref": "#/definitions/ReadResourceResult" + }, + { + "$ref": "#/definitions/CallToolResult" + }, + { + "$ref": "#/definitions/ListToolsResult" + }, + { + "$ref": "#/definitions/CreateElicitationResult" + }, + { + "$ref": "#/definitions/EmptyObject" + } + ] + }, + "StringFormat": { + "description": "String format types allowed by the MCP specification.", + "oneOf": [ + { + "description": "Email address format", + "type": "string", + "const": "email" + }, + { + "description": "URI format", + "type": "string", + "const": "uri" + }, + { + "description": "Date format (YYYY-MM-DD)", + "type": "string", + "const": "date" + }, + { + "description": "Date-time format (ISO 8601)", + "type": "string", + "const": "date-time" + } + ] + }, + "StringSchema": { + "description": "Schema definition for string properties.\n\nCompliant with MCP 2025-06-18 specification for elicitation schemas.\nSupports only the fields allowed by the MCP spec:\n- format limited to: \"email\", \"uri\", \"date\", \"date-time\"", + "type": "object", + "properties": { + "description": { + "description": "Human-readable description", + "type": [ + "string", + "null" + ] + }, + "format": { + "description": "String format - limited to: \"email\", \"uri\", \"date\", \"date-time\"", + "anyOf": [ + { + "$ref": "#/definitions/StringFormat" + }, + { + "type": "null" + } + ] + }, + "maxLength": { + "description": "Maximum string length", + "type": [ + "integer", + "null" + ], + "format": "uint32", + "minimum": 0 + }, + "minLength": { + "description": "Minimum string length", + "type": [ + "integer", + "null" + ], + "format": "uint32", + "minimum": 0 + }, + "title": { + "description": "Optional title for the schema", + "type": [ + "string", + "null" + ] + }, + "type": { + "description": "Type discriminator", + "allOf": [ + { + "$ref": "#/definitions/StringTypeConst" + } + ] + } + }, + "required": [ + "type" + ] + }, + "StringTypeConst": { + "type": "string", + "format": "const", + "const": "string" + }, + "Tool": { + "description": "A tool that can be used by a model.", + "type": "object", + "properties": { + "annotations": { + "description": "Optional additional tool information.", + "anyOf": [ + { + "$ref": "#/definitions/ToolAnnotations" + }, + { + "type": "null" + } + ] + }, + "description": { + "description": "A description of what the tool does", + "type": [ + "string", + "null" + ] + }, + "icons": { + "description": "Optional list of icons for the tool", + "type": [ + "array", + "null" + ], + "items": { + "$ref": "#/definitions/Icon" + } + }, + "inputSchema": { + "description": "A JSON Schema object defining the expected parameters for the tool", + "type": "object", + "additionalProperties": true + }, + "name": { + "description": "The name of the tool", + "type": "string" + }, + "outputSchema": { + "description": "An optional JSON Schema object defining the structure of the tool's output", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "title": { + "description": "A human-readable title for the tool", + "type": [ + "string", + "null" + ] + } + }, + "required": [ + "name", + "inputSchema" + ] + }, + "ToolAnnotations": { + "description": "Additional properties describing a Tool to clients.\n\nNOTE: all properties in ToolAnnotations are **hints**.\nThey are not guaranteed to provide a faithful description of\ntool behavior (including descriptive properties like `title`).\n\nClients should never make tool use decisions based on ToolAnnotations\nreceived from untrusted servers.", + "type": "object", + "properties": { + "destructiveHint": { + "description": "If true, the tool may perform destructive updates to its environment.\nIf false, the tool performs only additive updates.\n\n(This property is meaningful only when `readOnlyHint == false`)\n\nDefault: true\nA human-readable description of the tool's purpose.", + "type": [ + "boolean", + "null" + ] + }, + "idempotentHint": { + "description": "If true, calling the tool repeatedly with the same arguments\nwill have no additional effect on the its environment.\n\n(This property is meaningful only when `readOnlyHint == false`)\n\nDefault: false.", + "type": [ + "boolean", + "null" + ] + }, + "openWorldHint": { + "description": "If true, this tool may interact with an \"open world\" of external\nentities. If false, the tool's domain of interaction is closed.\nFor example, the world of a web search tool is open, whereas that\nof a memory tool is not.\n\nDefault: true", + "type": [ + "boolean", + "null" + ] + }, + "readOnlyHint": { + "description": "If true, the tool does not modify its environment.\n\nDefault: false", + "type": [ + "boolean", + "null" + ] + }, + "title": { + "description": "A human-readable title for the tool.", + "type": [ + "string", + "null" + ] + } + } + }, + "ToolListChangedNotificationMethod": { + "type": "string", + "format": "const", + "const": "notifications/tools/list_changed" + }, + "ToolsCapability": { + "type": "object", + "properties": { + "listChanged": { + "type": [ + "boolean", + "null" + ] + } + } + } + } +} \ No newline at end of file diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_message_schema/server_json_rpc_message_schema_current.json b/code-rs/third_party/rmcp-0.8.3/tests/test_message_schema/server_json_rpc_message_schema_current.json new file mode 100644 index 00000000000..663a68941cb --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_message_schema/server_json_rpc_message_schema_current.json @@ -0,0 +1,2411 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "JsonRpcMessage", + "description": "Represents any JSON-RPC message that can be sent or received.\n\nThis enum covers all possible message types in the JSON-RPC protocol:\nindividual requests/responses, notifications, and errors.\nIt serves as the top-level message container for MCP communication.", + "anyOf": [ + { + "description": "A single request expecting a response", + "allOf": [ + { + "$ref": "#/definitions/JsonRpcRequest" + } + ] + }, + { + "description": "A response to a previous request", + "allOf": [ + { + "$ref": "#/definitions/JsonRpcResponse" + } + ] + }, + { + "description": "A one-way notification (no response expected)", + "allOf": [ + { + "$ref": "#/definitions/JsonRpcNotification" + } + ] + }, + { + "description": "An error response", + "allOf": [ + { + "$ref": "#/definitions/JsonRpcError" + } + ] + } + ], + "definitions": { + "Annotated": { + "type": "object", + "properties": { + "annotations": { + "anyOf": [ + { + "$ref": "#/definitions/Annotations" + }, + { + "type": "null" + } + ] + } + }, + "oneOf": [ + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "text" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawTextContent" + } + ], + "required": [ + "type" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "image" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawImageContent" + } + ], + "required": [ + "type" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "resource" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawEmbeddedResource" + } + ], + "required": [ + "type" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "audio" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawAudioContent" + } + ], + "required": [ + "type" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "resource_link" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawResource" + } + ], + "required": [ + "type" + ] + } + ] + }, + "Annotated2": { + "type": "object", + "properties": { + "_meta": { + "description": "Optional protocol-level metadata for this content block", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "annotations": { + "anyOf": [ + { + "$ref": "#/definitions/Annotations" + }, + { + "type": "null" + } + ] + }, + "resource": { + "$ref": "#/definitions/ResourceContents" + } + }, + "required": [ + "resource" + ] + }, + "Annotated3": { + "description": "Represents a resource in the extension with metadata", + "type": "object", + "properties": { + "annotations": { + "anyOf": [ + { + "$ref": "#/definitions/Annotations" + }, + { + "type": "null" + } + ] + }, + "description": { + "description": "Optional description of the resource", + "type": [ + "string", + "null" + ] + }, + "icons": { + "description": "Optional list of icons for the resource", + "type": [ + "array", + "null" + ], + "items": { + "$ref": "#/definitions/Icon" + } + }, + "mimeType": { + "description": "MIME type of the resource content (\"text\" or \"blob\")", + "type": [ + "string", + "null" + ] + }, + "name": { + "description": "Name of the resource", + "type": "string" + }, + "size": { + "description": "The size of the raw resource content, in bytes (i.e., before base64 encoding or any tokenization), if known.\n\nThis can be used by Hosts to display file sizes and estimate context window us", + "type": [ + "integer", + "null" + ], + "format": "uint32", + "minimum": 0 + }, + "title": { + "description": "Human-readable title of the resource", + "type": [ + "string", + "null" + ] + }, + "uri": { + "description": "URI representing the resource location (e.g., \"file:///path/to/file\" or \"str:///content\")", + "type": "string" + } + }, + "required": [ + "uri", + "name" + ] + }, + "Annotated4": { + "type": "object", + "properties": { + "annotations": { + "anyOf": [ + { + "$ref": "#/definitions/Annotations" + }, + { + "type": "null" + } + ] + }, + "description": { + "type": [ + "string", + "null" + ] + }, + "mimeType": { + "type": [ + "string", + "null" + ] + }, + "name": { + "type": "string" + }, + "title": { + "type": [ + "string", + "null" + ] + }, + "uriTemplate": { + "type": "string" + } + }, + "required": [ + "uriTemplate", + "name" + ] + }, + "Annotations": { + "type": "object", + "properties": { + "audience": { + "type": [ + "array", + "null" + ], + "items": { + "$ref": "#/definitions/Role" + } + }, + "lastModified": { + "type": [ + "string", + "null" + ], + "format": "date-time" + }, + "priority": { + "type": [ + "number", + "null" + ], + "format": "float" + } + } + }, + "BooleanSchema": { + "description": "Schema definition for boolean properties.", + "type": "object", + "properties": { + "default": { + "description": "Default value", + "type": [ + "boolean", + "null" + ] + }, + "description": { + "description": "Human-readable description", + "type": [ + "string", + "null" + ] + }, + "title": { + "description": "Optional title for the schema", + "type": [ + "string", + "null" + ] + }, + "type": { + "description": "Type discriminator", + "allOf": [ + { + "$ref": "#/definitions/BooleanTypeConst" + } + ] + } + }, + "required": [ + "type" + ] + }, + "BooleanTypeConst": { + "type": "string", + "format": "const", + "const": "boolean" + }, + "CallToolResult": { + "description": "The result of a tool call operation.\n\nContains the content returned by the tool execution and an optional\nflag indicating whether the operation resulted in an error.", + "type": "object", + "properties": { + "_meta": { + "description": "Optional protocol-level metadata for this result", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "content": { + "description": "The content returned by the tool (text, images, etc.)", + "type": "array", + "items": { + "$ref": "#/definitions/Annotated" + } + }, + "isError": { + "description": "Whether this result represents an error condition", + "type": [ + "boolean", + "null" + ] + }, + "structuredContent": { + "description": "An optional JSON object that represents the structured result of the tool call" + } + }, + "required": [ + "content" + ] + }, + "CancelledNotificationMethod": { + "type": "string", + "format": "const", + "const": "notifications/cancelled" + }, + "CancelledNotificationParam": { + "type": "object", + "properties": { + "reason": { + "type": [ + "string", + "null" + ] + }, + "requestId": { + "$ref": "#/definitions/NumberOrString" + } + }, + "required": [ + "requestId" + ] + }, + "CompleteResult": { + "type": "object", + "properties": { + "completion": { + "$ref": "#/definitions/CompletionInfo" + } + }, + "required": [ + "completion" + ] + }, + "CompletionInfo": { + "type": "object", + "properties": { + "hasMore": { + "type": [ + "boolean", + "null" + ] + }, + "total": { + "type": [ + "integer", + "null" + ], + "format": "uint32", + "minimum": 0 + }, + "values": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "required": [ + "values" + ] + }, + "ContextInclusion": { + "description": "Specifies how much context should be included in sampling requests.\n\nThis allows clients to control what additional context information\nshould be provided to the LLM when processing sampling requests.", + "oneOf": [ + { + "description": "Include context from all connected MCP servers", + "type": "string", + "const": "allServers" + }, + { + "description": "Include no additional context", + "type": "string", + "const": "none" + }, + { + "description": "Include context only from the requesting server", + "type": "string", + "const": "thisServer" + } + ] + }, + "CreateElicitationRequestParam": { + "description": "Parameters for creating an elicitation request to gather user input.\n\nThis structure contains everything needed to request interactive input from a user:\n- A human-readable message explaining what information is needed\n- A type-safe schema defining the expected structure of the response\n\n# Example\n\n```rust\nuse rmcp::model::*;\n\nlet params = CreateElicitationRequestParam {\n message: \"Please provide your email\".to_string(),\n requested_schema: ElicitationSchema::builder()\n .required_email(\"email\")\n .build()\n .unwrap(),\n};\n```", + "type": "object", + "properties": { + "message": { + "description": "Human-readable message explaining what input is needed from the user.\nThis should be clear and provide sufficient context for the user to understand\nwhat information they need to provide.", + "type": "string" + }, + "requestedSchema": { + "description": "Type-safe schema defining the expected structure and validation rules for the user's response.\nThis enforces the MCP 2025-06-18 specification that elicitation schemas must be objects\nwith primitive-typed properties.", + "allOf": [ + { + "$ref": "#/definitions/ElicitationSchema" + } + ] + } + }, + "required": [ + "message", + "requestedSchema" + ] + }, + "CreateElicitationResult": { + "description": "The result returned by a client in response to an elicitation request.\n\nContains the user's decision (accept/decline/cancel) and optionally their input data\nif they chose to accept the request.", + "type": "object", + "properties": { + "action": { + "description": "The user's decision on how to handle the elicitation request", + "allOf": [ + { + "$ref": "#/definitions/ElicitationAction" + } + ] + }, + "content": { + "description": "The actual data provided by the user, if they accepted the request.\nMust conform to the JSON schema specified in the original request.\nOnly present when action is Accept." + } + }, + "required": [ + "action" + ] + }, + "CreateMessageRequestMethod": { + "type": "string", + "format": "const", + "const": "sampling/createMessage" + }, + "CreateMessageRequestParam": { + "description": "Parameters for creating a message through LLM sampling.\n\nThis structure contains all the necessary information for a client to\ngenerate an LLM response, including conversation history, model preferences,\nand generation parameters.", + "type": "object", + "properties": { + "includeContext": { + "description": "How much context to include from MCP servers", + "anyOf": [ + { + "$ref": "#/definitions/ContextInclusion" + }, + { + "type": "null" + } + ] + }, + "maxTokens": { + "description": "Maximum number of tokens to generate", + "type": "integer", + "format": "uint32", + "minimum": 0 + }, + "messages": { + "description": "The conversation history and current messages", + "type": "array", + "items": { + "$ref": "#/definitions/SamplingMessage" + } + }, + "metadata": { + "description": "Additional metadata for the request" + }, + "modelPreferences": { + "description": "Preferences for model selection and behavior", + "anyOf": [ + { + "$ref": "#/definitions/ModelPreferences" + }, + { + "type": "null" + } + ] + }, + "stopSequences": { + "description": "Sequences that should stop generation", + "type": [ + "array", + "null" + ], + "items": { + "type": "string" + } + }, + "systemPrompt": { + "description": "System prompt to guide the model's behavior", + "type": [ + "string", + "null" + ] + }, + "temperature": { + "description": "Temperature for controlling randomness (0.0 to 1.0)", + "type": [ + "number", + "null" + ], + "format": "float" + } + }, + "required": [ + "messages", + "maxTokens" + ] + }, + "ElicitationAction": { + "description": "Represents the possible actions a user can take in response to an elicitation request.\n\nWhen a server requests user input through elicitation, the user can:\n- Accept: Provide the requested information and continue\n- Decline: Refuse to provide the information but continue the operation\n- Cancel: Stop the entire operation", + "oneOf": [ + { + "description": "User accepts the request and provides the requested information", + "type": "string", + "const": "accept" + }, + { + "description": "User declines to provide the information but allows the operation to continue", + "type": "string", + "const": "decline" + }, + { + "description": "User cancels the entire operation", + "type": "string", + "const": "cancel" + } + ] + }, + "ElicitationCreateRequestMethod": { + "type": "string", + "format": "const", + "const": "elicitation/create" + }, + "ElicitationSchema": { + "description": "Type-safe elicitation schema for requesting structured user input.\n\nThis enforces the MCP 2025-06-18 specification that elicitation schemas\nmust be objects with primitive-typed properties.\n\n# Example\n\n```rust\nuse rmcp::model::*;\n\nlet schema = ElicitationSchema::builder()\n .required_email(\"email\")\n .required_integer(\"age\", 0, 150)\n .optional_bool(\"newsletter\", false)\n .build();\n```", + "type": "object", + "properties": { + "description": { + "description": "Optional description of what this schema represents", + "type": [ + "string", + "null" + ] + }, + "properties": { + "description": "Property definitions (must be primitive types)", + "type": "object", + "additionalProperties": { + "$ref": "#/definitions/PrimitiveSchema" + } + }, + "required": { + "description": "List of required property names", + "type": [ + "array", + "null" + ], + "items": { + "type": "string" + } + }, + "title": { + "description": "Optional title for the schema", + "type": [ + "string", + "null" + ] + }, + "type": { + "description": "Always \"object\" for elicitation schemas", + "allOf": [ + { + "$ref": "#/definitions/ObjectTypeConst" + } + ] + } + }, + "required": [ + "type", + "properties" + ] + }, + "EmptyObject": { + "description": "This is commonly used for representing empty objects in MCP messages.\n\nwithout returning any specific data.", + "type": "object" + }, + "EnumSchema": { + "description": "Schema definition for enum properties.\n\nCompliant with MCP 2025-06-18 specification for elicitation schemas.\nEnums must have string type and can optionally include human-readable names.", + "type": "object", + "properties": { + "description": { + "description": "Human-readable description", + "type": [ + "string", + "null" + ] + }, + "enum": { + "description": "Allowed enum values (string values only per MCP spec)", + "type": "array", + "items": { + "type": "string" + } + }, + "enumNames": { + "description": "Optional human-readable names for each enum value", + "type": [ + "array", + "null" + ], + "items": { + "type": "string" + } + }, + "title": { + "description": "Optional title for the schema", + "type": [ + "string", + "null" + ] + }, + "type": { + "description": "Type discriminator (always \"string\" for enums)", + "allOf": [ + { + "$ref": "#/definitions/StringTypeConst" + } + ] + } + }, + "required": [ + "type", + "enum" + ] + }, + "ErrorCode": { + "description": "Standard JSON-RPC error codes used throughout the MCP protocol.\n\nThese codes follow the JSON-RPC 2.0 specification and provide\nstandardized error reporting across all MCP implementations.", + "type": "integer", + "format": "int32" + }, + "ErrorData": { + "description": "Error information for JSON-RPC error responses.\n\nThis structure follows the JSON-RPC 2.0 specification for error reporting,\nproviding a standardized way to communicate errors between clients and servers.", + "type": "object", + "properties": { + "code": { + "description": "The error type that occurred (using standard JSON-RPC error codes)", + "allOf": [ + { + "$ref": "#/definitions/ErrorCode" + } + ] + }, + "data": { + "description": "Additional information about the error. The value of this member is defined by the\nsender (e.g. detailed error information, nested errors etc.)." + }, + "message": { + "description": "A short description of the error. The message SHOULD be limited to a concise single sentence.", + "type": "string" + } + }, + "required": [ + "code", + "message" + ] + }, + "GetPromptResult": { + "type": "object", + "properties": { + "description": { + "type": [ + "string", + "null" + ] + }, + "messages": { + "type": "array", + "items": { + "$ref": "#/definitions/PromptMessage" + } + } + }, + "required": [ + "messages" + ] + }, + "Icon": { + "description": "A URL pointing to an icon resource or a base64-encoded data URI.\n\nClients that support rendering icons MUST support at least the following MIME types:\n- image/png - PNG images (safe, universal compatibility)\n- image/jpeg (and image/jpg) - JPEG images (safe, universal compatibility)\n\nClients that support rendering icons SHOULD also support:\n- image/svg+xml - SVG images (scalable but requires security precautions)\n- image/webp - WebP images (modern, efficient format)", + "type": "object", + "properties": { + "mimeType": { + "description": "Optional override if the server's MIME type is missing or generic", + "type": [ + "string", + "null" + ] + }, + "sizes": { + "description": "Size specification, each string should be in WxH format (e.g., `\\\"48x48\\\"`, `\\\"96x96\\\"`) or `\\\"any\\\"` for scalable formats like SVG", + "type": [ + "array", + "null" + ], + "items": { + "type": "string" + } + }, + "src": { + "description": "A standard URI pointing to an icon resource", + "type": "string" + } + }, + "required": [ + "src" + ] + }, + "Implementation": { + "type": "object", + "properties": { + "icons": { + "type": [ + "array", + "null" + ], + "items": { + "$ref": "#/definitions/Icon" + } + }, + "name": { + "type": "string" + }, + "title": { + "type": [ + "string", + "null" + ] + }, + "version": { + "type": "string" + }, + "websiteUrl": { + "type": [ + "string", + "null" + ] + } + }, + "required": [ + "name", + "version" + ] + }, + "InitializeResult": { + "description": "The server's response to an initialization request.\n\nContains the server's protocol version, capabilities, and implementation\ninformation, along with optional instructions for the client.", + "type": "object", + "properties": { + "capabilities": { + "description": "The capabilities this server provides (tools, resources, prompts, etc.)", + "allOf": [ + { + "$ref": "#/definitions/ServerCapabilities" + } + ] + }, + "instructions": { + "description": "Optional human-readable instructions about using this server", + "type": [ + "string", + "null" + ] + }, + "protocolVersion": { + "description": "The MCP protocol version this server supports", + "allOf": [ + { + "$ref": "#/definitions/ProtocolVersion" + } + ] + }, + "serverInfo": { + "description": "Information about the server implementation", + "allOf": [ + { + "$ref": "#/definitions/Implementation" + } + ] + } + }, + "required": [ + "protocolVersion", + "capabilities", + "serverInfo" + ] + }, + "IntegerSchema": { + "description": "Schema definition for integer properties.\n\nCompliant with MCP 2025-06-18 specification for elicitation schemas.\nSupports only the fields allowed by the MCP spec.", + "type": "object", + "properties": { + "description": { + "description": "Human-readable description", + "type": [ + "string", + "null" + ] + }, + "maximum": { + "description": "Maximum value (inclusive)", + "type": [ + "integer", + "null" + ], + "format": "int64" + }, + "minimum": { + "description": "Minimum value (inclusive)", + "type": [ + "integer", + "null" + ], + "format": "int64" + }, + "title": { + "description": "Optional title for the schema", + "type": [ + "string", + "null" + ] + }, + "type": { + "description": "Type discriminator", + "allOf": [ + { + "$ref": "#/definitions/IntegerTypeConst" + } + ] + } + }, + "required": [ + "type" + ] + }, + "IntegerTypeConst": { + "type": "string", + "format": "const", + "const": "integer" + }, + "JsonRpcError": { + "type": "object", + "properties": { + "error": { + "$ref": "#/definitions/ErrorData" + }, + "id": { + "$ref": "#/definitions/NumberOrString" + }, + "jsonrpc": { + "$ref": "#/definitions/JsonRpcVersion2_0" + } + }, + "required": [ + "jsonrpc", + "id", + "error" + ] + }, + "JsonRpcNotification": { + "type": "object", + "properties": { + "jsonrpc": { + "$ref": "#/definitions/JsonRpcVersion2_0" + } + }, + "anyOf": [ + { + "$ref": "#/definitions/Notification" + }, + { + "$ref": "#/definitions/Notification2" + }, + { + "$ref": "#/definitions/Notification3" + }, + { + "$ref": "#/definitions/Notification4" + }, + { + "$ref": "#/definitions/NotificationNoParam" + }, + { + "$ref": "#/definitions/NotificationNoParam2" + }, + { + "$ref": "#/definitions/NotificationNoParam3" + } + ], + "required": [ + "jsonrpc" + ] + }, + "JsonRpcRequest": { + "type": "object", + "properties": { + "id": { + "$ref": "#/definitions/NumberOrString" + }, + "jsonrpc": { + "$ref": "#/definitions/JsonRpcVersion2_0" + } + }, + "anyOf": [ + { + "$ref": "#/definitions/RequestNoParam" + }, + { + "$ref": "#/definitions/Request" + }, + { + "$ref": "#/definitions/RequestNoParam2" + }, + { + "$ref": "#/definitions/Request2" + } + ], + "required": [ + "jsonrpc", + "id" + ] + }, + "JsonRpcResponse": { + "type": "object", + "properties": { + "id": { + "$ref": "#/definitions/NumberOrString" + }, + "jsonrpc": { + "$ref": "#/definitions/JsonRpcVersion2_0" + }, + "result": { + "$ref": "#/definitions/ServerResult" + } + }, + "required": [ + "jsonrpc", + "id", + "result" + ] + }, + "JsonRpcVersion2_0": { + "type": "string", + "format": "const", + "const": "2.0" + }, + "ListPromptsResult": { + "type": "object", + "properties": { + "nextCursor": { + "type": [ + "string", + "null" + ] + }, + "prompts": { + "type": "array", + "items": { + "$ref": "#/definitions/Prompt" + } + } + }, + "required": [ + "prompts" + ] + }, + "ListResourceTemplatesResult": { + "type": "object", + "properties": { + "nextCursor": { + "type": [ + "string", + "null" + ] + }, + "resourceTemplates": { + "type": "array", + "items": { + "$ref": "#/definitions/Annotated4" + } + } + }, + "required": [ + "resourceTemplates" + ] + }, + "ListResourcesResult": { + "type": "object", + "properties": { + "nextCursor": { + "type": [ + "string", + "null" + ] + }, + "resources": { + "type": "array", + "items": { + "$ref": "#/definitions/Annotated3" + } + } + }, + "required": [ + "resources" + ] + }, + "ListRootsRequestMethod": { + "type": "string", + "format": "const", + "const": "roots/list" + }, + "ListToolsResult": { + "type": "object", + "properties": { + "nextCursor": { + "type": [ + "string", + "null" + ] + }, + "tools": { + "type": "array", + "items": { + "$ref": "#/definitions/Tool" + } + } + }, + "required": [ + "tools" + ] + }, + "LoggingLevel": { + "description": "Logging levels supported by the MCP protocol", + "type": "string", + "enum": [ + "debug", + "info", + "notice", + "warning", + "error", + "critical", + "alert", + "emergency" + ] + }, + "LoggingMessageNotificationMethod": { + "type": "string", + "format": "const", + "const": "notifications/message" + }, + "LoggingMessageNotificationParam": { + "description": "Parameters for a logging message notification", + "type": "object", + "properties": { + "data": { + "description": "The actual log data" + }, + "level": { + "description": "The severity level of this log message", + "allOf": [ + { + "$ref": "#/definitions/LoggingLevel" + } + ] + }, + "logger": { + "description": "Optional logger name that generated this message", + "type": [ + "string", + "null" + ] + } + }, + "required": [ + "level", + "data" + ] + }, + "ModelHint": { + "description": "A hint suggesting a preferred model name or family.\n\nModel hints are advisory suggestions that help clients choose appropriate\nmodels. They can be specific model names or general families like \"claude\" or \"gpt\".", + "type": "object", + "properties": { + "name": { + "description": "The suggested model name or family identifier", + "type": [ + "string", + "null" + ] + } + } + }, + "ModelPreferences": { + "description": "Preferences for model selection and behavior in sampling requests.\n\nThis allows servers to express their preferences for which model to use\nand how to balance different priorities when the client has multiple\nmodel options available.", + "type": "object", + "properties": { + "costPriority": { + "description": "Priority for cost optimization (0.0 to 1.0, higher = prefer cheaper models)", + "type": [ + "number", + "null" + ], + "format": "float" + }, + "hints": { + "description": "Specific model names or families to prefer (e.g., \"claude\", \"gpt\")", + "type": [ + "array", + "null" + ], + "items": { + "$ref": "#/definitions/ModelHint" + } + }, + "intelligencePriority": { + "description": "Priority for intelligence/capability (0.0 to 1.0, higher = prefer more capable models)", + "type": [ + "number", + "null" + ], + "format": "float" + }, + "speedPriority": { + "description": "Priority for speed/latency (0.0 to 1.0, higher = prefer faster models)", + "type": [ + "number", + "null" + ], + "format": "float" + } + } + }, + "Notification": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/CancelledNotificationMethod" + }, + "params": { + "$ref": "#/definitions/CancelledNotificationParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "Notification2": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/ProgressNotificationMethod" + }, + "params": { + "$ref": "#/definitions/ProgressNotificationParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "Notification3": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/LoggingMessageNotificationMethod" + }, + "params": { + "$ref": "#/definitions/LoggingMessageNotificationParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "Notification4": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/ResourceUpdatedNotificationMethod" + }, + "params": { + "$ref": "#/definitions/ResourceUpdatedNotificationParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "NotificationNoParam": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/ResourceListChangedNotificationMethod" + } + }, + "required": [ + "method" + ] + }, + "NotificationNoParam2": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/ToolListChangedNotificationMethod" + } + }, + "required": [ + "method" + ] + }, + "NotificationNoParam3": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/PromptListChangedNotificationMethod" + } + }, + "required": [ + "method" + ] + }, + "NumberOrString": { + "oneOf": [ + { + "type": "number" + }, + { + "type": "string" + } + ] + }, + "NumberSchema": { + "description": "Schema definition for number properties (floating-point).\n\nCompliant with MCP 2025-06-18 specification for elicitation schemas.\nSupports only the fields allowed by the MCP spec.", + "type": "object", + "properties": { + "description": { + "description": "Human-readable description", + "type": [ + "string", + "null" + ] + }, + "maximum": { + "description": "Maximum value (inclusive)", + "type": [ + "number", + "null" + ], + "format": "double" + }, + "minimum": { + "description": "Minimum value (inclusive)", + "type": [ + "number", + "null" + ], + "format": "double" + }, + "title": { + "description": "Optional title for the schema", + "type": [ + "string", + "null" + ] + }, + "type": { + "description": "Type discriminator", + "allOf": [ + { + "$ref": "#/definitions/NumberTypeConst" + } + ] + } + }, + "required": [ + "type" + ] + }, + "NumberTypeConst": { + "type": "string", + "format": "const", + "const": "number" + }, + "ObjectTypeConst": { + "type": "string", + "format": "const", + "const": "object" + }, + "PingRequestMethod": { + "type": "string", + "format": "const", + "const": "ping" + }, + "PrimitiveSchema": { + "description": "Primitive schema definition for elicitation properties.\n\nAccording to MCP 2025-06-18 specification, elicitation schemas must have\nproperties of primitive types only (string, number, integer, boolean, enum).", + "anyOf": [ + { + "description": "String property (with optional enum constraint)", + "allOf": [ + { + "$ref": "#/definitions/StringSchema" + } + ] + }, + { + "description": "Number property (with optional enum constraint)", + "allOf": [ + { + "$ref": "#/definitions/NumberSchema" + } + ] + }, + { + "description": "Integer property (with optional enum constraint)", + "allOf": [ + { + "$ref": "#/definitions/IntegerSchema" + } + ] + }, + { + "description": "Boolean property", + "allOf": [ + { + "$ref": "#/definitions/BooleanSchema" + } + ] + }, + { + "description": "Enum property (explicit enum schema)", + "allOf": [ + { + "$ref": "#/definitions/EnumSchema" + } + ] + } + ] + }, + "ProgressNotificationMethod": { + "type": "string", + "format": "const", + "const": "notifications/progress" + }, + "ProgressNotificationParam": { + "type": "object", + "properties": { + "message": { + "description": "An optional message describing the current progress.", + "type": [ + "string", + "null" + ] + }, + "progress": { + "description": "The progress thus far. This should increase every time progress is made, even if the total is unknown.", + "type": "number", + "format": "double" + }, + "progressToken": { + "$ref": "#/definitions/ProgressToken" + }, + "total": { + "description": "Total number of items to process (or total progress required), if known", + "type": [ + "number", + "null" + ], + "format": "double" + } + }, + "required": [ + "progressToken", + "progress" + ] + }, + "ProgressToken": { + "description": "A token used to track the progress of long-running operations.\n\nProgress tokens allow clients and servers to associate progress notifications\nwith specific requests, enabling real-time updates on operation status.", + "allOf": [ + { + "$ref": "#/definitions/NumberOrString" + } + ] + }, + "Prompt": { + "description": "A prompt that can be used to generate text from a model", + "type": "object", + "properties": { + "arguments": { + "description": "Optional arguments that can be passed to customize the prompt", + "type": [ + "array", + "null" + ], + "items": { + "$ref": "#/definitions/PromptArgument" + } + }, + "description": { + "description": "Optional description of what the prompt does", + "type": [ + "string", + "null" + ] + }, + "icons": { + "description": "Optional list of icons for the prompt", + "type": [ + "array", + "null" + ], + "items": { + "$ref": "#/definitions/Icon" + } + }, + "name": { + "description": "The name of the prompt", + "type": "string" + }, + "title": { + "type": [ + "string", + "null" + ] + } + }, + "required": [ + "name" + ] + }, + "PromptArgument": { + "description": "Represents a prompt argument that can be passed to customize the prompt", + "type": "object", + "properties": { + "description": { + "description": "A description of what the argument is used for", + "type": [ + "string", + "null" + ] + }, + "name": { + "description": "The name of the argument", + "type": "string" + }, + "required": { + "description": "Whether this argument is required", + "type": [ + "boolean", + "null" + ] + }, + "title": { + "description": "A human-readable title for the argument", + "type": [ + "string", + "null" + ] + } + }, + "required": [ + "name" + ] + }, + "PromptListChangedNotificationMethod": { + "type": "string", + "format": "const", + "const": "notifications/prompts/list_changed" + }, + "PromptMessage": { + "description": "A message in a prompt conversation", + "type": "object", + "properties": { + "content": { + "description": "The content of the message", + "allOf": [ + { + "$ref": "#/definitions/PromptMessageContent" + } + ] + }, + "role": { + "description": "The role of the message sender", + "allOf": [ + { + "$ref": "#/definitions/PromptMessageRole" + } + ] + } + }, + "required": [ + "role", + "content" + ] + }, + "PromptMessageContent": { + "description": "Content types that can be included in prompt messages", + "oneOf": [ + { + "description": "Plain text content", + "type": "object", + "properties": { + "text": { + "type": "string" + }, + "type": { + "type": "string", + "const": "text" + } + }, + "required": [ + "type", + "text" + ] + }, + { + "description": "Image content with base64-encoded data", + "type": "object", + "properties": { + "_meta": { + "description": "Optional protocol-level metadata for this content block", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "annotations": { + "anyOf": [ + { + "$ref": "#/definitions/Annotations" + }, + { + "type": "null" + } + ] + }, + "data": { + "description": "The base64-encoded image", + "type": "string" + }, + "mimeType": { + "type": "string" + }, + "type": { + "type": "string", + "const": "image" + } + }, + "required": [ + "type", + "data", + "mimeType" + ] + }, + { + "description": "Embedded server-side resource", + "type": "object", + "properties": { + "resource": { + "$ref": "#/definitions/Annotated2" + }, + "type": { + "type": "string", + "const": "resource" + } + }, + "required": [ + "type", + "resource" + ] + }, + { + "description": "A link to a resource that can be fetched separately", + "type": "object", + "properties": { + "annotations": { + "anyOf": [ + { + "$ref": "#/definitions/Annotations" + }, + { + "type": "null" + } + ] + }, + "description": { + "description": "Optional description of the resource", + "type": [ + "string", + "null" + ] + }, + "icons": { + "description": "Optional list of icons for the resource", + "type": [ + "array", + "null" + ], + "items": { + "$ref": "#/definitions/Icon" + } + }, + "mimeType": { + "description": "MIME type of the resource content (\"text\" or \"blob\")", + "type": [ + "string", + "null" + ] + }, + "name": { + "description": "Name of the resource", + "type": "string" + }, + "size": { + "description": "The size of the raw resource content, in bytes (i.e., before base64 encoding or any tokenization), if known.\n\nThis can be used by Hosts to display file sizes and estimate context window us", + "type": [ + "integer", + "null" + ], + "format": "uint32", + "minimum": 0 + }, + "title": { + "description": "Human-readable title of the resource", + "type": [ + "string", + "null" + ] + }, + "type": { + "type": "string", + "const": "resource_link" + }, + "uri": { + "description": "URI representing the resource location (e.g., \"file:///path/to/file\" or \"str:///content\")", + "type": "string" + } + }, + "required": [ + "type", + "uri", + "name" + ] + } + ] + }, + "PromptMessageRole": { + "description": "Represents the role of a message sender in a prompt conversation", + "type": "string", + "enum": [ + "user", + "assistant" + ] + }, + "PromptsCapability": { + "type": "object", + "properties": { + "listChanged": { + "type": [ + "boolean", + "null" + ] + } + } + }, + "ProtocolVersion": { + "description": "Represents the MCP protocol version used for communication.\n\nThis ensures compatibility between clients and servers by specifying\nwhich version of the Model Context Protocol is being used.", + "type": "string" + }, + "RawAudioContent": { + "type": "object", + "properties": { + "data": { + "type": "string" + }, + "mimeType": { + "type": "string" + } + }, + "required": [ + "data", + "mimeType" + ] + }, + "RawEmbeddedResource": { + "type": "object", + "properties": { + "_meta": { + "description": "Optional protocol-level metadata for this content block", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "resource": { + "$ref": "#/definitions/ResourceContents" + } + }, + "required": [ + "resource" + ] + }, + "RawImageContent": { + "type": "object", + "properties": { + "_meta": { + "description": "Optional protocol-level metadata for this content block", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "data": { + "description": "The base64-encoded image", + "type": "string" + }, + "mimeType": { + "type": "string" + } + }, + "required": [ + "data", + "mimeType" + ] + }, + "RawResource": { + "description": "Represents a resource in the extension with metadata", + "type": "object", + "properties": { + "description": { + "description": "Optional description of the resource", + "type": [ + "string", + "null" + ] + }, + "icons": { + "description": "Optional list of icons for the resource", + "type": [ + "array", + "null" + ], + "items": { + "$ref": "#/definitions/Icon" + } + }, + "mimeType": { + "description": "MIME type of the resource content (\"text\" or \"blob\")", + "type": [ + "string", + "null" + ] + }, + "name": { + "description": "Name of the resource", + "type": "string" + }, + "size": { + "description": "The size of the raw resource content, in bytes (i.e., before base64 encoding or any tokenization), if known.\n\nThis can be used by Hosts to display file sizes and estimate context window us", + "type": [ + "integer", + "null" + ], + "format": "uint32", + "minimum": 0 + }, + "title": { + "description": "Human-readable title of the resource", + "type": [ + "string", + "null" + ] + }, + "uri": { + "description": "URI representing the resource location (e.g., \"file:///path/to/file\" or \"str:///content\")", + "type": "string" + } + }, + "required": [ + "uri", + "name" + ] + }, + "RawTextContent": { + "type": "object", + "properties": { + "_meta": { + "description": "Optional protocol-level metadata for this content block", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "text": { + "type": "string" + } + }, + "required": [ + "text" + ] + }, + "ReadResourceResult": { + "description": "Result containing the contents of a read resource", + "type": "object", + "properties": { + "contents": { + "description": "The actual content of the resource", + "type": "array", + "items": { + "$ref": "#/definitions/ResourceContents" + } + } + }, + "required": [ + "contents" + ] + }, + "Request": { + "description": "Represents a JSON-RPC request with method, parameters, and extensions.\n\nThis is the core structure for all MCP requests, containing:\n- `method`: The name of the method being called\n- `params`: The parameters for the method\n- `extensions`: Additional context data (similar to HTTP headers)", + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/CreateMessageRequestMethod" + }, + "params": { + "$ref": "#/definitions/CreateMessageRequestParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "Request2": { + "description": "Represents a JSON-RPC request with method, parameters, and extensions.\n\nThis is the core structure for all MCP requests, containing:\n- `method`: The name of the method being called\n- `params`: The parameters for the method\n- `extensions`: Additional context data (similar to HTTP headers)", + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/ElicitationCreateRequestMethod" + }, + "params": { + "$ref": "#/definitions/CreateElicitationRequestParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "RequestNoParam": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/PingRequestMethod" + } + }, + "required": [ + "method" + ] + }, + "RequestNoParam2": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/ListRootsRequestMethod" + } + }, + "required": [ + "method" + ] + }, + "ResourceContents": { + "anyOf": [ + { + "type": "object", + "properties": { + "_meta": { + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "mimeType": { + "type": [ + "string", + "null" + ] + }, + "text": { + "type": "string" + }, + "uri": { + "type": "string" + } + }, + "required": [ + "uri", + "text" + ] + }, + { + "type": "object", + "properties": { + "_meta": { + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "blob": { + "type": "string" + }, + "mimeType": { + "type": [ + "string", + "null" + ] + }, + "uri": { + "type": "string" + } + }, + "required": [ + "uri", + "blob" + ] + } + ] + }, + "ResourceListChangedNotificationMethod": { + "type": "string", + "format": "const", + "const": "notifications/resources/list_changed" + }, + "ResourceUpdatedNotificationMethod": { + "type": "string", + "format": "const", + "const": "notifications/resources/updated" + }, + "ResourceUpdatedNotificationParam": { + "description": "Parameters for a resource update notification", + "type": "object", + "properties": { + "uri": { + "description": "The URI of the resource that was updated", + "type": "string" + } + }, + "required": [ + "uri" + ] + }, + "ResourcesCapability": { + "type": "object", + "properties": { + "listChanged": { + "type": [ + "boolean", + "null" + ] + }, + "subscribe": { + "type": [ + "boolean", + "null" + ] + } + } + }, + "Role": { + "description": "Represents the role of a participant in a conversation or message exchange.\n\nUsed in sampling and chat contexts to distinguish between different\ntypes of message senders in the conversation flow.", + "oneOf": [ + { + "description": "A human user or client making a request", + "type": "string", + "const": "user" + }, + { + "description": "An AI assistant or server providing a response", + "type": "string", + "const": "assistant" + } + ] + }, + "SamplingMessage": { + "description": "A message in a sampling conversation, containing a role and content.\n\nThis represents a single message in a conversation flow, used primarily\nin LLM sampling requests where the conversation history is important\nfor generating appropriate responses.", + "type": "object", + "properties": { + "content": { + "description": "The actual content of the message (text, image, etc.)", + "allOf": [ + { + "$ref": "#/definitions/Annotated" + } + ] + }, + "role": { + "description": "The role of the message sender (User or Assistant)", + "allOf": [ + { + "$ref": "#/definitions/Role" + } + ] + } + }, + "required": [ + "role", + "content" + ] + }, + "ServerCapabilities": { + "title": "Builder", + "description": "```rust\n# use rmcp::model::ServerCapabilities;\nlet cap = ServerCapabilities::builder()\n .enable_logging()\n .enable_experimental()\n .enable_prompts()\n .enable_resources()\n .enable_tools()\n .enable_tool_list_changed()\n .build();\n```", + "type": "object", + "properties": { + "completions": { + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "experimental": { + "type": [ + "object", + "null" + ], + "additionalProperties": { + "type": "object", + "additionalProperties": true + } + }, + "logging": { + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "prompts": { + "anyOf": [ + { + "$ref": "#/definitions/PromptsCapability" + }, + { + "type": "null" + } + ] + }, + "resources": { + "anyOf": [ + { + "$ref": "#/definitions/ResourcesCapability" + }, + { + "type": "null" + } + ] + }, + "tools": { + "anyOf": [ + { + "$ref": "#/definitions/ToolsCapability" + }, + { + "type": "null" + } + ] + } + } + }, + "ServerResult": { + "anyOf": [ + { + "$ref": "#/definitions/InitializeResult" + }, + { + "$ref": "#/definitions/CompleteResult" + }, + { + "$ref": "#/definitions/GetPromptResult" + }, + { + "$ref": "#/definitions/ListPromptsResult" + }, + { + "$ref": "#/definitions/ListResourcesResult" + }, + { + "$ref": "#/definitions/ListResourceTemplatesResult" + }, + { + "$ref": "#/definitions/ReadResourceResult" + }, + { + "$ref": "#/definitions/CallToolResult" + }, + { + "$ref": "#/definitions/ListToolsResult" + }, + { + "$ref": "#/definitions/CreateElicitationResult" + }, + { + "$ref": "#/definitions/EmptyObject" + } + ] + }, + "StringFormat": { + "description": "String format types allowed by the MCP specification.", + "oneOf": [ + { + "description": "Email address format", + "type": "string", + "const": "email" + }, + { + "description": "URI format", + "type": "string", + "const": "uri" + }, + { + "description": "Date format (YYYY-MM-DD)", + "type": "string", + "const": "date" + }, + { + "description": "Date-time format (ISO 8601)", + "type": "string", + "const": "date-time" + } + ] + }, + "StringSchema": { + "description": "Schema definition for string properties.\n\nCompliant with MCP 2025-06-18 specification for elicitation schemas.\nSupports only the fields allowed by the MCP spec:\n- format limited to: \"email\", \"uri\", \"date\", \"date-time\"", + "type": "object", + "properties": { + "description": { + "description": "Human-readable description", + "type": [ + "string", + "null" + ] + }, + "format": { + "description": "String format - limited to: \"email\", \"uri\", \"date\", \"date-time\"", + "anyOf": [ + { + "$ref": "#/definitions/StringFormat" + }, + { + "type": "null" + } + ] + }, + "maxLength": { + "description": "Maximum string length", + "type": [ + "integer", + "null" + ], + "format": "uint32", + "minimum": 0 + }, + "minLength": { + "description": "Minimum string length", + "type": [ + "integer", + "null" + ], + "format": "uint32", + "minimum": 0 + }, + "title": { + "description": "Optional title for the schema", + "type": [ + "string", + "null" + ] + }, + "type": { + "description": "Type discriminator", + "allOf": [ + { + "$ref": "#/definitions/StringTypeConst" + } + ] + } + }, + "required": [ + "type" + ] + }, + "StringTypeConst": { + "type": "string", + "format": "const", + "const": "string" + }, + "Tool": { + "description": "A tool that can be used by a model.", + "type": "object", + "properties": { + "annotations": { + "description": "Optional additional tool information.", + "anyOf": [ + { + "$ref": "#/definitions/ToolAnnotations" + }, + { + "type": "null" + } + ] + }, + "description": { + "description": "A description of what the tool does", + "type": [ + "string", + "null" + ] + }, + "icons": { + "description": "Optional list of icons for the tool", + "type": [ + "array", + "null" + ], + "items": { + "$ref": "#/definitions/Icon" + } + }, + "inputSchema": { + "description": "A JSON Schema object defining the expected parameters for the tool", + "type": "object", + "additionalProperties": true + }, + "name": { + "description": "The name of the tool", + "type": "string" + }, + "outputSchema": { + "description": "An optional JSON Schema object defining the structure of the tool's output", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "title": { + "description": "A human-readable title for the tool", + "type": [ + "string", + "null" + ] + } + }, + "required": [ + "name", + "inputSchema" + ] + }, + "ToolAnnotations": { + "description": "Additional properties describing a Tool to clients.\n\nNOTE: all properties in ToolAnnotations are **hints**.\nThey are not guaranteed to provide a faithful description of\ntool behavior (including descriptive properties like `title`).\n\nClients should never make tool use decisions based on ToolAnnotations\nreceived from untrusted servers.", + "type": "object", + "properties": { + "destructiveHint": { + "description": "If true, the tool may perform destructive updates to its environment.\nIf false, the tool performs only additive updates.\n\n(This property is meaningful only when `readOnlyHint == false`)\n\nDefault: true\nA human-readable description of the tool's purpose.", + "type": [ + "boolean", + "null" + ] + }, + "idempotentHint": { + "description": "If true, calling the tool repeatedly with the same arguments\nwill have no additional effect on the its environment.\n\n(This property is meaningful only when `readOnlyHint == false`)\n\nDefault: false.", + "type": [ + "boolean", + "null" + ] + }, + "openWorldHint": { + "description": "If true, this tool may interact with an \"open world\" of external\nentities. If false, the tool's domain of interaction is closed.\nFor example, the world of a web search tool is open, whereas that\nof a memory tool is not.\n\nDefault: true", + "type": [ + "boolean", + "null" + ] + }, + "readOnlyHint": { + "description": "If true, the tool does not modify its environment.\n\nDefault: false", + "type": [ + "boolean", + "null" + ] + }, + "title": { + "description": "A human-readable title for the tool.", + "type": [ + "string", + "null" + ] + } + } + }, + "ToolListChangedNotificationMethod": { + "type": "string", + "format": "const", + "const": "notifications/tools/list_changed" + }, + "ToolsCapability": { + "type": "object", + "properties": { + "listChanged": { + "type": [ + "boolean", + "null" + ] + } + } + } + } +} \ No newline at end of file diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_notification.rs b/code-rs/third_party/rmcp-0.8.3/tests/test_notification.rs new file mode 100644 index 00000000000..a46ac2fd568 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_notification.rs @@ -0,0 +1,95 @@ +use std::sync::Arc; + +use rmcp::{ + ClientHandler, ServerHandler, ServiceExt, + model::{ + ResourceUpdatedNotificationParam, ServerCapabilities, ServerInfo, SubscribeRequestParam, + }, +}; +use tokio::sync::Notify; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +pub struct Server {} + +impl ServerHandler for Server { + fn get_info(&self) -> ServerInfo { + ServerInfo { + capabilities: ServerCapabilities::builder() + .enable_resources() + .enable_resources_subscribe() + .enable_resources_list_changed() + .build(), + ..Default::default() + } + } + + async fn subscribe( + &self, + request: rmcp::model::SubscribeRequestParam, + context: rmcp::service::RequestContext, + ) -> Result<(), rmcp::ErrorData> { + let uri = request.uri; + let peer = context.peer; + + tokio::spawn(async move { + let span = tracing::info_span!("subscribe", uri = %uri); + let _enter = span.enter(); + + if let Err(e) = peer + .notify_resource_updated(ResourceUpdatedNotificationParam { uri: uri.clone() }) + .await + { + panic!("Failed to send notification: {}", e); + } + }); + + Ok(()) + } +} + +pub struct Client { + receive_signal: Arc, +} + +impl ClientHandler for Client { + async fn on_resource_updated( + &self, + params: rmcp::model::ResourceUpdatedNotificationParam, + _context: rmcp::service::NotificationContext, + ) { + let uri = params.uri; + tracing::info!("Resource updated: {}", uri); + self.receive_signal.notify_one(); + } +} + +#[tokio::test] +async fn test_server_notification() -> anyhow::Result<()> { + let _ = tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .try_init(); + let (server_transport, client_transport) = tokio::io::duplex(4096); + tokio::spawn(async move { + let server = Server {}.serve(server_transport).await?; + server.waiting().await?; + anyhow::Ok(()) + }); + let receive_signal = Arc::new(Notify::new()); + let client = Client { + receive_signal: receive_signal.clone(), + } + .serve(client_transport) + .await?; + client + .subscribe(SubscribeRequestParam { + uri: "test://test-resource".to_owned(), + }) + .await?; + receive_signal.notified().await; + client.cancel().await?; + Ok(()) +} diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_progress_subscriber.rs b/code-rs/third_party/rmcp-0.8.3/tests/test_progress_subscriber.rs new file mode 100644 index 00000000000..531b1692507 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_progress_subscriber.rs @@ -0,0 +1,132 @@ +use futures::StreamExt; +use rmcp::{ + ClientHandler, Peer, RoleServer, ServerHandler, ServiceExt, + handler::{client::progress::ProgressDispatcher, server::tool::ToolRouter}, + model::{CallToolRequestParam, ClientRequest, Meta, ProgressNotificationParam, Request}, + service::PeerRequestOptions, + tool, tool_handler, tool_router, +}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +pub struct MyClient { + progress_handler: ProgressDispatcher, +} + +impl MyClient { + pub fn new() -> Self { + Self { + progress_handler: ProgressDispatcher::new(), + } + } +} + +impl Default for MyClient { + fn default() -> Self { + Self::new() + } +} + +impl ClientHandler for MyClient { + async fn on_progress( + &self, + params: rmcp::model::ProgressNotificationParam, + _context: rmcp::service::NotificationContext, + ) { + tracing::info!("Received progress notification: {:?}", params); + self.progress_handler.handle_notification(params).await; + } +} + +pub struct MyServer { + tool_router: ToolRouter, +} + +impl MyServer { + pub fn new() -> Self { + Self { + tool_router: Self::tool_router(), + } + } +} + +impl Default for MyServer { + fn default() -> Self { + Self::new() + } +} + +#[tool_router] +impl MyServer { + #[tool] + pub async fn some_progress( + meta: Meta, + client: Peer, + ) -> Result<(), rmcp::ErrorData> { + let progress_token = meta + .get_progress_token() + .ok_or(rmcp::ErrorData::invalid_params( + "Progress token is required for this tool", + None, + ))?; + for step in 0..10 { + let _ = client + .notify_progress(ProgressNotificationParam { + progress_token: progress_token.clone(), + progress: (step as f64), + total: Some(10.0), + message: Some("Some message".into()), + }) + .await; + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + } + Ok(()) + } +} + +#[tool_handler] +impl ServerHandler for MyServer {} + +#[tokio::test] +async fn test_progress_subscriber() -> anyhow::Result<()> { + let _ = tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .try_init(); + let client = MyClient::new(); + + let server = MyServer::new(); + let (transport_server, transport_client) = tokio::io::duplex(4096); + tokio::spawn(async move { + let service = server.serve(transport_server).await?; + service.waiting().await?; + anyhow::Ok(()) + }); + let client_service = client.serve(transport_client).await?; + let handle = client_service + .send_cancellable_request( + ClientRequest::CallToolRequest(Request::new(CallToolRequestParam { + name: "some_progress".into(), + arguments: None, + })), + PeerRequestOptions::no_options(), + ) + .await?; + let mut progress_subscriber = client_service + .service() + .progress_handler + .subscribe(handle.progress_token.clone()) + .await; + tokio::spawn(async move { + while let Some(notification) = progress_subscriber.next().await { + tracing::info!("Progress notification: {:?}", notification); + } + }); + let _response = handle.await_response().await?; + + // Simulate some delay to allow the async task to complete + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + Ok(()) +} diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_prompt_handler.rs b/code-rs/third_party/rmcp-0.8.3/tests/test_prompt_handler.rs new file mode 100644 index 00000000000..86f204347f1 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_prompt_handler.rs @@ -0,0 +1,161 @@ +//cargo test --test test_prompt_handler --features "client server" +// Tests for verifying that the #[prompt_handler] macro correctly generates +// the ServerHandler trait implementation methods. +#![allow(dead_code)] + +use rmcp::{ + RoleServer, ServerHandler, + handler::server::router::prompt::PromptRouter, + model::{GetPromptRequestParam, GetPromptResult, ListPromptsResult, PaginatedRequestParam}, + prompt_handler, + service::RequestContext, +}; + +#[derive(Debug, Clone)] +pub struct TestPromptServer { + prompt_router: PromptRouter, +} + +impl Default for TestPromptServer { + fn default() -> Self { + Self::new() + } +} + +impl TestPromptServer { + pub fn new() -> Self { + Self { + prompt_router: PromptRouter::new(), + } + } +} + +#[prompt_handler] +impl ServerHandler for TestPromptServer {} + +#[derive(Debug, Clone)] +pub struct CustomRouterServer { + custom_router: PromptRouter, +} + +impl Default for CustomRouterServer { + fn default() -> Self { + Self::new() + } +} + +impl CustomRouterServer { + pub fn new() -> Self { + Self { + custom_router: PromptRouter::new(), + } + } + + pub fn get_custom_router(&self) -> &PromptRouter { + &self.custom_router + } +} + +#[prompt_handler(router = self.custom_router)] +impl ServerHandler for CustomRouterServer {} + +#[derive(Debug, Clone)] +pub struct GenericPromptServer { + prompt_router: PromptRouter, + _marker: std::marker::PhantomData, +} + +impl Default for GenericPromptServer { + fn default() -> Self { + Self::new() + } +} + +impl GenericPromptServer { + pub fn new() -> Self { + Self { + prompt_router: PromptRouter::new(), + _marker: std::marker::PhantomData, + } + } +} + +#[prompt_handler] +impl ServerHandler for GenericPromptServer {} + +#[test] +fn test_prompt_handler_basic() { + let server = TestPromptServer::new(); + + // Test that the server implements ServerHandler + fn assert_server_handler(_: &T) {} + assert_server_handler(&server); + + // Test that the prompt router is accessible + assert_eq!(server.prompt_router.list_all().len(), 0); +} + +#[test] +fn test_prompt_handler_custom_router() { + let server = CustomRouterServer::new(); + + // Test that the server implements ServerHandler + fn assert_server_handler(_: &T) {} + assert_server_handler(&server); + + // Test that the custom router is used + assert_eq!(server.custom_router.list_all().len(), 0); +} + +#[test] +fn test_prompt_handler_with_generics() { + let server = GenericPromptServer::::new(); + + // Test that generic server implements ServerHandler + fn assert_server_handler(_: &T) {} + assert_server_handler(&server); + + // Test with a different generic type + let server2 = GenericPromptServer::::new(); + assert_server_handler(&server2); +} + +#[test] +fn test_prompt_handler_trait_implementation() { + // This test verifies that the prompt_handler macro generates proper ServerHandler implementation + // The actual method signatures are tested through the ServerHandler trait bound + fn compile_time_check() {} + + compile_time_check::(); + compile_time_check::(); + compile_time_check::>(); +} + +// Test that the macro works with different server configurations +mod nested { + use super::*; + + #[derive(Debug, Clone)] + pub struct NestedServer { + prompt_router: PromptRouter, + } + + impl NestedServer { + pub fn new() -> Self { + Self { + prompt_router: PromptRouter::new(), + } + } + } + + #[prompt_handler] + impl ServerHandler for NestedServer {} + + #[test] + fn test_nested_prompt_handler() { + let server = NestedServer::new(); + // Verify it implements ServerHandler + fn assert_server_handler(_: &T) {} + assert_server_handler(&server); + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_prompt_macro_annotations.rs b/code-rs/third_party/rmcp-0.8.3/tests/test_prompt_macro_annotations.rs new file mode 100644 index 00000000000..f313927f5b5 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_prompt_macro_annotations.rs @@ -0,0 +1,291 @@ +//cargo test --test test_prompt_macro_annotations --features "client server" +#![allow(dead_code)] + +use rmcp::{ + ServerHandler, + handler::server::wrapper::Parameters, + model::{GetPromptResult, Prompt, PromptMessage, PromptMessageRole}, + prompt, +}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone)] +struct TestServer; + +impl ServerHandler for TestServer {} + +#[derive(Serialize, Deserialize, JsonSchema)] +struct TestArgs { + /// The input text to process + input: String, + /// Optional configuration + #[serde(skip_serializing_if = "Option::is_none")] + config: Option, +} + +#[derive(Serialize, Deserialize, JsonSchema)] +struct ComplexArgs { + /// Required field + required_field: String, + /// Optional string field + #[schemars(description = "An optional string parameter")] + optional_string: Option, + /// Optional number field + optional_number: Option, + /// Array field + items: Vec, +} + +// Test basic prompt attribute generation +#[prompt] +async fn basic_prompt(_server: &TestServer) -> Vec { + vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Basic response", + )] +} + +// Test prompt with custom name +#[prompt(name = "custom_name")] +async fn named_prompt(_server: &TestServer) -> Vec { + vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Named response", + )] +} + +// Test prompt with custom description +#[prompt(description = "This is a custom description")] +async fn described_prompt(_server: &TestServer) -> Vec { + vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Described response", + )] +} + +// Test prompt with both name and description +#[prompt(name = "full_custom", description = "Fully customized prompt")] +async fn fully_custom_prompt(_server: &TestServer) -> Vec { + vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Fully custom response", + )] +} + +// Test prompt with doc comments +/// This is a doc comment description +/// that spans multiple lines +#[prompt] +async fn doc_comment_prompt(_server: &TestServer) -> Vec { + vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Doc comment response", + )] +} + +// Test prompt with doc comments and explicit description (explicit wins) +/// This is a doc comment +#[prompt(description = "This overrides the doc comment")] +async fn override_doc_prompt(_server: &TestServer) -> Vec { + vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Override response", + )] +} + +// Test prompt with arguments +#[prompt] +async fn args_prompt(_server: &TestServer, _args: Parameters) -> Vec { + vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Args response", + )] +} + +// Test prompt with complex arguments +#[prompt] +async fn complex_args_prompt( + _server: &TestServer, + _args: Parameters, +) -> GetPromptResult { + GetPromptResult { + description: Some("Complex args result".to_string()), + messages: vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Complex response", + )], + } +} + +// Test sync prompt +#[prompt] +fn sync_prompt(_server: &TestServer) -> Vec { + vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Sync response", + )] +} + +#[test] +fn test_basic_prompt_attr() { + let attr = basic_prompt_prompt_attr(); + assert_eq!(attr.name, "basic_prompt"); + assert_eq!(attr.description, None); + assert!(attr.arguments.is_none()); +} + +#[test] +fn test_named_prompt_attr() { + let attr = named_prompt_prompt_attr(); + assert_eq!(attr.name, "custom_name"); + assert_eq!(attr.description, None); + assert!(attr.arguments.is_none()); +} + +#[test] +fn test_described_prompt_attr() { + let attr = described_prompt_prompt_attr(); + assert_eq!(attr.name, "described_prompt"); + assert_eq!( + attr.description.as_deref(), + Some("This is a custom description") + ); + assert!(attr.arguments.is_none()); +} + +#[test] +fn test_fully_custom_prompt_attr() { + let attr = fully_custom_prompt_prompt_attr(); + assert_eq!(attr.name, "full_custom"); + assert_eq!(attr.description.as_deref(), Some("Fully customized prompt")); + assert!(attr.arguments.is_none()); +} + +#[test] +fn test_doc_comment_prompt_attr() { + let attr = doc_comment_prompt_prompt_attr(); + assert_eq!(attr.name, "doc_comment_prompt"); + assert!(attr.description.is_some()); + let desc = attr.description.unwrap(); + assert!(desc.contains("This is a doc comment description")); + assert!(desc.contains("that spans multiple lines")); +} + +#[test] +fn test_override_doc_prompt_attr() { + let attr = override_doc_prompt_prompt_attr(); + assert_eq!(attr.name, "override_doc_prompt"); + assert_eq!( + attr.description.as_deref(), + Some("This overrides the doc comment") + ); +} + +#[test] +fn test_args_prompt_attr() { + let attr = args_prompt_prompt_attr(); + assert_eq!(attr.name, "args_prompt"); + + let args = attr.arguments.as_ref().unwrap(); + assert_eq!(args.len(), 2); + + // Check input field + let input_arg = args.iter().find(|a| a.name == "input").unwrap(); + assert_eq!(input_arg.required, Some(true)); + assert_eq!( + input_arg.description.as_deref(), + Some("The input text to process") + ); + + // Check config field + let config_arg = args.iter().find(|a| a.name == "config").unwrap(); + assert_eq!(config_arg.required, Some(false)); + assert_eq!( + config_arg.description.as_deref(), + Some("Optional configuration") + ); +} + +#[test] +fn test_complex_args_prompt_attr() { + let attr = complex_args_prompt_prompt_attr(); + assert_eq!(attr.name, "complex_args_prompt"); + + let args = attr.arguments.as_ref().unwrap(); + assert_eq!(args.len(), 4); + + // Check required_field + let required_arg = args.iter().find(|a| a.name == "required_field").unwrap(); + assert_eq!(required_arg.required, Some(true)); + assert_eq!(required_arg.description.as_deref(), Some("Required field")); + + // Check optional_string + let optional_string_arg = args.iter().find(|a| a.name == "optional_string").unwrap(); + assert_eq!(optional_string_arg.required, Some(false)); + assert_eq!( + optional_string_arg.description.as_deref(), + Some("An optional string parameter") + ); + + // Check optional_number + let optional_number_arg = args.iter().find(|a| a.name == "optional_number").unwrap(); + assert_eq!(optional_number_arg.required, Some(false)); + assert_eq!( + optional_number_arg.description.as_deref(), + Some("Optional number field") + ); + + // Check items + let items_arg = args.iter().find(|a| a.name == "items").unwrap(); + assert_eq!(items_arg.required, Some(true)); + assert_eq!(items_arg.description.as_deref(), Some("Array field")); +} + +#[test] +fn test_sync_prompt_attr() { + let attr = sync_prompt_prompt_attr(); + assert_eq!(attr.name, "sync_prompt"); + assert!(attr.arguments.is_none()); +} + +#[test] +fn test_prompt_attr_function_type() { + // Test that the generated function returns the correct type + fn assert_prompt_attr_fn(_: impl Fn() -> Prompt) {} + + assert_prompt_attr_fn(basic_prompt_prompt_attr); + assert_prompt_attr_fn(named_prompt_prompt_attr); + assert_prompt_attr_fn(described_prompt_prompt_attr); + assert_prompt_attr_fn(fully_custom_prompt_prompt_attr); + assert_prompt_attr_fn(doc_comment_prompt_prompt_attr); + assert_prompt_attr_fn(override_doc_prompt_prompt_attr); + assert_prompt_attr_fn(args_prompt_prompt_attr); + assert_prompt_attr_fn(complex_args_prompt_prompt_attr); + assert_prompt_attr_fn(sync_prompt_prompt_attr); +} + +// Test generic prompts +#[derive(Debug, Clone)] +struct GenericServer { + _marker: std::marker::PhantomData, +} + +impl ServerHandler for GenericServer {} + +#[prompt] +async fn generic_prompt( + _server: &GenericServer, +) -> Vec { + vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Generic response", + )] +} + +#[test] +fn test_generic_prompt_attr() { + let attr = generic_prompt_prompt_attr(); + assert_eq!(attr.name, "generic_prompt"); + assert!(attr.arguments.is_none()); +} diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_prompt_macros.rs b/code-rs/third_party/rmcp-0.8.3/tests/test_prompt_macros.rs new file mode 100644 index 00000000000..5d5ece8cba2 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_prompt_macros.rs @@ -0,0 +1,383 @@ +//cargo test --test test_prompt_macros --features "client server" +#![allow(dead_code)] +use std::sync::Arc; + +use rmcp::{ + ClientHandler, RoleServer, ServerHandler, ServiceExt, + handler::server::{router::prompt::PromptRouter, wrapper::Parameters}, + model::{ + ClientInfo, GetPromptRequestParam, GetPromptResult, ListPromptsResult, + PaginatedRequestParam, PromptMessage, PromptMessageRole, + }, + prompt, prompt_handler, prompt_router, + service::RequestContext, +}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, JsonSchema)] +pub struct CodeReviewRequest { + pub file_path: String, + pub language: String, +} + +#[prompt_handler(router = self.prompt_router)] +impl ServerHandler for Server {} + +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct Server { + prompt_router: PromptRouter, +} + +impl Default for Server { + fn default() -> Self { + Self::new() + } +} + +#[prompt_router] +impl Server { + pub fn new() -> Self { + Self { + prompt_router: Self::prompt_router(), + } + } + + /// This prompt is used to review code for best practices. + #[prompt( + name = "code-review", + description = "Review code for best practices and issues." + )] + pub async fn code_review(&self, params: Parameters) -> Vec { + vec![ + PromptMessage::new_text( + PromptMessageRole::User, + format!( + "Please review the {} code in: {}", + params.0.language, params.0.file_path + ), + ), + PromptMessage::new_text( + PromptMessageRole::Assistant, + "I'll review this code for best practices and potential issues.".to_string(), + ), + ] + } + + #[prompt] + async fn empty_param(&self) -> Vec { + vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "This is a prompt with no parameters.".to_string(), + )] + } +} + +// define generic service trait +pub trait DataService: Send + Sync + 'static { + fn get_context(&self) -> String; +} + +// mock service for test +#[derive(Clone)] +struct MockDataService; +impl DataService for MockDataService { + fn get_context(&self) -> String { + "mock context data".to_string() + } +} + +// define generic server +#[derive(Debug, Clone)] +pub struct GenericServer { + data_service: Arc, + prompt_router: PromptRouter, +} + +#[prompt_router] +impl GenericServer { + pub fn new(data_service: DS) -> Self { + Self { + data_service: Arc::new(data_service), + prompt_router: Self::prompt_router(), + } + } + + #[prompt(description = "Get contextual help from the service")] + async fn get_help(&self) -> GetPromptResult { + let context = self.data_service.get_context(); + GetPromptResult { + description: Some("Contextual help based on service data".to_string()), + messages: vec![ + PromptMessage::new_text( + PromptMessageRole::User, + "I need help with the current context.".to_string(), + ), + PromptMessage::new_text( + PromptMessageRole::Assistant, + format!( + "Based on the context '{}', here's how I can help...", + context + ), + ), + ], + } + } +} + +#[prompt_handler] +impl ServerHandler for GenericServer {} + +#[tokio::test] +async fn test_prompt_macros() { + let server = Server::new(); + let _attr = Server::code_review_prompt_attr(); + let _code_review_prompt_attr_fn = Server::code_review_prompt_attr; + let _code_review_fn = Server::code_review; + let result = server + .code_review(Parameters(CodeReviewRequest { + file_path: "/src/main.rs".into(), + language: "rust".into(), + })) + .await; + assert_eq!(result.len(), 2); + assert_eq!(result[0].role, PromptMessageRole::User); + assert_eq!(result[1].role, PromptMessageRole::Assistant); +} + +#[tokio::test] +async fn test_prompt_macros_with_empty_param() { + let _attr = Server::empty_param_prompt_attr(); + println!("{_attr:?}"); + assert!( + _attr.arguments.is_none(), + "Empty param prompt should have no arguments" + ); +} + +#[tokio::test] +async fn test_prompt_macros_with_generics() { + let mock_service = MockDataService; + let server = GenericServer::new(mock_service); + let _attr = GenericServer::::get_help_prompt_attr(); + let _get_help_call_fn = GenericServer::::get_help; + let _get_help_fn = GenericServer::::get_help; + let result = server.get_help().await; + assert!(result.description.is_some()); + assert_eq!(result.messages.len(), 2); + match &result.messages[1].content { + rmcp::model::PromptMessageContent::Text { text } => { + assert!(text.contains("mock context data")); + } + _ => panic!("Expected text content"), + } +} + +#[tokio::test] +async fn test_prompt_macros_with_optional_param() { + let _attr = Server::code_review_prompt_attr(); + let arguments = _attr.arguments.as_ref().unwrap(); + + // Check that we have the expected number of arguments + assert_eq!(arguments.len(), 2); + + // Verify file_path is required + let file_path_arg = arguments.iter().find(|a| a.name == "file_path").unwrap(); + assert_eq!(file_path_arg.required, Some(true)); + + // Verify language is required + let language_arg = arguments.iter().find(|a| a.name == "language").unwrap(); + assert_eq!(language_arg.required, Some(true)); +} + +impl CodeReviewRequest {} + +// Struct defined for testing optional field schema generation +#[derive(Debug, Deserialize, Serialize, JsonSchema)] +pub struct OptionalFieldTestSchema { + #[schemars(description = "An optional description field")] + pub description: Option, +} + +// Struct defined for testing optional i64 field schema generation and null handling +#[derive(Debug, Deserialize, Serialize, JsonSchema)] +pub struct OptionalI64TestSchema { + #[schemars(description = "An optional i64 field")] + pub count: Option, + pub mandatory_field: String, // Added to ensure non-empty object schema +} + +// Dummy struct to host the test prompt method +#[derive(Debug, Clone)] +pub struct OptionalSchemaTester { + prompt_router: PromptRouter, +} + +impl Default for OptionalSchemaTester { + fn default() -> Self { + Self::new() + } +} + +impl OptionalSchemaTester { + pub fn new() -> Self { + Self { + prompt_router: Self::prompt_router(), + } + } +} + +#[prompt_router] +impl OptionalSchemaTester { + // Dummy prompt function using the test schema as an aggregated parameter + #[prompt(description = "A prompt to test optional schema generation")] + async fn test_optional(&self, _req: Parameters) -> Vec { + vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Testing optional fields".to_string(), + )] + } + + // Prompt function to test optional i64 handling + #[prompt(description = "A prompt to test optional i64 schema generation")] + async fn test_optional_i64( + &self, + Parameters(req): Parameters, + ) -> GetPromptResult { + let message = match req.count { + Some(c) => format!("Received count: {}", c), + None => "Received null count".to_string(), + }; + + GetPromptResult { + description: Some("Test result for optional i64".to_string()), + messages: vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + message, + )], + } + } +} + +#[prompt_handler] +// Implement ServerHandler to route prompt calls for OptionalSchemaTester +impl ServerHandler for OptionalSchemaTester {} + +#[test] +fn test_optional_field_schema_generation_via_macro() { + // tests https://github.com/modelcontextprotocol/rust-sdk/issues/135 + + // Get the attributes generated by the #[prompt] macro helper + let prompt_attr = OptionalSchemaTester::test_optional_prompt_attr(); + + // Print the actual generated schema for debugging + println!( + "Actual arguments generated by macro: {:#?}", + prompt_attr.arguments + ); + + // Verify the schema generated for the aggregated OptionalFieldTestSchema + let arguments = prompt_attr.arguments.expect("Should have arguments"); + + // Check that we have an argument for the optional description field + let description_arg = arguments + .iter() + .find(|arg| arg.name == "description") + .expect("Should have description argument"); + + // Assert that optional fields are marked as not required + assert_eq!( + description_arg.required, + Some(false), + "Optional fields should be marked as not required" + ); + + // Check the description is correct + assert_eq!( + description_arg.description.as_deref(), + Some("An optional description field") + ); +} + +// Define a dummy client handler +#[derive(Debug, Clone, Default)] +struct DummyClientHandler {} + +impl ClientHandler for DummyClientHandler { + fn get_info(&self) -> ClientInfo { + ClientInfo::default() + } +} + +#[tokio::test] +async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> { + let (server_transport, client_transport) = tokio::io::duplex(4096); + + // Server setup + let server = OptionalSchemaTester::new(); + let server_handle = tokio::spawn(async move { + server.serve(server_transport).await?.waiting().await?; + anyhow::Ok(()) + }); + + // Create a simple client handler that just forwards prompt calls + let client_handler = DummyClientHandler::default(); + let client = client_handler.serve(client_transport).await?; + + // Test null case + let result = client + .get_prompt(GetPromptRequestParam { + name: "test_optional_i64".into(), + arguments: Some( + serde_json::json!({ + "count": null, + "mandatory_field": "test_null" + }) + .as_object() + .unwrap() + .clone(), + ), + }) + .await?; + + let result_text = match &result.messages.first().unwrap().content { + rmcp::model::PromptMessageContent::Text { text } => text.as_str(), + _ => panic!("Expected text content"), + }; + + assert_eq!( + result_text, "Received null count", + "Null case should return expected message" + ); + + // Test Some case + let some_result = client + .get_prompt(GetPromptRequestParam { + name: "test_optional_i64".into(), + arguments: Some( + serde_json::json!({ + "count": 42, + "mandatory_field": "test_some" + }) + .as_object() + .unwrap() + .clone(), + ), + }) + .await?; + + let some_result_text = match &some_result.messages.first().unwrap().content { + rmcp::model::PromptMessageContent::Text { text } => text.as_str(), + _ => panic!("Expected text content"), + }; + + assert_eq!( + some_result_text, "Received count: 42", + "Some case should return expected message" + ); + + client.cancel().await?; + server_handle.await??; + Ok(()) +} diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_prompt_routers.rs b/code-rs/third_party/rmcp-0.8.3/tests/test_prompt_routers.rs new file mode 100644 index 00000000000..6dc223b39e6 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_prompt_routers.rs @@ -0,0 +1,105 @@ +use std::collections::HashMap; + +use futures::future::BoxFuture; +use rmcp::{ + ServerHandler, + handler::server::wrapper::Parameters, + model::{GetPromptResult, PromptMessage, PromptMessageRole}, +}; + +#[derive(Debug, Default)] +pub struct TestHandler { + pub _marker: std::marker::PhantomData, +} + +impl ServerHandler for TestHandler {} + +#[derive(Debug, schemars::JsonSchema, serde::Deserialize, serde::Serialize)] +pub struct Request { + pub fields: HashMap, +} + +#[derive(Debug, schemars::JsonSchema, serde::Deserialize, serde::Serialize)] +pub struct Sum { + pub a: i32, + pub b: i32, +} + +#[rmcp::prompt_router(router = "test_router")] +impl TestHandler { + #[rmcp::prompt] + async fn async_method( + &self, + Parameters(Request { fields }): Parameters, + ) -> Vec { + drop(fields); + vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Async method response", + )] + } + + #[rmcp::prompt] + fn sync_method( + &self, + Parameters(Request { fields }): Parameters, + ) -> Vec { + drop(fields); + vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Sync method response", + )] + } +} + +#[rmcp::prompt] +async fn async_function(Parameters(Request { fields }): Parameters) -> Vec { + drop(fields); + vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Async function response", + )] +} + +#[rmcp::prompt] +fn async_function2(_callee: &TestHandler) -> BoxFuture<'_, GetPromptResult> { + Box::pin(async move { + GetPromptResult { + description: Some("Async function 2".to_string()), + messages: vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Async function 2 response", + )], + } + }) +} + +#[test] +fn test_prompt_router() { + let test_prompt_router = TestHandler::<()>::test_router() + .with_route(rmcp::handler::server::router::prompt::PromptRoute::new_dyn( + async_function_prompt_attr(), + |mut context| { + Box::pin(async move { + use rmcp::handler::server::{ + common::FromContextPart, prompt::IntoGetPromptResult, + }; + let params = Parameters::::from_context_part(&mut context)?; + let result = async_function(params).await; + result.into_get_prompt_result() + }) + }, + )) + .with_route(rmcp::handler::server::router::prompt::PromptRoute::new_dyn( + async_function2_prompt_attr(), + |context| { + Box::pin(async move { + use rmcp::handler::server::prompt::IntoGetPromptResult; + let result = async_function2(context.server).await; + result.into_get_prompt_result() + }) + }, + )); + let prompts = test_prompt_router.list_all(); + assert_eq!(prompts.len(), 4); +} diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_resource_link.rs b/code-rs/third_party/rmcp-0.8.3/tests/test_resource_link.rs new file mode 100644 index 00000000000..685a645c4c0 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_resource_link.rs @@ -0,0 +1,86 @@ +use rmcp::model::{CallToolResult, Content, RawResource}; + +#[test] +fn test_resource_link_in_tool_result() { + // Test creating a tool result with resource links + let resource = RawResource::new("file:///test/file.txt", "test.txt"); + + // Create a tool result with a resource link + let result = CallToolResult::success(vec![ + Content::text("Found a file"), + Content::resource_link(resource), + ]); + + // Serialize to JSON to verify format + let json = serde_json::to_string_pretty(&result).unwrap(); + println!("Tool result with resource link:\n{}", json); + + // Verify JSON contains expected structure + assert!( + json.contains("\"type\":\"resource_link\"") || json.contains("\"type\": \"resource_link\"") + ); + assert!( + json.contains("\"uri\":\"file:///test/file.txt\"") + || json.contains("\"uri\": \"file:///test/file.txt\"") + ); + assert!(json.contains("\"name\":\"test.txt\"") || json.contains("\"name\": \"test.txt\"")); + + // Test deserialization + let deserialized: CallToolResult = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.content.len(), 2); + + // Check the text content + assert!(deserialized.content[0].as_text().is_some()); + + // Check the resource link + let resource_link = deserialized.content[1] + .as_resource_link() + .expect("Expected resource link in content[1]"); + assert_eq!(resource_link.uri, "file:///test/file.txt"); + assert_eq!(resource_link.name, "test.txt"); +} + +#[test] +fn test_resource_link_with_full_metadata() { + let mut resource = RawResource::new("https://example.com/data.json", "API Data"); + resource.description = Some("JSON data from external API".to_string()); + resource.mime_type = Some("application/json".to_string()); + resource.size = Some(1024); + + let result = CallToolResult::success(vec![Content::resource_link(resource)]); + + let json = serde_json::to_string(&result).unwrap(); + let deserialized: CallToolResult = serde_json::from_str(&json).unwrap(); + + let resource_link = deserialized.content[0] + .as_resource_link() + .expect("Expected resource link"); + assert_eq!(resource_link.uri, "https://example.com/data.json"); + assert_eq!(resource_link.name, "API Data"); + assert_eq!( + resource_link.description, + Some("JSON data from external API".to_string()) + ); + assert_eq!( + resource_link.mime_type, + Some("application/json".to_string()) + ); + assert_eq!(resource_link.size, Some(1024)); +} + +#[test] +fn test_mixed_content_types() { + // Test that resource links can be mixed with other content types + let resource = RawResource::new("file:///doc.pdf", "Document"); + + let result = CallToolResult::success(vec![ + Content::text("Processing complete"), + Content::resource_link(resource), + Content::embedded_text("memo://result", "Analysis results here"), + ]); + + assert_eq!(result.content.len(), 3); + assert!(result.content[0].as_text().is_some()); + assert!(result.content[1].as_resource_link().is_some()); + assert!(result.content[2].as_resource().is_some()); +} diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_resource_link_integration.rs b/code-rs/third_party/rmcp-0.8.3/tests/test_resource_link_integration.rs new file mode 100644 index 00000000000..ab663525853 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_resource_link_integration.rs @@ -0,0 +1,131 @@ +/// Integration tests for resource_link support in both tools and prompts +use rmcp::model::{ + AnnotateAble, CallToolResult, Content, PromptMessage, PromptMessageContent, PromptMessageRole, + RawResource, Resource, +}; + +#[test] +fn test_tool_and_prompt_resource_link_compatibility() { + // Create a resource that can be used in both tools and prompts + let resource = RawResource::new("file:///shared/data.json", "Shared Data"); + let resource_annotated: Resource = resource.clone().no_annotation(); + + // Test 1: Tool returning a resource link + let tool_result = CallToolResult::success(vec![ + Content::text("Found shared data"), + Content::resource_link(resource.clone()), + ]); + + let tool_json = serde_json::to_string(&tool_result).unwrap(); + assert!(tool_json.contains("\"type\":\"resource_link\"")); + + // Test 2: Prompt returning a resource link + let prompt_message = + PromptMessage::new_resource_link(PromptMessageRole::Assistant, resource_annotated.clone()); + + let prompt_json = serde_json::to_string(&prompt_message).unwrap(); + assert!(prompt_json.contains("\"type\":\"resource_link\"")); + + // Test 3: Verify both serialize to the same resource link structure + let tool_content = &tool_result.content[1]; + let prompt_content = &prompt_message.content; + + // Extract just the resource link parts + let tool_resource_json = serde_json::to_value(tool_content).unwrap(); + let prompt_resource_json = serde_json::to_value(prompt_content).unwrap(); + + // Both should have the same structure + assert_eq!( + tool_resource_json.get("type").unwrap(), + prompt_resource_json.get("type").unwrap() + ); + assert_eq!( + tool_resource_json.get("uri").unwrap(), + prompt_resource_json.get("uri").unwrap() + ); + assert_eq!( + tool_resource_json.get("name").unwrap(), + prompt_resource_json.get("name").unwrap() + ); +} + +#[test] +fn test_resource_link_roundtrip() { + // Test that resource links can be serialized and deserialized correctly + // in both tool results and prompt messages + + let mut resource = RawResource::new("https://api.example.com/resource", "API Resource"); + resource.description = Some("External API resource".to_string()); + resource.mime_type = Some("application/json".to_string()); + resource.size = Some(2048); + + // Test with tool result + let tool_result = CallToolResult::success(vec![Content::resource_link(resource.clone())]); + + let tool_json = serde_json::to_string(&tool_result).unwrap(); + let tool_deserialized: CallToolResult = serde_json::from_str(&tool_json).unwrap(); + + if let Some(resource_link) = tool_deserialized.content[0].as_resource_link() { + assert_eq!(resource_link.uri, "https://api.example.com/resource"); + assert_eq!(resource_link.name, "API Resource"); + assert_eq!( + resource_link.description, + Some("External API resource".to_string()) + ); + assert_eq!( + resource_link.mime_type, + Some("application/json".to_string()) + ); + assert_eq!(resource_link.size, Some(2048)); + } else { + panic!("Expected resource link in tool result"); + } + + // Test with prompt message + let prompt_message = PromptMessage { + role: PromptMessageRole::User, + content: PromptMessageContent::resource_link(resource.no_annotation()), + }; + + let prompt_json = serde_json::to_string(&prompt_message).unwrap(); + let prompt_deserialized: PromptMessage = serde_json::from_str(&prompt_json).unwrap(); + + if let PromptMessageContent::ResourceLink { link } = prompt_deserialized.content { + assert_eq!(link.uri, "https://api.example.com/resource"); + assert_eq!(link.name, "API Resource"); + assert_eq!(link.description, Some("External API resource".to_string())); + assert_eq!(link.mime_type, Some("application/json".to_string())); + assert_eq!(link.size, Some(2048)); + } else { + panic!("Expected resource link in prompt message"); + } +} + +#[test] +fn test_mixed_content_in_prompts_and_tools() { + // Test that resource links can be mixed with other content types + // in both prompts and tools + + let resource1 = RawResource::new("file:///doc1.md", "Document 1"); + let resource2 = RawResource::new("file:///doc2.md", "Document 2"); + + // Tool with mixed content + let tool_result = CallToolResult::success(vec![ + Content::text("Processing complete. Found documents:"), + Content::resource_link(resource1.clone()), + Content::resource_link(resource2.clone()), + Content::embedded_text("summary://result", "Both documents processed successfully"), + ]); + + assert_eq!(tool_result.content.len(), 4); + assert!(tool_result.content[0].as_text().is_some()); + assert!(tool_result.content[1].as_resource_link().is_some()); + assert!(tool_result.content[2].as_resource_link().is_some()); + assert!(tool_result.content[3].as_resource().is_some()); + + // Verify serialization includes all types + let json = serde_json::to_string(&tool_result).unwrap(); + assert!(json.contains("\"type\":\"text\"")); + assert!(json.contains("\"type\":\"resource_link\"")); + assert!(json.contains("\"type\":\"resource\"")); +} diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_sampling.rs b/code-rs/third_party/rmcp-0.8.3/tests/test_sampling.rs new file mode 100644 index 00000000000..b760796f7d3 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_sampling.rs @@ -0,0 +1,321 @@ +//cargo test --test test_sampling --features "client server" + +mod common; + +use anyhow::Result; +use common::handlers::{TestClientHandler, TestServer}; +use rmcp::{ + ServiceExt, + model::*, + service::{RequestContext, Service}, +}; +use tokio_util::sync::CancellationToken; + +#[tokio::test] +async fn test_basic_sampling_message_creation() -> Result<()> { + // Test basic sampling message structure + let message = SamplingMessage { + role: Role::User, + content: Content::text("What is the capital of France?"), + }; + + // Verify serialization/deserialization + let json = serde_json::to_string(&message)?; + let deserialized: SamplingMessage = serde_json::from_str(&json)?; + assert_eq!(message, deserialized); + assert_eq!(message.role, Role::User); + + Ok(()) +} + +#[tokio::test] +async fn test_sampling_request_params() -> Result<()> { + // Test sampling request parameters structure + let params = CreateMessageRequestParam { + messages: vec![SamplingMessage { + role: Role::User, + content: Content::text("Hello, world!"), + }], + model_preferences: Some(ModelPreferences { + hints: Some(vec![ModelHint { + name: Some("claude".to_string()), + }]), + cost_priority: Some(0.5), + speed_priority: Some(0.8), + intelligence_priority: Some(0.7), + }), + system_prompt: Some("You are a helpful assistant.".to_string()), + temperature: Some(0.7), + max_tokens: 100, + stop_sequences: Some(vec!["STOP".to_string()]), + include_context: Some(ContextInclusion::None), + metadata: Some(serde_json::json!({"test": "value"})), + }; + + // Verify serialization/deserialization + let json = serde_json::to_string(¶ms)?; + let deserialized: CreateMessageRequestParam = serde_json::from_str(&json)?; + assert_eq!(params, deserialized); + + // Verify specific fields + assert_eq!(params.messages.len(), 1); + assert_eq!(params.max_tokens, 100); + assert_eq!(params.temperature, Some(0.7)); + + Ok(()) +} + +#[tokio::test] +async fn test_sampling_result_structure() -> Result<()> { + // Test sampling result structure + let result = CreateMessageResult { + message: SamplingMessage { + role: Role::Assistant, + content: Content::text("The capital of France is Paris."), + }, + model: "test-model".to_string(), + stop_reason: Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()), + }; + + // Verify serialization/deserialization + let json = serde_json::to_string(&result)?; + let deserialized: CreateMessageResult = serde_json::from_str(&json)?; + assert_eq!(result, deserialized); + + // Verify specific fields + assert_eq!(result.message.role, Role::Assistant); + assert_eq!(result.model, "test-model"); + assert_eq!( + result.stop_reason, + Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()) + ); + + Ok(()) +} + +#[tokio::test] +async fn test_sampling_context_inclusion_enum() -> Result<()> { + // Test context inclusion enum values + let test_cases = vec![ + (ContextInclusion::None, "none"), + (ContextInclusion::ThisServer, "thisServer"), + (ContextInclusion::AllServers, "allServers"), + ]; + + for (context, expected_json) in test_cases { + let json = serde_json::to_string(&context)?; + assert_eq!(json, format!("\"{}\"", expected_json)); + + let deserialized: ContextInclusion = serde_json::from_str(&json)?; + assert_eq!(context, deserialized); + } + + Ok(()) +} + +#[tokio::test] +async fn test_sampling_integration_with_test_handlers() -> Result<()> { + let (server_transport, client_transport) = tokio::io::duplex(4096); + + // Start server + let server_handle = tokio::spawn(async move { + let server = TestServer::new().serve(server_transport).await?; + server.waiting().await?; + anyhow::Ok(()) + }); + + // Start client that honors sampling requests + let handler = TestClientHandler::new(true, true); + let client = handler.clone().serve(client_transport).await?; + + // Wait for initialization + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Test sampling with context inclusion + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { + method: Default::default(), + params: CreateMessageRequestParam { + messages: vec![SamplingMessage { + role: Role::User, + content: Content::text("What is the capital of France?"), + }], + include_context: Some(ContextInclusion::ThisServer), + model_preferences: Some(ModelPreferences { + hints: Some(vec![ModelHint { + name: Some("test-model".to_string()), + }]), + cost_priority: Some(0.5), + speed_priority: Some(0.8), + intelligence_priority: Some(0.7), + }), + system_prompt: Some("You are a helpful assistant.".to_string()), + temperature: Some(0.7), + max_tokens: 100, + stop_sequences: None, + metadata: None, + }, + extensions: Default::default(), + }); + + let result = handler + .handle_request( + request.clone(), + RequestContext { + peer: client.peer().clone(), + ct: CancellationToken::new(), + id: NumberOrString::Number(1), + meta: Default::default(), + extensions: Default::default(), + }, + ) + .await?; + + // Verify the response + if let ClientResult::CreateMessageResult(result) = result { + assert_eq!(result.message.role, Role::Assistant); + assert_eq!(result.model, "test-model"); + assert_eq!( + result.stop_reason, + Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()) + ); + + let response_text = result.message.content.as_text().unwrap().text.as_str(); + assert!( + response_text.contains("test context"), + "Response should include context for ThisServer inclusion" + ); + } else { + panic!("Expected CreateMessageResult"); + } + + client.cancel().await?; + server_handle.await??; + Ok(()) +} + +#[tokio::test] +async fn test_sampling_no_context_inclusion() -> Result<()> { + let (server_transport, client_transport) = tokio::io::duplex(4096); + + // Start server + let server_handle = tokio::spawn(async move { + let server = TestServer::new().serve(server_transport).await?; + server.waiting().await?; + anyhow::Ok(()) + }); + + // Start client that honors sampling requests + let handler = TestClientHandler::new(true, true); + let client = handler.clone().serve(client_transport).await?; + + // Wait for initialization + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Test sampling without context inclusion + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { + method: Default::default(), + params: CreateMessageRequestParam { + messages: vec![SamplingMessage { + role: Role::User, + content: Content::text("Hello"), + }], + include_context: Some(ContextInclusion::None), + model_preferences: None, + system_prompt: None, + temperature: None, + max_tokens: 50, + stop_sequences: None, + metadata: None, + }, + extensions: Default::default(), + }); + + let result = handler + .handle_request( + request.clone(), + RequestContext { + peer: client.peer().clone(), + ct: CancellationToken::new(), + id: NumberOrString::Number(2), + meta: Default::default(), + extensions: Default::default(), + }, + ) + .await?; + + // Verify the response + if let ClientResult::CreateMessageResult(result) = result { + assert_eq!(result.message.role, Role::Assistant); + assert_eq!(result.model, "test-model"); + + let response_text = result.message.content.as_text().unwrap().text.as_str(); + assert!( + !response_text.contains("test context"), + "Response should not include context for None inclusion" + ); + } else { + panic!("Expected CreateMessageResult"); + } + + client.cancel().await?; + server_handle.await??; + Ok(()) +} + +#[tokio::test] +async fn test_sampling_error_invalid_message_sequence() -> Result<()> { + let (server_transport, client_transport) = tokio::io::duplex(4096); + + // Start server + let server_handle = tokio::spawn(async move { + let server = TestServer::new().serve(server_transport).await?; + server.waiting().await?; + anyhow::Ok(()) + }); + + // Start client + let handler = TestClientHandler::new(true, true); + let client = handler.clone().serve(client_transport).await?; + + // Wait for initialization + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Test sampling with no user messages (should fail) + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { + method: Default::default(), + params: CreateMessageRequestParam { + messages: vec![SamplingMessage { + role: Role::Assistant, + content: Content::text("I'm an assistant message without a user message"), + }], + include_context: Some(ContextInclusion::None), + model_preferences: None, + system_prompt: None, + temperature: None, + max_tokens: 50, + stop_sequences: None, + metadata: None, + }, + extensions: Default::default(), + }); + + let result = handler + .handle_request( + request.clone(), + RequestContext { + peer: client.peer().clone(), + ct: CancellationToken::new(), + id: NumberOrString::Number(3), + meta: Default::default(), + extensions: Default::default(), + }, + ) + .await; + + // This should result in an error + assert!(result.is_err()); + + client.cancel().await?; + server_handle.await??; + Ok(()) +} diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_structured_output.rs b/code-rs/third_party/rmcp-0.8.3/tests/test_structured_output.rs new file mode 100644 index 00000000000..cb9a11b9fb3 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_structured_output.rs @@ -0,0 +1,282 @@ +//cargo test --test test_structured_output --features "client server macros" +use rmcp::{ + Json, ServerHandler, + handler::server::{router::tool::ToolRouter, tool::IntoCallToolResult, wrapper::Parameters}, + model::{CallToolResult, Content, Tool}, + tool, tool_handler, tool_router, +}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use serde_json::{Value, json}; + +#[derive(Serialize, Deserialize, JsonSchema)] +pub struct CalculationRequest { + pub a: i32, + pub b: i32, +} + +#[derive(Serialize, Deserialize, JsonSchema)] +pub struct CalculationResult { + pub sum: i32, + pub product: i32, +} + +#[derive(Serialize, Deserialize, JsonSchema)] +pub struct UserInfo { + pub name: String, + pub age: u32, +} + +#[tool_handler(router = self.tool_router)] +impl ServerHandler for TestServer {} + +#[derive(Debug, Clone)] +pub struct TestServer { + tool_router: ToolRouter, +} + +impl Default for TestServer { + fn default() -> Self { + Self::new() + } +} + +#[tool_router(router = tool_router)] +impl TestServer { + pub fn new() -> Self { + Self { + tool_router: Self::tool_router(), + } + } + + /// Tool that returns structured output + #[tool(name = "calculate", description = "Perform calculations")] + pub async fn calculate( + &self, + params: Parameters, + ) -> Result, String> { + Ok(Json(CalculationResult { + sum: params.0.a + params.0.b, + product: params.0.a * params.0.b, + })) + } + + /// Tool that returns regular string output + #[tool(name = "get-greeting", description = "Get a greeting")] + pub async fn get_greeting(&self, name: Parameters) -> String { + format!("Hello, {}!", name.0) + } + + /// Tool that returns structured user info + #[tool(name = "get-user", description = "Get user info")] + pub async fn get_user(&self, user_id: Parameters) -> Result, String> { + if user_id.0 == "123" { + Ok(Json(UserInfo { + name: "Alice".to_string(), + age: 30, + })) + } else { + Err("User not found".to_string()) + } + } +} + +#[tokio::test] +async fn test_tool_with_output_schema() { + let server = TestServer::new(); + let tools = server.tool_router.list_all(); + + // Find the calculate tool + let calculate_tool = tools.iter().find(|t| t.name == "calculate").unwrap(); + + // Verify it has an output schema + assert!(calculate_tool.output_schema.is_some()); + + let schema = calculate_tool.output_schema.as_ref().unwrap(); + + // Check that the schema contains expected fields + let schema_str = serde_json::to_string(schema).unwrap(); + assert!(schema_str.contains("sum")); + assert!(schema_str.contains("product")); +} + +#[tokio::test] +async fn test_tool_without_output_schema() { + let server = TestServer::new(); + let tools = server.tool_router.list_all(); + + // Find the get-greeting tool + let greeting_tool = tools.iter().find(|t| t.name == "get-greeting").unwrap(); + + // Verify it doesn't have an output schema (returns String) + assert!(greeting_tool.output_schema.is_none()); +} + +#[tokio::test] +async fn test_structured_content_in_call_result() { + // Test creating a CallToolResult with structured content + let structured_data = json!({ + "sum": 7, + "product": 12 + }); + + let result = CallToolResult::structured(structured_data.clone()); + + assert!(!result.content.is_empty()); + assert!(result.structured_content.is_some()); + + let contents = result.content; + + assert_eq!(contents.len(), 1); + + let content_text = contents.first().unwrap().as_text(); + + assert!(content_text.is_some()); + + let content_value: Value = serde_json::from_str(&content_text.unwrap().text).unwrap(); + + assert_eq!(content_value, structured_data); + assert_eq!(result.structured_content.unwrap(), structured_data); + assert_eq!(result.is_error, Some(false)); +} + +#[tokio::test] +async fn test_structured_error_in_call_result() { + // Test creating a CallToolResult with structured error + let error_data = json!({ + "error_code": "NOT_FOUND", + "message": "User not found" + }); + + let result = CallToolResult::structured_error(error_data.clone()); + + assert!(!result.content.is_empty()); + assert!(result.structured_content.is_some()); + + let contents = result.content; + + assert_eq!(contents.len(), 1); + + let content_text = contents.first().unwrap().as_text(); + + assert!(content_text.is_some()); + + let content_value: Value = serde_json::from_str(&content_text.unwrap().text).unwrap(); + + assert_eq!(content_value, error_data); + assert_eq!(result.structured_content.unwrap(), error_data); + assert_eq!(result.is_error, Some(true)); +} + +#[tokio::test] +async fn test_mutual_exclusivity_validation() { + #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] + pub struct Response { + message: String, + } + let response = Response { + message: "Hello".into(), + }; + // Test that content and structured_content can both be passed separately + let content_result = CallToolResult::success(vec![Content::json(response.clone()).unwrap()]); + let structured_result = CallToolResult::structured(json!({"message": "Hello"})); + + // Verify the validation + content_result + .into_typed::() + .expect("Failed to extract content"); + structured_result + .into_typed::() + .expect("Failed to extract content"); + + // Try to create a result with both fields + let json_with_both = json!({ + "content": [{"type": "text", "text": "Hello"}], + "structuredContent": {"message": "Hello"} + }); + + // The deserialization itself should not fail + let deserialized: Result = serde_json::from_value(json_with_both); + assert!(deserialized.is_ok()); +} + +#[tokio::test] +async fn test_structured_return_conversion() { + // Test that Json converts to CallToolResult with structured_content + let calc_result = CalculationResult { + sum: 7, + product: 12, + }; + + let structured = Json(calc_result); + let result: Result = + rmcp::handler::server::tool::IntoCallToolResult::into_call_tool_result(structured); + + assert!(result.is_ok()); + let call_result = result.unwrap(); + + // Tools which return structured content should also return a serialized version as + // Content::text for backwards compatibility. + assert!(!call_result.content.is_empty()); + assert!(call_result.structured_content.is_some()); + + let contents = call_result.content; + + assert_eq!(contents.len(), 1); + + let content_text = contents.first().unwrap().as_text(); + + assert!(content_text.is_some()); + + let content_value: Value = serde_json::from_str(&content_text.unwrap().text).unwrap(); + let structured_value = call_result.structured_content.unwrap(); + + assert_eq!(content_value, structured_value); + + assert_eq!(structured_value["sum"], 7); + assert_eq!(structured_value["product"], 12); +} + +#[tokio::test] +async fn test_tool_serialization_with_output_schema() { + let server = TestServer::new(); + let tools = server.tool_router.list_all(); + + let calculate_tool = tools.iter().find(|t| t.name == "calculate").unwrap(); + + // Serialize the tool + let serialized = serde_json::to_value(calculate_tool).unwrap(); + + // Check that outputSchema is included + assert!(serialized["outputSchema"].is_object()); + + // Deserialize back + let deserialized: Tool = serde_json::from_value(serialized).unwrap(); + assert!(deserialized.output_schema.is_some()); +} + +#[tokio::test] +async fn test_output_schema_requires_structured_content() { + // Test that tools with output_schema must use structured_content + let server = TestServer::new(); + let tools = server.tool_router.list_all(); + + // The calculate tool should have output_schema + let calculate_tool = tools.iter().find(|t| t.name == "calculate").unwrap(); + assert!(calculate_tool.output_schema.is_some()); + + // Directly call the tool and verify its result structure + let params = Parameters(CalculationRequest { a: 5, b: 3 }); + let result = server.calculate(params).await.unwrap(); + + // Convert the Json to CallToolResult + let call_result: Result = + IntoCallToolResult::into_call_tool_result(result); + + assert!(call_result.is_ok()); + let call_result = call_result.unwrap(); + + // Verify it has structured_content and content + assert!(call_result.structured_content.is_some()); + assert!(!call_result.content.is_empty()); +} diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_tool_builder_methods.rs b/code-rs/third_party/rmcp-0.8.3/tests/test_tool_builder_methods.rs new file mode 100644 index 00000000000..f93c0546263 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_tool_builder_methods.rs @@ -0,0 +1,62 @@ +//cargo test --test test_tool_builder_methods --features "client server macros" +use rmcp::model::{JsonObject, Tool}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, JsonSchema)] +pub struct InputData { + pub name: String, + pub age: u32, +} + +#[derive(Serialize, Deserialize, JsonSchema)] +pub struct OutputData { + pub greeting: String, + pub is_adult: bool, +} + +#[test] +fn test_with_output_schema() { + let tool = Tool::new("test", "Test tool", JsonObject::new()).with_output_schema::(); + + assert!(tool.output_schema.is_some()); + + // Verify the schema contains expected fields + let schema_str = serde_json::to_string(tool.output_schema.as_ref().unwrap()).unwrap(); + assert!(schema_str.contains("greeting")); + assert!(schema_str.contains("is_adult")); +} + +#[test] +fn test_with_input_schema() { + let tool = Tool::new("test", "Test tool", JsonObject::new()).with_input_schema::(); + + // Verify the schema contains expected fields + let schema_str = serde_json::to_string(&tool.input_schema).unwrap(); + assert!(schema_str.contains("name")); + assert!(schema_str.contains("age")); +} + +#[test] +fn test_chained_builder_methods() { + let tool = Tool::new("test", "Test tool", JsonObject::new()) + .with_input_schema::() + .with_output_schema::() + .annotate(rmcp::model::ToolAnnotations::new().read_only(true)); + + assert!(tool.output_schema.is_some()); + assert!(tool.annotations.is_some()); + assert_eq!( + tool.annotations.as_ref().unwrap().read_only_hint, + Some(true) + ); + + // Verify both schemas are set correctly + let input_schema_str = serde_json::to_string(&tool.input_schema).unwrap(); + assert!(input_schema_str.contains("name")); + assert!(input_schema_str.contains("age")); + + let output_schema_str = serde_json::to_string(tool.output_schema.as_ref().unwrap()).unwrap(); + assert!(output_schema_str.contains("greeting")); + assert!(output_schema_str.contains("is_adult")); +} diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_tool_handler.rs b/code-rs/third_party/rmcp-0.8.3/tests/test_tool_handler.rs new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_tool_handler.rs @@ -0,0 +1 @@ + diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_tool_macro_annotations.rs b/code-rs/third_party/rmcp-0.8.3/tests/test_tool_macro_annotations.rs new file mode 100644 index 00000000000..e945a10fef7 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_tool_macro_annotations.rs @@ -0,0 +1,46 @@ +#[cfg(test)] +mod tests { + use rmcp::{ServerHandler, handler::server::router::tool::ToolRouter, tool, tool_handler}; + + #[derive(Debug, Clone, Default)] + pub struct AnnotatedServer { + tool_router: ToolRouter, + } + + impl AnnotatedServer { + // Tool with inline comments for documentation + /// Direct annotation test tool + /// This is used to test tool annotations + #[tool( + name = "direct-annotated-tool", + annotations(title = "Annotated Tool", read_only_hint = true) + )] + pub async fn direct_annotated_tool(&self, input: String) -> String { + format!("Direct: {}", input) + } + } + #[tool_handler] + impl ServerHandler for AnnotatedServer {} + + #[test] + fn test_direct_tool_attributes() { + // Get the tool definition + let tool = AnnotatedServer::direct_annotated_tool_tool_attr(); + + // Verify basic properties + assert_eq!(tool.name, "direct-annotated-tool"); + + // Verify description is extracted from doc comments + assert!(tool.description.is_some()); + assert!( + tool.description + .as_ref() + .unwrap() + .contains("Direct annotation test tool") + ); + + let annotations = tool.annotations.unwrap(); + assert_eq!(annotations.title.as_ref().unwrap(), "Annotated Tool"); + assert_eq!(annotations.read_only_hint, Some(true)); + } +} diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_tool_macros.rs b/code-rs/third_party/rmcp-0.8.3/tests/test_tool_macros.rs new file mode 100644 index 00000000000..db5242b3366 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_tool_macros.rs @@ -0,0 +1,369 @@ +//! Test tool macros, including documentation for generated fns. + +//cargo test --test test_tool_macros --features "client server" +// Enforce that all generated code has sufficient docs to pass missing_docs lint +#![deny(missing_docs)] +#![allow(dead_code)] +use std::sync::Arc; + +use rmcp::{ + ClientHandler, ServerHandler, ServiceExt, + handler::server::{router::tool::ToolRouter, wrapper::Parameters}, + model::{CallToolRequestParam, ClientInfo}, + tool, tool_handler, tool_router, +}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +/// Parameters for weather tool. +#[derive(Serialize, Deserialize, JsonSchema)] +pub struct GetWeatherRequest { + /// City of interest. + pub city: String, + /// Date of interest. + pub date: String, +} + +#[tool_handler(router = self.tool_router)] +impl ServerHandler for Server {} + +/// Trivial stateless server. +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct Server { + tool_router: ToolRouter, +} + +impl Server { + /// Create weather server. + pub fn new() -> Self { + Self { + tool_router: Self::tool_router(), + } + } +} + +impl Default for Server { + fn default() -> Self { + Self::new() + } +} + +#[tool_router(router = tool_router)] +impl Server { + /// This tool is used to get the weather of a city. + #[tool(name = "get-weather", description = "Get the weather of a city.")] + pub async fn get_weather(&self, city: Parameters) -> String { + drop(city); + "rain".to_string() + } + + #[tool] + async fn empty_param(&self) {} +} + +/// Generic service trait. +pub trait DataService: Send + Sync + 'static { + /// Get data from service. + fn get_data(&self) -> String; +} + +// mock service for test +#[derive(Clone)] +struct MockDataService; +impl DataService for MockDataService { + fn get_data(&self) -> String { + "mock data".to_string() + } +} + +/// Generic server. +#[derive(Debug, Clone)] +pub struct GenericServer { + data_service: Arc, + tool_router: ToolRouter, +} + +#[tool_router] +impl GenericServer { + /// Create data server instance. + pub fn new(data_service: DS) -> Self { + Self { + data_service: Arc::new(data_service), + tool_router: Self::tool_router(), + } + } + + #[tool(description = "Get data from the service")] + async fn get_data(&self) -> String { + self.data_service.get_data() + } +} + +#[tool_handler] +impl ServerHandler for GenericServer {} + +#[tokio::test] +async fn test_tool_macros() { + let server = Server::new(); + let _attr = Server::get_weather_tool_attr(); + let _get_weather_tool_attr_fn = Server::get_weather_tool_attr; + let _get_weather_fn = Server::get_weather; + server + .get_weather(Parameters(GetWeatherRequest { + city: "Harbin".into(), + date: "Yesterday".into(), + })) + .await; +} + +#[tokio::test] +async fn test_tool_macros_with_empty_param() { + let _attr = Server::empty_param_tool_attr(); + println!("{_attr:?}"); + assert_eq!( + _attr.input_schema.get("type"), + Some(&serde_json::Value::String("object".to_string())) + ); + assert_eq!( + _attr.input_schema.get("properties"), + Some(&serde_json::Value::Object(serde_json::Map::new())) + ); +} + +#[tokio::test] +async fn test_tool_macros_with_generics() { + let mock_service = MockDataService; + let server = GenericServer::new(mock_service); + let _attr = GenericServer::::get_data_tool_attr(); + let _get_data_call_fn = GenericServer::::get_data; + let _get_data_fn = GenericServer::::get_data; + assert_eq!(server.get_data().await, "mock data"); +} + +#[tokio::test] +async fn test_tool_macros_with_optional_param() { + let _attr = Server::get_weather_tool_attr(); + // println!("{_attr:?}"); + let attr_type = _attr + .input_schema + .get("properties") + .unwrap() + .get("city") + .unwrap() + .get("type") + .unwrap(); + println!("_attr.input_schema: {:?}", attr_type); + assert_eq!(attr_type.as_str().unwrap(), "string"); +} + +impl GetWeatherRequest {} + +/// Struct defined for testing optional field schema generation. +#[derive(Debug, Deserialize, Serialize, JsonSchema)] +pub struct OptionalFieldTestSchema { + /// Field description. + #[schemars(description = "An optional description field")] + pub description: Option, +} + +/// Struct defined for testing optional i64 field schema generation and null handling. +#[derive(Debug, Deserialize, Serialize, JsonSchema)] +pub struct OptionalI64TestSchema { + /// Optional count field. + #[schemars(description = "An optional i64 field")] + pub count: Option, + + /// Added to ensure non-empty object schema. + pub mandatory_field: String, +} + +/// Dummy struct to host the test tool method. +#[derive(Debug, Clone)] +pub struct OptionalSchemaTester { + tool_router: ToolRouter, +} + +impl Default for OptionalSchemaTester { + fn default() -> Self { + Self::new() + } +} + +impl OptionalSchemaTester { + /// Create instance of optional schema tester service. + pub fn new() -> Self { + Self { + tool_router: Self::tool_router(), + } + } +} + +#[tool_router] +impl OptionalSchemaTester { + // Dummy tool function using the test schema as an aggregated parameter + #[tool(description = "A tool to test optional schema generation")] + async fn test_optional(&self, _req: Parameters) { + // Implementation doesn't matter for schema testing + // Return type changed to () to satisfy IntoCallToolResult + } + + // Tool function to test optional i64 handling + #[tool(description = "A tool to test optional i64 schema generation")] + async fn test_optional_i64( + &self, + Parameters(req): Parameters, + ) -> String { + match req.count { + Some(c) => format!("Received count: {}", c), + None => "Received null count".to_string(), + } + } +} +#[tool_handler] +// Implement ServerHandler to route tool calls for OptionalSchemaTester +impl ServerHandler for OptionalSchemaTester {} + +#[test] +fn test_optional_field_schema_generation_via_macro() { + // tests https://github.com/modelcontextprotocol/rust-sdk/issues/135 + + // Get the attributes generated by the #[tool] macro helper + let tool_attr = OptionalSchemaTester::test_optional_tool_attr(); + + // Print the actual generated schema for debugging + println!( + "Actual input schema generated by macro: {:#?}", + tool_attr.input_schema + ); + + // Verify the schema generated for the aggregated OptionalFieldTestSchema + // by the macro infrastructure (which should now use OpenAPI 3 settings) + let input_schema_map = &*tool_attr.input_schema; // Dereference Arc + + // Check the schema for the 'description' property within the input schema + let properties = input_schema_map + .get("properties") + .expect("Schema should have properties") + .as_object() + .unwrap(); + let description_schema = properties + .get("description") + .expect("Properties should include description") + .as_object() + .unwrap(); + + // Assert that the format is now `type: "string", nullable: true` + assert_eq!( + description_schema.get("type").map(|v| v.as_str().unwrap()), + Some("string"), + "Schema for Option generated by macro should be type: \"string\"" + ); + assert_eq!( + description_schema + .get("nullable") + .map(|v| v.as_bool().unwrap()), + Some(true), + "Schema for Option generated by macro should have nullable: true" + ); + // We still check the description is correct + assert_eq!( + description_schema + .get("description") + .map(|v| v.as_str().unwrap()), + Some("An optional description field") + ); + + // Ensure the old 'type: [T, null]' format is NOT used + let type_value = description_schema.get("type").unwrap(); + assert!( + !type_value.is_array(), + "Schema type should not be an array [T, null]" + ); +} + +// Define a dummy client handler +#[derive(Debug, Clone, Default)] +struct DummyClientHandler {} + +impl ClientHandler for DummyClientHandler { + fn get_info(&self) -> ClientInfo { + ClientInfo::default() + } +} + +#[tokio::test] +async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> { + let (server_transport, client_transport) = tokio::io::duplex(4096); + + // Server setup + let server = OptionalSchemaTester::new(); + let server_handle = tokio::spawn(async move { + server.serve(server_transport).await?.waiting().await?; + anyhow::Ok(()) + }); + + // Create a simple client handler that just forwards tool calls + let client_handler = DummyClientHandler::default(); + let client = client_handler.serve(client_transport).await?; + + // Test null case + let result = client + .call_tool(CallToolRequestParam { + name: "test_optional_i64".into(), + arguments: Some( + serde_json::json!({ + "count": null, + "mandatory_field": "test_null" + }) + .as_object() + .unwrap() + .clone(), + ), + }) + .await?; + + let result_text = result + .content + .first() + .and_then(|content| content.raw.as_text()) + .map(|text| text.text.as_str()) + .expect("Expected text content"); + + assert_eq!( + result_text, "Received null count", + "Null case should return expected message" + ); + + // Test Some case + let some_result = client + .call_tool(CallToolRequestParam { + name: "test_optional_i64".into(), + arguments: Some( + serde_json::json!({ + "count": 42, + "mandatory_field": "test_some" + }) + .as_object() + .unwrap() + .clone(), + ), + }) + .await?; + + let some_result_text = some_result + .content + .first() + .and_then(|content| content.raw.as_text()) + .map(|text| text.text.as_str()) + .expect("Expected text content"); + + assert_eq!( + some_result_text, "Received count: 42", + "Some case should return expected message" + ); + + client.cancel().await?; + server_handle.await??; + Ok(()) +} diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_tool_result_meta.rs b/code-rs/third_party/rmcp-0.8.3/tests/test_tool_result_meta.rs new file mode 100644 index 00000000000..78e1809efc9 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_tool_result_meta.rs @@ -0,0 +1,45 @@ +use rmcp::model::{CallToolResult, Content, Meta}; +use serde_json::{Value, json}; + +#[test] +fn serialize_tool_result_with_meta() { + let content = vec![Content::text("ok")]; + let mut meta = Meta::new(); + meta.insert("foo".to_string(), json!("bar")); + let result = CallToolResult { + content, + structured_content: None, + is_error: Some(false), + meta: Some(meta), + }; + let v = serde_json::to_value(&result).unwrap(); + let expected = json!({ + "content": [{"type":"text","text":"ok"}], + "isError": false, + "_meta": {"foo":"bar"} + }); + assert_eq!(v, expected); +} + +#[test] +fn deserialize_tool_result_with_meta() { + let raw: Value = json!({ + "content": [{"type":"text","text":"hello"}], + "isError": true, + "_meta": {"a": 1, "b": "two"} + }); + let result: CallToolResult = serde_json::from_value(raw).unwrap(); + assert_eq!(result.is_error, Some(true)); + assert_eq!(result.content.len(), 1); + let meta = result.meta.expect("meta should exist"); + assert_eq!(meta.get("a").unwrap(), &json!(1)); + assert_eq!(meta.get("b").unwrap(), &json!("two")); +} + +#[test] +fn serialize_tool_result_without_meta_omits_field() { + let result = CallToolResult::success(vec![Content::text("no meta")]); + let v = serde_json::to_value(&result).unwrap(); + // Ensure _meta is omitted + assert!(v.get("_meta").is_none()); +} diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_tool_routers.rs b/code-rs/third_party/rmcp-0.8.3/tests/test_tool_routers.rs new file mode 100644 index 00000000000..442c70ea1fb --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_tool_routers.rs @@ -0,0 +1,68 @@ +use std::collections::HashMap; + +use futures::future::BoxFuture; +use rmcp::{ + ServerHandler, + handler::server::{router::tool::ToolRouter, tool::CallToolHandler, wrapper::Parameters}, +}; + +#[derive(Debug, Default)] +pub struct TestHandler { + pub _marker: std::marker::PhantomData, +} + +impl ServerHandler for TestHandler {} +#[derive(Debug, schemars::JsonSchema, serde::Deserialize, serde::Serialize)] +pub struct Request { + pub fields: HashMap, +} + +#[derive(Debug, schemars::JsonSchema, serde::Deserialize, serde::Serialize)] +pub struct Sum { + pub a: i32, + pub b: i32, +} + +#[rmcp::tool_router(router = test_router_1)] +impl TestHandler { + #[rmcp::tool] + async fn async_method(&self, Parameters(Request { fields }): Parameters) { + drop(fields) + } +} + +#[rmcp::tool_router(router = test_router_2)] +impl TestHandler { + #[rmcp::tool] + fn sync_method(&self, Parameters(Request { fields }): Parameters) { + drop(fields) + } +} + +#[rmcp::tool] +async fn async_function(Parameters(Request { fields }): Parameters) { + drop(fields) +} + +#[rmcp::tool] +fn async_function2(_callee: &TestHandler) -> BoxFuture<'_, ()> { + Box::pin(async move {}) +} + +#[test] +fn test_tool_router() { + let test_tool_router: ToolRouter> = ToolRouter::>::new() + .with_route((async_function_tool_attr(), async_function)) + .with_route((async_function2_tool_attr(), async_function2)) + + TestHandler::<()>::test_router_1() + + TestHandler::<()>::test_router_2(); + let tools = test_tool_router.list_all(); + assert_eq!(tools.len(), 4); + assert_handler(TestHandler::<()>::async_method); +} + +fn assert_handler(_handler: H) +where + H: CallToolHandler, +{ +} diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_with_js.rs b/code-rs/third_party/rmcp-0.8.3/tests/test_with_js.rs new file mode 100644 index 00000000000..3f2761cd312 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_with_js.rs @@ -0,0 +1,164 @@ +use rmcp::{ + ServiceExt, + service::QuitReason, + transport::{ + ConfigureCommandExt, SseServer, StreamableHttpClientTransport, StreamableHttpServerConfig, + TokioChildProcess, + streamable_http_server::{ + session::local::LocalSessionManager, tower::StreamableHttpService, + }, + }, +}; +use tokio_util::sync::CancellationToken; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; +mod common; +use common::calculator::Calculator; + +const SSE_BIND_ADDRESS: &str = "127.0.0.1:8000"; +const STREAMABLE_HTTP_BIND_ADDRESS: &str = "127.0.0.1:8001"; +const STREAMABLE_HTTP_JS_BIND_ADDRESS: &str = "127.0.0.1:8002"; + +#[tokio::test] +async fn test_with_js_client() -> anyhow::Result<()> { + let _ = tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .try_init(); + tokio::process::Command::new("npm") + .arg("install") + .current_dir("tests/test_with_js") + .spawn()? + .wait() + .await?; + + let ct = SseServer::serve(SSE_BIND_ADDRESS.parse()?) + .await? + .with_service(Calculator::default); + + let exit_status = tokio::process::Command::new("node") + .arg("tests/test_with_js/client.js") + .spawn()? + .wait() + .await?; + assert!(exit_status.success()); + ct.cancel(); + Ok(()) +} + +#[tokio::test] +async fn test_with_js_server() -> anyhow::Result<()> { + let _ = tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .try_init(); + tokio::process::Command::new("npm") + .arg("install") + .current_dir("tests/test_with_js") + .spawn()? + .wait() + .await?; + let transport = + TokioChildProcess::new(tokio::process::Command::new("node").configure(|cmd| { + cmd.arg("tests/test_with_js/server.js"); + }))?; + + let client = ().serve(transport).await?; + let resources = client.list_all_resources().await?; + tracing::info!("{:#?}", resources); + let tools = client.list_all_tools().await?; + tracing::info!("{:#?}", tools); + + client.cancel().await?; + Ok(()) +} + +#[tokio::test] +async fn test_with_js_streamable_http_client() -> anyhow::Result<()> { + let _ = tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .try_init(); + tokio::process::Command::new("npm") + .arg("install") + .current_dir("tests/test_with_js") + .spawn()? + .wait() + .await?; + + let service: StreamableHttpService = + StreamableHttpService::new( + || Ok(Calculator::new()), + Default::default(), + StreamableHttpServerConfig { + stateful_mode: true, + sse_keep_alive: None, + }, + ); + let router = axum::Router::new().nest_service("/mcp", service); + let tcp_listener = tokio::net::TcpListener::bind(STREAMABLE_HTTP_BIND_ADDRESS).await?; + let ct = CancellationToken::new(); + let handle = tokio::spawn({ + let ct = ct.clone(); + async move { + let _ = axum::serve(tcp_listener, router) + .with_graceful_shutdown(async move { ct.cancelled_owned().await }) + .await; + } + }); + let exit_status = tokio::process::Command::new("node") + .arg("tests/test_with_js/streamable_client.js") + .spawn()? + .wait() + .await?; + assert!(exit_status.success()); + ct.cancel(); + handle.await?; + Ok(()) +} + +#[tokio::test] +async fn test_with_js_streamable_http_server() -> anyhow::Result<()> { + let _ = tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .try_init(); + tokio::process::Command::new("npm") + .arg("install") + .current_dir("tests/test_with_js") + .spawn()? + .wait() + .await?; + + let transport = StreamableHttpClientTransport::from_uri(format!( + "http://{STREAMABLE_HTTP_JS_BIND_ADDRESS}/mcp" + )); + + let mut server = tokio::process::Command::new("node") + .arg("tests/test_with_js/streamable_server.js") + .spawn()?; + + // waiting for server up + tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await; + + let client = ().serve(transport).await?; + let resources = client.list_all_resources().await?; + tracing::info!("{:#?}", resources); + let tools = client.list_all_tools().await?; + tracing::info!("{:#?}", tools); + let quit_reason = client.cancel().await?; + server.kill().await?; + assert!(matches!(quit_reason, QuitReason::Cancelled)); + Ok(()) +} diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_with_js/.gitignore b/code-rs/third_party/rmcp-0.8.3/tests/test_with_js/.gitignore new file mode 100644 index 00000000000..572406bfdde --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_with_js/.gitignore @@ -0,0 +1,2 @@ +/node_modules +package-lock.json \ No newline at end of file diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_with_js/client.js b/code-rs/third_party/rmcp-0.8.3/tests/test_with_js/client.js new file mode 100644 index 00000000000..17b189a743a --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_with_js/client.js @@ -0,0 +1,29 @@ +import { Client } from "@modelcontextprotocol/sdk/client/index.js"; +import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js"; + +const transport = new SSEClientTransport(new URL(`http://127.0.0.1:8000/sse`)); + +const client = new Client( + { + name: "example-client", + version: "1.0.0" + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {} + } + } +); +await client.connect(transport); +const tools = await client.listTools(); +console.log(tools); +const resources = await client.listResources(); +console.log(resources); +const templates = await client.listResourceTemplates(); +console.log(templates); +const prompts = await client.listPrompts(); +console.log(prompts); +await client.close(); +await transport.close(); \ No newline at end of file diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_with_js/package.json b/code-rs/third_party/rmcp-0.8.3/tests/test_with_js/package.json new file mode 100644 index 00000000000..d8ecfbfb9b6 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_with_js/package.json @@ -0,0 +1,17 @@ +{ + "dependencies": { + "@modelcontextprotocol/sdk": "^1.10", + "eventsource-parser": "^3.0.1", + "express": "^5.1.0" + }, + "type": "module", + "name": "test_with_ts", + "version": "1.0.0", + "main": "index.js", + "scripts": { + "test": "echo \"Error: no test specified\" && exit 1" + }, + "author": "", + "license": "ISC", + "description": "" +} diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_with_js/server.js b/code-rs/third_party/rmcp-0.8.3/tests/test_with_js/server.js new file mode 100644 index 00000000000..d959224db78 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_with_js/server.js @@ -0,0 +1,35 @@ +import { McpServer, ResourceTemplate } from "@modelcontextprotocol/sdk/server/mcp.js"; +import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; +import { z } from "zod"; + +const server = new McpServer({ + name: "Demo", + version: "1.0.0" +}); + +server.resource( + "greeting", + new ResourceTemplate("greeting://{name}", { list: undefined }), + async (uri, { name }) => ({ + contents: [{ + uri: uri.href, + text: `Hello, ${name}` + }] + }) +); + +server.tool( + "add", + { a: z.number(), b: z.number() }, + async ({ a, b }) => ({ + "content": [ + { + "type": "text", + "text": `${a + b}` + } + ] + }) +); + +const transport = new StdioServerTransport(); +await server.connect(transport); \ No newline at end of file diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_with_js/streamable_client.js b/code-rs/third_party/rmcp-0.8.3/tests/test_with_js/streamable_client.js new file mode 100644 index 00000000000..99826131241 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_with_js/streamable_client.js @@ -0,0 +1,28 @@ +import { Client } from "@modelcontextprotocol/sdk/client/index.js"; +import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; + +const transport = new StreamableHTTPClientTransport(new URL(`http://127.0.0.1:8001/mcp/`)); + +const client = new Client( + { + name: "example-client", + version: "1.0.0" + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {} + } + } +); +await client.connect(transport); +const tools = await client.listTools(); +console.log(tools); +const resources = await client.listResources(); +console.log(resources); +const templates = await client.listResourceTemplates(); +console.log(templates); +const prompts = await client.listPrompts(); +console.log(prompts); +await client.close(); diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_with_js/streamable_server.js b/code-rs/third_party/rmcp-0.8.3/tests/test_with_js/streamable_server.js new file mode 100644 index 00000000000..3f87ccf0134 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_with_js/streamable_server.js @@ -0,0 +1,105 @@ +import { McpServer, ResourceTemplate } from "@modelcontextprotocol/sdk/server/mcp.js"; +import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; +import { isInitializeRequest } from "@modelcontextprotocol/sdk/types.js" +import { randomUUID } from "node:crypto" +import { z } from "zod"; +import express from "express" + +const app = express(); +app.use(express.json()); + +// Map to store transports by session ID +const transports = {}; + +// Handle POST requests for client-to-server communication +app.post('/mcp', async (req, res) => { + // Check for existing session ID + const sessionId = req.headers['mcp-session-id']; + let transport; + + if (sessionId && transports[sessionId]) { + // Reuse existing transport + transport = transports[sessionId]; + } else if (!sessionId && isInitializeRequest(req.body)) { + // New initialization request + transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID().toString(), + onsessioninitialized: (sessionId) => { + // Store the transport by session ID + transports[sessionId] = transport; + } + }); + + // Clean up transport when closed + transport.onclose = () => { + if (transport.sessionId) { + delete transports[transport.sessionId]; + } + }; + const server = new McpServer({ + name: "example-server", + version: "1.0.0" + }); + + server.resource( + "greeting", + new ResourceTemplate("greeting://{name}", { list: undefined }), + async (uri, { name }) => ({ + contents: [{ + uri: uri.href, + text: `Hello, ${name}` + }] + }) + ); + + server.tool( + "add", + { a: z.number(), b: z.number() }, + async ({ a, b }) => ({ + "content": [ + { + "type": "text", + "text": `${a + b}` + } + ] + }) + ); + + // Connect to the MCP server + await server.connect(transport); + } else { + // Invalid request + res.status(400).json({ + jsonrpc: '2.0', + error: { + code: -32000, + message: 'Bad Request: No valid session ID provided', + }, + id: null, + }); + return; + } + + // Handle the request + await transport.handleRequest(req, res, req.body); +}); + +// Reusable handler for GET and DELETE requests +const handleSessionRequest = async (req, res) => { + const sessionId = req.headers['mcp-session-id']; + if (!sessionId || !transports[sessionId]) { + res.status(400).send('Invalid or missing session ID'); + return; + } + + const transport = transports[sessionId]; + await transport.handleRequest(req, res); +}; + +// Handle GET requests for server-to-client notifications via SSE +app.get('/mcp', handleSessionRequest); + +// Handle DELETE requests for session termination +app.delete('/mcp', handleSessionRequest); +console.log("Listening on port 8002"); +app.listen(8002); \ No newline at end of file diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_with_python.rs b/code-rs/third_party/rmcp-0.8.3/tests/test_with_python.rs new file mode 100644 index 00000000000..8971a32734e --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_with_python.rs @@ -0,0 +1,155 @@ +use std::process::Stdio; + +use axum::Router; +use rmcp::{ + ServiceExt, + transport::{ConfigureCommandExt, SseServer, TokioChildProcess, sse_server::SseServerConfig}, +}; +use tokio::{io::AsyncReadExt, time::timeout}; +use tokio_util::sync::CancellationToken; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; +mod common; +use common::calculator::Calculator; + +async fn init() -> anyhow::Result<()> { + let _ = tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .try_init(); + tokio::process::Command::new("uv") + .args(["sync"]) + .current_dir("tests/test_with_python") + .spawn()? + .wait() + .await?; + Ok(()) +} + +#[tokio::test] +async fn test_with_python_client() -> anyhow::Result<()> { + init().await?; + + const BIND_ADDRESS: &str = "127.0.0.1:8000"; + + let ct = SseServer::serve(BIND_ADDRESS.parse()?) + .await? + .with_service(Calculator::default); + + let status = tokio::process::Command::new("uv") + .arg("run") + .arg("client.py") + .arg(format!("http://{BIND_ADDRESS}/sse")) + .current_dir("tests/test_with_python") + .spawn()? + .wait() + .await?; + assert!(status.success()); + ct.cancel(); + Ok(()) +} + +/// Test the SSE server in a nested Axum router. +#[tokio::test] +async fn test_nested_with_python_client() -> anyhow::Result<()> { + init().await?; + + const BIND_ADDRESS: &str = "127.0.0.1:8001"; + + // Create an SSE router + let sse_config = SseServerConfig { + bind: BIND_ADDRESS.parse()?, + sse_path: "/sse".to_string(), + post_path: "/message".to_string(), + ct: CancellationToken::new(), + sse_keep_alive: None, + }; + + let listener = tokio::net::TcpListener::bind(&sse_config.bind).await?; + + let (sse_server, sse_router) = SseServer::new(sse_config); + let ct = sse_server.with_service(Calculator::default); + + let main_router = Router::new().nest("/nested", sse_router); + + let server_ct = ct.clone(); + let server = axum::serve(listener, main_router).with_graceful_shutdown(async move { + server_ct.cancelled().await; + tracing::info!("sse server cancelled"); + }); + + tokio::spawn(async move { + let _ = server.await; + tracing::info!("sse server shutting down"); + }); + + // Spawn the process with timeout, as failure to access the '/message' URL + // causes the client to never exit. + let status = timeout( + tokio::time::Duration::from_secs(5), + tokio::process::Command::new("uv") + .arg("run") + .arg("client.py") + .arg(format!("http://{BIND_ADDRESS}/nested/sse")) + .current_dir("tests/test_with_python") + .spawn()? + .wait(), + ) + .await?; + assert!(status?.success()); + ct.cancel(); + Ok(()) +} + +#[tokio::test] +async fn test_with_python_server() -> anyhow::Result<()> { + init().await?; + + let transport = TokioChildProcess::new(tokio::process::Command::new("uv").configure(|cmd| { + cmd.arg("run") + .arg("server.py") + .current_dir("tests/test_with_python"); + }))?; + + let client = ().serve(transport).await?; + let resources = client.list_all_resources().await?; + tracing::info!("{:#?}", resources); + let tools = client.list_all_tools().await?; + tracing::info!("{:#?}", tools); + client.cancel().await?; + Ok(()) +} + +#[tokio::test] +async fn test_with_python_server_stderr() -> anyhow::Result<()> { + init().await?; + + let (transport, stderr) = + TokioChildProcess::builder(tokio::process::Command::new("uv").configure(|cmd| { + cmd.arg("run") + .arg("server.py") + .current_dir("tests/test_with_python"); + })) + .stderr(Stdio::piped()) + .spawn()?; + + let mut stderr = stderr.expect("stderr must be piped"); + + let stderr_task = tokio::spawn(async move { + let mut buffer = String::new(); + stderr.read_to_string(&mut buffer).await?; + Ok::<_, std::io::Error>(buffer) + }); + + let client = ().serve(transport).await?; + let _ = client.list_all_resources().await?; + let _ = client.list_all_tools().await?; + client.cancel().await?; + + let stderr_output = stderr_task.await??; + assert!(stderr_output.contains("server starting up...")); + + Ok(()) +} diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_with_python/.gitignore b/code-rs/third_party/rmcp-0.8.3/tests/test_with_python/.gitignore new file mode 100644 index 00000000000..2eb00089091 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_with_python/.gitignore @@ -0,0 +1,22 @@ +# Lock files +*.lock + +# Python build artifacts +*.egg-info/ +build/ +dist/ +__pycache__/ +*.py[cod] +*$py.class + +# Virtual environments +venv/ +env/ +.venv/ +.env/ + +# IDE +.vscode/ +.idea/ +*.swp +*.swo diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_with_python/client.py b/code-rs/third_party/rmcp-0.8.3/tests/test_with_python/client.py new file mode 100644 index 00000000000..83d2c35b040 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_with_python/client.py @@ -0,0 +1,28 @@ +from mcp import ClientSession, StdioServerParameters, types +from mcp.client.sse import sse_client +import sys + +async def run(): + url = sys.argv[1] + async with sse_client(url) as (read, write): + async with ClientSession( + read, write + ) as session: + # Initialize the connection + await session.initialize() + + # List available prompts + prompts = await session.list_prompts() + print(prompts) + # List available resources + resources = await session.list_resources() + print(resources) + + # List available tools + tools = await session.list_tools() + print(tools) + +if __name__ == "__main__": + import asyncio + + asyncio.run(run()) diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_with_python/pyproject.toml b/code-rs/third_party/rmcp-0.8.3/tests/test_with_python/pyproject.toml new file mode 100644 index 00000000000..9c8ce6ae9f1 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_with_python/pyproject.toml @@ -0,0 +1,14 @@ +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "test_with_python" +version = "0.1.0" +description = "Test Python client for RMCP" +dependencies = [ + "fastmcp", +] + +[tool.setuptools] +py-modules = ["client", "server"] diff --git a/code-rs/third_party/rmcp-0.8.3/tests/test_with_python/server.py b/code-rs/third_party/rmcp-0.8.3/tests/test_with_python/server.py new file mode 100644 index 00000000000..1c8ee6986d9 --- /dev/null +++ b/code-rs/third_party/rmcp-0.8.3/tests/test_with_python/server.py @@ -0,0 +1,25 @@ +from fastmcp import FastMCP + +import sys + +mcp = FastMCP("Demo") + +print("server starting up...", file=sys.stderr) + + +@mcp.tool() +def add(a: int, b: int) -> int: + """Add two numbers""" + return a + b + + +# Add a dynamic greeting resource +@mcp.resource("greeting://{name}") +def get_greeting(name: str) -> str: + """Get a personalized greeting""" + return f"Hello, {name}!" + + + +if __name__ == "__main__": + mcp.run() \ No newline at end of file