From c0d1bb6a198de8fce29ae0e01377528cd9f1f231 Mon Sep 17 00:00:00 2001 From: acrognale-oai Date: Fri, 22 May 2026 12:21:14 -0400 Subject: [PATCH] feat(exec-server): forward runtime install lifecycle remotely --- codex-rs/exec-server/src/client.rs | 265 +++++++++++++++++- codex-rs/exec-server/src/environment.rs | 14 +- codex-rs/exec-server/src/protocol.rs | 2 + codex-rs/exec-server/src/rpc.rs | 52 +++- codex-rs/exec-server/src/server/handler.rs | 106 ++++++- .../exec-server/src/server/handler/tests.rs | 7 + codex-rs/exec-server/src/server/processor.rs | 14 +- codex-rs/exec-server/src/server/registry.rs | 13 +- 8 files changed, 446 insertions(+), 27 deletions(-) diff --git a/codex-rs/exec-server/src/client.rs b/codex-rs/exec-server/src/client.rs index cbe6a0fc7971..f2ef5b6c7694 100644 --- a/codex-rs/exec-server/src/client.rs +++ b/codex-rs/exec-server/src/client.rs @@ -8,7 +8,9 @@ use std::time::Duration; use arc_swap::ArcSwap; use codex_app_server_protocol::JSONRPCNotification; +use codex_app_server_protocol::RuntimeInstallCancelResponse; use codex_app_server_protocol::RuntimeInstallParams; +use codex_app_server_protocol::RuntimeInstallProgressNotification; use codex_app_server_protocol::RuntimeInstallResponse; use codex_utils_absolute_path::AbsolutePathBuf; use futures::FutureExt; @@ -18,6 +20,7 @@ use tokio::sync::Mutex; use tokio::sync::Semaphore; use tokio::sync::mpsc; use tokio::sync::watch; +use tokio_util::sync::CancellationToken; use tokio::time::timeout; use tracing::debug; @@ -72,7 +75,9 @@ use crate::protocol::INITIALIZED_METHOD; use crate::protocol::InitializeParams; use crate::protocol::InitializeResponse; use crate::protocol::ProcessOutputChunk; +use crate::protocol::RUNTIME_INSTALL_CANCEL_METHOD; use crate::protocol::RUNTIME_INSTALL_METHOD; +use crate::protocol::RUNTIME_INSTALL_PROGRESS_METHOD; use crate::protocol::ReadParams; use crate::protocol::ReadResponse; use crate::protocol::TerminateParams; @@ -82,6 +87,7 @@ use crate::protocol::WriteResponse; use crate::rpc::RpcCallError; use crate::rpc::RpcClient; use crate::rpc::RpcClientEvent; +use crate::rpc::RpcPendingResponse; pub(crate) mod http_client; @@ -178,6 +184,8 @@ struct Inner { http_body_stream_failures: ArcSwap>, http_body_streams_write_lock: Mutex<()>, http_body_stream_next_id: AtomicU64, + runtime_install_progress: + StdMutex>>, session_id: std::sync::RwLock>, codex_home: std::sync::RwLock>, reader_task: tokio::task::JoinHandle<()>, @@ -452,6 +460,47 @@ impl ExecServerClient { self.call(RUNTIME_INSTALL_METHOD, ¶ms).await } + pub(crate) async fn runtime_install_with_progress( + &self, + params: RuntimeInstallParams, + progress: mpsc::UnboundedSender, + cancellation: CancellationToken, + ) -> Result { + let _progress_guard = self.route_runtime_install_progress(progress)?; + let install = self.start_call(RUNTIME_INSTALL_METHOD, ¶ms).await?; + let install = self.finish_call(install); + tokio::pin!(install); + tokio::select! { + response = &mut install => response, + _ = cancellation.cancelled() => { + let _: RuntimeInstallCancelResponse = self + .call(RUNTIME_INSTALL_CANCEL_METHOD, &serde_json::json!({})) + .await?; + install.await + } + } + } + + fn route_runtime_install_progress( + &self, + progress: mpsc::UnboundedSender, + ) -> Result { + let mut active = self + .inner + .runtime_install_progress + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + if active.is_some() { + return Err(ExecServerError::Protocol( + "runtime install progress receiver is already active".to_string(), + )); + } + *active = Some(progress); + Ok(RuntimeInstallProgressGuard { + inner: Arc::clone(&self.inner), + }) + } + pub(crate) async fn register_session( &self, process_id: &ProcessId, @@ -537,6 +586,7 @@ impl ExecServerClient { http_body_stream_failures: ArcSwap::from_pointee(HashMap::new()), http_body_streams_write_lock: Mutex::new(()), http_body_stream_next_id: AtomicU64::new(1), + runtime_install_progress: StdMutex::new(None), session_id: std::sync::RwLock::new(None), codex_home: std::sync::RwLock::new(None), reader_task, @@ -560,6 +610,18 @@ impl ExecServerClient { where P: serde::Serialize, T: serde::de::DeserializeOwned, + { + let response = self.start_call(method, params).await?; + self.finish_call(response).await + } + + async fn start_call

( + &self, + method: &str, + params: &P, + ) -> Result + where + P: serde::Serialize, { // Reject new work before allocating a JSON-RPC request id. MCP tool // calls, process writes, and fs operations all pass through here, so @@ -568,7 +630,17 @@ impl ExecServerClient { return Err(error); } - match self.inner.client.call(method, params).await { + match self.inner.client.start_call(method, params).await { + Ok(response) => Ok(response), + Err(error) => Err(ExecServerError::from(error)), + } + } + + async fn finish_call(&self, response: RpcPendingResponse) -> Result + where + T: serde::de::DeserializeOwned, + { + match response.response().await { Ok(response) => Ok(response), Err(error) => { let error = ExecServerError::from(error); @@ -868,6 +940,20 @@ async fn fail_all_sessions(inner: &Arc, message: String) { } } +struct RuntimeInstallProgressGuard { + inner: Arc, +} + +impl Drop for RuntimeInstallProgressGuard { + fn drop(&mut self) { + self.inner + .runtime_install_progress + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .take(); + } +} + /// Fails all in-flight work that depends on the shared JSON-RPC transport. async fn fail_all_in_flight_work(inner: &Arc, message: String) { fail_all_sessions(inner, message.clone()).await; @@ -929,6 +1015,18 @@ async fn handle_server_notification( .handle_http_body_delta_notification(notification.params) .await?; } + RUNTIME_INSTALL_PROGRESS_METHOD => { + let progress: RuntimeInstallProgressNotification = + serde_json::from_value(notification.params.unwrap_or(Value::Null))?; + let sender = inner + .runtime_install_progress + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .clone(); + if let Some(sender) = sender { + let _ = sender.send(progress); + } + } other => { debug!("ignoring unknown exec-server notification: {other}"); } @@ -938,9 +1036,17 @@ async fn handle_server_notification( #[cfg(test)] mod tests { + use codex_app_server_protocol::JSONRPCError; + use codex_app_server_protocol::JSONRPCErrorError; use codex_app_server_protocol::JSONRPCMessage; use codex_app_server_protocol::JSONRPCNotification; use codex_app_server_protocol::JSONRPCResponse; + use codex_app_server_protocol::RuntimeInstallCancelResponse; + use codex_app_server_protocol::RuntimeInstallCancelStatus; + use codex_app_server_protocol::RuntimeInstallManifestParams; + use codex_app_server_protocol::RuntimeInstallParams; + use codex_app_server_protocol::RuntimeInstallProgressNotification; + use codex_app_server_protocol::RuntimeInstallProgressPhase; use codex_utils_absolute_path::AbsolutePathBuf; use futures::SinkExt; use futures::StreamExt; @@ -967,9 +1073,11 @@ mod tests { use tokio_tungstenite::WebSocketStream; use tokio_tungstenite::accept_async; use tokio_tungstenite::tungstenite::Message; + use tokio_util::sync::CancellationToken; use super::ExecServerClient; use super::ExecServerClientConnectOptions; + use super::ExecServerError; use super::LazyRemoteExecServerClient; use crate::ProcessId; #[cfg(not(windows))] @@ -990,6 +1098,9 @@ mod tests { use crate::protocol::INITIALIZED_METHOD; use crate::protocol::InitializeResponse; use crate::protocol::ProcessOutputChunk; + use crate::protocol::RUNTIME_INSTALL_CANCEL_METHOD; + use crate::protocol::RUNTIME_INSTALL_METHOD; + use crate::protocol::RUNTIME_INSTALL_PROGRESS_METHOD; async fn read_jsonrpc_line(lines: &mut tokio::io::Lines>) -> JSONRPCMessage where @@ -1014,6 +1125,158 @@ mod tests { .expect("json-rpc line should write"); } + #[tokio::test] + async fn runtime_install_progress_and_cancel_are_forwarded_over_remote_transport() { + let (client_stdin, server_reader) = duplex(1 << 20); + let (mut server_writer, client_stdout) = duplex(1 << 20); + let server = tokio::spawn(async move { + let mut lines = BufReader::new(server_reader).lines(); + let initialize = read_jsonrpc_line(&mut lines).await; + let initialize_request = match initialize { + JSONRPCMessage::Request(request) if request.method == INITIALIZE_METHOD => request, + other => panic!("expected initialize request, got {other:?}"), + }; + write_jsonrpc_line( + &mut server_writer, + JSONRPCMessage::Response(JSONRPCResponse { + id: initialize_request.id, + result: serde_json::to_value(InitializeResponse { + session_id: "runtime-session".to_string(), + codex_home: AbsolutePathBuf::try_from( + std::env::current_dir().expect("current dir"), + ) + .expect("absolute current dir"), + }) + .expect("initialize response should serialize"), + }), + ) + .await; + let initialized = read_jsonrpc_line(&mut lines).await; + assert!(matches!( + initialized, + JSONRPCMessage::Notification(JSONRPCNotification { method, .. }) + if method == INITIALIZED_METHOD + )); + + let install = read_jsonrpc_line(&mut lines).await; + let install_request_id = match install { + JSONRPCMessage::Request(request) if request.method == RUNTIME_INSTALL_METHOD => { + request.id + } + other => panic!("expected runtime install request, got {other:?}"), + }; + write_jsonrpc_line( + &mut server_writer, + JSONRPCMessage::Notification(JSONRPCNotification { + method: RUNTIME_INSTALL_PROGRESS_METHOD.to_string(), + params: Some( + serde_json::to_value(RuntimeInstallProgressNotification { + bundle_version: Some("runtime-test".to_string()), + downloaded_bytes: Some(64), + phase: RuntimeInstallProgressPhase::Downloading, + total_bytes: Some(128), + }) + .expect("progress should serialize"), + ), + }), + ) + .await; + + let cancel = read_jsonrpc_line(&mut lines).await; + let cancel_request_id = match cancel { + JSONRPCMessage::Request(request) + if request.method == RUNTIME_INSTALL_CANCEL_METHOD => + { + request.id + } + other => panic!("expected runtime install cancel request, got {other:?}"), + }; + write_jsonrpc_line( + &mut server_writer, + JSONRPCMessage::Response(JSONRPCResponse { + id: cancel_request_id, + result: serde_json::to_value(RuntimeInstallCancelResponse { + status: RuntimeInstallCancelStatus::Canceled, + }) + .expect("cancel response should serialize"), + }), + ) + .await; + write_jsonrpc_line( + &mut server_writer, + JSONRPCMessage::Error(JSONRPCError { + id: install_request_id, + error: JSONRPCErrorError { + code: -32603, + data: None, + message: "runtime install canceled".to_string(), + }, + }), + ) + .await; + }); + + let client = ExecServerClient::connect( + JsonRpcConnection::from_stdio( + client_stdout, + client_stdin, + "runtime-progress-test".to_string(), + ), + ExecServerClientConnectOptions::default(), + ) + .await + .expect("client should initialize"); + let cancellation = CancellationToken::new(); + let cancellation_for_install = cancellation.clone(); + let (progress_tx, mut progress_rx) = mpsc::unbounded_channel(); + let request = tokio::spawn(async move { + client + .runtime_install_with_progress( + RuntimeInstallParams { + environment_id: None, + manifest: Box::new(RuntimeInstallManifestParams { + archive_name: None, + archive_sha256: "0".repeat(64), + archive_size_bytes: None, + archive_url: "https://example.test/runtime.zip".to_string(), + bundle_format_version: None, + bundle_version: Some("runtime-test".to_string()), + format: Some("zip".to_string()), + runtime_root_directory_name: None, + }), + release: "test".to_string(), + }, + progress_tx, + cancellation_for_install, + ) + .await + }); + + let progress = timeout(Duration::from_secs(1), progress_rx.recv()) + .await + .expect("progress should arrive before timeout") + .expect("progress stream should remain open"); + assert_eq!( + progress, + RuntimeInstallProgressNotification { + bundle_version: Some("runtime-test".to_string()), + downloaded_bytes: Some(64), + phase: RuntimeInstallProgressPhase::Downloading, + total_bytes: Some(128), + } + ); + cancellation.cancel(); + let error = request + .await + .expect("request task should join") + .expect_err("canceled runtime install should fail"); + assert!(matches!( + error, + ExecServerError::Server { message, .. } if message == "runtime install canceled" + )); + server.await.expect("server task should finish"); + } + async fn accept_websocket(listener: &TcpListener) -> WebSocketStream { let (stream, _) = listener.accept().await.expect("listener should accept"); accept_async(stream) diff --git a/codex-rs/exec-server/src/environment.rs b/codex-rs/exec-server/src/environment.rs index 15e5a13bae2a..3ae7bc3e0b4c 100644 --- a/codex-rs/exec-server/src/environment.rs +++ b/codex-rs/exec-server/src/environment.rs @@ -358,11 +358,15 @@ impl RuntimeInstaller { }, RuntimeInstaller::Remote(client) => { let client = client.get().await.map_err(exec_server_error_to_jsonrpc)?; - tokio::select! { - _ = cancellation.cancelled() => Err(internal_error("runtime install canceled")), - response = client.runtime_install(params) => { - response.map_err(exec_server_error_to_jsonrpc) - } + match progress { + Some(progress) => client + .runtime_install_with_progress(params, progress, cancellation) + .await + .map_err(exec_server_error_to_jsonrpc), + None => client + .runtime_install(params) + .await + .map_err(exec_server_error_to_jsonrpc), } } } diff --git a/codex-rs/exec-server/src/protocol.rs b/codex-rs/exec-server/src/protocol.rs index 5ecd6a34d672..116b47698ccc 100644 --- a/codex-rs/exec-server/src/protocol.rs +++ b/codex-rs/exec-server/src/protocol.rs @@ -31,6 +31,8 @@ pub const HTTP_REQUEST_METHOD: &str = "http/request"; /// JSON-RPC notification method for streamed executor HTTP response bodies. pub const HTTP_REQUEST_BODY_DELTA_METHOD: &str = "http/request/bodyDelta"; pub const RUNTIME_INSTALL_METHOD: &str = "runtime/install"; +pub const RUNTIME_INSTALL_CANCEL_METHOD: &str = "runtime/install/cancel"; +pub const RUNTIME_INSTALL_PROGRESS_METHOD: &str = "runtime/install/progress"; #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(transparent)] diff --git a/codex-rs/exec-server/src/rpc.rs b/codex-rs/exec-server/src/rpc.rs index 4195bc18335d..0556e3161442 100644 --- a/codex-rs/exec-server/src/rpc.rs +++ b/codex-rs/exec-server/src/rpc.rs @@ -49,6 +49,22 @@ pub(crate) enum RpcClientEvent { Disconnected { reason: Option }, } +pub(crate) struct RpcPendingResponse { + response_rx: oneshot::Receiver>, +} + +impl RpcPendingResponse { + pub(crate) async fn response(self) -> Result + where + T: DeserializeOwned, + { + let result: Result = + self.response_rx.await.map_err(|_| RpcCallError::Closed)?; + let response = result?; + serde_json::from_value(response).map_err(RpcCallError::Json) + } +} + #[derive(Debug, Clone, PartialEq)] pub(crate) enum RpcServerOutboundMessage { Response { @@ -83,6 +99,17 @@ impl RpcNotificationSender { .map_err(|_| internal_error("RPC connection closed while sending response")) } + pub(crate) async fn error( + &self, + request_id: RequestId, + error: JSONRPCErrorError, + ) -> Result<(), JSONRPCErrorError> { + self.outgoing_tx + .send(RpcServerOutboundMessage::Error { request_id, error }) + .await + .map_err(|_| internal_error("RPC connection closed while sending error")) + } + pub(crate) async fn notify( &self, method: &str, @@ -320,6 +347,17 @@ impl RpcClient { where P: Serialize, T: DeserializeOwned, + { + self.start_call(method, params).await?.response().await + } + + pub(crate) async fn start_call

( + &self, + method: &str, + params: &P, + ) -> Result + where + P: Serialize, { let request_id = RequestId::Integer(self.next_request_id.fetch_add(1, Ordering::SeqCst)); let (response_tx, response_rx) = oneshot::channel(); @@ -356,19 +394,7 @@ impl RpcClient { return Err(RpcCallError::Closed); } - // Do not race in-flight requests directly against the transport-close - // watch value. The connection reader receives JSON-RPC messages and - // the terminal disconnect event on one ordered queue, then drains any - // still-pending requests. Awaiting this receiver preserves that order: - // responses already read before EOF still win, and truly pending calls - // are failed once the reader observes the disconnect. - let result: Result = - response_rx.await.map_err(|_| RpcCallError::Closed)?; - let response = match result { - Ok(response) => response, - Err(error) => return Err(error), - }; - serde_json::from_value(response).map_err(RpcCallError::Json) + Ok(RpcPendingResponse { response_rx }) } #[cfg(test)] diff --git a/codex-rs/exec-server/src/server/handler.rs b/codex-rs/exec-server/src/server/handler.rs index 5dd6ea537a13..8a93eed2cd62 100644 --- a/codex-rs/exec-server/src/server/handler.rs +++ b/codex-rs/exec-server/src/server/handler.rs @@ -5,8 +5,9 @@ use std::sync::atomic::Ordering; use codex_app_server_protocol::JSONRPCErrorError; use codex_app_server_protocol::RequestId; +use codex_app_server_protocol::RuntimeInstallCancelResponse; +use codex_app_server_protocol::RuntimeInstallCancelStatus; use codex_app_server_protocol::RuntimeInstallParams; -use codex_app_server_protocol::RuntimeInstallResponse; use serde_json::to_value; use std::collections::HashSet; use tokio::sync::Mutex; @@ -35,6 +36,7 @@ use crate::protocol::FsWriteFileResponse; use crate::protocol::HttpRequestParams; use crate::protocol::InitializeParams; use crate::protocol::InitializeResponse; +use crate::protocol::RUNTIME_INSTALL_PROGRESS_METHOD; use crate::protocol::ReadParams; use crate::protocol::ReadResponse; use crate::protocol::TerminateParams; @@ -51,6 +53,7 @@ use crate::server::session_registry::SessionRegistry; pub(crate) struct ExecServerHandler { session_registry: Arc, + active_runtime_install: Arc>>, notifications: RpcNotificationSender, session: StdMutex>, active_body_stream_ids: Mutex>, @@ -64,11 +67,13 @@ pub(crate) struct ExecServerHandler { impl ExecServerHandler { pub(crate) fn new( session_registry: Arc, + active_runtime_install: Arc>>, notifications: RpcNotificationSender, runtime_paths: ExecServerRuntimePaths, ) -> Self { Self { session_registry, + active_runtime_install, notifications, session: StdMutex::new(None), active_body_stream_ids: Mutex::new(HashSet::new()), @@ -270,11 +275,91 @@ impl ExecServerHandler { } pub(crate) async fn runtime_install( - &self, + self: &Arc, + request_id: RequestId, params: RuntimeInstallParams, - ) -> Result { + ) -> Result<(), JSONRPCErrorError> { + self.require_initialized_for("runtime")?; + let (cancellation, active_install) = self.begin_runtime_install()?; + let (progress_tx, mut progress_rx) = tokio::sync::mpsc::unbounded_channel(); + let notifications = self.notifications.clone(); + self.background_tasks.spawn(async move { + let install = crate::runtime_install::install_runtime_with_progress( + params, + progress_tx, + cancellation, + ); + let progress_forwarder = async { + while let Some(progress) = progress_rx.recv().await { + if notifications + .notify(RUNTIME_INSTALL_PROGRESS_METHOD, &progress) + .await + .is_err() + { + break; + } + } + }; + let (result, ()) = tokio::join!(install, progress_forwarder); + let send_result = match result { + Ok(response) => match to_value(response) { + Ok(result) => notifications.response(request_id, result).await, + Err(error) => { + notifications + .error(request_id, internal_error(error.to_string())) + .await + } + }, + Err(error) => notifications.error(request_id, error).await, + }; + if let Err(error) = send_result { + tracing::warn!("failed to send runtime install result: {error:?}"); + } + drop(active_install); + }); + Ok(()) + } + + pub(crate) async fn cancel_runtime_install( + &self, + ) -> Result { self.require_initialized_for("runtime")?; - crate::runtime_install::install_runtime(params).await + let status = { + let active_install = self + .active_runtime_install + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + match active_install.as_ref() { + Some(cancellation) => { + cancellation.cancel(); + RuntimeInstallCancelStatus::Canceled + } + None => RuntimeInstallCancelStatus::NotFound, + } + }; + Ok(RuntimeInstallCancelResponse { status }) + } + + fn begin_runtime_install( + &self, + ) -> Result<(CancellationToken, ActiveRuntimeInstallGuard), JSONRPCErrorError> { + let cancellation = CancellationToken::new(); + let mut active_install = self + .active_runtime_install + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + if active_install.is_some() { + return Err(invalid_request( + "runtime install is already in progress".to_string(), + )); + } + *active_install = Some(cancellation.clone()); + Ok(( + cancellation, + ActiveRuntimeInstallGuard { + active_runtime_install: Arc::clone(&self.active_runtime_install), + }, + )) } fn require_initialized_for( @@ -356,5 +441,18 @@ impl ExecServerHandler { } } +struct ActiveRuntimeInstallGuard { + active_runtime_install: Arc>>, +} + +impl Drop for ActiveRuntimeInstallGuard { + fn drop(&mut self) { + self.active_runtime_install + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .take(); + } +} + #[cfg(test)] mod tests; diff --git a/codex-rs/exec-server/src/server/handler/tests.rs b/codex-rs/exec-server/src/server/handler/tests.rs index 6b632fe8909b..581bc0d8a20d 100644 --- a/codex-rs/exec-server/src/server/handler/tests.rs +++ b/codex-rs/exec-server/src/server/handler/tests.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; use std::sync::Arc; +use std::sync::Mutex as StdMutex; use std::time::Duration; use pretty_assertions::assert_eq; @@ -80,6 +81,7 @@ async fn initialized_handler() -> Arc { let registry = SessionRegistry::new(); let handler = Arc::new(ExecServerHandler::new( registry, + Arc::new(StdMutex::new(None)), RpcNotificationSender::new(outgoing_tx), test_runtime_paths(), )); @@ -158,6 +160,7 @@ async fn long_poll_read_fails_after_session_resume() { let registry = SessionRegistry::new(); let first_handler = Arc::new(ExecServerHandler::new( Arc::clone(®istry), + Arc::new(StdMutex::new(None)), RpcNotificationSender::new(first_tx), test_runtime_paths(), )); @@ -198,6 +201,7 @@ async fn long_poll_read_fails_after_session_resume() { let (second_tx, _second_rx) = mpsc::channel(16); let second_handler = Arc::new(ExecServerHandler::new( registry, + Arc::new(StdMutex::new(None)), RpcNotificationSender::new(second_tx), test_runtime_paths(), )); @@ -231,6 +235,7 @@ async fn active_session_resume_is_rejected() { let registry = SessionRegistry::new(); let first_handler = Arc::new(ExecServerHandler::new( Arc::clone(®istry), + Arc::new(StdMutex::new(None)), RpcNotificationSender::new(first_tx), test_runtime_paths(), )); @@ -245,6 +250,7 @@ async fn active_session_resume_is_rejected() { let (second_tx, _second_rx) = mpsc::channel(16); let second_handler = Arc::new(ExecServerHandler::new( registry, + Arc::new(StdMutex::new(None)), RpcNotificationSender::new(second_tx), test_runtime_paths(), )); @@ -273,6 +279,7 @@ async fn output_and_exit_are_retained_after_notification_receiver_closes() { let (outgoing_tx, outgoing_rx) = mpsc::channel(16); let handler = Arc::new(ExecServerHandler::new( SessionRegistry::new(), + Arc::new(StdMutex::new(None)), RpcNotificationSender::new(outgoing_tx), test_runtime_paths(), )); diff --git a/codex-rs/exec-server/src/server/processor.rs b/codex-rs/exec-server/src/server/processor.rs index 6fc0723f0c1e..d4e50ae61909 100644 --- a/codex-rs/exec-server/src/server/processor.rs +++ b/codex-rs/exec-server/src/server/processor.rs @@ -1,6 +1,8 @@ use std::sync::Arc; +use std::sync::Mutex as StdMutex; use tokio::sync::mpsc; +use tokio_util::sync::CancellationToken; use tracing::debug; use tracing::warn; @@ -20,6 +22,7 @@ use crate::server::session_registry::SessionRegistry; #[derive(Clone)] pub(crate) struct ConnectionProcessor { session_registry: Arc, + active_runtime_install: Arc>>, runtime_paths: ExecServerRuntimePaths, } @@ -27,6 +30,7 @@ impl ConnectionProcessor { pub(crate) fn new(runtime_paths: ExecServerRuntimePaths) -> Self { Self { session_registry: SessionRegistry::new(), + active_runtime_install: Arc::new(StdMutex::new(None)), runtime_paths, } } @@ -35,6 +39,7 @@ impl ConnectionProcessor { run_connection( connection, Arc::clone(&self.session_registry), + Arc::clone(&self.active_runtime_install), self.runtime_paths.clone(), ) .await; @@ -44,6 +49,7 @@ impl ConnectionProcessor { async fn run_connection( connection: JsonRpcConnection, session_registry: Arc, + active_runtime_install: Arc>>, runtime_paths: ExecServerRuntimePaths, ) { let router = Arc::new(build_router()); @@ -59,6 +65,7 @@ async fn run_connection( let notifications = RpcNotificationSender::new(outgoing_tx.clone()); let handler = Arc::new(ExecServerHandler::new( session_registry, + active_runtime_install, notifications, runtime_paths, )); @@ -322,7 +329,12 @@ mod tests { let (server_writer, client_reader) = duplex(1 << 20); let connection = JsonRpcConnection::from_stdio(server_reader, server_writer, label.to_string()); - let task = tokio::spawn(run_connection(connection, registry, test_runtime_paths())); + let task = tokio::spawn(run_connection( + connection, + registry, + Arc::new(std::sync::Mutex::new(None)), + test_runtime_paths(), + )); (client_writer, BufReader::new(client_reader).lines(), task) } diff --git a/codex-rs/exec-server/src/server/registry.rs b/codex-rs/exec-server/src/server/registry.rs index eafeffbcb00d..5755f2e547c7 100644 --- a/codex-rs/exec-server/src/server/registry.rs +++ b/codex-rs/exec-server/src/server/registry.rs @@ -24,6 +24,7 @@ use crate::protocol::HttpRequestParams; use crate::protocol::INITIALIZE_METHOD; use crate::protocol::INITIALIZED_METHOD; use crate::protocol::InitializeParams; +use crate::protocol::RUNTIME_INSTALL_CANCEL_METHOD; use crate::protocol::RUNTIME_INSTALL_METHOD; use crate::protocol::ReadParams; use crate::protocol::TerminateParams; @@ -116,10 +117,16 @@ pub(crate) fn build_router() -> RpcRouter { handler.fs_copy(params).await }, ); - router.request( + router.request_with_id( RUNTIME_INSTALL_METHOD, - |handler: Arc, params: RuntimeInstallParams| async move { - handler.runtime_install(params).await + |handler: Arc, request_id, params: RuntimeInstallParams| async move { + handler.runtime_install(request_id, params).await + }, + ); + router.request( + RUNTIME_INSTALL_CANCEL_METHOD, + |handler: Arc, _params: serde_json::Value| async move { + handler.cancel_runtime_install().await }, ); router