Skip to content

Commit 83f28af

Browse files
committed
feat: add cross-platform SSE listener support
1 parent 34c5c23 commit 83f28af

1 file changed

Lines changed: 163 additions & 32 deletions

File tree

src/eval.rs

Lines changed: 163 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ use serde::{Deserialize, Serialize};
2828
use serde_json::{json, Value};
2929
use strip_ansi_escapes::strip;
3030
use tokio::io::{AsyncBufReadExt, BufReader};
31+
#[cfg(not(unix))]
32+
use tokio::net::TcpListener;
33+
#[cfg(unix)]
3134
use tokio::net::UnixListener;
3235
use tokio::process::Command;
3336
use tokio::sync::mpsc;
@@ -71,7 +74,7 @@ struct EvalRunnerProcess {
7174
rx: mpsc::UnboundedReceiver<EvalEvent>,
7275
sse_task: tokio::task::JoinHandle<()>,
7376
sse_connected: Arc<AtomicBool>,
74-
_socket_cleanup_guard: SocketCleanupGuard,
77+
_socket_cleanup_guard: Option<SocketCleanupGuard>,
7578
}
7679

7780
struct EvalProcessOutput {
@@ -210,6 +213,20 @@ impl Drop for SocketCleanupGuard {
210213
}
211214
}
212215

216+
enum SseListener {
217+
#[cfg(unix)]
218+
Unix(UnixListener),
219+
#[cfg(not(unix))]
220+
Tcp(TcpListener),
221+
}
222+
223+
struct BoundSseListener {
224+
listener: SseListener,
225+
env_key: &'static str,
226+
env_value: String,
227+
socket_cleanup_guard: Option<SocketCleanupGuard>,
228+
}
229+
213230
#[derive(Debug, Copy, Clone, Eq, PartialEq, ValueEnum)]
214231
pub enum EvalLanguage {
215232
#[value(alias = "js")]
@@ -660,32 +677,60 @@ async fn spawn_eval_runner(
660677
let (js_runner, py_runner) = prepare_eval_runners()?;
661678
let force_esm = matches!(js_mode, JsMode::ForceEsm);
662679

663-
let (listener, socket_path, socket_cleanup_guard) = bind_sse_listener()?;
680+
let BoundSseListener {
681+
listener,
682+
env_key,
683+
env_value,
684+
socket_cleanup_guard,
685+
} = bind_sse_listener()?;
664686
let (tx, rx) = mpsc::unbounded_channel();
665687
let sse_connected = Arc::new(AtomicBool::new(false));
666688

667689
let tx_sse = tx.clone();
668690
let sse_connected_for_task = Arc::clone(&sse_connected);
669691
let sse_task = tokio::spawn(async move {
670-
match listener.accept().await {
671-
Ok((stream, _)) => {
672-
sse_connected_for_task.store(true, Ordering::Relaxed);
673-
if let Err(err) = read_sse_stream(stream, tx_sse.clone()).await {
692+
match listener {
693+
#[cfg(unix)]
694+
SseListener::Unix(listener) => match listener.accept().await {
695+
Ok((stream, _)) => {
696+
sse_connected_for_task.store(true, Ordering::Relaxed);
697+
if let Err(err) = read_sse_stream(stream, tx_sse.clone()).await {
698+
let _ = tx_sse.send(EvalEvent::Error {
699+
message: format!("SSE stream error: {err}"),
700+
stack: None,
701+
status: None,
702+
});
703+
}
704+
}
705+
Err(err) => {
674706
let _ = tx_sse.send(EvalEvent::Error {
675-
message: format!("SSE stream error: {err}"),
707+
message: format!("Failed to accept SSE connection: {err}"),
676708
stack: None,
677709
status: None,
678710
});
679711
}
680-
}
681-
Err(err) => {
682-
let _ = tx_sse.send(EvalEvent::Error {
683-
message: format!("Failed to accept SSE connection: {err}"),
684-
stack: None,
685-
status: None,
686-
});
687-
}
688-
};
712+
},
713+
#[cfg(not(unix))]
714+
SseListener::Tcp(listener) => match listener.accept().await {
715+
Ok((stream, _)) => {
716+
sse_connected_for_task.store(true, Ordering::Relaxed);
717+
if let Err(err) = read_sse_stream(stream, tx_sse.clone()).await {
718+
let _ = tx_sse.send(EvalEvent::Error {
719+
message: format!("SSE stream error: {err}"),
720+
stack: None,
721+
status: None,
722+
});
723+
}
724+
}
725+
Err(err) => {
726+
let _ = tx_sse.send(EvalEvent::Error {
727+
message: format!("Failed to accept SSE connection: {err}"),
728+
stack: None,
729+
status: None,
730+
});
731+
}
732+
},
733+
}
689734
});
690735

691736
let (mut cmd, runner_kind) = match language {
@@ -753,10 +798,7 @@ async fn spawn_eval_runner(
753798
serde_json::to_string(&options.extra_args).context("failed to serialize extra args")?;
754799
cmd.env("BT_EVAL_EXTRA_ARGS_JSON", serialized);
755800
}
756-
cmd.env(
757-
"BT_EVAL_SSE_SOCK",
758-
socket_path.to_string_lossy().to_string(),
759-
);
801+
cmd.env(env_key, env_value);
760802
cmd.stdout(Stdio::piped());
761803
cmd.stderr(Stdio::piped());
762804

@@ -2304,7 +2346,11 @@ fn prepare_js_runner_in_cwd() -> Result<PathBuf> {
23042346

23052347
fn runner_bin_name(runner_command: &Path) -> Option<String> {
23062348
let name = runner_command.file_name()?.to_str()?.to_ascii_lowercase();
2307-
Some(name.strip_suffix(".cmd").unwrap_or(&name).to_string())
2349+
let trimmed = name
2350+
.strip_suffix(".cmd")
2351+
.or_else(|| name.strip_suffix(".exe"))
2352+
.unwrap_or(&name);
2353+
Some(trimmed.to_string())
23082354
}
23092355

23102356
fn runner_kind_for_bin(runner_command: &Path) -> RunnerKind {
@@ -2342,9 +2388,15 @@ fn is_ts_node_runner(runner_command: &Path) -> bool {
23422388

23432389
fn find_python_binary() -> Option<PathBuf> {
23442390
if let Some(venv_root) = std::env::var_os("VIRTUAL_ENV") {
2345-
let candidate = PathBuf::from(venv_root).join("bin").join("python");
2346-
if candidate.is_file() {
2347-
return Some(candidate);
2391+
let venv_root = PathBuf::from(venv_root);
2392+
let unix = venv_root.join("bin").join("python");
2393+
if unix.is_file() {
2394+
return Some(unix);
2395+
}
2396+
2397+
let windows = venv_root.join("Scripts").join("python.exe");
2398+
if windows.is_file() {
2399+
return Some(windows);
23482400
}
23492401
}
23502402
find_binary_in_path(&["python3", "python"])
@@ -2358,9 +2410,10 @@ fn find_node_module_bin(binary: &str, start: &Path) -> Option<PathBuf> {
23582410
return Some(base);
23592411
}
23602412
if cfg!(windows) {
2361-
let cmd = base.with_extension("cmd");
2362-
if cmd.is_file() {
2363-
return Some(cmd);
2413+
for candidate in with_windows_extensions(&base) {
2414+
if candidate.is_file() {
2415+
return Some(candidate);
2416+
}
23642417
}
23652418
}
23662419
current = dir.parent();
@@ -2377,16 +2430,27 @@ fn find_binary_in_path(candidates: &[&str]) -> Option<PathBuf> {
23772430
return Some(path);
23782431
}
23792432
if cfg!(windows) {
2380-
let cmd = path.with_extension("cmd");
2381-
if cmd.is_file() {
2382-
return Some(cmd);
2433+
for candidate_path in with_windows_extensions(&path) {
2434+
if candidate_path.is_file() {
2435+
return Some(candidate_path);
2436+
}
23832437
}
23842438
}
23852439
}
23862440
}
23872441
None
23882442
}
23892443

2444+
#[cfg(windows)]
2445+
fn with_windows_extensions(path: &Path) -> [PathBuf; 2] {
2446+
[path.with_extension("exe"), path.with_extension("cmd")]
2447+
}
2448+
2449+
#[cfg(not(windows))]
2450+
fn with_windows_extensions(_path: &Path) -> [PathBuf; 0] {
2451+
[]
2452+
}
2453+
23902454
fn build_sse_socket_path() -> Result<PathBuf> {
23912455
let pid = std::process::id();
23922456
let serial = SSE_SOCKET_COUNTER.fetch_add(1, Ordering::Relaxed);
@@ -2397,14 +2461,22 @@ fn build_sse_socket_path() -> Result<PathBuf> {
23972461
Ok(std::env::temp_dir().join(format!("bt-eval-{pid}-{now}-{serial}.sock")))
23982462
}
23992463

2400-
fn bind_sse_listener() -> Result<(UnixListener, PathBuf, SocketCleanupGuard)> {
2464+
#[cfg(unix)]
2465+
fn bind_sse_listener() -> Result<BoundSseListener> {
24012466
let mut last_bind_err: Option<std::io::Error> = None;
24022467
for _ in 0..SSE_SOCKET_BIND_MAX_ATTEMPTS {
24032468
let socket_path = build_sse_socket_path()?;
24042469
let socket_cleanup_guard = SocketCleanupGuard::new(socket_path.clone());
24052470
let _ = std::fs::remove_file(&socket_path);
24062471
match UnixListener::bind(&socket_path) {
2407-
Ok(listener) => return Ok((listener, socket_path, socket_cleanup_guard)),
2472+
Ok(listener) => {
2473+
return Ok(BoundSseListener {
2474+
listener: SseListener::Unix(listener),
2475+
env_key: "BT_EVAL_SSE_SOCK",
2476+
env_value: socket_path.to_string_lossy().to_string(),
2477+
socket_cleanup_guard: Some(socket_cleanup_guard),
2478+
})
2479+
}
24082480
Err(err)
24092481
if matches!(
24102482
err.kind(),
@@ -2430,6 +2502,32 @@ fn bind_sse_listener() -> Result<(UnixListener, PathBuf, SocketCleanupGuard)> {
24302502
))
24312503
}
24322504

2505+
#[cfg(not(unix))]
2506+
fn bind_sse_listener() -> Result<BoundSseListener> {
2507+
bind_sse_tcp_listener()
2508+
}
2509+
2510+
#[cfg(not(unix))]
2511+
fn bind_sse_tcp_listener() -> Result<BoundSseListener> {
2512+
let std_listener = std::net::TcpListener::bind((std::net::Ipv4Addr::LOCALHOST, 0))
2513+
.context("failed to bind SSE TCP listener")?;
2514+
let addr = std_listener
2515+
.local_addr()
2516+
.context("failed to resolve SSE TCP listener address")?;
2517+
std_listener
2518+
.set_nonblocking(true)
2519+
.context("failed to set SSE TCP listener non-blocking mode")?;
2520+
let listener =
2521+
TcpListener::from_std(std_listener).context("failed to create tokio SSE TCP listener")?;
2522+
2523+
Ok(BoundSseListener {
2524+
listener: SseListener::Tcp(listener),
2525+
env_key: "BT_EVAL_SSE_ADDR",
2526+
env_value: format!("127.0.0.1:{}", addr.port()),
2527+
socket_cleanup_guard: None,
2528+
})
2529+
}
2530+
24332531
fn eval_runner_cache_dir() -> PathBuf {
24342532
let root = std::env::var_os("XDG_CACHE_HOME")
24352533
.map(PathBuf::from)
@@ -3953,10 +4051,43 @@ mod tests {
39534051
#[test]
39544052
fn runner_kind_for_bin_detects_bun() {
39554053
assert_eq!(runner_kind_for_bin(Path::new("bun")), RunnerKind::Bun);
4054+
assert_eq!(runner_kind_for_bin(Path::new("bun.exe")), RunnerKind::Bun);
39564055
assert_eq!(runner_kind_for_bin(Path::new("bunx")), RunnerKind::Bun);
4056+
assert_eq!(runner_kind_for_bin(Path::new("bunx.cmd")), RunnerKind::Bun);
39574057
assert_eq!(runner_kind_for_bin(Path::new("deno")), RunnerKind::Other);
39584058
}
39594059

4060+
#[test]
4061+
fn find_python_binary_uses_virtual_env_scripts_python_exe() {
4062+
let _guard = env_test_lock()
4063+
.lock()
4064+
.unwrap_or_else(|poisoned| poisoned.into_inner());
4065+
let dir = make_temp_dir("venv-scripts-python");
4066+
let scripts_dir = dir.join("Scripts");
4067+
fs::create_dir_all(&scripts_dir).expect("create scripts directory");
4068+
let scripts_python = scripts_dir.join("python.exe");
4069+
fs::write(&scripts_python, "").expect("write scripts python");
4070+
4071+
let previous = set_env_var("VIRTUAL_ENV", &dir.to_string_lossy());
4072+
let found = find_python_binary();
4073+
restore_env_var("VIRTUAL_ENV", previous);
4074+
4075+
assert_eq!(found, Some(scripts_python));
4076+
let _ = fs::remove_dir_all(&dir);
4077+
}
4078+
4079+
#[cfg(not(unix))]
4080+
#[test]
4081+
fn bind_sse_tcp_listener_sets_sse_addr_env() {
4082+
let bound = bind_sse_tcp_listener().expect("bind tcp sse listener");
4083+
assert_eq!(bound.env_key, "BT_EVAL_SSE_ADDR");
4084+
assert!(
4085+
bound.env_value.starts_with("127.0.0.1:"),
4086+
"expected loopback host in SSE address env"
4087+
);
4088+
assert!(bound.socket_cleanup_guard.is_none());
4089+
}
4090+
39604091
#[test]
39614092
fn set_node_heap_size_env_sets_default_when_absent() {
39624093
let _guard = env_test_lock()

0 commit comments

Comments
 (0)