From fc2d932dbcd8b885768d1eafc611a5df9777f4fb Mon Sep 17 00:00:00 2001 From: Dale Seo <5466341+DaleSeo@users.noreply.github.com> Date: Wed, 17 Jun 2026 15:36:25 -0400 Subject: [PATCH] fix: align progress timeout token --- crates/rmcp/src/service.rs | 40 ++++++++++--------- .../tests/test_request_timeout_progress.rs | 32 ++++++++++----- 2 files changed, 42 insertions(+), 30 deletions(-) diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs index 9bc0fc97..70045d11 100644 --- a/crates/rmcp/src/service.rs +++ b/crates/rmcp/src/service.rs @@ -348,7 +348,6 @@ pub struct RequestHandle { pub peer: Peer, pub id: RequestId, pub progress_token: ProgressToken, - progress_timeout_watchers: ProgressTimeoutWatchers, progress_reset_rx: Option>, } @@ -416,8 +415,10 @@ impl RequestHandle { max_total_timeout: Option, reset_timeout_on_progress: bool, ) -> Result { - let mut idle_sleep = timeout.map(tokio::time::sleep).map(Box::pin); - let mut max_total_sleep = max_total_timeout.map(tokio::time::sleep).map(Box::pin); + let mut idle_sleep = + timeout.map(|timeout| (timeout, Box::pin(tokio::time::sleep(timeout)))); + let mut max_total_sleep = + max_total_timeout.map(|timeout| (timeout, Box::pin(tokio::time::sleep(timeout)))); loop { tokio::select! { @@ -427,32 +428,34 @@ impl RequestHandle { return response.map_err(|_e| ServiceError::TransportClosed)?; } _ = async { - if let Some(sleep) = idle_sleep.as_mut() { + if let Some((_, sleep)) = idle_sleep.as_mut() { sleep.as_mut().await; } }, if idle_sleep.is_some() => { - let timeout = timeout.expect("idle timeout exists when idle sleep exists"); - self.send_timeout_cancel_notification(Self::REQUEST_TIMEOUT_REASON).await; - return Err(ServiceError::Timeout { timeout }); + if let Some((timeout, _)) = idle_sleep.as_ref() { + self.send_timeout_cancel_notification(Self::REQUEST_TIMEOUT_REASON).await; + return Err(ServiceError::Timeout { timeout: *timeout }); + } } _ = async { - if let Some(sleep) = max_total_sleep.as_mut() { + if let Some((_, sleep)) = max_total_sleep.as_mut() { sleep.as_mut().await; } }, if max_total_sleep.is_some() => { - let timeout = max_total_timeout.expect("max total timeout exists when max total sleep exists"); - self.send_timeout_cancel_notification(Self::REQUEST_MAX_TOTAL_TIMEOUT_REASON).await; - return Err(ServiceError::Timeout { timeout }); + if let Some((timeout, _)) = max_total_sleep.as_ref() { + self.send_timeout_cancel_notification(Self::REQUEST_MAX_TOTAL_TIMEOUT_REASON).await; + return Err(ServiceError::Timeout { timeout: *timeout }); + } } progress = async { match self.progress_reset_rx.as_mut() { Some(rx) => rx.recv().await, None => None, } - }, if reset_timeout_on_progress && timeout.is_some() && self.progress_reset_rx.is_some() => { + }, if reset_timeout_on_progress && idle_sleep.is_some() && self.progress_reset_rx.is_some() => { if progress.is_some() { - if let (Some(timeout), Some(sleep)) = (timeout, idle_sleep.as_mut()) { - sleep.as_mut().reset(tokio::time::Instant::now() + timeout); + if let Some((timeout, sleep)) = idle_sleep.as_mut() { + sleep.as_mut().reset(tokio::time::Instant::now() + *timeout); } } } @@ -463,7 +466,7 @@ impl RequestHandle { /// Cancel this request pub async fn cancel(self, reason: Option) -> Result<(), ServiceError> { Self::cleanup_progress_timeout_watcher( - &self.progress_timeout_watchers, + &self.peer.progress_timeout_watchers, &self.progress_token, self.progress_reset_rx.is_some(), ) @@ -617,12 +620,12 @@ impl Peer { ) -> Result, ServiceError> { let id = self.request_id_provider.next_request_id(); let progress_token = self.progress_token_provider.next_progress_token(); - request - .get_meta_mut() - .set_progress_token(progress_token.clone()); if let Some(meta) = options.meta.clone() { request.get_meta_mut().extend(meta); } + request + .get_meta_mut() + .set_progress_token(progress_token.clone()); let (responder, receiver) = tokio::sync::oneshot::channel(); let progress_reset_rx = if options.reset_timeout_on_progress && options.timeout.is_some() { let (sender, receiver) = mpsc::channel(1); @@ -658,7 +661,6 @@ impl Peer { progress_token, options, peer: self.clone(), - progress_timeout_watchers: self.progress_timeout_watchers.clone(), progress_reset_rx, }) } diff --git a/crates/rmcp/tests/test_request_timeout_progress.rs b/crates/rmcp/tests/test_request_timeout_progress.rs index ff3f5369..af62a466 100644 --- a/crates/rmcp/tests/test_request_timeout_progress.rs +++ b/crates/rmcp/tests/test_request_timeout_progress.rs @@ -10,7 +10,10 @@ use std::{ use rmcp::{ ClientHandler, Peer, RoleServer, ServiceError, ServiceExt, - model::{CallToolRequestParams, ClientRequest, Meta, ProgressNotificationParam, Request}, + model::{ + CallToolRequestParams, ClientRequest, Meta, NumberOrString, ProgressNotificationParam, + ProgressToken, Request, + }, service::PeerRequestOptions, tool, tool_router, }; @@ -32,12 +35,6 @@ impl ClientHandler for ProgressCountingClient { struct ProgressTimeoutServer; -impl ProgressTimeoutServer { - fn new() -> Self { - Self - } -} - #[tool_router(server_handler)] impl ProgressTimeoutServer { #[tool] @@ -83,9 +80,7 @@ impl ProgressTimeoutServer { tokio::time::sleep(Duration::from_millis(50)).await; let _ = client .notify_progress(ProgressNotificationParam { - progress_token: rmcp::model::ProgressToken( - rmcp::model::NumberOrString::Number(999_999), - ), + progress_token: ProgressToken(NumberOrString::Number(999_999)), progress: step as f64, total: Some(4.0), message: Some("unrelated".into()), @@ -99,7 +94,7 @@ impl ProgressTimeoutServer { async fn start_pair() -> anyhow::Result> { - let server = ProgressTimeoutServer::new(); + let server = ProgressTimeoutServer; let client = ProgressCountingClient::default(); let (transport_server, transport_client) = tokio::io::duplex(4096); @@ -172,6 +167,21 @@ async fn matching_progress_resets_timeout_when_enabled() -> anyhow::Result<()> { Ok(()) } +#[tokio::test] +async fn generated_progress_token_overrides_option_meta_token() -> anyhow::Result<()> { + let client = start_pair().await?; + let mut options = + PeerRequestOptions::with_timeout(Duration::from_millis(75)).reset_timeout_on_progress(); + options.meta = Some(Meta::with_progress_token(ProgressToken( + NumberOrString::Number(999_999), + ))); + + let result = call_tool_with_options(&client, "delayed_with_progress", options).await; + + assert!(result.is_ok()); + Ok(()) +} + #[tokio::test] async fn max_total_timeout_wins_over_progress_reset() -> anyhow::Result<()> { let client = start_pair().await?;