Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
265 changes: 264 additions & 1 deletion codex-rs/exec-server/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ use std::time::Duration;

use arc_swap::ArcSwap;
use codex_app_server_protocol::JSONRPCNotification;
use codex_app_server_protocol::RuntimeInstallCancelResponse;
use codex_app_server_protocol::RuntimeInstallParams;
use codex_app_server_protocol::RuntimeInstallProgressNotification;
use codex_app_server_protocol::RuntimeInstallResponse;
use codex_utils_absolute_path::AbsolutePathBuf;
use futures::FutureExt;
Expand All @@ -18,6 +20,7 @@ use tokio::sync::Mutex;
use tokio::sync::Semaphore;
use tokio::sync::mpsc;
use tokio::sync::watch;
use tokio_util::sync::CancellationToken;

use tokio::time::timeout;
use tracing::debug;
Expand Down Expand Up @@ -72,7 +75,9 @@ use crate::protocol::INITIALIZED_METHOD;
use crate::protocol::InitializeParams;
use crate::protocol::InitializeResponse;
use crate::protocol::ProcessOutputChunk;
use crate::protocol::RUNTIME_INSTALL_CANCEL_METHOD;
use crate::protocol::RUNTIME_INSTALL_METHOD;
use crate::protocol::RUNTIME_INSTALL_PROGRESS_METHOD;
use crate::protocol::ReadParams;
use crate::protocol::ReadResponse;
use crate::protocol::TerminateParams;
Expand All @@ -82,6 +87,7 @@ use crate::protocol::WriteResponse;
use crate::rpc::RpcCallError;
use crate::rpc::RpcClient;
use crate::rpc::RpcClientEvent;
use crate::rpc::RpcPendingResponse;

pub(crate) mod http_client;

Expand Down Expand Up @@ -178,6 +184,8 @@ struct Inner {
http_body_stream_failures: ArcSwap<HashMap<String, String>>,
http_body_streams_write_lock: Mutex<()>,
http_body_stream_next_id: AtomicU64,
runtime_install_progress:
StdMutex<Option<mpsc::UnboundedSender<RuntimeInstallProgressNotification>>>,
session_id: std::sync::RwLock<Option<String>>,
codex_home: std::sync::RwLock<Option<AbsolutePathBuf>>,
reader_task: tokio::task::JoinHandle<()>,
Expand Down Expand Up @@ -452,6 +460,47 @@ impl ExecServerClient {
self.call(RUNTIME_INSTALL_METHOD, &params).await
}

pub(crate) async fn runtime_install_with_progress(
&self,
params: RuntimeInstallParams,
progress: mpsc::UnboundedSender<RuntimeInstallProgressNotification>,
cancellation: CancellationToken,
) -> Result<RuntimeInstallResponse, ExecServerError> {
let _progress_guard = self.route_runtime_install_progress(progress)?;
let install = self.start_call(RUNTIME_INSTALL_METHOD, &params).await?;
let install = self.finish_call(install);
tokio::pin!(install);
tokio::select! {
response = &mut install => response,
_ = cancellation.cancelled() => {
let _: RuntimeInstallCancelResponse = self
.call(RUNTIME_INSTALL_CANCEL_METHOD, &serde_json::json!({}))
.await?;
install.await
}
}
}

fn route_runtime_install_progress(
&self,
progress: mpsc::UnboundedSender<RuntimeInstallProgressNotification>,
) -> Result<RuntimeInstallProgressGuard, ExecServerError> {
let mut active = self
.inner
.runtime_install_progress
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
if active.is_some() {
return Err(ExecServerError::Protocol(
"runtime install progress receiver is already active".to_string(),
));
}
*active = Some(progress);
Ok(RuntimeInstallProgressGuard {
inner: Arc::clone(&self.inner),
})
}

pub(crate) async fn register_session(
&self,
process_id: &ProcessId,
Expand Down Expand Up @@ -537,6 +586,7 @@ impl ExecServerClient {
http_body_stream_failures: ArcSwap::from_pointee(HashMap::new()),
http_body_streams_write_lock: Mutex::new(()),
http_body_stream_next_id: AtomicU64::new(1),
runtime_install_progress: StdMutex::new(None),
session_id: std::sync::RwLock::new(None),
codex_home: std::sync::RwLock::new(None),
reader_task,
Expand All @@ -560,6 +610,18 @@ impl ExecServerClient {
where
P: serde::Serialize,
T: serde::de::DeserializeOwned,
{
let response = self.start_call(method, params).await?;
self.finish_call(response).await
}

async fn start_call<P>(
&self,
method: &str,
params: &P,
) -> Result<RpcPendingResponse, ExecServerError>
where
P: serde::Serialize,
{
// Reject new work before allocating a JSON-RPC request id. MCP tool
// calls, process writes, and fs operations all pass through here, so
Expand All @@ -568,7 +630,17 @@ impl ExecServerClient {
return Err(error);
}

match self.inner.client.call(method, params).await {
match self.inner.client.start_call(method, params).await {
Ok(response) => Ok(response),
Err(error) => Err(ExecServerError::from(error)),
}
}

async fn finish_call<T>(&self, response: RpcPendingResponse) -> Result<T, ExecServerError>
where
T: serde::de::DeserializeOwned,
{
match response.response().await {
Ok(response) => Ok(response),
Err(error) => {
let error = ExecServerError::from(error);
Expand Down Expand Up @@ -868,6 +940,20 @@ async fn fail_all_sessions(inner: &Arc<Inner>, message: String) {
}
}

struct RuntimeInstallProgressGuard {
inner: Arc<Inner>,
}

impl Drop for RuntimeInstallProgressGuard {
fn drop(&mut self) {
self.inner
.runtime_install_progress
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.take();
}
}

/// Fails all in-flight work that depends on the shared JSON-RPC transport.
async fn fail_all_in_flight_work(inner: &Arc<Inner>, message: String) {
fail_all_sessions(inner, message.clone()).await;
Expand Down Expand Up @@ -929,6 +1015,18 @@ async fn handle_server_notification(
.handle_http_body_delta_notification(notification.params)
.await?;
}
RUNTIME_INSTALL_PROGRESS_METHOD => {
let progress: RuntimeInstallProgressNotification =
serde_json::from_value(notification.params.unwrap_or(Value::Null))?;
let sender = inner
.runtime_install_progress
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.clone();
if let Some(sender) = sender {
let _ = sender.send(progress);
}
}
other => {
debug!("ignoring unknown exec-server notification: {other}");
}
Expand All @@ -938,9 +1036,17 @@ async fn handle_server_notification(

#[cfg(test)]
mod tests {
use codex_app_server_protocol::JSONRPCError;
use codex_app_server_protocol::JSONRPCErrorError;
use codex_app_server_protocol::JSONRPCMessage;
use codex_app_server_protocol::JSONRPCNotification;
use codex_app_server_protocol::JSONRPCResponse;
use codex_app_server_protocol::RuntimeInstallCancelResponse;
use codex_app_server_protocol::RuntimeInstallCancelStatus;
use codex_app_server_protocol::RuntimeInstallManifestParams;
use codex_app_server_protocol::RuntimeInstallParams;
use codex_app_server_protocol::RuntimeInstallProgressNotification;
use codex_app_server_protocol::RuntimeInstallProgressPhase;
use codex_utils_absolute_path::AbsolutePathBuf;
use futures::SinkExt;
use futures::StreamExt;
Expand All @@ -967,9 +1073,11 @@ mod tests {
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::accept_async;
use tokio_tungstenite::tungstenite::Message;
use tokio_util::sync::CancellationToken;

use super::ExecServerClient;
use super::ExecServerClientConnectOptions;
use super::ExecServerError;
use super::LazyRemoteExecServerClient;
use crate::ProcessId;
#[cfg(not(windows))]
Expand All @@ -990,6 +1098,9 @@ mod tests {
use crate::protocol::INITIALIZED_METHOD;
use crate::protocol::InitializeResponse;
use crate::protocol::ProcessOutputChunk;
use crate::protocol::RUNTIME_INSTALL_CANCEL_METHOD;
use crate::protocol::RUNTIME_INSTALL_METHOD;
use crate::protocol::RUNTIME_INSTALL_PROGRESS_METHOD;

async fn read_jsonrpc_line<R>(lines: &mut tokio::io::Lines<BufReader<R>>) -> JSONRPCMessage
where
Expand All @@ -1014,6 +1125,158 @@ mod tests {
.expect("json-rpc line should write");
}

#[tokio::test]
async fn runtime_install_progress_and_cancel_are_forwarded_over_remote_transport() {
let (client_stdin, server_reader) = duplex(1 << 20);
let (mut server_writer, client_stdout) = duplex(1 << 20);
let server = tokio::spawn(async move {
let mut lines = BufReader::new(server_reader).lines();
let initialize = read_jsonrpc_line(&mut lines).await;
let initialize_request = match initialize {
JSONRPCMessage::Request(request) if request.method == INITIALIZE_METHOD => request,
other => panic!("expected initialize request, got {other:?}"),
};
write_jsonrpc_line(
&mut server_writer,
JSONRPCMessage::Response(JSONRPCResponse {
id: initialize_request.id,
result: serde_json::to_value(InitializeResponse {
session_id: "runtime-session".to_string(),
codex_home: AbsolutePathBuf::try_from(
std::env::current_dir().expect("current dir"),
)
.expect("absolute current dir"),
})
.expect("initialize response should serialize"),
}),
)
.await;
let initialized = read_jsonrpc_line(&mut lines).await;
assert!(matches!(
initialized,
JSONRPCMessage::Notification(JSONRPCNotification { method, .. })
if method == INITIALIZED_METHOD
));

let install = read_jsonrpc_line(&mut lines).await;
let install_request_id = match install {
JSONRPCMessage::Request(request) if request.method == RUNTIME_INSTALL_METHOD => {
request.id
}
other => panic!("expected runtime install request, got {other:?}"),
};
write_jsonrpc_line(
&mut server_writer,
JSONRPCMessage::Notification(JSONRPCNotification {
method: RUNTIME_INSTALL_PROGRESS_METHOD.to_string(),
params: Some(
serde_json::to_value(RuntimeInstallProgressNotification {
bundle_version: Some("runtime-test".to_string()),
downloaded_bytes: Some(64),
phase: RuntimeInstallProgressPhase::Downloading,
total_bytes: Some(128),
})
.expect("progress should serialize"),
),
}),
)
.await;

let cancel = read_jsonrpc_line(&mut lines).await;
let cancel_request_id = match cancel {
JSONRPCMessage::Request(request)
if request.method == RUNTIME_INSTALL_CANCEL_METHOD =>
{
request.id
}
other => panic!("expected runtime install cancel request, got {other:?}"),
};
write_jsonrpc_line(
&mut server_writer,
JSONRPCMessage::Response(JSONRPCResponse {
id: cancel_request_id,
result: serde_json::to_value(RuntimeInstallCancelResponse {
status: RuntimeInstallCancelStatus::Canceled,
})
.expect("cancel response should serialize"),
}),
)
.await;
write_jsonrpc_line(
&mut server_writer,
JSONRPCMessage::Error(JSONRPCError {
id: install_request_id,
error: JSONRPCErrorError {
code: -32603,
data: None,
message: "runtime install canceled".to_string(),
},
}),
)
.await;
});

let client = ExecServerClient::connect(
JsonRpcConnection::from_stdio(
client_stdout,
client_stdin,
"runtime-progress-test".to_string(),
),
ExecServerClientConnectOptions::default(),
)
.await
.expect("client should initialize");
let cancellation = CancellationToken::new();
let cancellation_for_install = cancellation.clone();
let (progress_tx, mut progress_rx) = mpsc::unbounded_channel();
let request = tokio::spawn(async move {
client
.runtime_install_with_progress(
RuntimeInstallParams {
environment_id: None,
manifest: Box::new(RuntimeInstallManifestParams {
archive_name: None,
archive_sha256: "0".repeat(64),
archive_size_bytes: None,
archive_url: "https://example.test/runtime.zip".to_string(),
bundle_format_version: None,
bundle_version: Some("runtime-test".to_string()),
format: Some("zip".to_string()),
runtime_root_directory_name: None,
}),
release: "test".to_string(),
},
progress_tx,
cancellation_for_install,
)
.await
});

let progress = timeout(Duration::from_secs(1), progress_rx.recv())
.await
.expect("progress should arrive before timeout")
.expect("progress stream should remain open");
assert_eq!(
progress,
RuntimeInstallProgressNotification {
bundle_version: Some("runtime-test".to_string()),
downloaded_bytes: Some(64),
phase: RuntimeInstallProgressPhase::Downloading,
total_bytes: Some(128),
}
);
cancellation.cancel();
let error = request
.await
.expect("request task should join")
.expect_err("canceled runtime install should fail");
assert!(matches!(
error,
ExecServerError::Server { message, .. } if message == "runtime install canceled"
));
server.await.expect("server task should finish");
}

async fn accept_websocket(listener: &TcpListener) -> WebSocketStream<TcpStream> {
let (stream, _) = listener.accept().await.expect("listener should accept");
accept_async(stream)
Expand Down
14 changes: 9 additions & 5 deletions codex-rs/exec-server/src/environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,11 +358,15 @@ impl RuntimeInstaller {
},
RuntimeInstaller::Remote(client) => {
let client = client.get().await.map_err(exec_server_error_to_jsonrpc)?;
tokio::select! {
_ = cancellation.cancelled() => Err(internal_error("runtime install canceled")),
response = client.runtime_install(params) => {
response.map_err(exec_server_error_to_jsonrpc)
}
match progress {
Some(progress) => client
.runtime_install_with_progress(params, progress, cancellation)
.await
.map_err(exec_server_error_to_jsonrpc),
None => client
.runtime_install(params)
.await
.map_err(exec_server_error_to_jsonrpc),
}
}
}
Expand Down
Loading
Loading