Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 122 additions & 5 deletions crates/rmcp/src/transport/streamable_http_client.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
use std::{borrow::Cow, collections::HashMap, sync::Arc, time::Duration};
use std::{
borrow::Cow,
collections::{HashMap, HashSet},
sync::Arc,
time::Duration,
};

use futures::{Stream, StreamExt, future::BoxFuture, stream::BoxStream};
use http::{HeaderName, HeaderValue};
Expand All @@ -12,8 +17,8 @@ use super::common::client_side_sse::{ExponentialBackoff, SseRetryPolicy, SseStre
use crate::{
RoleClient,
model::{
ClientJsonRpcMessage, ClientNotification, InitializedNotification, ServerJsonRpcMessage,
ServerResult,
ClientJsonRpcMessage, ClientNotification, ErrorData, InitializedNotification, RequestId,
ServerJsonRpcMessage, ServerResult,
},
transport::{
common::client_side_sse::SseAutoReconnectStream,
Expand Down Expand Up @@ -298,6 +303,79 @@ impl<C: StreamableHttpClient> StreamableHttpClientWorker<C> {
}

impl<C: StreamableHttpClient> StreamableHttpClientWorker<C> {
fn client_request_id(message: &ClientJsonRpcMessage) -> Option<RequestId> {
match message {
ClientJsonRpcMessage::Request(request) => Some(request.id.clone()),
_ => None,
}
}

fn server_response_id(message: &ServerJsonRpcMessage) -> Option<&RequestId> {
match message {
ServerJsonRpcMessage::Response(response) => Some(&response.id),
ServerJsonRpcMessage::Error(error) => error.id.as_ref(),
_ => None,
}
}

fn mark_stream_response_pending(
pending_stream_response_ids: &mut HashSet<RequestId>,
request_id: Option<RequestId>,
) {
if let Some(request_id) = request_id {
pending_stream_response_ids.insert(request_id);
}
}

fn clear_stream_response_pending(
pending_stream_response_ids: &mut HashSet<RequestId>,
message: &ServerJsonRpcMessage,
) {
if let Some(id) = Self::server_response_id(message) {
pending_stream_response_ids.remove(id);
}
}

async fn drain_queued_stream_messages(
sse_worker_rx: &mut tokio::sync::mpsc::Receiver<ServerJsonRpcMessage>,
context: &mut super::worker::WorkerContext<Self>,
pending_stream_response_ids: &mut HashSet<RequestId>,
) -> Result<(), WorkerQuitReason<StreamableHttpError<C::Error>>> {
loop {
match sse_worker_rx.try_recv() {
Ok(message) => {
Self::clear_stream_response_pending(pending_stream_response_ids, &message);
context.send_to_handler(message).await?;
}
Err(tokio::sync::mpsc::error::TryRecvError::Empty) => return Ok(()),
Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => return Ok(()),
}
}
}

async fn fail_pending_stream_responses(
context: &mut super::worker::WorkerContext<Self>,
pending_stream_response_ids: &mut HashSet<RequestId>,
) -> Result<(), WorkerQuitReason<StreamableHttpError<C::Error>>> {
if pending_stream_response_ids.is_empty() {
return Ok(());
}

let pending_ids = pending_stream_response_ids.drain().collect::<Vec<_>>();
for id in pending_ids {
context
.send_to_handler(ServerJsonRpcMessage::error(
ErrorData::internal_error(
"streamable HTTP session was re-initialized before the response arrived",
None,
),
Some(id),
))
.await?;
}
Ok(())
}

/// Convert a raw SSE stream into a JSON-RPC message stream without
/// reconnection logic.
fn raw_sse_to_jsonrpc(
Expand Down Expand Up @@ -557,6 +635,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
StreamResult(Result<(), StreamableHttpError<E>>),
}
let mut streams = tokio::task::JoinSet::new();
let mut pending_stream_response_ids = HashSet::new();
if let Some(session_id) = &session_id {
let client = self.client.clone();
let uri = config.uri.clone();
Expand Down Expand Up @@ -646,6 +725,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
match event {
Event::ClientMessage(send_request) => {
let WorkerSendRequest { message, responder } = send_request;
let request_id = Self::client_request_id(&message);
// Pass a clone to the first attempt so `message` is retained for a
// potential re-init retry. `post_message` takes ownership and the
// trait cannot be changed, so the clone is unavoidable.
Expand Down Expand Up @@ -679,9 +759,26 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
.await
{
Ok((new_session_id, new_protocol_headers)) => {
// Old streams hold the stale session ID; abort them
// so the new standalone SSE stream takes over.
// Old streams hold the stale session ID. Stop them first
// so no late stale-session messages can arrive after the
// pending requests below are completed.
streams.abort_all();
while streams.join_next().await.is_some() {}

// Forward any already queued response messages and fail
// the remaining accepted requests so callers do not wait
// forever for responses that can no longer arrive.
Self::drain_queued_stream_messages(
&mut sse_worker_rx,
&mut context,
&mut pending_stream_response_ids,
)
.await?;
Self::fail_pending_stream_responses(
&mut context,
&mut pending_stream_response_ids,
)
.await?;

session_id = new_session_id;
protocol_headers = new_protocol_headers;
Expand Down Expand Up @@ -765,6 +862,10 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
match retry_response {
Err(e) => Err(e),
Ok(StreamableHttpPostResponse::Accepted) => {
Self::mark_stream_response_pending(
&mut pending_stream_response_ids,
request_id,
);
tracing::trace!(
"client message accepted after re-init"
);
Expand All @@ -775,6 +876,10 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
Ok(())
}
Ok(StreamableHttpPostResponse::Sse(stream, ..)) => {
Self::mark_stream_response_pending(
&mut pending_stream_response_ids,
request_id,
);
streams.spawn(Self::execute_sse_stream(
Self::raw_sse_to_jsonrpc(stream),
sse_worker_tx.clone(),
Expand All @@ -792,6 +897,10 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
}
Err(e) => Err(e),
Ok(StreamableHttpPostResponse::Accepted) => {
Self::mark_stream_response_pending(
&mut pending_stream_response_ids,
request_id,
);
tracing::trace!("client message accepted");
Ok(())
}
Expand All @@ -800,6 +909,10 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
Ok(())
}
Ok(StreamableHttpPostResponse::Sse(stream, ..)) => {
Self::mark_stream_response_pending(
&mut pending_stream_response_ids,
request_id,
);
streams.spawn(Self::execute_sse_stream(
Self::raw_sse_to_jsonrpc(stream),
sse_worker_tx.clone(),
Expand All @@ -813,6 +926,10 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
let _ = responder.send(send_result);
}
Event::ServerMessage(json_rpc_message) => {
Self::clear_stream_response_pending(
&mut pending_stream_response_ids,
&json_rpc_message,
);
// send the message to the handler
if let Err(e) = context.send_to_handler(json_rpc_message).await {
break 'main_loop Err(e);
Expand Down
Loading