diff --git a/devolutions-agent/Cargo.toml b/devolutions-agent/Cargo.toml index 1496a1ed4..27ce4e7fd 100644 --- a/devolutions-agent/Cargo.toml +++ b/devolutions-agent/Cargo.toml @@ -41,6 +41,7 @@ sha2 = "0.10" serde_json = "1" serde = { version = "1", features = ["derive"] } tap = "1.0" +tempfile = "3" tokio-tungstenite = { version = "0.26", features = ["rustls-tls-native-roots"] } tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "tls12", "ring"] } tracing = "0.1" @@ -98,8 +99,5 @@ features = [ [target.'cfg(windows)'.build-dependencies] embed-resource = "3.0" -[dev-dependencies] -tempfile = "3" - [target.'cfg(windows)'.dev-dependencies] expect-test = "1.5" diff --git a/devolutions-agent/src/config.rs b/devolutions-agent/src/config.rs index eb1ff549a..9cde999b5 100644 --- a/devolutions-agent/src/config.rs +++ b/devolutions-agent/src/config.rs @@ -525,7 +525,10 @@ pub mod dto { pub server_spki_sha256: Option, } - #[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize)] + /// PowerShell Universal Event Hub compatibility configuration. + /// + /// Defaults to disabled. + #[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize, Default)] #[serde(rename_all = "PascalCase")] pub struct PsuEventHubConf { /// Enable PowerShell Universal Event Hub compatibility. @@ -544,17 +547,6 @@ pub mod dto { pub powershell: PsuPowerShellConf, } - #[allow(clippy::derivable_impls)] // Just to be explicit about default disabled behavior. - impl Default for PsuEventHubConf { - fn default() -> Self { - Self { - enabled: false, - connections: Vec::new(), - powershell: PsuPowerShellConf::default(), - } - } - } - #[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "PascalCase")] pub struct PsuEventHubConnectionConf { diff --git a/devolutions-agent/src/psu_event_hub/executor.rs b/devolutions-agent/src/psu_event_hub/executor.rs index 8769a42d1..20865ead5 100644 --- a/devolutions-agent/src/psu_event_hub/executor.rs +++ b/devolutions-agent/src/psu_event_hub/executor.rs @@ -3,9 +3,10 @@ use std::sync::Arc; use anyhow::Context as _; use camino::Utf8PathBuf; use serde_json::Value; +use tokio::task::JoinSet; use uuid::Uuid; -use crate::config::dto::{PsuEventHubConnectionConf, PsuPowerShellConf}; +use crate::config::dto::PsuEventHubConnectionConf; use crate::psu_event_hub::models::WebsocketEventResponse; use crate::psu_event_hub::powershell_worker::PowerShellWorker; use crate::psu_event_hub::result_store::ResultStore; @@ -19,16 +20,21 @@ pub(super) struct EventHubExecutor { } impl EventHubExecutor { - pub(super) fn new(connection: &PsuEventHubConnectionConf, power_shell: PsuPowerShellConf) -> Self { + pub(super) fn new(connection: &PsuEventHubConnectionConf, worker: Arc) -> Self { Self { hub: connection.hub.clone(), script_path: connection.script_path.as_ref().map(normalize_script_path), - worker: Arc::new(PowerShellWorker::new(power_shell)), + worker, result_store: ResultStore::default(), } } - pub(super) fn handle_invocation(&self, target: &str, arguments: &[Value]) -> anyhow::Result> { + pub(super) fn handle_invocation( + &self, + target: &str, + arguments: &[Value], + execution_tasks: &mut JoinSet<()>, + ) -> anyhow::Result> { if target == "GetResult" { let execution_id = required_string_argument(arguments, 0, "event id")?; let result = self.result_store.take(execution_id); @@ -39,27 +45,27 @@ impl EventHubExecutor { if target == self.hub { let data = required_string_argument(arguments, 0, "event data")?.to_owned(); - let execution_id = self.execute_script(data, true); + let execution_id = self.execute_script(data, true, execution_tasks); return Ok(Some(Value::String(execution_id))); } if target == format!("{}Void", self.hub) { let data = required_string_argument(arguments, 0, "event data")?.to_owned(); - self.execute_script(data, false); + self.execute_script(data, false, execution_tasks); return Ok(None); } if target == format!("{}Module", self.hub) { let command = required_string_argument(arguments, 0, "command")?.to_owned(); let data = required_string_argument(arguments, 1, "event data")?.to_owned(); - let execution_id = self.execute_command(command, data, true); + let execution_id = self.execute_command(command, data, true, execution_tasks); return Ok(Some(Value::String(execution_id))); } if target == format!("{}ModuleVoid", self.hub) { let command = required_string_argument(arguments, 0, "command")?.to_owned(); let data = required_string_argument(arguments, 1, "event data")?.to_owned(); - self.execute_command(command, data, false); + self.execute_command(command, data, false, execution_tasks); return Ok(None); } @@ -67,13 +73,19 @@ impl EventHubExecutor { Ok(None) } - fn execute_command(&self, command: String, data: String, return_result: bool) -> String { + fn execute_command( + &self, + command: String, + data: String, + return_result: bool, + execution_tasks: &mut JoinSet<()>, + ) -> String { let execution_id = Uuid::new_v4().to_string(); let worker = Arc::clone(&self.worker); let result_store = self.result_store.clone(); let stored_execution_id = execution_id.clone(); - tokio::spawn(async move { + execution_tasks.spawn(async move { match worker.execute_command(command, data, return_result).await { Ok(response) if return_result => result_store.insert(stored_execution_id, response), Ok(_) => {} @@ -90,7 +102,7 @@ impl EventHubExecutor { execution_id } - fn execute_script(&self, data: String, return_result: bool) -> String { + fn execute_script(&self, data: String, return_result: bool, execution_tasks: &mut JoinSet<()>) -> String { let execution_id = Uuid::new_v4().to_string(); let Some(script_path) = self.script_path.clone() else { if return_result { @@ -106,7 +118,7 @@ impl EventHubExecutor { let result_store = self.result_store.clone(); let stored_execution_id = execution_id.clone(); - tokio::spawn(async move { + execution_tasks.spawn(async move { match worker.execute_script(script_path, data, return_result).await { Ok(response) if return_result => result_store.insert(stored_execution_id, response), Ok(_) => {} diff --git a/devolutions-agent/src/psu_event_hub/mod.rs b/devolutions-agent/src/psu_event_hub/mod.rs index 2f3123612..eee44b74f 100644 --- a/devolutions-agent/src/psu_event_hub/mod.rs +++ b/devolutions-agent/src/psu_event_hub/mod.rs @@ -5,11 +5,14 @@ mod powershell_worker; mod result_store; mod signalr; +use std::sync::Arc; + +use anyhow::Context as _; use async_trait::async_trait; use devolutions_gateway_task::{ShutdownSignal, Task}; use tokio::task::JoinSet; -use crate::config::ConfHandle; +use crate::config::{ConfHandle, dto}; use crate::psu_event_hub::executor::EventHubExecutor; use crate::psu_event_hub::powershell_worker::PowerShellWorker; @@ -45,7 +48,9 @@ impl Task for PsuEventHubTask { let mut join_set = JoinSet::new(); - let secret_resolver = PowerShellWorker::new(psu_conf.powershell.clone()); + let worker = Arc::new( + PowerShellWorker::new(psu_conf.powershell.clone()).context("failed to initialize PSU PowerShell worker")?, + ); for mut connection in psu_conf.connections { if connection.hub.trim().is_empty() { @@ -53,8 +58,17 @@ impl Task for PsuEventHubTask { continue; } + if let Err(error) = validate_connection(&connection) { + error!( + hub = %connection.hub, + error = format!("{error:#}"), + "Skipping PSU Event Hub connection because configuration is invalid" + ); + continue; + } + if let Some(app_token) = connection.app_token.as_deref() { - match secret_resolver.resolve_app_token(app_token).await { + match worker.resolve_app_token(app_token).await { Ok(resolved) => connection.app_token = Some(resolved), Err(error) => { error!( @@ -67,7 +81,7 @@ impl Task for PsuEventHubTask { } } - let executor = EventHubExecutor::new(&connection, psu_conf.powershell.clone()); + let executor = EventHubExecutor::new(&connection, Arc::clone(&worker)); let connection_shutdown_signal = shutdown_signal.clone(); join_set @@ -85,3 +99,35 @@ impl Task for PsuEventHubTask { Ok(()) } } + +fn validate_connection(connection: &dto::PsuEventHubConnectionConf) -> anyhow::Result<()> { + if connection.use_default_credentials && connection.app_token.is_none() { + anyhow::bail!( + "PSU Event Hub UseDefaultCredentials is configured for hub {}, but Windows default credentials are not implemented", + connection.hub + ); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use url::Url; + + use super::*; + + #[test] + fn default_credentials_without_app_token_are_rejected() { + let connection = dto::PsuEventHubConnectionConf { + hub: "Hub".to_owned(), + url: Url::parse("http://localhost:5000").expect("parse URL"), + app_token: None, + use_default_credentials: true, + script_path: None, + description: None, + }; + + assert!(validate_connection(&connection).is_err()); + } +} diff --git a/devolutions-agent/src/psu_event_hub/models.rs b/devolutions-agent/src/psu_event_hub/models.rs index dddd6106a..9d2989bb4 100644 --- a/devolutions-agent/src/psu_event_hub/models.rs +++ b/devolutions-agent/src/psu_event_hub/models.rs @@ -37,6 +37,16 @@ impl WebsocketEventResponse { terminating_error: Some(message.into()), } } + + pub(super) fn timeout(message: impl Into) -> Self { + Self { + data: None, + job_outputs: Vec::new(), + complete: true, + timeout: true, + terminating_error: Some(message.into()), + } + } } impl Default for WebsocketEventResponse { diff --git a/devolutions-agent/src/psu_event_hub/powershell_worker.rs b/devolutions-agent/src/psu_event_hub/powershell_worker.rs index 41ae0928a..d036c5d9b 100644 --- a/devolutions-agent/src/psu_event_hub/powershell_worker.rs +++ b/devolutions-agent/src/psu_event_hub/powershell_worker.rs @@ -1,13 +1,13 @@ use std::ffi::OsString; use std::process::Stdio; use std::sync::Arc; +use std::time::Duration; use anyhow::{Context as _, bail}; use camino::{Utf8Path, Utf8PathBuf}; use serde::Serialize; use tokio::process::Command; use tokio::sync::Semaphore; -use uuid::Uuid; use crate::config::dto::PsuPowerShellConf; use crate::psu_event_hub::models::WebsocketEventResponse; @@ -166,19 +166,29 @@ try { $response | ConvertTo-Json -Compress -Depth 16 "#; +const POWERSHELL_EXECUTION_TIMEOUT: Duration = Duration::from_secs(30 * 60); + #[derive(Debug, Clone)] pub(super) struct PowerShellWorker { conf: PsuPowerShellConf, permits: Arc, + worker_script: Arc, + execution_timeout: Duration, } impl PowerShellWorker { - pub(super) fn new(conf: PsuPowerShellConf) -> Self { + pub(super) fn new(conf: PsuPowerShellConf) -> anyhow::Result { + Self::with_execution_timeout(conf, POWERSHELL_EXECUTION_TIMEOUT) + } + + fn with_execution_timeout(conf: PsuPowerShellConf, execution_timeout: Duration) -> anyhow::Result { let worker_limit = effective_worker_limit(&conf); - Self { + Ok(Self { conf, permits: Arc::new(Semaphore::new(worker_limit)), - } + worker_script: Arc::new(WorkerScriptFile::new()?), + execution_timeout, + }) } pub(super) async fn resolve_app_token(&self, app_token: &str) -> anyhow::Result { @@ -223,25 +233,9 @@ impl PowerShellWorker { .acquire() .await .context("PSU PowerShell worker pool is closed")?; - let temp_dir = Utf8PathBuf::from_path_buf(std::env::temp_dir()) - .map_err(|path| anyhow::anyhow!("non-UTF-8 temp path: {path:?}"))?; - let request_path = temp_dir.join(format!("devolutions-agent-psu-{}.json", Uuid::new_v4())); - let script_path = temp_dir.join(format!("devolutions-agent-psu-{}.ps1", Uuid::new_v4())); + let request_file = TempRequestFile::write(&request).await?; - let request_json = serde_json::to_vec(&request).context("failed to serialize PSU worker request")?; - tokio::fs::write(&request_path, request_json) - .await - .with_context(|| format!("failed to write PSU worker request at {request_path}"))?; - tokio::fs::write(&script_path, WORKER_SCRIPT) - .await - .with_context(|| format!("failed to write PSU worker script at {script_path}"))?; - - let output = self.invoke_worker(&script_path, &request_path).await; - - remove_temp_file(&request_path).await; - remove_temp_file(&script_path).await; - - output + self.invoke_worker(self.worker_script.path(), request_file.path()).await } async fn invoke_worker( @@ -267,13 +261,23 @@ impl PowerShellWorker { if let Some(virtual_environment) = &self.conf.virtual_environment { command.env("PSMODULE_VENV_PATH", virtual_environment); } - - let output = command.output().await.with_context(|| { - format!( - "failed to start PowerShell worker using {}", - executable.to_string_lossy() - ) - })?; + command.kill_on_drop(true); + + let output = match tokio::time::timeout(self.execution_timeout, command.output()).await { + Ok(output) => output.with_context(|| { + format!( + "failed to start PowerShell worker using {}", + executable.to_string_lossy() + ) + })?, + Err(_) => { + warn!( + timeout_secs = self.execution_timeout.as_secs(), + "PowerShell worker timed out" + ); + return Ok(WebsocketEventResponse::timeout("PowerShell worker timed out.")); + } + }; if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); @@ -288,6 +292,69 @@ impl PowerShellWorker { } } +#[derive(Debug)] +struct WorkerScriptFile { + path: Utf8PathBuf, + _temp_path: tempfile::TempPath, +} + +impl WorkerScriptFile { + fn new() -> anyhow::Result { + let temp_path = tempfile::Builder::new() + .prefix("devolutions-agent-psu-worker-") + .suffix(".ps1") + .tempfile_in(temp_dir()?.as_std_path()) + .context("failed to create temporary PSU worker script")? + .into_temp_path(); + let path = Utf8PathBuf::from_path_buf(temp_path.to_path_buf()) + .map_err(|path| anyhow::anyhow!("non-UTF-8 PSU worker script path: {path:?}"))?; + + std::fs::write(&path, WORKER_SCRIPT).with_context(|| format!("failed to write PSU worker script at {path}"))?; + + Ok(Self { + path, + _temp_path: temp_path, + }) + } + + fn path(&self) -> &Utf8Path { + &self.path + } +} + +#[derive(Debug)] +struct TempRequestFile { + path: Utf8PathBuf, + _temp_path: tempfile::TempPath, +} + +impl TempRequestFile { + async fn write(request: &WorkerRequest) -> anyhow::Result { + let request_json = serde_json::to_vec(request).context("failed to serialize PSU worker request")?; + let temp_path = tempfile::Builder::new() + .prefix("devolutions-agent-psu-") + .suffix(".json") + .tempfile_in(temp_dir()?.as_std_path()) + .context("failed to create temporary PSU worker request")? + .into_temp_path(); + let path = Utf8PathBuf::from_path_buf(temp_path.to_path_buf()) + .map_err(|path| anyhow::anyhow!("non-UTF-8 PSU worker request path: {path:?}"))?; + + tokio::fs::write(&path, request_json) + .await + .with_context(|| format!("failed to write PSU worker request at {path}"))?; + + Ok(Self { + path, + _temp_path: temp_path, + }) + } + + fn path(&self) -> &Utf8Path { + &self.path + } +} + #[derive(Serialize)] #[serde(rename_all = "camelCase")] struct WorkerRequest { @@ -377,10 +444,8 @@ fn resolve_powershell_executable(conf: &PsuPowerShellConf) -> OsString { } } -async fn remove_temp_file(path: &Utf8Path) { - if let Err(error) = tokio::fs::remove_file(path).await { - debug!(%path, %error, "Failed to remove temporary PSU worker file"); - } +fn temp_dir() -> anyhow::Result { + Utf8PathBuf::from_path_buf(std::env::temp_dir()).map_err(|path| anyhow::anyhow!("non-UTF-8 temp path: {path:?}")) } #[cfg(test)] @@ -422,9 +487,24 @@ mod tests { "#; + const HASHTABLE_SECONDS: &str = r#" + + + System.Collections.Hashtable + System.Object + + + + Seconds + 10 + + + +"#; + #[tokio::test] async fn command_execution_returns_clixml_result() { - let worker = PowerShellWorker::new(PsuPowerShellConf::default()); + let worker = PowerShellWorker::new(PsuPowerShellConf::default()).expect("create worker"); let response = worker .execute_command("Get-Variable".to_owned(), HASHTABLE_PS_VERSION_TABLE.to_owned(), true) .await @@ -437,7 +517,7 @@ mod tests { #[tokio::test] async fn command_execution_captures_error_stream() { - let worker = PowerShellWorker::new(PsuPowerShellConf::default()); + let worker = PowerShellWorker::new(PsuPowerShellConf::default()).expect("create worker"); let response = worker .execute_command("Write-Error".to_owned(), HASHTABLE_MESSAGE.to_owned(), true) .await @@ -454,12 +534,27 @@ mod tests { ); } + #[tokio::test] + async fn command_execution_times_out() { + let worker = PowerShellWorker::with_execution_timeout(PsuPowerShellConf::default(), Duration::from_millis(1)) + .expect("create worker"); + let response = worker + .execute_command("Start-Sleep".to_owned(), HASHTABLE_SECONDS.to_owned(), true) + .await + .expect("execute command"); + + assert!(response.complete); + assert!(response.timeout); + assert!(response.terminating_error.is_some()); + } + #[tokio::test] async fn literal_app_token_does_not_require_secret_resolution() { let worker = PowerShellWorker::new(PsuPowerShellConf { executable_path: Some(Utf8PathBuf::from("missing-pwsh")), ..PsuPowerShellConf::default() - }); + }) + .expect("create worker"); let token = worker.resolve_app_token("literal-token").await.expect("resolve token"); diff --git a/devolutions-agent/src/psu_event_hub/result_store.rs b/devolutions-agent/src/psu_event_hub/result_store.rs index b86b3d9bc..d567a3b80 100644 --- a/devolutions-agent/src/psu_event_hub/result_store.rs +++ b/devolutions-agent/src/psu_event_hub/result_store.rs @@ -1,32 +1,113 @@ -use std::collections::HashMap; +use std::collections::{HashMap, VecDeque}; use std::sync::Arc; +use std::time::{Duration, Instant}; use parking_lot::Mutex; use crate::psu_event_hub::models::WebsocketEventResponse; -#[derive(Debug, Clone, Default)] +const RESULT_TTL: Duration = Duration::from_secs(15 * 60); +const MAX_RESULTS: usize = 1024; + +#[derive(Debug, Clone)] pub(super) struct ResultStore { - inner: Arc>>, + inner: Arc>, + ttl: Duration, + max_results: usize, } impl ResultStore { + fn new(ttl: Duration, max_results: usize) -> Self { + Self { + inner: Arc::new(Mutex::new(ResultStoreInner::default())), + ttl, + max_results, + } + } + pub(super) fn insert(&self, execution_id: String, response: WebsocketEventResponse) { - self.inner.lock().insert(execution_id, response); + let mut inner = self.inner.lock(); + let now = Instant::now(); + + inner.remove_expired(now, self.ttl); + inner.results.insert( + execution_id.clone(), + StoredResult { + inserted_at: now, + response, + }, + ); + inner.order.push_back(execution_id); + inner.enforce_limit(self.max_results); } pub(super) fn take(&self, execution_id: &str) -> WebsocketEventResponse { - self.inner - .lock() + let mut inner = self.inner.lock(); + inner.remove_expired(Instant::now(), self.ttl); + inner + .results .remove(execution_id) + .map(|stored| stored.response) .unwrap_or_else(WebsocketEventResponse::pending) } } +impl Default for ResultStore { + fn default() -> Self { + Self::new(RESULT_TTL, MAX_RESULTS) + } +} + +#[derive(Debug, Default)] +struct ResultStoreInner { + results: HashMap, + order: VecDeque, +} + +impl ResultStoreInner { + fn remove_expired(&mut self, now: Instant, ttl: Duration) { + while let Some(execution_id) = self.order.front() { + let Some(result) = self.results.get(execution_id) else { + self.order.pop_front(); + continue; + }; + + if now.duration_since(result.inserted_at) < ttl { + break; + } + + let execution_id = self.order.pop_front().expect("front exists"); + self.results.remove(&execution_id); + } + } + + fn enforce_limit(&mut self, max_results: usize) { + while self.results.len() > max_results { + let Some(execution_id) = self.order.pop_front() else { + break; + }; + + self.results.remove(&execution_id); + } + } +} + +#[derive(Debug)] +struct StoredResult { + inserted_at: Instant, + response: WebsocketEventResponse, +} + #[cfg(test)] mod tests { use super::*; + impl ResultStore { + fn test_with_limits(ttl: Duration, max_results: usize) -> Self { + Self::new(ttl, max_results) + } + } + #[test] fn take_removes_result_after_first_read() { let store = ResultStore::default(); @@ -41,4 +122,40 @@ mod tests { assert!(store.take("execution-id").complete); assert!(!store.take("execution-id").complete); } + + #[test] + fn insert_evicts_oldest_result_when_limit_is_reached() { + let store = ResultStore::test_with_limits(Duration::from_secs(60), 1); + store.insert( + "first".to_owned(), + WebsocketEventResponse { + complete: true, + ..WebsocketEventResponse::default() + }, + ); + store.insert( + "second".to_owned(), + WebsocketEventResponse { + complete: true, + ..WebsocketEventResponse::default() + }, + ); + + assert!(!store.take("first").complete); + assert!(store.take("second").complete); + } + + #[test] + fn take_ignores_expired_results() { + let store = ResultStore::test_with_limits(Duration::ZERO, 10); + store.insert( + "execution-id".to_owned(), + WebsocketEventResponse { + complete: true, + ..WebsocketEventResponse::default() + }, + ); + + assert!(!store.take("execution-id").complete); + } } diff --git a/devolutions-agent/src/psu_event_hub/signalr.rs b/devolutions-agent/src/psu_event_hub/signalr.rs index f5279e17b..3bc8492ca 100644 --- a/devolutions-agent/src/psu_event_hub/signalr.rs +++ b/devolutions-agent/src/psu_event_hub/signalr.rs @@ -1,8 +1,11 @@ +use std::time::{Duration, Instant}; + use anyhow::{Context as _, bail}; use futures::{SinkExt as _, StreamExt as _}; use reqwest::header::{AUTHORIZATION, HeaderMap, HeaderValue}; use serde::Deserialize; use serde_json::{Value, json}; +use tokio::task::{JoinError, JoinSet}; use tokio_tungstenite::connect_async; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::tungstenite::client::IntoClientRequest as _; @@ -29,10 +32,28 @@ pub(super) async fn run_connection( executor: EventHubExecutor, mut shutdown_signal: devolutions_gateway_task::ShutdownSignal, ) -> anyhow::Result<()> { + use backoff::backoff::Backoff as _; + + const RETRY_INITIAL_INTERVAL: Duration = Duration::from_secs(1); + const RETRY_MAX_INTERVAL: Duration = Duration::from_secs(60); + const RETRY_MULTIPLIER: f64 = 2.0; + const CONNECTED_THRESHOLD: Duration = Duration::from_secs(30); + + let mut backoff = backoff::ExponentialBackoffBuilder::default() + .with_initial_interval(RETRY_INITIAL_INTERVAL) + .with_max_interval(RETRY_MAX_INTERVAL) + .with_multiplier(RETRY_MULTIPLIER) + .with_max_elapsed_time(None) + .build(); + let mut execution_tasks = JoinSet::new(); + loop { - match run_single_connection(&connection, &executor, &mut shutdown_signal).await { + let start = Instant::now(); + + match run_single_connection(&connection, &executor, &mut shutdown_signal, &mut execution_tasks).await { Ok(()) => { info!(hub = %connection.hub, "Stopping PSU Event Hub connection"); + execution_tasks.shutdown().await; return Ok(()); } Err(error) => { @@ -45,9 +66,24 @@ pub(super) async fn run_connection( } } - tokio::select! { - _ = shutdown_signal.wait() => return Ok(()), - _ = tokio::time::sleep(std::time::Duration::from_secs(30)) => {} + if start.elapsed() > CONNECTED_THRESHOLD { + backoff.reset(); + } + + let wait = match backoff.next_backoff() { + Some(wait) => wait, + None => { + warn!("PSU Event Hub reconnect backoff exhausted, resetting"); + backoff.reset(); + RETRY_INITIAL_INTERVAL + } + }; + + info!(hub = %connection.hub, ?wait, "Reconnecting PSU Event Hub after backoff"); + + if !wait_before_reconnect(&mut shutdown_signal, wait, &mut execution_tasks).await { + execution_tasks.shutdown().await; + return Ok(()); } } } @@ -56,14 +92,8 @@ async fn run_single_connection( connection: &PsuEventHubConnectionConf, executor: &EventHubExecutor, shutdown_signal: &mut devolutions_gateway_task::ShutdownSignal, + execution_tasks: &mut JoinSet<()>, ) -> anyhow::Result<()> { - if connection.use_default_credentials && connection.app_token.is_none() { - warn!( - hub = %connection.hub, - "PSU Event Hub UseDefaultCredentials is configured, but Windows default credentials are not implemented yet" - ); - } - let endpoint = endpoint_url(connection)?; let negotiate = negotiate_url(&endpoint)?; let headers = psu_headers(connection)?; @@ -117,30 +147,71 @@ async fn run_single_connection( tokio::select! { _ = shutdown_signal.wait() => { let _ = socket.close(None).await; - return Ok(()); + break Ok(()); } message = socket.next() => { let Some(message) = message else { - bail!("SignalR WebSocket closed"); + break Err(anyhow::anyhow!("SignalR WebSocket closed")); + }; + + let message = match message.context("failed to read SignalR WebSocket message") { + Ok(message) => message, + Err(error) => break Err(error), }; - match message.context("failed to read SignalR WebSocket message")? { - Message::Text(text) => handle_text_message(&mut socket, executor, &text).await?, + let message_result = match message { + Message::Text(text) => handle_text_message(&mut socket, executor, &text, execution_tasks).await, Message::Binary(bytes) => { - let text = String::from_utf8(bytes.to_vec()).context("SignalR binary message was not UTF-8")?; - handle_text_message(&mut socket, executor, &text).await?; + let text = match std::str::from_utf8(&bytes).context("SignalR binary message was not UTF-8") { + Ok(text) => text, + Err(error) => break Err(error), + }; + handle_text_message(&mut socket, executor, text, execution_tasks).await } - Message::Close(frame) => bail!("SignalR WebSocket closed: {frame:?}"), - Message::Ping(payload) => socket.send(Message::Pong(payload)).await?, - Message::Pong(_) => {} - Message::Frame(_) => {} + Message::Close(frame) => break Err(anyhow::anyhow!("SignalR WebSocket closed: {frame:?}")), + Message::Ping(payload) => socket + .send(Message::Pong(payload)) + .await + .context("failed to send SignalR pong"), + Message::Pong(_) | Message::Frame(_) => Ok(()), + }; + + if let Err(error) = message_result { + break Err(error); } } + Some(result) = execution_tasks.join_next(), if !execution_tasks.is_empty() => { + log_execution_task_result(result); + } + } + } +} + +async fn wait_before_reconnect( + shutdown_signal: &mut devolutions_gateway_task::ShutdownSignal, + wait: Duration, + execution_tasks: &mut JoinSet<()>, +) -> bool { + let sleep = tokio::time::sleep(wait); + tokio::pin!(sleep); + + loop { + tokio::select! { + _ = shutdown_signal.wait() => return false, + _ = &mut sleep => return true, + Some(result) = execution_tasks.join_next(), if !execution_tasks.is_empty() => { + log_execution_task_result(result); + } } } } -async fn handle_text_message(socket: &mut S, executor: &EventHubExecutor, text: &str) -> anyhow::Result<()> +async fn handle_text_message( + socket: &mut S, + executor: &EventHubExecutor, + text: &str, + execution_tasks: &mut JoinSet<()>, +) -> anyhow::Result<()> where S: futures::Sink + Unpin, { @@ -151,7 +222,7 @@ where match message_type { None => {} - Some(1) => handle_invocation(socket, executor, value).await?, + Some(1) => handle_invocation(socket, executor, value, execution_tasks).await?, Some(6) => {} Some(7) => bail!("SignalR server sent close message"), Some(message_type) => trace!(message_type, "Ignoring unsupported SignalR message"), @@ -161,7 +232,12 @@ where Ok(()) } -async fn handle_invocation(socket: &mut S, executor: &EventHubExecutor, value: Value) -> anyhow::Result<()> +async fn handle_invocation( + socket: &mut S, + executor: &EventHubExecutor, + value: Value, + execution_tasks: &mut JoinSet<()>, +) -> anyhow::Result<()> where S: futures::Sink + Unpin, { @@ -176,7 +252,7 @@ where .unwrap_or(&[]); let invocation_id = value.get("invocationId").and_then(Value::as_str); - let result = executor.handle_invocation(target, arguments)?; + let result = executor.handle_invocation(target, arguments, execution_tasks)?; if let Some(invocation_id) = invocation_id { let completion = if let Some(result) = result { json!({ @@ -200,6 +276,12 @@ where Ok(()) } +fn log_execution_task_result(result: Result<(), JoinError>) { + if let Err(error) = result { + error!(%error, "PSU Event Hub execution task panicked"); + } +} + fn endpoint_url(connection: &PsuEventHubConnectionConf) -> anyhow::Result { let endpoint = if connection.app_token.is_some() || connection.use_default_credentials { "autheventhub" @@ -257,20 +339,30 @@ fn redact_url(url: &Url) -> String { fn psu_headers(connection: &PsuEventHubConnectionConf) -> anyhow::Result { let mut headers = HeaderMap::new(); - headers.insert("PSUComputerName", HeaderValue::from_str(&computer_name())?); - headers.insert("PSUUserName", HeaderValue::from_str(&user_name())?); - headers.insert("PSUDomainName", HeaderValue::from_str(&domain_name())?); - headers.insert("PSUVersion", HeaderValue::from_static(env!("CARGO_PKG_VERSION"))); + let identity = psu_identity(); headers.insert( - "PSUDescription", - HeaderValue::from_str(connection.description.as_deref().unwrap_or_default())?, + "PSUComputerName", + psu_header_value("PSUComputerName", &computer_name())?, ); + headers.insert("PSUUserName", psu_header_value("PSUUserName", &identity.user_name)?); + headers.insert( + "PSUDomainName", + psu_header_value("PSUDomainName", &identity.domain_name)?, + ); + headers.insert("PSUVersion", HeaderValue::from_static(env!("CARGO_PKG_VERSION"))); + let description = sanitize_header_value(connection.description.as_deref().unwrap_or_default()); + headers.insert("PSUDescription", psu_header_value("PSUDescription", &description)?); if let Some(token) = &connection.app_token { - headers.insert(AUTHORIZATION, HeaderValue::from_str(&format!("Bearer {token}"))?); + let authorization = format!("Bearer {token}"); + headers.insert(AUTHORIZATION, psu_header_value("Authorization", &authorization)?); } Ok(headers) } +fn psu_header_value(name: &str, value: &str) -> anyhow::Result { + HeaderValue::from_str(value).with_context(|| format!("invalid PSU header value for {name}")) +} + fn apply_ws_headers(target: &mut WsHeaderMap, source: &HeaderMap) -> anyhow::Result<()> { for (name, value) in source { let name = WsHeaderName::from_bytes(name.as_str().as_bytes())?; @@ -287,12 +379,102 @@ fn computer_name() -> String { .unwrap_or_else(|| "localhost".to_owned()) } -fn user_name() -> String { - std::env::var("USERNAME") - .or_else(|_| std::env::var("USER")) - .unwrap_or_default() +#[derive(Debug, Clone, PartialEq, Eq)] +struct PsuIdentity { + user_name: String, + domain_name: String, +} + +fn psu_identity() -> PsuIdentity { + platform_psu_identity().unwrap_or_else(env_psu_identity) +} + +#[cfg(target_os = "windows")] +fn platform_psu_identity() -> Option { + use win_api_wrappers::identity::account::get_username; + use win_api_wrappers::raw::Win32::Security::Authentication::Identity::NameSamCompatible; + + let name = get_username(NameSamCompatible).ok()?.to_string_lossy(); + let identity = split_sam_compatible_name(&name); + if identity.user_name.is_empty() { + None + } else { + Some(identity) + } +} + +#[cfg(not(target_os = "windows"))] +fn platform_psu_identity() -> Option { + None +} + +fn split_sam_compatible_name(name: &str) -> PsuIdentity { + if let Some((domain_name, user_name)) = name.split_once('\\') { + PsuIdentity { + user_name: user_name.to_owned(), + domain_name: domain_name.to_owned(), + } + } else { + PsuIdentity { + user_name: name.to_owned(), + domain_name: env_domain_name(), + } + } +} + +fn env_psu_identity() -> PsuIdentity { + PsuIdentity { + user_name: std::env::var("USERNAME") + .or_else(|_| std::env::var("USER")) + .unwrap_or_default(), + domain_name: env_domain_name(), + } } -fn domain_name() -> String { +fn env_domain_name() -> String { std::env::var("USERDOMAIN").unwrap_or_default() } + +fn sanitize_header_value(value: &str) -> String { + value + .chars() + .map(|ch| if ch.is_ascii_control() && ch != '\t' { ' ' } else { ch }) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_connection(description: Option) -> PsuEventHubConnectionConf { + PsuEventHubConnectionConf { + hub: "Hub".to_owned(), + url: Url::parse("http://localhost:5000").expect("parse URL"), + app_token: None, + use_default_credentials: false, + script_path: None, + description, + } + } + + #[test] + fn psu_headers_sanitize_description_control_characters() { + let headers = psu_headers(&test_connection(Some("line 1\r\nline 2".to_owned()))).expect("build headers"); + + assert_eq!( + headers["PSUDescription"].to_str().expect("description header"), + "line 1 line 2" + ); + } + + #[test] + fn sam_compatible_identity_splits_domain_and_user() { + assert_eq!( + split_sam_compatible_name("DOMAIN\\user"), + PsuIdentity { + user_name: "user".to_owned(), + domain_name: "DOMAIN".to_owned(), + } + ); + } +}