From 4f03bf9f8f9d76dc49a98c000bac21faef9e1f5e Mon Sep 17 00:00:00 2001 From: jif-oai Date: Tue, 16 Jun 2026 18:29:06 +0200 Subject: [PATCH 1/5] feat: allow custom HTTP clients for OAuth --- crates/rmcp/src/transport/auth.rs | 606 ++++++++++++++++++++++++------ 1 file changed, 481 insertions(+), 125 deletions(-) diff --git a/crates/rmcp/src/transport/auth.rs b/crates/rmcp/src/transport/auth.rs index 37abfc37..48bb4554 100644 --- a/crates/rmcp/src/transport/auth.rs +++ b/crates/rmcp/src/transport/auth.rs @@ -1,5 +1,7 @@ use std::{ collections::HashMap, + future::Future, + pin::Pin, sync::Arc, time::{Duration, SystemTime, UNIX_EPOCH}, }; @@ -7,13 +9,13 @@ use std::{ use async_trait::async_trait; use oauth2::{ AsyncHttpClient, AuthType, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, - EmptyExtraTokenFields, ExtraTokenFields, HttpClientError, HttpRequest, HttpResponse, - PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError, Scope, - StandardTokenResponse, TokenResponse, TokenUrl, basic::BasicTokenType, + EmptyExtraTokenFields, ExtraTokenFields, HttpRequest, HttpResponse, PkceCodeChallenge, + PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError, Scope, StandardTokenResponse, + TokenResponse, TokenUrl, basic::BasicTokenType, }; use reqwest::{ - Client as HttpClient, IntoUrl, StatusCode, Url, - header::{AUTHORIZATION, WWW_AUTHENTICATE}, + Client as ReqwestClient, IntoUrl, StatusCode, Url, + header::{AUTHORIZATION, CONTENT_TYPE, WWW_AUTHENTICATE}, }; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -23,39 +25,151 @@ use tracing::{debug, warn}; use crate::transport::common::http_header::HEADER_MCP_PROTOCOL_VERSION; -/// Owned wrapper around [`reqwest::Client`] that implements [`AsyncHttpClient`] for oauth2. -struct OAuthReqwestClient(HttpClient); +const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_secs(30); -impl<'c> AsyncHttpClient<'c> for OAuthReqwestClient { - type Error = HttpClientError; +/// Redirect handling requested for an outbound OAuth HTTP operation. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +#[non_exhaustive] +pub enum OAuthHttpRedirectPolicy { + /// Follow redirects using the client's normal limits. + #[default] + Follow, + /// Return the redirect response without following its location. + Stop, +} - type Future = std::pin::Pin< - Box> + Send + Sync + 'c>, - >; +/// A complete outbound HTTP operation requested by the OAuth implementation. +#[derive(Debug)] +#[non_exhaustive] +pub struct OAuthHttpRequest { + /// HTTP request with an absolute URI and buffered body. + pub request: HttpRequest, + /// Redirect behavior required by the OAuth operation. + pub redirect_policy: OAuthHttpRedirectPolicy, + /// Maximum duration for the operation, or no SDK-specified timeout. + pub timeout: Option, +} - fn call(&'c self, request: HttpRequest) -> Self::Future { +impl OAuthHttpRequest { + fn discovery(request: HttpRequest) -> Self { + Self { + request, + redirect_policy: OAuthHttpRedirectPolicy::Follow, + timeout: Some(DEFAULT_HTTP_TIMEOUT), + } + } + + fn credentials(request: HttpRequest) -> Self { + Self { + request, + redirect_policy: OAuthHttpRedirectPolicy::Stop, + timeout: Some(DEFAULT_HTTP_TIMEOUT), + } + } +} + +/// Error returned by a custom OAuth HTTP client. +#[derive(Debug, Error)] +#[error("{message}")] +pub struct OAuthHttpClientError { + message: String, +} + +impl OAuthHttpClientError { + /// Create an error from a transport-provided message. + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + } + } +} + +/// Future returned by [`OAuthHttpClient::execute`]. +pub type OAuthHttpClientFuture<'a> = + Pin> + Send + 'a>>; + +/// Executes every outbound HTTP request made by the OAuth state machine. +/// +/// Implementations may route requests through a remote execution environment. +/// They must honor the request's redirect policy and return the raw response +/// status, headers, and body. +pub trait OAuthHttpClient: Send + Sync { + /// Execute one OAuth HTTP operation. + fn execute(&self, request: OAuthHttpRequest) -> OAuthHttpClientFuture<'_>; +} + +struct ReqwestOAuthHttpClient { + follow_redirects: ReqwestClient, + stop_redirects: ReqwestClient, +} + +impl ReqwestOAuthHttpClient { + fn new(follow_redirects: ReqwestClient) -> Result { + let stop_redirects = ReqwestClient::builder() + .timeout(DEFAULT_HTTP_TIMEOUT) + .redirect(reqwest::redirect::Policy::none()) + .build() + .map_err(|error| AuthError::InternalError(error.to_string()))?; + Ok(Self { + follow_redirects, + stop_redirects, + }) + } +} + +impl OAuthHttpClient for ReqwestOAuthHttpClient { + fn execute(&self, request: OAuthHttpRequest) -> OAuthHttpClientFuture<'_> { Box::pin(async move { - let response = self - .0 - .execute(request.try_into().map_err(Box::new)?) + let OAuthHttpRequest { + request, + redirect_policy, + timeout, + } = request; + let client = match redirect_policy { + OAuthHttpRedirectPolicy::Follow => &self.follow_redirects, + OAuthHttpRedirectPolicy::Stop => &self.stop_redirects, + }; + let mut request = reqwest::Request::try_from(request) + .map_err(|error| OAuthHttpClientError::new(error.to_string()))?; + *request.timeout_mut() = timeout; + let response = client + .execute(request) .await - .map_err(Box::new)?; + .map_err(|error| OAuthHttpClientError::new(error.to_string()))?; let mut builder = oauth2::http::Response::builder() .status(response.status()) .version(response.version()); - - for (name, value) in response.headers().iter() { + for (name, value) in response.headers() { builder = builder.header(name, value); } - builder - .body(response.bytes().await.map_err(Box::new)?.to_vec()) - .map_err(HttpClientError::Http) + .body( + response + .bytes() + .await + .map_err(|error| OAuthHttpClientError::new(error.to_string()))? + .to_vec(), + ) + .map_err(|error| OAuthHttpClientError::new(error.to_string())) }) } } +struct OAuth2HttpClient<'a>(&'a dyn OAuthHttpClient); + +impl<'c> AsyncHttpClient<'c> for OAuth2HttpClient<'_> { + type Error = OAuthHttpClientError; + + type Future = std::pin::Pin< + Box> + Send + 'c>, + >; + + fn call(&'c self, request: HttpRequest) -> Self::Future { + self.0.execute(OAuthHttpRequest::credentials(request)) + } +} + const DEFAULT_EXCHANGE_URL: &str = "http://localhost"; /// Default OIDC Dynamic Client Registration `application_type` (SEP-837) @@ -639,7 +753,7 @@ impl Default for ScopeUpgradeConfig { /// oauth2 auth manager pub struct AuthorizationManager { - http_client: HttpClient, + http_client: Arc, metadata: Option, oauth_client: Option, credential_store: Arc, @@ -732,11 +846,23 @@ impl AuthorizationManager { /// create new auth manager with base url pub async fn new(base_url: U) -> Result { - let base_url = base_url.into_url()?; - let http_client = HttpClient::builder() - .timeout(Duration::from_secs(30)) + let http_client = ReqwestClient::builder() + .timeout(DEFAULT_HTTP_TIMEOUT) .build() .map_err(|e| AuthError::InternalError(e.to_string()))?; + Self::new_with_oauth_http_client( + base_url, + Arc::new(ReqwestOAuthHttpClient::new(http_client)?), + ) + .await + } + + /// Create an auth manager with a client used for every OAuth HTTP operation. + pub async fn new_with_oauth_http_client( + base_url: U, + http_client: Arc, + ) -> Result { + let base_url = base_url.into_url()?; let manager = Self { http_client, @@ -804,11 +930,20 @@ impl AuthorizationManager { Ok(false) } - pub fn with_client(&mut self, http_client: HttpClient) -> Result<(), AuthError> { - self.http_client = http_client; + pub fn with_client(&mut self, http_client: ReqwestClient) -> Result<(), AuthError> { + self.http_client = Arc::new(ReqwestOAuthHttpClient::new(http_client)?); Ok(()) } + /// Replace the HTTP client used by every OAuth network operation. + /// + /// This includes protected-resource and authorization-server discovery, + /// dynamic client registration, code exchange, refresh, and client + /// credentials. The ordinary MCP transport is configured separately. + pub fn with_oauth_http_client(&mut self, http_client: Arc) { + self.http_client = http_client; + } + /// discover oauth2 metadata (per SEP-985: Protected Resource Metadata first, then direct OAuth) pub async fn discover_metadata(&self) -> Result { if let Some(metadata) = self.discover_oauth_server_via_resource_metadata().await? { @@ -957,11 +1092,18 @@ impl AuthorizationManager { application_type: application_type.clone(), }; + let request = oauth2::http::Request::builder() + .method("POST") + .uri(registration_url) + .header(CONTENT_TYPE, "application/json") + .body( + serde_json::to_vec(®istration_request) + .map_err(|error| AuthError::RegistrationFailed(error.to_string()))?, + ) + .map_err(|error| AuthError::RegistrationFailed(error.to_string()))?; let response = match self .http_client - .post(registration_url) - .json(®istration_request) - .send() + .execute(OAuthHttpRequest::discovery(request)) .await { Ok(response) => response, @@ -975,10 +1117,7 @@ impl AuthorizationManager { if !response.status().is_success() { let status = response.status(); - let error_text = match response.text().await { - Ok(text) => text, - Err(_) => "cannot get error details".to_string(), - }; + let error_text = String::from_utf8_lossy(response.body()); return Err(AuthError::RegistrationFailed(format!( "HTTP {}: {}", @@ -986,16 +1125,17 @@ impl AuthorizationManager { ))); } - debug!("registration response: {:?}", response); - let reg_response = match response.json::().await { - Ok(response) => response, - Err(e) => { - return Err(AuthError::RegistrationFailed(format!( - "analyze response error: {}", - e - ))); - } - }; + debug!("registration response status: {:?}", response.status()); + let reg_response = + match serde_json::from_slice::(response.body()) { + Ok(response) => response, + Err(e) => { + return Err(AuthError::RegistrationFailed(format!( + "analyze response error: {}", + e + ))); + } + }; let config = OAuthClientConfig { client_id: reg_response.client_id, @@ -1287,10 +1427,6 @@ impl AuthorizationManager { // Reconstruct the PKCE verifier let pkce_verifier = stored_state.into_pkce_verifier(); - let http_client = reqwest::ClientBuilder::new() - .redirect(reqwest::redirect::Policy::none()) - .build() - .map_err(|e| AuthError::InternalError(e.to_string()))?; debug!("client_id: {:?}", oauth_client.client_id()); // exchange token @@ -1298,7 +1434,7 @@ impl AuthorizationManager { .exchange_code(AuthorizationCode::new(code.to_string())) .set_pkce_verifier(pkce_verifier) .add_extra_param("resource", self.base_url.to_string()) - .request_async(&OAuthReqwestClient(http_client)) + .request_async(&OAuth2HttpClient(self.http_client.as_ref())) .await { Ok(token) => token, @@ -1432,7 +1568,7 @@ impl AuthorizationManager { refresh_request = refresh_request.add_scope(Scope::new(scope)); } let token_result = refresh_request - .request_async(&OAuthReqwestClient(self.http_client.clone())) + .request_async(&OAuth2HttpClient(self.http_client.as_ref())) .await .map_err(|e| AuthError::TokenRefreshFailed(e.to_string()))?; @@ -1539,13 +1675,7 @@ impl AuthorizationManager { discovery_url: &Url, ) -> Result, AuthError> { debug!("discovery url: {:?}", discovery_url); - let response = match self - .http_client - .get(discovery_url.clone()) - .header(HEADER_MCP_PROTOCOL_VERSION, "2024-11-05") - .send() - .await - { + let response = match self.discovery_get(discovery_url).await { Ok(r) => r, Err(e) => { debug!("discovery request failed: {}", e); @@ -1558,8 +1688,7 @@ impl AuthorizationManager { return Ok(None); } - let body = response.text().await?; - match serde_json::from_str::(&body) { + match serde_json::from_slice::(response.body()) { Ok(metadata) => Ok(Some(metadata)), Err(err) => { debug!("Failed to parse metadata for {}: {}", discovery_url, err); @@ -1659,13 +1788,7 @@ impl AuthorizationManager { /// Extract the resource metadata url from the WWW-Authenticate header value. /// https://www.rfc-editor.org/rfc/rfc9728.html#name-use-of-www-authenticate-for async fn fetch_resource_metadata_url(&self, url: &Url) -> Result, AuthError> { - let response = match self - .http_client - .get(url.clone()) - .header(HEADER_MCP_PROTOCOL_VERSION, "2024-11-05") - .send() - .await - { + let response = match self.discovery_get(url).await { Ok(r) => r, Err(e) => { debug!("resource metadata probe failed: {}", e); @@ -1712,13 +1835,7 @@ impl AuthorizationManager { "resource metadata discovery url: {:?}", resource_metadata_url ); - let response = match self - .http_client - .get(resource_metadata_url.clone()) - .header(HEADER_MCP_PROTOCOL_VERSION, "2024-11-05") - .send() - .await - { + let response = match self.discovery_get(resource_metadata_url).await { Ok(r) => r, Err(e) => { debug!("resource metadata request failed: {}", e); @@ -1734,7 +1851,7 @@ impl AuthorizationManager { return Ok(None); } - let metadata = match response.json::().await { + let metadata = match serde_json::from_slice::(response.body()) { Ok(metadata) => metadata, Err(e) => { debug!("failed to parse resource metadata as JSON: {}", e); @@ -1744,6 +1861,18 @@ impl AuthorizationManager { Ok(Some(metadata)) } + async fn discovery_get(&self, url: &Url) -> Result { + let request = oauth2::http::Request::builder() + .method("GET") + .uri(url.as_str()) + .header(HEADER_MCP_PROTOCOL_VERSION, "2024-11-05") + .body(Vec::new()) + .map_err(|error| OAuthHttpClientError::new(error.to_string()))?; + self.http_client + .execute(OAuthHttpRequest::discovery(request)) + .await + } + /// extract parameters from WWW-Authenticate header (resource_metadata and scope) fn extract_www_authenticate_params(header: &str, base_url: &Url) -> WWWAuthenticateParams { let mut params = WWWAuthenticateParams::default(); @@ -2016,13 +2145,8 @@ impl AuthorizationManager { request = request.add_extra_param("resource", resource); } - let http_client = reqwest::ClientBuilder::new() - .redirect(reqwest::redirect::Policy::none()) - .build() - .map_err(|e| AuthError::InternalError(e.to_string()))?; - let token_result = match request - .request_async(&OAuthReqwestClient(http_client)) + .request_async(&OAuth2HttpClient(self.http_client.as_ref())) .await { Ok(token) => token, @@ -2135,28 +2259,25 @@ impl AuthorizationManager { } let body_str = serializer.finish(); - let http_client = reqwest::ClientBuilder::new() - .redirect(reqwest::redirect::Policy::none()) - .build() - .map_err(|e| AuthError::InternalError(e.to_string()))?; - - let response = http_client - .post(token_endpoint_url.as_str()) - .header("content-type", "application/x-www-form-urlencoded") - .body(body_str) - .send() + let request = oauth2::http::Request::builder() + .method("POST") + .uri(token_endpoint_url.as_str()) + .header(CONTENT_TYPE, "application/x-www-form-urlencoded") + .body(body_str.into_bytes()) + .map_err(|error| AuthError::ClientCredentialsError(error.to_string()))?; + let response = self + .http_client + .execute(OAuthHttpRequest::credentials(request)) .await .map_err(|e| { AuthError::ClientCredentialsError(format!("Token exchange request failed: {e}")) })?; let status = response.status(); - let body = response.bytes().await.map_err(|e| { - AuthError::ClientCredentialsError(format!("Failed to read token response: {e}")) - })?; + let body = response.body(); if !status.is_success() { - let msg = if let Ok(v) = serde_json::from_slice::(&body) { + let msg = if let Ok(v) = serde_json::from_slice::(body) { let error = v.get("error").and_then(|e| e.as_str()).unwrap_or("unknown"); let desc = v .get("error_description") @@ -2169,7 +2290,7 @@ impl AuthorizationManager { return Err(AuthError::ClientCredentialsError(msg)); } - let token_result = serde_json::from_slice::(&body).map_err(|e| { + let token_result = serde_json::from_slice::(body).map_err(|e| { AuthError::ClientCredentialsError(format!("Failed to parse token response: {e}")) })?; @@ -2415,12 +2536,12 @@ impl AuthorizationSession { /// http client extension, automatically add authorization header pub struct AuthorizedHttpClient { auth_manager: Arc, - inner_client: HttpClient, + inner_client: ReqwestClient, } impl AuthorizedHttpClient { /// create new authorized http client - pub fn new(auth_manager: Arc, client: Option) -> Self { + pub fn new(auth_manager: Arc, client: Option) -> Self { let inner_client = client.unwrap_or_default(); Self { auth_manager, @@ -2467,10 +2588,32 @@ pub enum OAuthState { } impl OAuthState { + fn oauth_http_client(&self) -> Arc { + match self { + OAuthState::Unauthorized(manager) | OAuthState::Authorized(manager) => { + Arc::clone(&manager.http_client) + } + OAuthState::Session(session) => Arc::clone(&session.auth_manager.http_client), + OAuthState::AuthorizedHttpClient(client) => { + Arc::clone(&client.auth_manager.http_client) + } + } + } + + async fn placeholder(&self) -> Result { + Ok(OAuthState::Unauthorized( + AuthorizationManager::new_with_oauth_http_client( + DEFAULT_EXCHANGE_URL, + self.oauth_http_client(), + ) + .await?, + )) + } + /// Create new OAuth state machine pub async fn new( base_url: U, - client: Option, + client: Option, ) -> Result { let mut manager = AuthorizationManager::new(base_url).await?; if let Some(client) = client { @@ -2480,6 +2623,16 @@ impl OAuthState { Ok(OAuthState::Unauthorized(manager)) } + /// Create an OAuth state machine that routes all OAuth HTTP operations + /// through the supplied client. + pub async fn new_with_oauth_http_client( + base_url: U, + client: Arc, + ) -> Result { + let manager = AuthorizationManager::new_with_oauth_http_client(base_url, client).await?; + Ok(OAuthState::Unauthorized(manager)) + } + /// Get client_id and OAuth credentials pub async fn get_credentials(&self) -> Result { // return client_id and credentials @@ -2500,10 +2653,12 @@ impl OAuthState { credentials: OAuthTokenResponse, ) -> Result<(), AuthError> { if let OAuthState::Unauthorized(manager) = self { - let mut manager = std::mem::replace( - manager, - AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?, - ); + let replacement = AuthorizationManager::new_with_oauth_http_client( + DEFAULT_EXCHANGE_URL, + Arc::clone(&manager.http_client), + ) + .await?; + let mut manager = std::mem::replace(manager, replacement); let granted_scopes: Vec = credentials .scopes() @@ -2553,10 +2708,8 @@ impl OAuthState { client_name: Option<&str>, client_metadata_url: Option<&str>, ) -> Result<(), AuthError> { - if let OAuthState::Unauthorized(mut manager) = std::mem::replace( - self, - OAuthState::Unauthorized(AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?), - ) { + let placeholder = self.placeholder().await?; + if let OAuthState::Unauthorized(mut manager) = std::mem::replace(self, placeholder) { debug!("start discovery"); let metadata = manager.discover_metadata().await?; manager.metadata = Some(metadata); @@ -2588,10 +2741,8 @@ impl OAuthState { /// complete authorization pub async fn complete_authorization(&mut self) -> Result<(), AuthError> { - if let OAuthState::Session(session) = std::mem::replace( - self, - OAuthState::Unauthorized(AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?), - ) { + let placeholder = self.placeholder().await?; + if let OAuthState::Session(session) = std::mem::replace(self, placeholder) { *self = OAuthState::Authorized(session.auth_manager); Ok(()) } else { @@ -2600,10 +2751,8 @@ impl OAuthState { } /// covert to authorized http client pub async fn to_authorized_http_client(&mut self) -> Result<(), AuthError> { - if let OAuthState::Authorized(manager) = std::mem::replace( - self, - OAuthState::Authorized(AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?), - ) { + let placeholder = self.placeholder().await?; + if let OAuthState::Authorized(manager) = std::mem::replace(self, placeholder) { *self = OAuthState::AuthorizedHttpClient(AuthorizedHttpClient::new( Arc::new(manager), None, @@ -2622,8 +2771,7 @@ impl OAuthState { required_scope: &str, redirect_uri: &str, ) -> Result { - let placeholder = - OAuthState::Authorized(AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?); + let placeholder = self.placeholder().await?; let old = std::mem::replace(self, placeholder); let OAuthState::Authorized(manager) = old else { *self = old; @@ -2755,10 +2903,8 @@ impl OAuthState { &mut self, config: ClientCredentialsConfig, ) -> Result<(), AuthError> { - let OAuthState::Unauthorized(mut manager) = std::mem::replace( - self, - OAuthState::Unauthorized(AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?), - ) else { + let placeholder = self.placeholder().await?; + let OAuthState::Unauthorized(mut manager) = std::mem::replace(self, placeholder) else { return Err(AuthError::InternalError( "Client credentials flow requires Unauthorized state".to_string(), )); @@ -2784,18 +2930,228 @@ impl OAuthState { #[cfg(test)] mod tests { - use std::{collections::HashMap, sync::Arc}; + use std::{ + collections::{HashMap, VecDeque}, + sync::{Arc, Mutex as StdMutex}, + }; - use oauth2::{AuthType, CsrfToken, PkceCodeVerifier}; + use oauth2::{AuthType, CsrfToken, HttpResponse, PkceCodeVerifier}; use url::Url; use super::{ AuthError, AuthorizationCallback, AuthorizationManager, AuthorizationMetadata, - InMemoryStateStore, OAuthClientConfig, ScopeUpgradeConfig, StateStore, - StoredAuthorizationState, is_https_url, + InMemoryStateStore, OAuthClientConfig, OAuthHttpClient, OAuthHttpClientError, + OAuthHttpClientFuture, OAuthHttpRedirectPolicy, OAuthHttpRequest, ScopeUpgradeConfig, + StateStore, StoredAuthorizationState, is_https_url, }; use crate::transport::auth::VendorExtraTokenFields; + #[derive(Clone, Debug, PartialEq, Eq)] + struct RecordedOAuthRequest { + method: String, + uri: String, + redirect_policy: OAuthHttpRedirectPolicy, + body: Vec, + } + + #[derive(Clone, Default)] + struct RecordingOAuthHttpClient { + requests: Arc>>, + responses: Arc>>, + } + + impl RecordingOAuthHttpClient { + fn with_responses(responses: Vec) -> Self { + Self { + responses: Arc::new(StdMutex::new(responses.into())), + ..Default::default() + } + } + + fn requests(&self) -> Vec { + self.requests.lock().unwrap().clone() + } + } + + impl OAuthHttpClient for RecordingOAuthHttpClient { + fn execute(&self, request: OAuthHttpRequest) -> OAuthHttpClientFuture<'_> { + self.requests.lock().unwrap().push(RecordedOAuthRequest { + method: request.request.method().to_string(), + uri: request.request.uri().to_string(), + redirect_policy: request.redirect_policy, + body: request.request.body().clone(), + }); + let response = self.responses.lock().unwrap().pop_front(); + Box::pin(async move { + response.ok_or_else(|| OAuthHttpClientError::new("missing fake response")) + }) + } + } + + fn http_response(status: u16, body: serde_json::Value) -> HttpResponse { + oauth2::http::Response::builder() + .status(status) + .body(serde_json::to_vec(&body).unwrap()) + .unwrap() + } + + #[tokio::test] + async fn custom_http_client_handles_protected_resource_discovery() { + let challenge = oauth2::http::Response::builder() + .status(401) + .header( + "www-authenticate", + r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource""#, + ) + .body(Vec::new()) + .unwrap(); + let client = RecordingOAuthHttpClient::with_responses(vec![ + challenge, + http_response( + 200, + serde_json::json!({ + "authorization_servers": ["https://auth.example.com"] + }), + ), + http_response( + 200, + serde_json::json!({ + "authorization_endpoint": "https://auth.example.com/authorize", + "token_endpoint": "https://auth.example.com/token" + }), + ), + ]); + let manager = AuthorizationManager::new_with_oauth_http_client( + "https://mcp.example.com/mcp", + Arc::new(client.clone()), + ) + .await + .unwrap(); + + let metadata = manager.discover_metadata().await.unwrap(); + + assert_eq!(metadata.token_endpoint, "https://auth.example.com/token"); + assert_eq!( + client.requests(), + vec![ + RecordedOAuthRequest { + method: "GET".to_string(), + uri: "https://mcp.example.com/mcp".to_string(), + redirect_policy: OAuthHttpRedirectPolicy::Follow, + body: Vec::new(), + }, + RecordedOAuthRequest { + method: "GET".to_string(), + uri: "https://mcp.example.com/.well-known/oauth-protected-resource".to_string(), + redirect_policy: OAuthHttpRedirectPolicy::Follow, + body: Vec::new(), + }, + RecordedOAuthRequest { + method: "GET".to_string(), + uri: "https://auth.example.com/.well-known/oauth-authorization-server" + .to_string(), + redirect_policy: OAuthHttpRedirectPolicy::Follow, + body: Vec::new(), + }, + ] + ); + } + + #[tokio::test] + async fn custom_http_client_handles_registration_exchange_and_refresh() { + let client = RecordingOAuthHttpClient::with_responses(vec![ + http_response( + 201, + serde_json::json!({ + "client_id": "test-client", + "redirect_uris": ["http://localhost/callback"] + }), + ), + http_response( + 200, + serde_json::json!({ + "access_token": "access-1", + "token_type": "bearer", + "refresh_token": "refresh-1", + "expires_in": 3600 + }), + ), + http_response( + 200, + serde_json::json!({ + "access_token": "access-2", + "token_type": "bearer", + "refresh_token": "refresh-2", + "expires_in": 3600 + }), + ), + ]); + let mut manager = AuthorizationManager::new_with_oauth_http_client( + "https://mcp.example.com/mcp", + Arc::new(client.clone()), + ) + .await + .unwrap(); + manager.set_metadata(AuthorizationMetadata { + authorization_endpoint: "https://auth.example.com/authorize".to_string(), + token_endpoint: "https://auth.example.com/token".to_string(), + registration_endpoint: Some("https://auth.example.com/register".to_string()), + response_types_supported: Some(vec!["code".to_string()]), + ..Default::default() + }); + manager + .register_client( + "Codex", + "http://localhost/callback", + &["profile", "offline_access"], + ) + .await + .unwrap(); + let authorization_url = manager + .get_authorization_url(&["profile", "offline_access"]) + .await + .unwrap(); + let state = Url::parse(&authorization_url) + .unwrap() + .query_pairs() + .find(|(name, _)| name == "state") + .unwrap() + .1 + .into_owned(); + + manager + .exchange_code_for_token("authorization-code", &state) + .await + .unwrap(); + manager.refresh_token().await.unwrap(); + + let requests = client.requests(); + let registration: serde_json::Value = serde_json::from_slice(&requests[0].body).unwrap(); + assert_eq!(registration["scope"], "profile offline_access"); + assert_eq!( + requests + .iter() + .map(|request| request.uri.as_str()) + .collect::>(), + vec![ + "https://auth.example.com/register", + "https://auth.example.com/token", + "https://auth.example.com/token", + ] + ); + assert_eq!( + requests + .iter() + .map(|request| request.redirect_policy) + .collect::>(), + vec![ + OAuthHttpRedirectPolicy::Follow, + OAuthHttpRedirectPolicy::Stop, + OAuthHttpRedirectPolicy::Stop, + ] + ); + } + // -- url helpers -- #[test] From e740a759c3e47b21a0d951afc57b3d73086c825d Mon Sep 17 00:00:00 2001 From: jif-oai Date: Tue, 16 Jun 2026 19:47:01 +0200 Subject: [PATCH 2/5] fix(auth): preserve configured client for refresh --- crates/rmcp/src/transport/auth.rs | 163 +++++++++++++++++++++++++----- 1 file changed, 138 insertions(+), 25 deletions(-) diff --git a/crates/rmcp/src/transport/auth.rs b/crates/rmcp/src/transport/auth.rs index 48bb4554..b13a2c79 100644 --- a/crates/rmcp/src/transport/auth.rs +++ b/crates/rmcp/src/transport/auth.rs @@ -52,17 +52,21 @@ pub struct OAuthHttpRequest { impl OAuthHttpRequest { fn discovery(request: HttpRequest) -> Self { - Self { - request, - redirect_policy: OAuthHttpRedirectPolicy::Follow, - timeout: Some(DEFAULT_HTTP_TIMEOUT), - } + Self::with_redirect_policy(request, OAuthHttpRedirectPolicy::Follow) } + #[cfg(feature = "auth-client-credentials-jwt")] fn credentials(request: HttpRequest) -> Self { + Self::with_redirect_policy(request, OAuthHttpRedirectPolicy::Stop) + } + + fn with_redirect_policy( + request: HttpRequest, + redirect_policy: OAuthHttpRedirectPolicy, + ) -> Self { Self { request, - redirect_policy: OAuthHttpRedirectPolicy::Stop, + redirect_policy, timeout: Some(DEFAULT_HTTP_TIMEOUT), } } @@ -156,7 +160,19 @@ impl OAuthHttpClient for ReqwestOAuthHttpClient { } } -struct OAuth2HttpClient<'a>(&'a dyn OAuthHttpClient); +struct OAuth2HttpClient<'a> { + client: &'a dyn OAuthHttpClient, + redirect_policy: OAuthHttpRedirectPolicy, +} + +impl<'a> OAuth2HttpClient<'a> { + fn new(client: &'a dyn OAuthHttpClient, redirect_policy: OAuthHttpRedirectPolicy) -> Self { + Self { + client, + redirect_policy, + } + } +} impl<'c> AsyncHttpClient<'c> for OAuth2HttpClient<'_> { type Error = OAuthHttpClientError; @@ -166,7 +182,10 @@ impl<'c> AsyncHttpClient<'c> for OAuth2HttpClient<'_> { >; fn call(&'c self, request: HttpRequest) -> Self::Future { - self.0.execute(OAuthHttpRequest::credentials(request)) + self.client.execute(OAuthHttpRequest::with_redirect_policy( + request, + self.redirect_policy, + )) } } @@ -754,6 +773,8 @@ impl Default for ScopeUpgradeConfig { /// oauth2 auth manager pub struct AuthorizationManager { http_client: Arc, + // Preserve legacy reqwest refresh behavior without weakening custom clients. + refresh_redirect_policy: OAuthHttpRedirectPolicy, metadata: Option, oauth_client: Option, credential_store: Arc, @@ -850,9 +871,10 @@ impl AuthorizationManager { .timeout(DEFAULT_HTTP_TIMEOUT) .build() .map_err(|e| AuthError::InternalError(e.to_string()))?; - Self::new_with_oauth_http_client( + Self::new_inner( base_url, Arc::new(ReqwestOAuthHttpClient::new(http_client)?), + OAuthHttpRedirectPolicy::Follow, ) .await } @@ -861,11 +883,20 @@ impl AuthorizationManager { pub async fn new_with_oauth_http_client( base_url: U, http_client: Arc, + ) -> Result { + Self::new_inner(base_url, http_client, OAuthHttpRedirectPolicy::Stop).await + } + + async fn new_inner( + base_url: U, + http_client: Arc, + refresh_redirect_policy: OAuthHttpRedirectPolicy, ) -> Result { let base_url = base_url.into_url()?; let manager = Self { http_client, + refresh_redirect_policy, metadata: None, oauth_client: None, credential_store: Arc::new(InMemoryCredentialStore::new()), @@ -932,6 +963,7 @@ impl AuthorizationManager { pub fn with_client(&mut self, http_client: ReqwestClient) -> Result<(), AuthError> { self.http_client = Arc::new(ReqwestOAuthHttpClient::new(http_client)?); + self.refresh_redirect_policy = OAuthHttpRedirectPolicy::Follow; Ok(()) } @@ -942,6 +974,7 @@ impl AuthorizationManager { /// credentials. The ordinary MCP transport is configured separately. pub fn with_oauth_http_client(&mut self, http_client: Arc) { self.http_client = http_client; + self.refresh_redirect_policy = OAuthHttpRedirectPolicy::Stop; } /// discover oauth2 metadata (per SEP-985: Protected Resource Metadata first, then direct OAuth) @@ -1434,7 +1467,10 @@ impl AuthorizationManager { .exchange_code(AuthorizationCode::new(code.to_string())) .set_pkce_verifier(pkce_verifier) .add_extra_param("resource", self.base_url.to_string()) - .request_async(&OAuth2HttpClient(self.http_client.as_ref())) + .request_async(&OAuth2HttpClient::new( + self.http_client.as_ref(), + OAuthHttpRedirectPolicy::Stop, + )) .await { Ok(token) => token, @@ -1568,7 +1604,10 @@ impl AuthorizationManager { refresh_request = refresh_request.add_scope(Scope::new(scope)); } let token_result = refresh_request - .request_async(&OAuth2HttpClient(self.http_client.as_ref())) + .request_async(&OAuth2HttpClient::new( + self.http_client.as_ref(), + self.refresh_redirect_policy, + )) .await .map_err(|e| AuthError::TokenRefreshFailed(e.to_string()))?; @@ -2146,7 +2185,10 @@ impl AuthorizationManager { } let token_result = match request - .request_async(&OAuth2HttpClient(self.http_client.as_ref())) + .request_async(&OAuth2HttpClient::new( + self.http_client.as_ref(), + OAuthHttpRedirectPolicy::Stop, + )) .await { Ok(token) => token, @@ -2588,23 +2630,25 @@ pub enum OAuthState { } impl OAuthState { - fn oauth_http_client(&self) -> Arc { - match self { - OAuthState::Unauthorized(manager) | OAuthState::Authorized(manager) => { - Arc::clone(&manager.http_client) - } - OAuthState::Session(session) => Arc::clone(&session.auth_manager.http_client), - OAuthState::AuthorizedHttpClient(client) => { - Arc::clone(&client.auth_manager.http_client) - } - } + fn oauth_http_client_config(&self) -> (Arc, OAuthHttpRedirectPolicy) { + let manager = match self { + OAuthState::Unauthorized(manager) | OAuthState::Authorized(manager) => manager, + OAuthState::Session(session) => &session.auth_manager, + OAuthState::AuthorizedHttpClient(client) => &client.auth_manager, + }; + ( + Arc::clone(&manager.http_client), + manager.refresh_redirect_policy, + ) } async fn placeholder(&self) -> Result { + let (http_client, refresh_redirect_policy) = self.oauth_http_client_config(); Ok(OAuthState::Unauthorized( - AuthorizationManager::new_with_oauth_http_client( + AuthorizationManager::new_inner( DEFAULT_EXCHANGE_URL, - self.oauth_http_client(), + http_client, + refresh_redirect_policy, ) .await?, )) @@ -2653,9 +2697,10 @@ impl OAuthState { credentials: OAuthTokenResponse, ) -> Result<(), AuthError> { if let OAuthState::Unauthorized(manager) = self { - let replacement = AuthorizationManager::new_with_oauth_http_client( + let replacement = AuthorizationManager::new_inner( DEFAULT_EXCHANGE_URL, Arc::clone(&manager.http_client), + manager.refresh_redirect_policy, ) .await?; let mut manager = std::mem::replace(manager, replacement); @@ -4776,6 +4821,74 @@ mod tests { ); } + #[tokio::test] + async fn refresh_token_uses_client_configured_by_with_client() { + use axum::{Router, body::Body, http::Response, routing::post}; + + let received_header = Arc::new(std::sync::Mutex::new(None)); + let received_header_clone = Arc::clone(&received_header); + let app = Router::new().route( + "/token", + post(move |headers: axum::http::HeaderMap| { + let received_header = Arc::clone(&received_header_clone); + async move { + *received_header.lock().unwrap() = headers + .get("x-custom-client") + .and_then(|value| value.to_str().ok()) + .map(str::to_string); + Response::builder() + .status(200) + .header("content-type", "application/json") + .body(Body::from( + r#"{"access_token":"new-token","token_type":"Bearer","expires_in":3600}"#, + )) + .unwrap() + } + }), + ); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + tokio::spawn(async move { axum::serve(listener, app).await.unwrap() }); + + let mut manager = manager_with_metadata(Some(AuthorizationMetadata { + authorization_endpoint: format!("http://{addr}/authorize"), + token_endpoint: format!("http://{addr}/token"), + ..Default::default() + })) + .await; + let mut default_headers = reqwest::header::HeaderMap::new(); + default_headers.insert("x-custom-client", "configured".parse().unwrap()); + manager + .with_client( + reqwest::Client::builder() + .default_headers(default_headers) + .build() + .unwrap(), + ) + .unwrap(); + manager.configure_client(test_client_config()).unwrap(); + manager + .credential_store + .save(StoredCredentials { + client_id: "my-client".to_string(), + token_response: Some(make_token_response_with_refresh( + "old-token", + "my-refresh-token", + )), + granted_scopes: vec![], + token_received_at: Some(AuthorizationManager::now_epoch_secs()), + }) + .await + .unwrap(); + + manager.refresh_token().await.unwrap(); + + assert_eq!( + received_header.lock().unwrap().as_deref(), + Some("configured") + ); + } + async fn start_token_server() -> (String, Arc>>) { use axum::{Router, body::Body, http::Response, routing::post}; let captured: Arc>> = Arc::new(std::sync::Mutex::new(None)); From aeb7e0d11e9efc101aab42e8f90a220c7c86e925 Mon Sep 17 00:00:00 2001 From: jif-oai Date: Tue, 16 Jun 2026 20:25:32 +0200 Subject: [PATCH 3/5] fix(auth): harden OAuth HTTP adapter --- crates/rmcp/src/transport/auth.rs | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/crates/rmcp/src/transport/auth.rs b/crates/rmcp/src/transport/auth.rs index b13a2c79..f5166fec 100644 --- a/crates/rmcp/src/transport/auth.rs +++ b/crates/rmcp/src/transport/auth.rs @@ -7,6 +7,7 @@ use std::{ }; use async_trait::async_trait; +use futures::StreamExt; use oauth2::{ AsyncHttpClient, AuthType, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields, ExtraTokenFields, HttpRequest, HttpResponse, PkceCodeChallenge, @@ -26,6 +27,7 @@ use tracing::{debug, warn}; use crate::transport::common::http_header::HEADER_MCP_PROTOCOL_VERSION; const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_secs(30); +const MAX_OAUTH_HTTP_RESPONSE_BODY_BYTES: usize = 1024 * 1024; /// Redirect handling requested for an outbound OAuth HTTP operation. #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] @@ -39,14 +41,14 @@ pub enum OAuthHttpRedirectPolicy { } /// A complete outbound HTTP operation requested by the OAuth implementation. -#[derive(Debug)] #[non_exhaustive] pub struct OAuthHttpRequest { /// HTTP request with an absolute URI and buffered body. pub request: HttpRequest, /// Redirect behavior required by the OAuth operation. pub redirect_policy: OAuthHttpRedirectPolicy, - /// Maximum duration for the operation, or no SDK-specified timeout. + /// Suggested maximum duration for the operation, or no SDK-specified timeout. + /// Implementations with their own timeout policy may retain it instead. pub timeout: Option, } @@ -127,15 +129,14 @@ impl OAuthHttpClient for ReqwestOAuthHttpClient { let OAuthHttpRequest { request, redirect_policy, - timeout, + .. } = request; let client = match redirect_policy { OAuthHttpRedirectPolicy::Follow => &self.follow_redirects, OAuthHttpRedirectPolicy::Stop => &self.stop_redirects, }; - let mut request = reqwest::Request::try_from(request) + let request = reqwest::Request::try_from(request) .map_err(|error| OAuthHttpClientError::new(error.to_string()))?; - *request.timeout_mut() = timeout; let response = client .execute(request) .await @@ -147,14 +148,19 @@ impl OAuthHttpClient for ReqwestOAuthHttpClient { for (name, value) in response.headers() { builder = builder.header(name, value); } + let mut body = Vec::new(); + let mut body_stream = response.bytes_stream(); + while let Some(chunk) = body_stream.next().await { + let chunk = chunk.map_err(|error| OAuthHttpClientError::new(error.to_string()))?; + if chunk.len() > MAX_OAUTH_HTTP_RESPONSE_BODY_BYTES - body.len() { + return Err(OAuthHttpClientError::new(format!( + "OAuth HTTP response body exceeds {MAX_OAUTH_HTTP_RESPONSE_BODY_BYTES} bytes" + ))); + } + body.extend_from_slice(&chunk); + } builder - .body( - response - .bytes() - .await - .map_err(|error| OAuthHttpClientError::new(error.to_string()))? - .to_vec(), - ) + .body(body) .map_err(|error| OAuthHttpClientError::new(error.to_string())) }) } From af170d2ea6cb2e5c9cbd859de6b0a335b9c0f63a Mon Sep 17 00:00:00 2001 From: jif-oai Date: Tue, 16 Jun 2026 20:40:17 +0200 Subject: [PATCH 4/5] refactor(auth): simplify OAuth HTTP plumbing --- crates/rmcp/src/transport/auth.rs | 78 +++++++++++-------------------- 1 file changed, 27 insertions(+), 51 deletions(-) diff --git a/crates/rmcp/src/transport/auth.rs b/crates/rmcp/src/transport/auth.rs index f5166fec..3fcacf91 100644 --- a/crates/rmcp/src/transport/auth.rs +++ b/crates/rmcp/src/transport/auth.rs @@ -53,19 +53,7 @@ pub struct OAuthHttpRequest { } impl OAuthHttpRequest { - fn discovery(request: HttpRequest) -> Self { - Self::with_redirect_policy(request, OAuthHttpRedirectPolicy::Follow) - } - - #[cfg(feature = "auth-client-credentials-jwt")] - fn credentials(request: HttpRequest) -> Self { - Self::with_redirect_policy(request, OAuthHttpRedirectPolicy::Stop) - } - - fn with_redirect_policy( - request: HttpRequest, - redirect_policy: OAuthHttpRedirectPolicy, - ) -> Self { + fn new(request: HttpRequest, redirect_policy: OAuthHttpRedirectPolicy) -> Self { Self { request, redirect_policy, @@ -171,15 +159,6 @@ struct OAuth2HttpClient<'a> { redirect_policy: OAuthHttpRedirectPolicy, } -impl<'a> OAuth2HttpClient<'a> { - fn new(client: &'a dyn OAuthHttpClient, redirect_policy: OAuthHttpRedirectPolicy) -> Self { - Self { - client, - redirect_policy, - } - } -} - impl<'c> AsyncHttpClient<'c> for OAuth2HttpClient<'_> { type Error = OAuthHttpClientError; @@ -188,10 +167,8 @@ impl<'c> AsyncHttpClient<'c> for OAuth2HttpClient<'_> { >; fn call(&'c self, request: HttpRequest) -> Self::Future { - self.client.execute(OAuthHttpRequest::with_redirect_policy( - request, - self.redirect_policy, - )) + self.client + .execute(OAuthHttpRequest::new(request, self.redirect_policy)) } } @@ -973,16 +950,6 @@ impl AuthorizationManager { Ok(()) } - /// Replace the HTTP client used by every OAuth network operation. - /// - /// This includes protected-resource and authorization-server discovery, - /// dynamic client registration, code exchange, refresh, and client - /// credentials. The ordinary MCP transport is configured separately. - pub fn with_oauth_http_client(&mut self, http_client: Arc) { - self.http_client = http_client; - self.refresh_redirect_policy = OAuthHttpRedirectPolicy::Stop; - } - /// discover oauth2 metadata (per SEP-985: Protected Resource Metadata first, then direct OAuth) pub async fn discover_metadata(&self) -> Result { if let Some(metadata) = self.discover_oauth_server_via_resource_metadata().await? { @@ -1142,7 +1109,10 @@ impl AuthorizationManager { .map_err(|error| AuthError::RegistrationFailed(error.to_string()))?; let response = match self .http_client - .execute(OAuthHttpRequest::discovery(request)) + .execute(OAuthHttpRequest::new( + request, + OAuthHttpRedirectPolicy::Follow, + )) .await { Ok(response) => response, @@ -1473,10 +1443,10 @@ impl AuthorizationManager { .exchange_code(AuthorizationCode::new(code.to_string())) .set_pkce_verifier(pkce_verifier) .add_extra_param("resource", self.base_url.to_string()) - .request_async(&OAuth2HttpClient::new( - self.http_client.as_ref(), - OAuthHttpRedirectPolicy::Stop, - )) + .request_async(&OAuth2HttpClient { + client: self.http_client.as_ref(), + redirect_policy: OAuthHttpRedirectPolicy::Stop, + }) .await { Ok(token) => token, @@ -1610,10 +1580,10 @@ impl AuthorizationManager { refresh_request = refresh_request.add_scope(Scope::new(scope)); } let token_result = refresh_request - .request_async(&OAuth2HttpClient::new( - self.http_client.as_ref(), - self.refresh_redirect_policy, - )) + .request_async(&OAuth2HttpClient { + client: self.http_client.as_ref(), + redirect_policy: self.refresh_redirect_policy, + }) .await .map_err(|e| AuthError::TokenRefreshFailed(e.to_string()))?; @@ -1914,7 +1884,10 @@ impl AuthorizationManager { .body(Vec::new()) .map_err(|error| OAuthHttpClientError::new(error.to_string()))?; self.http_client - .execute(OAuthHttpRequest::discovery(request)) + .execute(OAuthHttpRequest::new( + request, + OAuthHttpRedirectPolicy::Follow, + )) .await } @@ -2191,10 +2164,10 @@ impl AuthorizationManager { } let token_result = match request - .request_async(&OAuth2HttpClient::new( - self.http_client.as_ref(), - OAuthHttpRedirectPolicy::Stop, - )) + .request_async(&OAuth2HttpClient { + client: self.http_client.as_ref(), + redirect_policy: OAuthHttpRedirectPolicy::Stop, + }) .await { Ok(token) => token, @@ -2315,7 +2288,10 @@ impl AuthorizationManager { .map_err(|error| AuthError::ClientCredentialsError(error.to_string()))?; let response = self .http_client - .execute(OAuthHttpRequest::credentials(request)) + .execute(OAuthHttpRequest::new( + request, + OAuthHttpRedirectPolicy::Stop, + )) .await .map_err(|e| { AuthError::ClientCredentialsError(format!("Token exchange request failed: {e}")) From 7b26799e131648fe9d78db8d9a9146c68dce480f Mon Sep 17 00:00:00 2001 From: jif-oai Date: Wed, 17 Jun 2026 10:38:48 +0200 Subject: [PATCH 5/5] fix(auth): stop refresh token redirects by default --- crates/rmcp/src/transport/auth.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/rmcp/src/transport/auth.rs b/crates/rmcp/src/transport/auth.rs index 3fcacf91..91dcff87 100644 --- a/crates/rmcp/src/transport/auth.rs +++ b/crates/rmcp/src/transport/auth.rs @@ -857,7 +857,7 @@ impl AuthorizationManager { Self::new_inner( base_url, Arc::new(ReqwestOAuthHttpClient::new(http_client)?), - OAuthHttpRedirectPolicy::Follow, + OAuthHttpRedirectPolicy::Stop, ) .await }