From 2b9b31df8582087ea3d8606859765789cc625bf4 Mon Sep 17 00:00:00 2001 From: King Star Date: Fri, 19 Jun 2026 20:31:35 +0800 Subject: [PATCH] fix: fail orphaned streamable HTTP responses on reinit --- .../src/transport/streamable_http_client.rs | 127 +++++++++- .../test_streamable_http_stale_session.rs | 223 +++++++++++++++++- 2 files changed, 343 insertions(+), 7 deletions(-) diff --git a/crates/rmcp/src/transport/streamable_http_client.rs b/crates/rmcp/src/transport/streamable_http_client.rs index a2c1a7b19..0c9859d69 100644 --- a/crates/rmcp/src/transport/streamable_http_client.rs +++ b/crates/rmcp/src/transport/streamable_http_client.rs @@ -1,4 +1,9 @@ -use std::{borrow::Cow, collections::HashMap, sync::Arc, time::Duration}; +use std::{ + borrow::Cow, + collections::{HashMap, HashSet}, + sync::Arc, + time::Duration, +}; use futures::{Stream, StreamExt, future::BoxFuture, stream::BoxStream}; use http::{HeaderName, HeaderValue}; @@ -12,8 +17,8 @@ use super::common::client_side_sse::{ExponentialBackoff, SseRetryPolicy, SseStre use crate::{ RoleClient, model::{ - ClientJsonRpcMessage, ClientNotification, InitializedNotification, ServerJsonRpcMessage, - ServerResult, + ClientJsonRpcMessage, ClientNotification, ErrorData, InitializedNotification, RequestId, + ServerJsonRpcMessage, ServerResult, }, transport::{ common::client_side_sse::SseAutoReconnectStream, @@ -298,6 +303,79 @@ impl StreamableHttpClientWorker { } impl StreamableHttpClientWorker { + fn client_request_id(message: &ClientJsonRpcMessage) -> Option { + match message { + ClientJsonRpcMessage::Request(request) => Some(request.id.clone()), + _ => None, + } + } + + fn server_response_id(message: &ServerJsonRpcMessage) -> Option<&RequestId> { + match message { + ServerJsonRpcMessage::Response(response) => Some(&response.id), + ServerJsonRpcMessage::Error(error) => error.id.as_ref(), + _ => None, + } + } + + fn mark_stream_response_pending( + pending_stream_response_ids: &mut HashSet, + request_id: Option, + ) { + if let Some(request_id) = request_id { + pending_stream_response_ids.insert(request_id); + } + } + + fn clear_stream_response_pending( + pending_stream_response_ids: &mut HashSet, + message: &ServerJsonRpcMessage, + ) { + if let Some(id) = Self::server_response_id(message) { + pending_stream_response_ids.remove(id); + } + } + + async fn drain_queued_stream_messages( + sse_worker_rx: &mut tokio::sync::mpsc::Receiver, + context: &mut super::worker::WorkerContext, + pending_stream_response_ids: &mut HashSet, + ) -> Result<(), WorkerQuitReason>> { + loop { + match sse_worker_rx.try_recv() { + Ok(message) => { + Self::clear_stream_response_pending(pending_stream_response_ids, &message); + context.send_to_handler(message).await?; + } + Err(tokio::sync::mpsc::error::TryRecvError::Empty) => return Ok(()), + Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => return Ok(()), + } + } + } + + async fn fail_pending_stream_responses( + context: &mut super::worker::WorkerContext, + pending_stream_response_ids: &mut HashSet, + ) -> Result<(), WorkerQuitReason>> { + if pending_stream_response_ids.is_empty() { + return Ok(()); + } + + let pending_ids = pending_stream_response_ids.drain().collect::>(); + for id in pending_ids { + context + .send_to_handler(ServerJsonRpcMessage::error( + ErrorData::internal_error( + "streamable HTTP session was re-initialized before the response arrived", + None, + ), + Some(id), + )) + .await?; + } + Ok(()) + } + /// Convert a raw SSE stream into a JSON-RPC message stream without /// reconnection logic. fn raw_sse_to_jsonrpc( @@ -557,6 +635,7 @@ impl Worker for StreamableHttpClientWorker { StreamResult(Result<(), StreamableHttpError>), } let mut streams = tokio::task::JoinSet::new(); + let mut pending_stream_response_ids = HashSet::new(); if let Some(session_id) = &session_id { let client = self.client.clone(); let uri = config.uri.clone(); @@ -646,6 +725,7 @@ impl Worker for StreamableHttpClientWorker { match event { Event::ClientMessage(send_request) => { let WorkerSendRequest { message, responder } = send_request; + let request_id = Self::client_request_id(&message); // Pass a clone to the first attempt so `message` is retained for a // potential re-init retry. `post_message` takes ownership and the // trait cannot be changed, so the clone is unavoidable. @@ -679,9 +759,26 @@ impl Worker for StreamableHttpClientWorker { .await { Ok((new_session_id, new_protocol_headers)) => { - // Old streams hold the stale session ID; abort them - // so the new standalone SSE stream takes over. + // Old streams hold the stale session ID. Stop them first + // so no late stale-session messages can arrive after the + // pending requests below are completed. streams.abort_all(); + while streams.join_next().await.is_some() {} + + // Forward any already queued response messages and fail + // the remaining accepted requests so callers do not wait + // forever for responses that can no longer arrive. + Self::drain_queued_stream_messages( + &mut sse_worker_rx, + &mut context, + &mut pending_stream_response_ids, + ) + .await?; + Self::fail_pending_stream_responses( + &mut context, + &mut pending_stream_response_ids, + ) + .await?; session_id = new_session_id; protocol_headers = new_protocol_headers; @@ -765,6 +862,10 @@ impl Worker for StreamableHttpClientWorker { match retry_response { Err(e) => Err(e), Ok(StreamableHttpPostResponse::Accepted) => { + Self::mark_stream_response_pending( + &mut pending_stream_response_ids, + request_id, + ); tracing::trace!( "client message accepted after re-init" ); @@ -775,6 +876,10 @@ impl Worker for StreamableHttpClientWorker { Ok(()) } Ok(StreamableHttpPostResponse::Sse(stream, ..)) => { + Self::mark_stream_response_pending( + &mut pending_stream_response_ids, + request_id, + ); streams.spawn(Self::execute_sse_stream( Self::raw_sse_to_jsonrpc(stream), sse_worker_tx.clone(), @@ -792,6 +897,10 @@ impl Worker for StreamableHttpClientWorker { } Err(e) => Err(e), Ok(StreamableHttpPostResponse::Accepted) => { + Self::mark_stream_response_pending( + &mut pending_stream_response_ids, + request_id, + ); tracing::trace!("client message accepted"); Ok(()) } @@ -800,6 +909,10 @@ impl Worker for StreamableHttpClientWorker { Ok(()) } Ok(StreamableHttpPostResponse::Sse(stream, ..)) => { + Self::mark_stream_response_pending( + &mut pending_stream_response_ids, + request_id, + ); streams.spawn(Self::execute_sse_stream( Self::raw_sse_to_jsonrpc(stream), sse_worker_tx.clone(), @@ -813,6 +926,10 @@ impl Worker for StreamableHttpClientWorker { let _ = responder.send(send_result); } Event::ServerMessage(json_rpc_message) => { + Self::clear_stream_response_pending( + &mut pending_stream_response_ids, + &json_rpc_message, + ); // send the message to the handler if let Err(e) = context.send_to_handler(json_rpc_message).await { break 'main_loop Err(e); diff --git a/crates/rmcp/tests/test_streamable_http_stale_session.rs b/crates/rmcp/tests/test_streamable_http_stale_session.rs index d96d83c73..9aa0309ad 100644 --- a/crates/rmcp/tests/test_streamable_http_stale_session.rs +++ b/crates/rmcp/tests/test_streamable_http_stale_session.rs @@ -5,15 +5,25 @@ not(feature = "local") ))] -use std::{collections::HashMap, sync::Arc}; +use std::{ + collections::{HashMap, VecDeque}, + sync::Arc, +}; +use futures::stream; +use http::{HeaderName, HeaderValue}; use rmcp::{ ServiceError, ServiceExt, - model::{ClientJsonRpcMessage, ClientRequest, PingRequest, RequestId}, + model::{ + CallToolRequestParams, ClientInfo, ClientJsonRpcMessage, ClientRequest, ErrorCode, + ErrorData, InitializeResult, PingRequest, RequestId, ServerCapabilities, + ServerJsonRpcMessage, ServerResult, + }, transport::{ StreamableHttpClientTransport, streamable_http_client::{ StreamableHttpClient, StreamableHttpClientTransportConfig, StreamableHttpError, + StreamableHttpPostResponse, }, streamable_http_server::{ StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager, @@ -25,6 +35,215 @@ use tokio_util::sync::CancellationToken; mod common; use common::calculator::Calculator; +#[derive(Debug, thiserror::Error)] +#[error("mock streamable http client error")] +struct MockClientError; + +#[derive(Clone)] +struct ReinitDropsAcceptedResponseClient { + state: Arc>, + stale_stream_cancelled: CancellationToken, + initial_request_accepted: Arc, + final_retry_accepted: Arc, +} + +struct MockState { + session_counter: usize, + posts: VecDeque, +} + +enum MockPost { + Initialize, + Initialized, + Accepted, + SessionExpired, +} + +impl ReinitDropsAcceptedResponseClient { + fn new() -> Self { + Self { + state: Arc::new(tokio::sync::Mutex::new(MockState { + session_counter: 0, + posts: VecDeque::from([ + MockPost::Initialize, + MockPost::Initialized, + MockPost::Accepted, + MockPost::SessionExpired, + MockPost::Initialize, + MockPost::Initialized, + MockPost::Accepted, + ]), + })), + stale_stream_cancelled: CancellationToken::new(), + initial_request_accepted: Arc::new(tokio::sync::Semaphore::new(0)), + final_retry_accepted: Arc::new(tokio::sync::Semaphore::new(0)), + } + } +} + +impl StreamableHttpClient for ReinitDropsAcceptedResponseClient { + type Error = MockClientError; + + async fn post_message( + &self, + _uri: Arc, + message: ClientJsonRpcMessage, + _session_id: Option>, + _auth_header: Option, + _custom_headers: HashMap, + ) -> Result> { + let mut state = self.state.lock().await; + match state + .posts + .pop_front() + .expect("unexpected mock post_message call") + { + MockPost::Initialize => { + state.session_counter += 1; + let id = match message { + ClientJsonRpcMessage::Request(request) => request.id, + other => panic!("expected initialize request, got {other:?}"), + }; + Ok(StreamableHttpPostResponse::Json( + ServerJsonRpcMessage::response( + ServerResult::InitializeResult(InitializeResult::new( + ServerCapabilities::builder().enable_tools().build(), + )), + id, + ), + Some(format!("session-{}", state.session_counter)), + )) + } + MockPost::Initialized => { + assert!( + matches!(message, ClientJsonRpcMessage::Notification(_)), + "expected initialized notification, got {message:?}" + ); + Ok(StreamableHttpPostResponse::Accepted) + } + MockPost::Accepted => { + if state.posts.is_empty() { + self.final_retry_accepted.add_permits(1); + } else { + self.initial_request_accepted.add_permits(1); + } + Ok(StreamableHttpPostResponse::Accepted) + } + MockPost::SessionExpired => Err(StreamableHttpError::SessionExpired), + } + } + + async fn delete_session( + &self, + _uri: Arc, + _session_id: Arc, + _auth_header: Option, + _custom_headers: HashMap, + ) -> Result<(), StreamableHttpError> { + Ok(()) + } + + async fn get_stream( + &self, + _uri: Arc, + session_id: Arc, + _last_event_id: Option, + _auth_header: Option, + _custom_headers: HashMap, + ) -> Result< + futures::stream::BoxStream<'static, Result>, + StreamableHttpError, + > { + if session_id.as_ref() == "session-1" { + let cancel = self.stale_stream_cancelled.clone(); + Ok(Box::pin(stream::once(async move { + cancel.cancelled_owned().await; + Ok(sse_stream::Sse { + event: None, + data: Some( + serde_json::to_string(&ServerJsonRpcMessage::error( + ErrorData::new( + ErrorCode::INTERNAL_ERROR, + "stale stream should not deliver after re-init", + None, + ), + Some(RequestId::Number(2)), + )) + .expect("serialize stale error"), + ), + id: None, + retry: None, + }) + }))) + } else { + Ok(Box::pin(stream::pending())) + } + } +} + +#[tokio::test] +async fn test_reinitialization_completes_accepted_sse_request_instead_of_hanging() +-> anyhow::Result<()> { + let mock_client = ReinitDropsAcceptedResponseClient::new(); + let initial_request_accepted = mock_client.initial_request_accepted.clone(); + let final_retry_accepted = mock_client.final_retry_accepted.clone(); + let transport = StreamableHttpClientTransport::with_client( + mock_client, + StreamableHttpClientTransportConfig::with_uri("mock://mcp"), + ); + let mut client = ClientInfo::default().serve(transport).await?; + + let peer = client.peer().clone(); + let pending_call = tokio::spawn(async move { + peer.call_tool(CallToolRequestParams::new("slow_tool")) + .await + }); + + let _initial_permit = tokio::time::timeout( + std::time::Duration::from_secs(1), + initial_request_accepted.acquire(), + ) + .await + .expect("initial accepted request should be observed") + .expect("initial accepted request semaphore should stay open"); + + let reinit_trigger = { + let peer = client.peer().clone(); + tokio::spawn(async move { peer.list_tools(None).await }) + }; + + let _retry_permit = tokio::time::timeout( + std::time::Duration::from_secs(1), + final_retry_accepted.acquire(), + ) + .await + .expect("re-initialization retry should be accepted") + .expect("re-initialization retry semaphore should stay open"); + + let err = tokio::time::timeout(std::time::Duration::from_millis(100), pending_call) + .await + .expect("accepted SSE-backed request should complete instead of hanging")? + .expect_err( + "accepted request should fail after re-initialization drops its response stream", + ); + + match err { + ServiceError::McpError(error) => { + assert_eq!(error.code, ErrorCode::INTERNAL_ERROR); + assert!( + error.message.contains("session"), + "expected session-related error, got: {error}" + ); + } + other => panic!("expected McpError for orphaned request, got: {other:?}"), + } + + reinit_trigger.abort(); + let _ = client.close().await; + + Ok(()) +} + #[tokio::test] async fn test_stale_session_id_returns_status_aware_error() -> anyhow::Result<()> { let ct = CancellationToken::new();