Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 123 additions & 4 deletions crates/rmcp/src/transport/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{collections::HashMap, sync::Arc, time::Duration};

use async_trait::async_trait;
use oauth2::{
AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields,
AuthType, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields,
PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError, Scope,
StandardTokenResponse, TokenResponse, TokenUrl,
basic::{BasicClient, BasicTokenType},
Expand Down Expand Up @@ -482,6 +482,21 @@ impl AuthorizationManager {
client_builder = client_builder.set_client_secret(ClientSecret::new(secret));
}

let uses_secret_post = metadata
.additional_fields
.get("token_endpoint_auth_methods_supported")
.and_then(|v| v.as_array())
.map(|arr| {
let has_basic = arr.iter().any(|m| m.as_str() == Some("client_secret_basic"));
let has_post = arr.iter().any(|m| m.as_str() == Some("client_secret_post"));
has_post && !has_basic
})
.unwrap_or(false);

if uses_secret_post {
client_builder = client_builder.set_auth_type(AuthType::RequestBody);
}

self.oauth_client = Some(client_builder);
Ok(())
}
Expand Down Expand Up @@ -1451,14 +1466,15 @@ impl OAuthState {

#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::sync::Arc;

use oauth2::{CsrfToken, PkceCodeVerifier};
use oauth2::{AuthType, CsrfToken, PkceCodeVerifier};
use url::Url;

use super::{
AuthError, AuthorizationManager, InMemoryStateStore, StateStore, StoredAuthorizationState,
is_https_url,
AuthError, AuthorizationManager, AuthorizationMetadata, InMemoryStateStore,
OAuthClientConfig, StateStore, StoredAuthorizationState, is_https_url,
};

// SEP-991: URL-based Client IDs
Expand Down Expand Up @@ -1876,4 +1892,107 @@ mod tests {
let mut manager = AuthorizationManager::new("http://localhost").await.unwrap();
manager.set_state_store(TrackingStateStore::default());
}

/// Helper: create an AuthorizationManager with minimal metadata so
/// `configure_client` can be exercised without a live server.
async fn manager_with_metadata(
metadata_override: Option<AuthorizationMetadata>,
) -> AuthorizationManager {
let mut mgr = AuthorizationManager::new("http://localhost").await.unwrap();
mgr.set_metadata(metadata_override.unwrap_or(AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
..Default::default()
}));
mgr
}

fn test_client_config() -> OAuthClientConfig {
OAuthClientConfig {
client_id: "my-client".to_string(),
client_secret: Some("my-secret".to_string()),
scopes: vec![],
redirect_uri: "http://localhost/callback".to_string(),
}
}

#[tokio::test]
async fn test_configure_client_uses_client_secret_post_from_metadata() {
let mut additional_fields = HashMap::new();
additional_fields.insert(
"token_endpoint_auth_methods_supported".to_string(),
serde_json::json!(["client_secret_post"]),
);
let meta = AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
additional_fields,
..Default::default()
};
let mut mgr = manager_with_metadata(Some(meta)).await;
mgr.configure_client(test_client_config()).unwrap();
assert!(matches!(mgr.oauth_client.as_ref().unwrap().auth_type(), AuthType::RequestBody));
}

#[tokio::test]
async fn test_configure_client_defaults_to_basic_auth() {
let mut mgr = manager_with_metadata(None).await;
mgr.configure_client(test_client_config()).unwrap();
assert!(matches!(mgr.oauth_client.as_ref().unwrap().auth_type(), AuthType::BasicAuth));
}

#[tokio::test]
async fn test_configure_client_with_explicit_basic_in_metadata() {
let mut additional_fields = HashMap::new();
additional_fields.insert(
"token_endpoint_auth_methods_supported".to_string(),
serde_json::json!(["client_secret_basic"]),
);
let meta = AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
additional_fields,
..Default::default()
};
let mut mgr = manager_with_metadata(Some(meta)).await;
mgr.configure_client(test_client_config()).unwrap();
assert!(matches!(mgr.oauth_client.as_ref().unwrap().auth_type(), AuthType::BasicAuth));
}

#[tokio::test]
async fn test_configure_client_ignores_unsupported_auth_methods_in_metadata() {
let mut additional_fields = HashMap::new();
additional_fields.insert(
"token_endpoint_auth_methods_supported".to_string(),
serde_json::json!(["private_key_jwt"]),
);
let meta = AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
additional_fields,
..Default::default()
};
let mut mgr = manager_with_metadata(Some(meta)).await;
// Unsupported method should fall through to default (basic auth)
mgr.configure_client(test_client_config()).unwrap();
assert!(matches!(mgr.oauth_client.as_ref().unwrap().auth_type(), AuthType::BasicAuth));
}

#[tokio::test]
async fn test_configure_client_prefers_basic_when_both_methods_supported() {
let mut additional_fields = HashMap::new();
additional_fields.insert(
"token_endpoint_auth_methods_supported".to_string(),
serde_json::json!(["client_secret_post", "client_secret_basic"]),
);
let meta = AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
additional_fields,
..Default::default()
};
let mut mgr = manager_with_metadata(Some(meta)).await;
mgr.configure_client(test_client_config()).unwrap();
assert!(matches!(mgr.oauth_client.as_ref().unwrap().auth_type(), AuthType::BasicAuth));
}
}