diff --git a/crates/rmcp/src/transport/auth.rs b/crates/rmcp/src/transport/auth.rs index de2cf5e9..8d1833f5 100644 --- a/crates/rmcp/src/transport/auth.rs +++ b/crates/rmcp/src/transport/auth.rs @@ -2,7 +2,7 @@ use std::{collections::HashMap, sync::Arc, time::Duration}; use async_trait::async_trait; use oauth2::{ - AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields, + AuthType, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError, Scope, StandardTokenResponse, TokenResponse, TokenUrl, basic::{BasicClient, BasicTokenType}, @@ -482,6 +482,23 @@ impl AuthorizationManager { client_builder = client_builder.set_client_secret(ClientSecret::new(secret)); } + let uses_secret_post = metadata + .additional_fields + .get("token_endpoint_auth_methods_supported") + .and_then(|v| v.as_array()) + .map(|arr| { + let has_basic = arr + .iter() + .any(|m| m.as_str() == Some("client_secret_basic")); + let has_post = arr.iter().any(|m| m.as_str() == Some("client_secret_post")); + has_post && !has_basic + }) + .unwrap_or(false); + + if uses_secret_post { + client_builder = client_builder.set_auth_type(AuthType::RequestBody); + } + self.oauth_client = Some(client_builder); Ok(()) } @@ -1451,14 +1468,15 @@ impl OAuthState { #[cfg(test)] mod tests { + use std::collections::HashMap; use std::sync::Arc; - use oauth2::{CsrfToken, PkceCodeVerifier}; + use oauth2::{AuthType, CsrfToken, PkceCodeVerifier}; use url::Url; use super::{ - AuthError, AuthorizationManager, InMemoryStateStore, StateStore, StoredAuthorizationState, - is_https_url, + AuthError, AuthorizationManager, AuthorizationMetadata, InMemoryStateStore, + OAuthClientConfig, StateStore, StoredAuthorizationState, is_https_url, }; // SEP-991: URL-based Client IDs @@ -1876,4 +1894,122 @@ mod tests { let mut manager = AuthorizationManager::new("http://localhost").await.unwrap(); manager.set_state_store(TrackingStateStore::default()); } + + /// Helper: create an AuthorizationManager with minimal metadata so + /// `configure_client` can be exercised without a live server. + async fn manager_with_metadata( + metadata_override: Option, + ) -> AuthorizationManager { + let mut mgr = AuthorizationManager::new("http://localhost").await.unwrap(); + mgr.set_metadata(metadata_override.unwrap_or(AuthorizationMetadata { + authorization_endpoint: "http://localhost/authorize".to_string(), + token_endpoint: "http://localhost/token".to_string(), + ..Default::default() + })); + mgr + } + + fn test_client_config() -> OAuthClientConfig { + OAuthClientConfig { + client_id: "my-client".to_string(), + client_secret: Some("my-secret".to_string()), + scopes: vec![], + redirect_uri: "http://localhost/callback".to_string(), + } + } + + #[tokio::test] + async fn test_configure_client_uses_client_secret_post_from_metadata() { + let mut additional_fields = HashMap::new(); + additional_fields.insert( + "token_endpoint_auth_methods_supported".to_string(), + serde_json::json!(["client_secret_post"]), + ); + let meta = AuthorizationMetadata { + authorization_endpoint: "http://localhost/authorize".to_string(), + token_endpoint: "http://localhost/token".to_string(), + additional_fields, + ..Default::default() + }; + let mut mgr = manager_with_metadata(Some(meta)).await; + mgr.configure_client(test_client_config()).unwrap(); + assert!(matches!( + mgr.oauth_client.as_ref().unwrap().auth_type(), + AuthType::RequestBody + )); + } + + #[tokio::test] + async fn test_configure_client_defaults_to_basic_auth() { + let mut mgr = manager_with_metadata(None).await; + mgr.configure_client(test_client_config()).unwrap(); + assert!(matches!( + mgr.oauth_client.as_ref().unwrap().auth_type(), + AuthType::BasicAuth + )); + } + + #[tokio::test] + async fn test_configure_client_with_explicit_basic_in_metadata() { + let mut additional_fields = HashMap::new(); + additional_fields.insert( + "token_endpoint_auth_methods_supported".to_string(), + serde_json::json!(["client_secret_basic"]), + ); + let meta = AuthorizationMetadata { + authorization_endpoint: "http://localhost/authorize".to_string(), + token_endpoint: "http://localhost/token".to_string(), + additional_fields, + ..Default::default() + }; + let mut mgr = manager_with_metadata(Some(meta)).await; + mgr.configure_client(test_client_config()).unwrap(); + assert!(matches!( + mgr.oauth_client.as_ref().unwrap().auth_type(), + AuthType::BasicAuth + )); + } + + #[tokio::test] + async fn test_configure_client_ignores_unsupported_auth_methods_in_metadata() { + let mut additional_fields = HashMap::new(); + additional_fields.insert( + "token_endpoint_auth_methods_supported".to_string(), + serde_json::json!(["private_key_jwt"]), + ); + let meta = AuthorizationMetadata { + authorization_endpoint: "http://localhost/authorize".to_string(), + token_endpoint: "http://localhost/token".to_string(), + additional_fields, + ..Default::default() + }; + let mut mgr = manager_with_metadata(Some(meta)).await; + // Unsupported method should fall through to default (basic auth) + mgr.configure_client(test_client_config()).unwrap(); + assert!(matches!( + mgr.oauth_client.as_ref().unwrap().auth_type(), + AuthType::BasicAuth + )); + } + + #[tokio::test] + async fn test_configure_client_prefers_basic_when_both_methods_supported() { + let mut additional_fields = HashMap::new(); + additional_fields.insert( + "token_endpoint_auth_methods_supported".to_string(), + serde_json::json!(["client_secret_post", "client_secret_basic"]), + ); + let meta = AuthorizationMetadata { + authorization_endpoint: "http://localhost/authorize".to_string(), + token_endpoint: "http://localhost/token".to_string(), + additional_fields, + ..Default::default() + }; + let mut mgr = manager_with_metadata(Some(meta)).await; + mgr.configure_client(test_client_config()).unwrap(); + assert!(matches!( + mgr.oauth_client.as_ref().unwrap().auth_type(), + AuthType::BasicAuth + )); + } }