From 2080c3ff65fac9b1e980528dcdc0a9eb627903a0 Mon Sep 17 00:00:00 2001 From: Mathias Myrland Date: Tue, 9 Jun 2026 22:49:32 +0200 Subject: [PATCH 01/12] fix(security): re-apply WebSocket authz hardening - Add default-deny authorize_subscribe hook to MessageHandler; the default handle_subscribe now only subscribes topics the service explicitly authorizes (was: any attacker-supplied topic string). - Flip require_auth default to true in WebSocketServiceBuilder and the macro-generated service builder; anonymous connections now require an explicit opt-out. - Periodically re-validate connection credentials: the token is captured pre-upgrade and re-run through the AuthProvider every auth_revalidation_interval (default 30s, configurable). Failure closes the connection; success refreshes the cached user/permissions, so revoked or expired tokens are bounded to one interval on long-lived connections. - bidirectional-chat-server: depend on ras-rest-macro with default-features = false so the reqwest/client feature is not forced into the server build via proc-macro feature unification. Co-Authored-By: Claude Fable 5 --- .../src/server.rs | 5 +- .../Cargo.toml | 1 + .../src/handler.rs | 244 +++++++++++++++++- .../src/service.rs | 75 +++++- examples/bidirectional-chat/server/Cargo.toml | 2 +- 5 files changed, 313 insertions(+), 14 deletions(-) diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-macro/src/server.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-macro/src/server.rs index 631be47..b6ea1ea 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-macro/src/server.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-macro/src/server.rs @@ -436,11 +436,14 @@ pub fn generate_server_code( impl #builder_name { /// Create a new builder + /// + /// Authentication is required by default; call `.require_auth(false)` + /// to explicitly allow anonymous connections. pub fn new(service: T, auth_provider: A) -> Self { Self { service: std::sync::Arc::new(service), auth_provider: std::sync::Arc::new(auth_provider), - require_auth: false, + require_auth: true, } } diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/Cargo.toml b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/Cargo.toml index 92b568a..a6d454b 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/Cargo.toml +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/Cargo.toml @@ -36,3 +36,4 @@ dashmap = { workspace = true } [dev-dependencies] ras-jsonrpc-types = { path = "../../ras-jsonrpc-types", version = "0.1.1" } +tokio = { workspace = true, features = ["test-util"] } diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/handler.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/handler.rs index 13dd87a..c8a331e 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/handler.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/handler.rs @@ -4,9 +4,11 @@ use crate::{ConnectionContext, ServerError, ServerResult}; use async_trait::async_trait; use axum::extract::ws::{CloseFrame, Message, WebSocket}; use futures::stream::StreamExt; +use ras_auth_core::AuthProvider; use ras_jsonrpc_bidirectional_types::BidirectionalMessage; use ras_jsonrpc_types::{JsonRpcError, JsonRpcRequest, JsonRpcResponse, error_codes}; use std::sync::Arc; +use std::time::Duration; use tokio::sync::mpsc; use tracing::{debug, error, info, warn}; @@ -29,15 +31,37 @@ pub trait MessageHandler: Send + Sync + 'static { context: Arc, ) -> ServerResult>; + /// Decide whether this connection may subscribe to `topic`. + /// + /// Default-deny: services that broadcast over topics must override this + /// (or `handle_subscribe`) to allow the topics a connection is entitled + /// to. Errors propagate to the handler loop and close the connection. + async fn authorize_subscribe( + &self, + _topic: &str, + _context: &Arc, + ) -> ServerResult { + Ok(false) + } + /// Handle subscription requests async fn handle_subscribe( &self, topics: Vec, context: Arc, ) -> ServerResult<()> { - // Default implementation subscribes the connection to each requested topic. + // Default implementation subscribes the connection to each topic the + // service authorizes via `authorize_subscribe`; denied topics are + // skipped without closing the connection. for topic in topics { - context.subscribe(topic).await; + if self.authorize_subscribe(&topic, &context).await? { + context.subscribe(topic).await; + } else { + warn!( + "Denied subscription to topic '{}' for connection {}", + topic, context.id + ); + } } Ok(()) } @@ -154,6 +178,25 @@ impl WebSocketIo for AxumWebSocketIo { } } +/// Default interval between credential re-validations on long-lived connections. +pub const DEFAULT_AUTH_REVALIDATION_INTERVAL: Duration = Duration::from_secs(30); + +/// Periodic credential re-validation for a long-lived connection. +/// +/// The token is captured before the WebSocket upgrade and re-run through the +/// auth provider on every `interval` tick. Failure closes the connection; +/// success refreshes the cached user (so permission changes propagate). This +/// bounds the lifetime of revoked/expired credentials on an open socket to at +/// most one interval. +pub struct AuthRevalidation { + /// Provider used to re-run authentication + pub auth_provider: Arc, + /// Token captured at upgrade time + pub token: String, + /// How often to re-validate + pub interval: Duration, +} + /// WebSocket connection handler that manages the message flow pub struct WebSocketHandler { /// The message handler for processing requests @@ -163,6 +206,8 @@ pub struct WebSocketHandler { /// Channel for receiving messages to send to client message_rx: mpsc::Receiver, max_message_size: usize, + /// Optional periodic credential re-validation + auth_revalidation: Option, } impl WebSocketHandler { @@ -178,9 +223,16 @@ impl WebSocketHandler { context, message_rx, max_message_size, + auth_revalidation: None, } } + /// Enable periodic credential re-validation for this connection. + pub fn with_auth_revalidation(mut self, revalidation: AuthRevalidation) -> Self { + self.auth_revalidation = Some(revalidation); + self + } + /// Run the WebSocket handler loop pub async fn run(self, socket: WebSocket) -> ServerResult<()> { let mut socket = AxumWebSocketIo::new(socket); @@ -215,9 +267,51 @@ impl WebSocketHandler { error!("Failed to send connection established message: {}", e); } + let mut revalidation_timer = self.auth_revalidation.as_ref().map(|revalidation| { + let mut timer = tokio::time::interval_at( + tokio::time::Instant::now() + revalidation.interval, + revalidation.interval, + ); + timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + timer + }); + // Main message handling loop loop { tokio::select! { + // Re-validate credentials so revoked/expired tokens are + // bounded to at most one interval on a long-lived connection + _ = async { revalidation_timer.as_mut().expect("guarded by is_some").tick().await }, + if revalidation_timer.is_some() => + { + let revalidation = self + .auth_revalidation + .as_ref() + .expect("revalidation timer implies config"); + match revalidation + .auth_provider + .authenticate(revalidation.token.clone()) + .await + { + Ok(user) => { + // Refresh cached identity/permissions + self.context.set_user(user).await; + } + Err(e) => { + warn!( + "Closing connection {}: credential re-validation failed: {}", + self.context.id, e + ); + let _ = socket + .send(WebSocketIoMessage::Close(Some( + "credentials no longer valid".to_string(), + ))) + .await; + break; + } + } + } + // Handle incoming WebSocket messages msg = socket.recv() => { match msg { @@ -606,14 +700,53 @@ mod tests { } #[tokio::test] - async fn default_handle_subscribe_writes_to_context() { + async fn default_handle_subscribe_denies_all_topics() { let h = PassThrough; let c = ctx(); h.handle_subscribe(vec!["a".into(), "b".into()], c.clone()) .await .unwrap(); - assert!(c.is_subscribed_to("a").await); - assert!(c.is_subscribed_to("b").await); + assert!(!c.is_subscribed_to("a").await); + assert!(!c.is_subscribed_to("b").await); + } + + #[tokio::test] + async fn default_authorize_subscribe_denies() { + let h = PassThrough; + let c = ctx(); + assert!(!h.authorize_subscribe("any-topic", &c).await.unwrap()); + } + + struct AllowListHandler; + + #[async_trait] + impl MessageHandler for AllowListHandler { + async fn handle_request( + &self, + _request: JsonRpcRequest, + _context: Arc, + ) -> ServerResult> { + Ok(None) + } + + async fn authorize_subscribe( + &self, + topic: &str, + _context: &Arc, + ) -> ServerResult { + Ok(topic == "room:allowed") + } + } + + #[tokio::test] + async fn handle_subscribe_only_subscribes_authorized_topics() { + let h = AllowListHandler; + let c = ctx(); + h.handle_subscribe(vec!["room:allowed".into(), "room:denied".into()], c.clone()) + .await + .unwrap(); + assert!(c.is_subscribed_to("room:allowed").await); + assert!(!c.is_subscribed_to("room:denied").await); } #[tokio::test] @@ -890,6 +1023,107 @@ mod tests { ); } + fn auth_user(id: &str) -> ras_auth_core::AuthenticatedUser { + ras_auth_core::AuthenticatedUser { + user_id: id.to_string(), + permissions: std::collections::HashSet::new(), + metadata: None, + } + } + + /// Auth provider that replays a fixed sequence of results, then fails. + struct SequenceAuthProvider( + Mutex>>, + ); + + impl SequenceAuthProvider { + fn new( + results: impl IntoIterator< + Item = Result, + >, + ) -> Self { + Self(Mutex::new(results.into_iter().collect())) + } + } + + impl AuthProvider for SequenceAuthProvider { + fn authenticate(&self, _token: String) -> ras_auth_core::AuthFuture<'_> { + let result = self + .0 + .lock() + .expect("results lock") + .pop_front() + .unwrap_or(Err(ras_auth_core::AuthError::InvalidToken)); + Box::pin(async move { result }) + } + } + + #[tokio::test(start_paused = true)] + async fn revalidation_failure_closes_connection() { + let context = ctx(); + let (_tx, rx) = mpsc::channel(4); + let mut socket = InMemorySocket::pending(); + + WebSocketHandler::new(Arc::new(PassThrough), context, rx, 1024) + .with_auth_revalidation(AuthRevalidation { + auth_provider: Arc::new(SequenceAuthProvider::new([])), + token: "revoked-token".into(), + interval: Duration::from_secs(30), + }) + .run_with_io(&mut socket) + .await + .unwrap(); + + assert!(socket.outgoing.iter().any(|message| matches!( + message, + WebSocketIoMessage::Close(Some(reason)) if reason == "credentials no longer valid" + ))); + } + + #[tokio::test(start_paused = true)] + async fn revalidation_success_refreshes_cached_user() { + let context = ctx(); + context.set_user(auth_user("stale")).await; + let (_tx, rx) = mpsc::channel(4); + let mut socket = InMemorySocket::pending(); + + WebSocketHandler::new(Arc::new(PassThrough), context.clone(), rx, 1024) + .with_auth_revalidation(AuthRevalidation { + auth_provider: Arc::new(SequenceAuthProvider::new([Ok(auth_user("fresh"))])), + token: "valid-token".into(), + interval: Duration::from_secs(30), + }) + .run_with_io(&mut socket) + .await + .unwrap(); + + // First tick refreshed the cached user; the second (sequence + // exhausted) failed and closed the connection. + assert_eq!(context.get_user().await.expect("user").user_id, "fresh"); + assert!( + socket + .outgoing + .iter() + .any(|message| matches!(message, WebSocketIoMessage::Close(_))) + ); + } + + #[tokio::test] + async fn handler_without_revalidation_does_not_authenticate() { + // No auth provider involved at all: the loop must terminate on + // socket close without ticking a revalidation timer. + let mut socket = InMemorySocket::closing([]); + let (_tx, rx) = mpsc::channel(4); + + WebSocketHandler::new(Arc::new(PassThrough), ctx(), rx, 1024) + .run_with_io(&mut socket) + .await + .unwrap(); + + let messages = bidirectional_outgoing(&socket); + assert_eq!(messages.len(), 2); + } + fn bidirectional_outgoing(socket: &InMemorySocket) -> Vec { socket .outgoing diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/service.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/service.rs index a614b03..dec3e17 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/service.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/service.rs @@ -4,7 +4,7 @@ use crate::{ ConnectionContext, DefaultConnectionManager, MessageHandler, MessageRouter, ServerError, ServerResult, WebSocketHandler, WebSocketUpgrade, connection::ChannelMessageSender, - handler::{AxumWebSocketIo, WebSocketIo}, + handler::{AuthRevalidation, AxumWebSocketIo, DEFAULT_AUTH_REVALIDATION_INTERVAL, WebSocketIo}, }; use axum::{ extract::{State, ws::WebSocketUpgrade as AxumWebSocketUpgrade}, @@ -15,6 +15,7 @@ use bon::Builder; use ras_auth_core::AuthProvider; use ras_jsonrpc_bidirectional_types::{ConnectionId, ConnectionInfo, ConnectionManager}; use std::sync::Arc; +use std::time::Duration; use tokio::sync::mpsc; use tracing::{error, info}; @@ -53,6 +54,14 @@ pub trait WebSocketService: Clone + Send + Sync + 'static { DEFAULT_MAX_MESSAGE_SIZE } + /// How often to re-run authentication for a live connection. + /// + /// Bounds the lifetime of revoked/expired credentials on a long-lived + /// WebSocket to at most one interval. + fn auth_revalidation_interval(&self) -> Duration { + DEFAULT_AUTH_REVALIDATION_INTERVAL + } + /// Handle WebSocket upgrade async fn handle_upgrade( &self, @@ -60,6 +69,8 @@ pub trait WebSocketService: Clone + Send + Sync + 'static { headers: HeaderMap, ) -> Result { let ws_upgrade = WebSocketUpgrade::new(upgrade, headers); + // Captured pre-upgrade so the connection can periodically re-validate it. + let auth_token = ws_upgrade.extract_auth_token(); let service = self.clone(); ws_upgrade @@ -68,7 +79,7 @@ pub trait WebSocketService: Clone + Send + Sync + 'static { self.require_auth(), move |socket, user| { Box::pin(async move { - if let Err(e) = service.handle_connection(socket, user).await { + if let Err(e) = service.handle_connection(socket, user, auth_token).await { error!("WebSocket connection error: {}", e); } }) @@ -82,11 +93,12 @@ pub trait WebSocketService: Clone + Send + Sync + 'static { &self, socket: axum::extract::ws::WebSocket, user: Option, + auth_token: Option, ) -> impl std::future::Future> + Send { let service = self.clone(); async move { let mut socket = AxumWebSocketIo::new(socket); - run_connection_with_io(service, &mut socket, user).await + run_connection_with_io(service, &mut socket, user, auth_token).await } } @@ -103,7 +115,7 @@ pub trait WebSocketService: Clone + Send + Sync + 'static { S: WebSocketIo + ?Sized + 'a, { let service = self.clone(); - async move { run_connection_with_io(service, socket, user).await } + async move { run_connection_with_io(service, socket, user, None).await } } } @@ -111,6 +123,7 @@ async fn run_connection_with_io( service: Svc, socket: &mut S, user: Option, + auth_token: Option, ) -> ServerResult<()> where Svc: WebSocketService, @@ -139,13 +152,23 @@ where .await .map_err(ServerError::ConnectionError)?; - let handler = WebSocketHandler::new( + let mut handler = WebSocketHandler::new( service.handler(), context.clone(), message_rx, service.max_message_size(), ); + // Authenticated connections re-validate their token periodically so + // revocation/expiry takes effect without waiting for a disconnect. + if let Some(token) = auth_token { + handler = handler.with_auth_revalidation(AuthRevalidation { + auth_provider: service.auth_provider(), + token, + interval: service.auth_revalidation_interval(), + }); + } + let result = handler.run_with_io(socket).await; if let Err(e) = service @@ -168,8 +191,8 @@ pub struct WebSocketServiceBuilder { auth_provider: Arc, /// Connection manager connection_manager: Option>, - /// Whether authentication is required - #[builder(default = false)] + /// Whether authentication is required (secure default: required) + #[builder(default = true)] require_auth: bool, /// Maximum queued outbound messages per connection #[builder(default = DEFAULT_MESSAGE_CHANNEL_CAPACITY)] @@ -177,6 +200,9 @@ pub struct WebSocketServiceBuilder { /// Maximum accepted inbound WebSocket message size in bytes #[builder(default = DEFAULT_MAX_MESSAGE_SIZE)] max_message_size: usize, + /// Interval between credential re-validations for live connections + #[builder(default = DEFAULT_AUTH_REVALIDATION_INTERVAL)] + auth_revalidation_interval: Duration, } impl WebSocketServiceBuilder @@ -195,6 +221,7 @@ where require_auth: self.require_auth, message_channel_capacity: self.message_channel_capacity, max_message_size: self.max_message_size, + auth_revalidation_interval: self.auth_revalidation_interval, } } } @@ -214,6 +241,7 @@ where require_auth: self.require_auth, message_channel_capacity: self.message_channel_capacity, max_message_size: self.max_message_size, + auth_revalidation_interval: self.auth_revalidation_interval, } } } @@ -226,6 +254,7 @@ pub struct BuiltWebSocketService { require_auth: bool, message_channel_capacity: usize, max_message_size: usize, + auth_revalidation_interval: Duration, } impl Clone for BuiltWebSocketService { @@ -237,6 +266,7 @@ impl Clone for BuiltWebSocketService { require_auth: self.require_auth, message_channel_capacity: self.message_channel_capacity, max_message_size: self.max_message_size, + auth_revalidation_interval: self.auth_revalidation_interval, } } } @@ -274,6 +304,10 @@ where fn max_message_size(&self) -> usize { self.max_message_size } + + fn auth_revalidation_interval(&self) -> Duration { + self.auth_revalidation_interval + } } /// Convenience function to create a simple router-based service @@ -387,6 +421,33 @@ mod tests { assert_eq!(service.connection_manager().connection_count(), 0); } + #[tokio::test] + async fn builder_requires_auth_by_default() { + let builder = WebSocketServiceBuilder::builder() + .handler(Arc::new(MessageRouter::new())) + .auth_provider(Arc::new(MockAuthProvider)) + .build(); + let service = builder.build(); + + assert!(service.require_auth()); + assert_eq!( + service.auth_revalidation_interval(), + Duration::from_secs(30) + ); + } + + #[tokio::test] + async fn builder_revalidation_interval_is_configurable() { + let builder = WebSocketServiceBuilder::builder() + .handler(Arc::new(MessageRouter::new())) + .auth_provider(Arc::new(MockAuthProvider)) + .auth_revalidation_interval(Duration::from_secs(5)) + .build(); + let service = builder.build(); + + assert_eq!(service.auth_revalidation_interval(), Duration::from_secs(5)); + } + #[tokio::test] async fn test_service_with_auth_required() { let router = MessageRouter::new(); diff --git a/examples/bidirectional-chat/server/Cargo.toml b/examples/bidirectional-chat/server/Cargo.toml index a13df33..58b7217 100644 --- a/examples/bidirectional-chat/server/Cargo.toml +++ b/examples/bidirectional-chat/server/Cargo.toml @@ -17,7 +17,7 @@ ras-auth-core = { path = "../../../crates/core/ras-auth-core", version = "0.1.0" ras-jsonrpc-types = { path = "../../../crates/rpc/ras-jsonrpc-types", version = "0.1.1" } ras-jsonrpc-bidirectional-server = { path = "../../../crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server", version = "0.1.0" } ras-jsonrpc-bidirectional-types = { path = "../../../crates/rpc/bidirectional/ras-jsonrpc-bidirectional-types", version = "0.1.0" } -ras-rest-macro = { path = "../../../crates/rest/ras-rest-macro", version = "0.2.1", features = ["server"] } +ras-rest-macro = { path = "../../../crates/rest/ras-rest-macro", version = "0.2.1", default-features = false, features = ["server"] } ras-rest-core = { path = "../../../crates/rest/ras-rest-core", version = "0.1.1" } ras-identity-core = { path = "../../../crates/core/ras-identity-core", version = "0.1.1" } ras-identity-local = { path = "../../../crates/identity/ras-identity-local", version = "0.2.0" } From f8ebb5cab239b34d87635562c6737c089223eed6 Mon Sep 17 00:00:00 2001 From: Mathias Myrland Date: Tue, 9 Jun 2026 22:53:54 +0200 Subject: [PATCH 02/12] fix(client): remove connect() busy-wait and pending-request leak - connect() no longer spins in a hot loop waiting for the server's ConnectionEstablished message. It now parks on a Notify signaled by the message handler, bounded by connection_timeout, and tears down the half-open connection on timeout instead of hanging forever. - The client only reports Connected (and only starts the heartbeat) after the handshake completes; previously concurrent calls could pass the state check with no connection id, and the heartbeat task could race the state write and exit immediately. - call() removes its pending-request entry on send failure, response timeout, and channel closure. Previously timed-out entries leaked until the map hit max_pending_requests, after which every call failed permanently with 'Too many pending requests'. Co-Authored-By: Claude Fable 5 --- .../Cargo.toml | 1 + .../src/client.rs | 161 ++++++++++++++++-- 2 files changed, 144 insertions(+), 18 deletions(-) diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/Cargo.toml b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/Cargo.toml index 62e1a68..61b0d39 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/Cargo.toml +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/Cargo.toml @@ -64,3 +64,4 @@ wasm = ["tokio", "web-sys", "wasm-bindgen", "wasm-bindgen-futures", "js-sys"] [dev-dependencies] tracing-subscriber = { workspace = true } chrono = { workspace = true } +tokio = { workspace = true, features = ["test-util"] } diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/client.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/client.rs index d1e2e03..ff0864b 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/client.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/client.rs @@ -42,6 +42,8 @@ pub struct Client { request_id_counter: Arc, shutdown_tx: Arc>>>, message_tx: Arc>>>, + /// Signaled when the server's ConnectionEstablished message arrives + connected_notify: Arc, } struct IncomingMessageContext<'a> { @@ -52,6 +54,7 @@ struct IncomingMessageContext<'a> { connection_event_handlers: &'a DashMap, connection_id: &'a RwLock>, message_tx: &'a RwLock>>, + connected_notify: &'a tokio::sync::Notify, } impl Client { @@ -80,6 +83,7 @@ impl Client { request_id_counter: Arc::new(AtomicU64::new(1)), shutdown_tx: Arc::new(RwLock::new(None)), message_tx: Arc::new(RwLock::new(None)), + connected_notify: Arc::new(tokio::sync::Notify::new()), }) } @@ -110,20 +114,38 @@ impl Client { // Start message handling task self.start_message_handler(message_rx, shutdown_rx).await?; - // Start heartbeat if configured - if let Some(interval) = self.config.heartbeat_interval { - self.start_heartbeat(interval).await; + // Wait for the server's ConnectionEstablished message before + // reporting the client as connected. The notify is signaled by the + // message handler; bound the wait so a silent server cannot hang us. + let handshake = async { + loop { + if self.connection_id.read().await.is_some() { + break; + } + self.connected_notify.notified().await; + } + }; + if tokio::time::timeout(self.config.connection_timeout, handshake) + .await + .is_err() + { + // Tear down the half-open connection + let _ = self.disconnect().await; + return Err(ClientError::timeout( + self.config.connection_timeout.as_secs(), + )); } *self.state.write().await = ClientState::Connected; - info!("Client connected to {}", self.config.url); - loop { - if self.connection_id.read().await.is_some() { - break; - } + // Start heartbeat once connected (its loop exits when state leaves + // Connected, so starting earlier would race it to an immediate stop) + if let Some(interval) = self.config.heartbeat_interval { + self.start_heartbeat(interval).await; } + info!("Client connected to {}", self.config.url); + Ok(()) } @@ -203,19 +225,30 @@ impl Client { return Err(ClientError::internal("Too many pending requests")); } - self.pending_requests.insert(request_id, pending); + self.pending_requests.insert(request_id.clone(), pending); - // Send the request + // Send the request; on failure, drop our pending entry so the map + // cannot fill up with waiters that will never be answered. let message = BidirectionalMessage::Request(request); - self.send_message(message).await?; - - // Wait for response with timeout - let response = tokio::time::timeout(self.config.request_timeout, response_rx) - .await - .map_err(|_| ClientError::timeout(self.config.request_timeout.as_secs()))? - .map_err(|_| ClientError::internal("Response channel closed"))?; + if let Err(e) = self.send_message(message).await { + self.pending_requests.remove(&request_id); + return Err(e); + } - Ok(response) + // Wait for response with timeout; every failure path removes our + // entry for the same reason (the success path is removed by the + // message handler when the response arrives). + match tokio::time::timeout(self.config.request_timeout, response_rx).await { + Ok(Ok(response)) => Ok(response), + Ok(Err(_)) => { + self.pending_requests.remove(&request_id); + Err(ClientError::internal("Response channel closed")) + } + Err(_) => { + self.pending_requests.remove(&request_id); + Err(ClientError::timeout(self.config.request_timeout.as_secs())) + } + } } /// Send a notification (fire-and-forget) @@ -358,6 +391,7 @@ impl Client { let connection_id = Arc::clone(&self.connection_id); let state = Arc::clone(&self.state); let message_tx_clone = Arc::clone(&self.message_tx); + let connected_notify = Arc::clone(&self.connected_notify); spawn_background(async move { let mut receive_interval = tokio::time::interval(Duration::from_millis(10)); @@ -397,6 +431,7 @@ impl Client { connection_event_handlers: &connection_event_handlers, connection_id: &connection_id, message_tx: &message_tx_clone, + connected_notify: &connected_notify, }; Self::handle_incoming_message( message, @@ -450,6 +485,10 @@ impl Client { connection_id: conn_id, } => { *context.connection_id.write().await = Some(conn_id); + // Wake a connect() call waiting on the handshake. notify_one + // stores a permit, so this works even if connect() has not + // started waiting yet. + context.connected_notify.notify_one(); Self::emit_connection_event_static( ConnectionEvent::Connected { connection_id: conn_id, @@ -744,6 +783,7 @@ mod tests { connection_event_handlers: DashMap, connection_id: RwLock>, message_tx: RwLock>>, + connected_notify: tokio::sync::Notify, } impl IncomingHarness { @@ -756,6 +796,7 @@ mod tests { connection_event_handlers: DashMap::new(), connection_id: RwLock::new(None), message_tx: RwLock::new(None), + connected_notify: tokio::sync::Notify::new(), } } @@ -768,6 +809,7 @@ mod tests { connection_event_handlers: &self.connection_event_handlers, connection_id: &self.connection_id, message_tx: &self.message_tx, + connected_notify: &self.connected_notify, } } } @@ -1337,4 +1379,87 @@ mod tests { assert!(rx.try_recv().is_err()); } + + #[tokio::test] + async fn connection_established_wakes_handshake_waiter() { + let harness = std::sync::Arc::new(IncomingHarness::new()); + + // Mirrors the wait loop in connect(): park on the notify until the + // connection id is set. Without notify_one in the message handler + // this would hang and the timeout below would fail the test. + let waiter = { + let harness = std::sync::Arc::clone(&harness); + tokio::spawn(async move { + loop { + if harness.connection_id.read().await.is_some() { + break; + } + harness.connected_notify.notified().await; + } + }) + }; + + tokio::task::yield_now().await; + + Client::handle_incoming_message( + BidirectionalMessage::ConnectionEstablished { + connection_id: ConnectionId::new(), + }, + harness.context(), + ) + .await; + + tokio::time::timeout(Duration::from_secs(5), waiter) + .await + .expect("handshake waiter woke up") + .expect("waiter task completed"); + } + + #[tokio::test(start_paused = true)] + async fn call_timeout_removes_pending_entry_and_allows_retry() { + let client = ClientBuilder::new("ws://localhost:8080") + .with_request_timeout(Duration::from_millis(20)) + .build() + .await + .expect("build"); + *client.state.write().await = ClientState::Connected; + + let (tx, mut rx) = mpsc::channel(8); + *client.message_tx.write().await = Some(tx); + + let err = client.call("svc.slow", None).await.unwrap_err(); + assert!(matches!(err, ClientError::Timeout { .. })); + assert!( + client.pending_requests.is_empty(), + "timed-out call must remove its pending entry" + ); + let _ = rx.recv().await; + + // The map must not fill up with dead waiters: a retry times out + // again rather than failing with "Too many pending requests". + let err = client.call("svc.slow", None).await.unwrap_err(); + assert!(matches!(err, ClientError::Timeout { .. })); + assert!(client.pending_requests.is_empty()); + } + + #[tokio::test] + async fn call_send_failure_removes_pending_entry() { + let client = ClientBuilder::new("ws://localhost:8080") + .build() + .await + .expect("build"); + *client.state.write().await = ClientState::Connected; + + // Install a sender whose receiver is already gone so send fails. + let (tx, rx) = mpsc::channel(1); + drop(rx); + *client.message_tx.write().await = Some(tx); + + let err = client.call("svc.echo", None).await.unwrap_err(); + assert!(!matches!(err, ClientError::Timeout { .. })); + assert!( + client.pending_requests.is_empty(), + "failed send must remove its pending entry" + ); + } } From 3fdb921db503e6dfd472061ad2a87f663af01ed0 Mon Sep 17 00:00:00 2001 From: Mathias Myrland Date: Tue, 9 Jun 2026 22:57:00 +0200 Subject: [PATCH 03/12] fix(manager): no DashMap guards across awaits, concurrent fan-out, subscription races - send_to_connection and all broadcast paths clone senders out of the map before awaiting. Previously a slow consumer with a full bounded channel parked the send while holding the shard guard, blocking (or deadlocking) every other access to that shard, including the consumer's own remove_connection. - Broadcasts now fan out concurrently via join_all; wall-clock is bounded by the slowest recipient instead of the sum of all sends. - add_subscription only indexes live connections, dedups, and undoes its insert if the connection vanished mid-call, so a subscribe racing remove_connection cannot leave zombie ids in the topic index. - remove_subscription / remove_pending_request use remove_if for the empty-entry cleanup so a concurrent insert between retain and remove is not thrown away. Co-Authored-By: Claude Fable 5 --- .../src/manager.rs | 166 +++++++++++------- .../tests/manager_unit.rs | 62 +++++++ 2 files changed, 161 insertions(+), 67 deletions(-) diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/manager.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/manager.rs index 2e74c55..78ebeda 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/manager.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/manager.rs @@ -79,6 +79,36 @@ impl DefaultConnectionManager { pub fn get_sender(&self, id: ConnectionId) -> Option { self.connections.get(&id).map(|entry| entry.1.clone()) } + + /// Send `message` to every recipient concurrently. + /// + /// Senders are cloned out of the map before any send so no DashMap shard + /// lock is held across an await (a slow consumer with a full channel + /// would otherwise block every other map access on that shard), and the + /// slowest recipient bounds wall-clock time instead of the sum of all. + async fn fan_out( + &self, + recipients: Vec<(ConnectionId, ChannelMessageSender)>, + message: BidirectionalMessage, + ) -> (usize, Vec) { + let sends = recipients.into_iter().map(|(id, sender)| { + let message = message.clone(); + async move { (id, sender.send(message).await) } + }); + + let mut sent_count = 0; + let mut failed = Vec::new(); + for (id, result) in futures::future::join_all(sends).await { + match result { + Ok(()) => sent_count += 1, + Err(e) => { + warn!("Failed to broadcast to connection {}: {}", id, e); + failed.push(id); + } + } + } + (sent_count, failed) + } } #[async_trait] @@ -112,15 +142,13 @@ impl ConnectionManager for DefaultConnectionManager { async fn remove_connection(&self, id: ConnectionId) -> Result<()> { if let Some((_, (info, _))) = self.connections.remove(&id) { - // Remove from all topic subscriptions + // Remove from all topic subscriptions. remove_if keeps the + // empty-entry cleanup atomic against concurrent subscribes. for topic in info.subscriptions.iter() { if let Some(mut entry) = self.subscriptions.get_mut(topic) { entry.retain(|&connection_id| connection_id != id); - if entry.is_empty() { - drop(entry); - self.subscriptions.remove(topic); - } } + self.subscriptions.remove_if(topic, |_, ids| ids.is_empty()); } // Clean up pending requests for this connection @@ -183,17 +211,37 @@ impl ConnectionManager for DefaultConnectionManager { } async fn add_subscription(&self, id: ConnectionId, topic: String) -> Result<()> { - // Update topic subscriptions - self.subscriptions - .entry(topic.clone()) - .or_default() - .push(id); - - // Update connection subscriptions - if let Some(mut entry) = self.connections.get_mut(&id) { + // Only live connections may enter the topic index, otherwise a + // subscribe racing remove_connection leaves dangling ids behind. + { + let Some(mut entry) = self.connections.get_mut(&id) else { + warn!( + "Attempted to subscribe non-existent connection {} to topic {}", + id, topic + ); + return Ok(()); + }; entry.0.subscribe(topic.clone()); } + { + let mut entry = self.subscriptions.entry(topic.clone()).or_default(); + if !entry.contains(&id) { + entry.push(id); + } + } + + // The connection may have been removed between the liveness check and + // the index insert; undo so no zombie entry survives the race. + if !self.connections.contains_key(&id) { + if let Some(mut entry) = self.subscriptions.get_mut(&topic) { + entry.retain(|&connection_id| connection_id != id); + } + self.subscriptions + .remove_if(&topic, |_, ids| ids.is_empty()); + return Ok(()); + } + debug!("Connection {} subscribed to topic {}", id, topic); Ok(()) } @@ -202,11 +250,11 @@ impl ConnectionManager for DefaultConnectionManager { // Update topic subscriptions if let Some(mut entry) = self.subscriptions.get_mut(topic) { entry.retain(|&connection_id| connection_id != id); - if entry.is_empty() { - drop(entry); - self.subscriptions.remove(topic); - } } + // Drop the topic entry only if it is still empty at removal time, so + // a concurrent subscribe between the retain above and this call is + // not thrown away. + self.subscriptions.remove_if(topic, |_, ids| ids.is_empty()); // Update connection subscriptions if let Some(mut entry) = self.connections.get_mut(&id) { @@ -230,9 +278,11 @@ impl ConnectionManager for DefaultConnectionManager { id: ConnectionId, message: BidirectionalMessage, ) -> Result<()> { - if let Some(entry) = self.connections.get(&id) { - entry - .1 + // Clone the sender out of the map: awaiting a send on a full channel + // while holding the shard guard would block other map accesses. + let sender = self.connections.get(&id).map(|entry| entry.1.clone()); + if let Some(sender) = sender { + sender .send(message) .await .map_err(ras_jsonrpc_bidirectional_types::BidirectionalError::SendError)?; @@ -255,26 +305,21 @@ impl ConnectionManager for DefaultConnectionManager { } let mut failed_connections = Vec::new(); - let mut sent_count = 0; - - for connection_id in &topic_connections { - if let Some(entry) = self.connections.get(connection_id) { - if let Err(e) = entry.1.send(message.clone()).await { - warn!("Failed to broadcast to connection {}: {}", connection_id, e); - failed_connections.push(*connection_id); - } else { - sent_count += 1; - } + let mut recipients = Vec::with_capacity(topic_connections.len()); + for connection_id in topic_connections { + if let Some(entry) = self.connections.get(&connection_id) { + recipients.push((connection_id, entry.1.clone())); } else { - failed_connections.push(*connection_id); + failed_connections.push(connection_id); } } + let (sent_count, send_failures) = self.fan_out(recipients, message).await; + failed_connections.extend(send_failures); + // Clean up failed connections from topic subscriptions - if !failed_connections.is_empty() { - for connection_id in failed_connections { - let _ = self.remove_subscription(connection_id, topic).await; - } + for connection_id in failed_connections { + let _ = self.remove_subscription(connection_id, topic).await; } debug!( @@ -285,21 +330,14 @@ impl ConnectionManager for DefaultConnectionManager { } async fn broadcast_to_authenticated(&self, message: BidirectionalMessage) -> Result { - let mut sent_count = 0; + let recipients: Vec<_> = self + .connections + .iter() + .filter(|entry| entry.value().0.is_authenticated()) + .map(|entry| (*entry.key(), entry.value().1.clone())) + .collect(); - for entry in self.connections.iter() { - let (info, sender) = entry.value(); - if info.is_authenticated() { - if let Err(e) = sender.send(message.clone()).await { - warn!( - "Failed to broadcast to authenticated connection {}: {}", - info.id, e - ); - } else { - sent_count += 1; - } - } - } + let (sent_count, _) = self.fan_out(recipients, message).await; debug!("Broadcasted to {} authenticated connections", sent_count); Ok(sent_count) @@ -310,21 +348,14 @@ impl ConnectionManager for DefaultConnectionManager { permission: &str, message: BidirectionalMessage, ) -> Result { - let mut sent_count = 0; + let recipients: Vec<_> = self + .connections + .iter() + .filter(|entry| entry.value().0.has_permission(permission)) + .map(|entry| (*entry.key(), entry.value().1.clone())) + .collect(); - for entry in self.connections.iter() { - let (info, sender) = entry.value(); - if info.has_permission(permission) { - if let Err(e) = sender.send(message.clone()).await { - warn!( - "Failed to broadcast to connection {} with permission {}: {}", - info.id, permission, e - ); - } else { - sent_count += 1; - } - } - } + let (sent_count, _) = self.fan_out(recipients, message).await; debug!( "Broadcasted to {} connections with permission: {}", @@ -358,10 +389,11 @@ impl ConnectionManager for DefaultConnectionManager { ) -> Result>> { if let Some(mut entry) = self.pending_requests.get_mut(&connection_id) { let sender = entry.remove(request_id); - if entry.is_empty() { - drop(entry); - self.pending_requests.remove(&connection_id); - } + drop(entry); + // Conditional removal so a concurrent register between the drop + // above and this call is not thrown away. + self.pending_requests + .remove_if(&connection_id, |_, requests| requests.is_empty()); debug!("Removed pending request for connection: {}", connection_id); Ok(sender) } else { diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/tests/manager_unit.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/tests/manager_unit.rs index 9debe66..5f411a2 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/tests/manager_unit.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/tests/manager_unit.rs @@ -259,3 +259,65 @@ async fn send_to_missing_connection_is_silent_ok() { async fn default_impl_is_equivalent_to_new() { let _ = Arc::new(DefaultConnectionManager::default()); } + +#[tokio::test] +async fn broadcast_to_full_channel_does_not_block_map_access() { + // A slow consumer with a full bounded channel must not wedge the manager: + // if shard guards were held across the send await (the old behavior), + // the concurrent remove_connection below would deadlock and trip the + // timeout. + tokio::time::timeout(std::time::Duration::from_secs(5), async { + let mgr = Arc::new(DefaultConnectionManager::new()); + + // Connection with a capacity-1 channel, pre-filled so sends park. + let id = ConnectionId::new(); + let (tx, mut rx) = mpsc::channel(1); + tx.send(BidirectionalMessage::Ping).await.unwrap(); + let sender = ChannelMessageSender::new(id, tx); + mgr.add_connection_with_sender_direct(ConnectionInfo::new(id), sender) + .await + .unwrap(); + mgr.add_subscription(id, "topic".into()).await.unwrap(); + + let broadcast = { + let mgr = Arc::clone(&mgr); + tokio::spawn(async move { + mgr.broadcast_to_topic("topic", BidirectionalMessage::Pong) + .await + }) + }; + + // Give the broadcast a chance to park on the full channel. + tokio::task::yield_now().await; + + // Must complete while the broadcast is still parked. + mgr.remove_connection(id).await.unwrap(); + assert_eq!(mgr.connection_count(), 0); + + // Drain one message so the parked send resolves and the broadcast + // can finish. + let _ = rx.recv().await; + let sent = broadcast.await.unwrap().unwrap(); + assert_eq!(sent, 1); + }) + .await + .expect("manager deadlocked under a slow consumer"); +} + +#[tokio::test] +async fn subscribe_missing_connection_leaves_no_zombie_topic_entry() { + let mgr = DefaultConnectionManager::new(); + let ghost = ConnectionId::new(); + mgr.add_subscription(ghost, "t".into()).await.unwrap(); + assert!(mgr.get_topic_connections("t").is_empty()); + assert!(mgr.get_active_topics().is_empty()); +} + +#[tokio::test] +async fn duplicate_subscription_is_tracked_once() { + let mgr = DefaultConnectionManager::new(); + let (a, _ra) = join(&mgr).await; + mgr.add_subscription(a, "t".into()).await.unwrap(); + mgr.add_subscription(a, "t".into()).await.unwrap(); + assert_eq!(mgr.get_topic_connections("t"), vec![a]); +} From 5b37b392c2780410ed1c5fc0d90fe629ad3c39a9 Mon Sep 17 00:00:00 2001 From: Mathias Myrland Date: Tue, 9 Jun 2026 22:59:56 +0200 Subject: [PATCH 04/12] fix(server): answer unparseable WS frames with -32700 instead of dropping the connection A text frame that parses as neither a BidirectionalMessage nor a JSON-RPC request previously bubbled an InvalidRequest error into the handler loop, which closed the whole connection without any protocol-level response. Per JSON-RPC 2.0 the server now replies with a Parse Error (-32700, id null) and keeps serving the connection; only transport failures terminate the loop. Co-Authored-By: Claude Fable 5 --- .../src/handler.rs | 51 +++++++++++++++---- 1 file changed, 41 insertions(+), 10 deletions(-) diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/handler.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/handler.rs index c8a331e..1733c9e 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/handler.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/handler.rs @@ -442,10 +442,16 @@ impl WebSocketHandler { return self.handle_jsonrpc_request(request, socket).await; } - // If neither worked, return error - Err(ServerError::InvalidRequest( - "Could not parse message as JSON-RPC or bidirectional message".to_string(), - )) + // Neither shape parsed. Per JSON-RPC 2.0, answer with a Parse Error + // (-32700, id null) and keep the connection open; only transport + // failures terminate the handler loop. + warn!( + "Could not parse message as JSON-RPC or bidirectional message on connection {}", + self.context.id + ); + let response = JsonRpcResponse::error(JsonRpcError::parse_error(), None); + self.send_message(socket, BidirectionalMessage::Response(response)) + .await } /// Handle bidirectional messages @@ -938,24 +944,49 @@ mod tests { } #[tokio::test] - async fn handler_loop_closes_malformed_text_without_response() { - let mut socket = - InMemorySocket::closing([WebSocketIoMessage::Text("not json-rpc".to_string())]); + async fn handler_loop_answers_malformed_text_with_parse_error_and_continues() { + let request = JsonRpcRequest::new( + "echo".into(), + Some(serde_json::json!({})), + Some(serde_json::json!(9)), + ); + let mut socket = InMemorySocket::closing([ + WebSocketIoMessage::Text("not json-rpc".to_string()), + WebSocketIoMessage::Text( + serde_json::to_string(&BidirectionalMessage::Request(request)).unwrap(), + ), + ]); let (_tx, rx) = mpsc::channel(4); - WebSocketHandler::new(Arc::new(PassThrough), ctx(), rx, 1024) + WebSocketHandler::new(Arc::new(RespondingHandler), ctx(), rx, 1024) .run_with_io(&mut socket) .await .unwrap(); let messages = bidirectional_outgoing(&socket); - assert_eq!(messages.len(), 2); assert!(matches!( messages[0], BidirectionalMessage::ConnectionEstablished { .. } )); + + // The garbage frame is answered with -32700 (id null)... + let parse_error = match &messages[1] { + BidirectionalMessage::Response(response) => response, + other => panic!("expected parse error response, got {other:?}"), + }; + assert_eq!(parse_error.id, None); + let error = parse_error.error.as_ref().expect("parse error"); + assert_eq!(error.code, ras_jsonrpc_types::error_codes::PARSE_ERROR); + + // ...and the connection keeps serving subsequent requests. + let response = match &messages[2] { + BidirectionalMessage::Response(response) => response, + other => panic!("expected response, got {other:?}"), + }; + assert_eq!(response.id, Some(serde_json::json!(9))); + assert!(matches!( - messages[1], + messages[3], BidirectionalMessage::ConnectionClosed { .. } )); } From 52d6df68a12b004223912c40aa21edf0632c152c Mon Sep 17 00:00:00 2001 From: Mathias Myrland Date: Tue, 9 Jun 2026 23:06:09 +0200 Subject: [PATCH 05/12] fix(rest-macro): parse request bodies only after auth, add body_limit option MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Generated REST handlers previously used an axum Json extractor, so the full body was read and deserialized before any auth/CSRF/permission check ran — free pre-auth CPU and allocation for unauthenticated clients. Handlers now take the raw request and read + deserialize the body inside the handler, after authorization succeeds (matching what the file-service macro already did). Unauthorized endpoints keep parsing up front. rest_service! also gains an optional body_limit field (bytes, default 2 MiB to match axum); oversized bodies are answered with 413. Regression tests: malformed JSON without credentials now yields 401 (not 400), and the body_limit option is enforced end-to-end. Co-Authored-By: Claude Fable 5 --- crates/rest/ras-rest-macro/src/lib.rs | 120 ++++++++++-------- .../ras-rest-macro/tests/http_integration.rs | 67 ++++++++++ 2 files changed, 136 insertions(+), 51 deletions(-) diff --git a/crates/rest/ras-rest-macro/src/lib.rs b/crates/rest/ras-rest-macro/src/lib.rs index 6d33a01..b7d547a 100644 --- a/crates/rest/ras-rest-macro/src/lib.rs +++ b/crates/rest/ras-rest-macro/src/lib.rs @@ -78,9 +78,13 @@ struct ServiceDefinition { base_path: String, openapi: Option, static_hosting: static_hosting::StaticHostingConfig, + body_limit: Option, endpoints: Vec, } +/// Default maximum JSON body size in bytes (matches axum's default). +const DEFAULT_BODY_LIMIT: usize = 2 * 1024 * 1024; + #[derive(Debug)] enum OpenApiConfig { Enabled, @@ -244,9 +248,10 @@ impl Parse for ServiceDefinition { let base_path = base_path_lit.value(); let _ = content.parse::()?; - // Parse optional fields (openapi, serve_docs, docs_path, ui_theme) + // Parse optional fields (openapi, serve_docs, docs_path, ui_theme, body_limit) let mut openapi = None; let mut static_hosting = static_hosting::StaticHostingConfig::default(); + let mut body_limit = None; // Parse optional fields while content.peek(Ident) { @@ -292,6 +297,12 @@ impl Parse for ServiceDefinition { let theme = content.parse::()?; static_hosting.ui_theme = theme.value(); let _ = content.parse::()?; + } else if field_name == "body_limit" { + let _ = content.parse::()?; // "body_limit" + let _ = content.parse::()?; + let limit = content.parse::()?; + body_limit = Some(limit.base10_parse::()?); + let _ = content.parse::()?; } else if field_name == "endpoints" { break; // Start parsing endpoints } else { @@ -325,6 +336,7 @@ impl Parse for ServiceDefinition { base_path, openapi, static_hosting, + body_limit, endpoints, }) } @@ -766,11 +778,17 @@ fn generate_service_code(service_def: ServiceDefinition) -> syn::Result json.0, - Err(_) => { - use axum::response::IntoResponse; - return ( - axum::http::StatusCode::BAD_REQUEST, - axum::Json(serde_json::json!({ - "error": "Invalid JSON" - })) - ).into_response(); - }, - }; - } + generate_body_extraction() } else { quote! {} }; @@ -1365,8 +1370,6 @@ fn generate_legacy_handler_body( canonical_args.insert(0, quote! { &user }); quote! { - #json_handling - let auth_credential = match ras_auth_core::extract_auth_credential(&headers, &auth_transport) { Ok(credential) => credential, Err(_) => { @@ -1442,6 +1445,9 @@ fn generate_legacy_handler_body( } } + // Read and parse the body only after auth has succeeded + #json_handling + if let Some(tracker) = &with_usage_tracker { let tracker_headers = ras_auth_core::redact_sensitive_headers_for_auth_transport(&headers, &auth_transport); @@ -1546,9 +1552,11 @@ fn generate_axum_handler( }); } - // Add request body extractor if present - use Result to handle JSON parsing errors + // Take the raw request when a body is declared. The body is read and + // deserialized inside the handler AFTER auth/CSRF/permission checks, so + // unauthenticated clients cannot make the server buffer or parse payloads. if request_type.is_some() { - extractors.push(quote! { body_result: Result, axum::extract::rejection::JsonRejection> }); + extractors.push(quote! { request: axum::extract::Request }); } quote! { @@ -1556,6 +1564,43 @@ fn generate_axum_handler( } } +/// Generated code that reads and JSON-deserializes the request body from the +/// raw `request` extractor, bounded by `__RAS_BODY_LIMIT`. +/// +/// For authenticated endpoints this must be emitted AFTER the +/// auth/CSRF/permission block so unauthenticated clients cannot make the +/// server buffer or parse payloads. +fn generate_body_extraction() -> proc_macro2::TokenStream { + quote! { + let body = { + let body_bytes = match ::axum::body::to_bytes(request.into_body(), __RAS_BODY_LIMIT).await { + Ok(bytes) => bytes, + Err(_) => { + use axum::response::IntoResponse; + return ( + axum::http::StatusCode::PAYLOAD_TOO_LARGE, + axum::Json(serde_json::json!({ + "error": "Request body too large or unreadable" + })) + ).into_response(); + }, + }; + match serde_json::from_slice(&body_bytes) { + Ok(body) => body, + Err(_) => { + use axum::response::IntoResponse; + return ( + axum::http::StatusCode::BAD_REQUEST, + axum::Json(serde_json::json!({ + "error": "Invalid JSON" + })) + ).into_response(); + }, + } + }; + } +} + fn generate_handler_body( endpoint: &EndpointDefinition, handler_name: &Ident, @@ -1587,21 +1632,7 @@ fn generate_handler_body( // Handle JSON body extraction with error handling let json_handling = if endpoint.request_type.is_some() { args.push(quote! { body }); - quote! { - // Handle JSON parsing errors - let body = match body_result { - Ok(json) => json.0, - Err(_) => { - use axum::response::IntoResponse; - return ( - axum::http::StatusCode::BAD_REQUEST, - axum::Json(serde_json::json!({ - "error": "Invalid JSON" - })) - ).into_response(); - }, - }; - } + generate_body_extraction() } else { quote! {} }; @@ -1681,28 +1712,12 @@ fn generate_handler_body( // Handle JSON body extraction with error handling let json_handling = if endpoint.request_type.is_some() { args.push(quote! { body }); - quote! { - // Handle JSON parsing errors - let body = match body_result { - Ok(json) => json.0, - Err(_) => { - use axum::response::IntoResponse; - return ( - axum::http::StatusCode::BAD_REQUEST, - axum::Json(serde_json::json!({ - "error": "Invalid JSON" - })) - ).into_response(); - }, - }; - } + generate_body_extraction() } else { quote! {} }; quote! { - #json_handling - // Extract and validate auth token let auth_credential = match ras_auth_core::extract_auth_credential(&headers, &auth_transport) { Ok(credential) => credential, @@ -1786,6 +1801,9 @@ fn generate_handler_body( } } + // Read and parse the body only after auth has succeeded + #json_handling + // Call usage tracker if configured if let Some(tracker) = &with_usage_tracker { let tracker_headers = diff --git a/crates/rest/ras-rest-macro/tests/http_integration.rs b/crates/rest/ras-rest-macro/tests/http_integration.rs index aa4348b..5fd7604 100644 --- a/crates/rest/ras-rest-macro/tests/http_integration.rs +++ b/crates/rest/ras-rest-macro/tests/http_integration.rs @@ -1308,3 +1308,70 @@ async fn test_query_parameters_with_path_params() { let posts_response: PostsResponse = response.json(); assert_eq!(posts_response.posts.len(), 20); // Default per_page } + +// Minimal service exercising the body_limit option. +rest_service!({ + service_name: TinyBodyService, + base_path: "/tiny", + body_limit: 64, + endpoints: [ + POST UNAUTHORIZED echo(Value) -> Value, + ] +}); + +struct TinyBodyServiceImpl; + +#[async_trait::async_trait] +impl TinyBodyServiceTrait for TinyBodyServiceImpl { + async fn post_echo(&self, request: Value) -> ras_rest_core::RestResult { + Ok(RestResponse::ok(request)) + } +} + +#[tokio::test] +async fn test_body_is_not_parsed_before_auth() { + let server = create_rest_test_server(); + + // Invalid JSON without credentials must be rejected by auth (401, not + // 400), proving the body is neither read nor parsed before the + // auth/CSRF/permission checks succeed. + let response = server + .post("/api/v1/users") + .text("{invalid json") + .content_type("application/json") + .await; + assert_eq!(response.status_code().as_u16(), 401); + + // Same body with an invalid token: still rejected by auth. + let response = server + .post("/api/v1/users") + .authorization_bearer("wrong-token") + .text("{invalid json") + .content_type("application/json") + .await; + assert_eq!(response.status_code().as_u16(), 401); + + // With valid credentials the malformed body is now parsed and rejected. + let response = server + .post("/api/v1/users") + .authorization_bearer("admin-token") + .text("{invalid json") + .content_type("application/json") + .await; + assert_eq!(response.status_code().as_u16(), 400); +} + +#[tokio::test] +async fn test_body_limit_option_enforced() { + let app = TinyBodyServiceBuilder::new(TinyBodyServiceImpl).build(); + let server = TestServer::builder().mock_transport().build(app).unwrap(); + + let response = server.post("/tiny/echo").json(&json!({"ok": true})).await; + assert_eq!(response.status_code().as_u16(), 200); + + let response = server + .post("/tiny/echo") + .json(&json!({"data": "x".repeat(200)})) + .await; + assert_eq!(response.status_code().as_u16(), 413); +} From e91c874d93a8f3d52a7abe7ad647382afc073e10 Mon Sep 17 00:00:00 2001 From: Mathias Myrland Date: Tue, 9 Jun 2026 23:12:54 +0200 Subject: [PATCH 06/12] refactor(auth): single shared authorization pipeline in ras-auth-core MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The credential → CSRF → authenticate → OR-of-AND permission sequence was inlined by codegen in six places (REST canonical + versioned handlers, JSON-RPC dispatch, file-service handlers, bidirectional method dispatch) with already-diverging copies. It now lives in ras-auth-core: - authorize_request(): the full pipeline for header-authenticated HTTP services; generated REST handlers shrink to one call plus a per-service error-shape mapper, with byte-identical responses. - check_permission_groups(): provider-aware OR-of-AND group check used by the REST, JSON-RPC and file macros. - user_satisfies_permission_groups(): set-membership variant for the bidirectional handler, which authorizes against the cached connection user. This also fixes a real divergence: file_service! treated WITH_PERMISSIONS([]) as deny-all while the REST and JSON-RPC macros treated it as any-authenticated-user. All macros now share the authenticated-only semantics. Co-Authored-By: Claude Fable 5 --- Cargo.lock | 1 + crates/core/ras-auth-core/Cargo.toml | 3 + crates/core/ras-auth-core/src/authorize.rs | 253 ++++++++++++++++++ crates/core/ras-auth-core/src/lib.rs | 2 + crates/rest/ras-file-macro/src/server.rs | 8 +- crates/rest/ras-rest-macro/src/lib.rs | 205 ++++---------- .../src/server.rs | 44 +-- crates/rpc/ras-jsonrpc-macro/src/lib.rs | 44 +-- 8 files changed, 333 insertions(+), 227 deletions(-) create mode 100644 crates/core/ras-auth-core/src/authorize.rs diff --git a/Cargo.lock b/Cargo.lock index e5f66ed..f97130b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2646,6 +2646,7 @@ dependencies = [ "serde", "serde_json", "thiserror 2.0.18", + "tokio", ] [[package]] diff --git a/crates/core/ras-auth-core/Cargo.toml b/crates/core/ras-auth-core/Cargo.toml index 3e382f5..5d85cfc 100644 --- a/crates/core/ras-auth-core/Cargo.toml +++ b/crates/core/ras-auth-core/Cargo.toml @@ -15,3 +15,6 @@ http = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } thiserror = { workspace = true } + +[dev-dependencies] +tokio = { workspace = true } diff --git a/crates/core/ras-auth-core/src/authorize.rs b/crates/core/ras-auth-core/src/authorize.rs new file mode 100644 index 0000000..e245f52 --- /dev/null +++ b/crates/core/ras-auth-core/src/authorize.rs @@ -0,0 +1,253 @@ +//! Shared request-authorization pipeline for generated services. +//! +//! Every service macro (REST, file, JSON-RPC, bidirectional WebSocket) used +//! to inline its own copy of the credential → CSRF → authenticate → +//! permission-group sequence. These helpers are the single implementation; +//! generated code maps the returned [`AuthorizeError`] to its own protocol's +//! response shape. + +use crate::{ + AuthError, AuthProvider, AuthTransportConfig, AuthenticatedUser, extract_auth_credential, + validate_csrf_for_credential, +}; +use http::HeaderMap; + +/// Why [`authorize_request`] rejected a request. +#[derive(Debug)] +pub enum AuthorizeError { + /// No usable credential was found in the request + MissingCredential, + /// Double-submit CSRF validation failed for a cookie credential + CsrfValidationFailed, + /// The credential did not authenticate + AuthenticationFailed(AuthError), + /// The service was built without an auth provider + NoAuthProvider, + /// Authenticated, but no required permission group was satisfied + InsufficientPermissions(AuthError), +} + +/// OR-of-AND permission check shared by all generated services. +/// +/// `groups` is a disjunction of conjunctions: access is granted when the user +/// holds every permission of at least one group (verified through the +/// provider's `check_permissions`, which custom providers may override). A +/// group list with no non-empty groups — `WITH_PERMISSIONS([])` or any empty +/// inner group — grants access to any authenticated user. +pub fn check_permission_groups

( + provider: &P, + user: &AuthenticatedUser, + groups: &[Vec], +) -> Result<(), AuthError> +where + P: AuthProvider + ?Sized, +{ + if !groups.iter().any(|group| !group.is_empty()) { + return Ok(()); + } + + for group in groups { + if group.is_empty() || provider.check_permissions(user, group).is_ok() { + return Ok(()); + } + } + + Err(AuthError::InsufficientPermissions { + required: groups + .iter() + .find(|group| !group.is_empty()) + .cloned() + .unwrap_or_default(), + has: user.permissions.iter().cloned().collect(), + }) +} + +/// Set-membership variant of [`check_permission_groups`] for contexts without +/// an auth provider (e.g. the bidirectional WebSocket handler, which +/// authorizes against the cached connection user). +pub fn user_satisfies_permission_groups(user: &AuthenticatedUser, groups: &[Vec]) -> bool { + if !groups.iter().any(|group| !group.is_empty()) { + return true; + } + + groups + .iter() + .any(|group| !group.is_empty() && group.iter().all(|perm| user.permissions.contains(perm))) + || groups.iter().any(|group| group.is_empty()) +} + +/// The credential → CSRF → authenticate → permission pipeline shared by the +/// generated REST and file-service servers. +/// +/// `method` is the HTTP method, used to scope CSRF validation to unsafe +/// requests. Errors are ordered so no work happens for unauthenticated +/// callers: the request body has not been touched when this returns `Err`. +pub async fn authorize_request

( + method: &str, + headers: &HeaderMap, + auth_transport: &AuthTransportConfig, + auth_provider: Option<&P>, + required_permission_groups: &[Vec], +) -> Result +where + P: AuthProvider + ?Sized, +{ + let credential = extract_auth_credential(headers, auth_transport) + .map_err(|_| AuthorizeError::MissingCredential)?; + + validate_csrf_for_credential(method, headers, &credential, auth_transport) + .map_err(|_| AuthorizeError::CsrfValidationFailed)?; + + let provider = auth_provider.ok_or(AuthorizeError::NoAuthProvider)?; + + let user = provider + .authenticate(credential.token().to_string()) + .await + .map_err(AuthorizeError::AuthenticationFailed)?; + + check_permission_groups(provider, &user, required_permission_groups) + .map_err(AuthorizeError::InsufficientPermissions)?; + + Ok(user) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::AuthFuture; + use std::collections::HashSet; + + struct StaticProvider; + + impl AuthProvider for StaticProvider { + fn authenticate(&self, token: String) -> AuthFuture<'_> { + Box::pin(async move { + if token == "good" { + Ok(user(&["read", "write"])) + } else { + Err(AuthError::InvalidToken) + } + }) + } + } + + fn user(perms: &[&str]) -> AuthenticatedUser { + AuthenticatedUser { + user_id: "u".into(), + permissions: perms.iter().map(|p| p.to_string()).collect::>(), + metadata: None, + } + } + + fn groups(groups: &[&[&str]]) -> Vec> { + groups + .iter() + .map(|g| g.iter().map(|p| p.to_string()).collect()) + .collect() + } + + #[test] + fn empty_group_list_is_authenticated_only() { + assert!(check_permission_groups(&StaticProvider, &user(&[]), &[]).is_ok()); + assert!(user_satisfies_permission_groups(&user(&[]), &[])); + } + + #[test] + fn empty_inner_group_grants_any_authenticated_user() { + let g = groups(&[&["admin"], &[]]); + assert!(check_permission_groups(&StaticProvider, &user(&[]), &g).is_ok()); + assert!(user_satisfies_permission_groups(&user(&[]), &g)); + } + + #[test] + fn and_within_group_or_between_groups() { + let g = groups(&[&["read", "write"], &["admin"]]); + + // Satisfies the first group (all permissions present). + assert!(check_permission_groups(&StaticProvider, &user(&["read", "write"]), &g).is_ok()); + assert!(user_satisfies_permission_groups( + &user(&["read", "write"]), + &g + )); + + // Satisfies the second group. + assert!(check_permission_groups(&StaticProvider, &user(&["admin"]), &g).is_ok()); + assert!(user_satisfies_permission_groups(&user(&["admin"]), &g)); + + // Partial match on the first group, none on the second: denied. + let denied = check_permission_groups(&StaticProvider, &user(&["read"]), &g).unwrap_err(); + assert!(matches!( + denied, + AuthError::InsufficientPermissions { required, .. } if required == vec!["read", "write"] + )); + assert!(!user_satisfies_permission_groups(&user(&["read"]), &g)); + } + + #[tokio::test] + async fn authorize_request_full_pipeline() { + let transport = AuthTransportConfig::default(); + let mut headers = HeaderMap::new(); + + // No credential + let err = authorize_request( + "POST", + &headers, + &transport, + Some(&StaticProvider), + &groups(&[&["read"]]), + ) + .await + .unwrap_err(); + assert!(matches!(err, AuthorizeError::MissingCredential)); + + headers.insert("authorization", "Bearer bad".parse().unwrap()); + let err = authorize_request( + "POST", + &headers, + &transport, + Some(&StaticProvider), + &groups(&[&["read"]]), + ) + .await + .unwrap_err(); + assert!(matches!(err, AuthorizeError::AuthenticationFailed(_))); + + headers.insert("authorization", "Bearer good".parse().unwrap()); + + // Missing provider + let err = authorize_request( + "POST", + &headers, + &transport, + None::<&StaticProvider>, + &groups(&[&["read"]]), + ) + .await + .unwrap_err(); + assert!(matches!(err, AuthorizeError::NoAuthProvider)); + + // Insufficient permissions + let err = authorize_request( + "POST", + &headers, + &transport, + Some(&StaticProvider), + &groups(&[&["admin"]]), + ) + .await + .unwrap_err(); + assert!(matches!(err, AuthorizeError::InsufficientPermissions(_))); + + // Success + let user = authorize_request( + "POST", + &headers, + &transport, + Some(&StaticProvider), + &groups(&[&["read", "write"]]), + ) + .await + .unwrap(); + assert_eq!(user.user_id, "u"); + } +} diff --git a/crates/core/ras-auth-core/src/lib.rs b/crates/core/ras-auth-core/src/lib.rs index 54c3056..516bfa4 100644 --- a/crates/core/ras-auth-core/src/lib.rs +++ b/crates/core/ras-auth-core/src/lib.rs @@ -1,5 +1,6 @@ //! Authentication and authorization traits for JSON-RPC services. +mod authorize; mod transport; use std::collections::HashSet; @@ -9,6 +10,7 @@ use std::pin::Pin; use serde::{Deserialize, Serialize}; use thiserror::Error; +pub use authorize::*; pub use transport::*; /// Errors that can occur during authentication or authorization. diff --git a/crates/rest/ras-file-macro/src/server.rs b/crates/rest/ras-file-macro/src/server.rs index 21c1ac9..b16c977 100644 --- a/crates/rest/ras-file-macro/src/server.rs +++ b/crates/rest/ras-file-macro/src/server.rs @@ -929,12 +929,12 @@ fn generate_permission_check(auth: &AuthRequirement) -> TokenStream { }); quote! { + // OR-of-AND permission check (shared ras-auth-core implementation). + // A group list with no non-empty groups means "any authenticated + // user", consistent with the REST and JSON-RPC macros. let required_permission_groups: Vec> = vec![#(#groups),*]; let authenticated_user = user.as_ref().expect("authenticated user exists after auth check"); - let has_permission = required_permission_groups.iter().any(|group| { - group.is_empty() || auth_provider.check_permissions(authenticated_user, group).is_ok() - }); - if !has_permission { + if ::ras_auth_core::check_permission_groups(auth_provider.as_ref(), authenticated_user, &required_permission_groups).is_err() { return __ras_file_error_response(::ras_file_core::FileError::Forbidden); } } diff --git a/crates/rest/ras-rest-macro/src/lib.rs b/crates/rest/ras-rest-macro/src/lib.rs index b7d547a..185d631 100644 --- a/crates/rest/ras-rest-macro/src/lib.rs +++ b/crates/rest/ras-rest-macro/src/lib.rs @@ -789,6 +789,35 @@ fn generate_service_code(service_def: ServiceDefinition) -> syn::Result axum::response::Response { + use axum::response::IntoResponse; + let (status, message) = match error { + ras_auth_core::AuthorizeError::MissingCredential => ( + axum::http::StatusCode::UNAUTHORIZED, + "Missing or invalid Authorization header", + ), + ras_auth_core::AuthorizeError::CsrfValidationFailed => ( + axum::http::StatusCode::FORBIDDEN, + "CSRF validation failed", + ), + ras_auth_core::AuthorizeError::AuthenticationFailed(_) => ( + axum::http::StatusCode::UNAUTHORIZED, + "Authentication failed", + ), + ras_auth_core::AuthorizeError::NoAuthProvider => ( + axum::http::StatusCode::INTERNAL_SERVER_ERROR, + "No auth provider configured", + ), + ras_auth_core::AuthorizeError::InsufficientPermissions(_) => ( + axum::http::StatusCode::FORBIDDEN, + "Insufficient permissions", + ), + }; + (status, axum::Json(serde_json::json!({ "error": message }))).into_response() + } + /// Generated service trait #[async_trait::async_trait] #[allow(private_interfaces, private_bounds)] @@ -1370,81 +1399,19 @@ fn generate_legacy_handler_body( canonical_args.insert(0, quote! { &user }); quote! { - let auth_credential = match ras_auth_core::extract_auth_credential(&headers, &auth_transport) { - Ok(credential) => credential, - Err(_) => { - use axum::response::IntoResponse; - return ( - axum::http::StatusCode::UNAUTHORIZED, - axum::Json(serde_json::json!({ - "error": "Missing or invalid Authorization header" - })) - ).into_response(); - }, - }; - - if let Err(_) = ras_auth_core::validate_csrf_for_credential(#method, &headers, &auth_credential, &auth_transport) { - use axum::response::IntoResponse; - return ( - axum::http::StatusCode::FORBIDDEN, - axum::Json(serde_json::json!({ - "error": "CSRF validation failed" - })) - ).into_response(); - } - - let user = match &auth_provider { - Some(provider) => match provider.authenticate(auth_credential.token().to_string()).await { - Ok(user) => user, - Err(_) => { - use axum::response::IntoResponse; - return ( - axum::http::StatusCode::UNAUTHORIZED, - axum::Json(serde_json::json!({ - "error": "Authentication failed" - })) - ).into_response(); - }, - }, - None => { - use axum::response::IntoResponse; - return ( - axum::http::StatusCode::INTERNAL_SERVER_ERROR, - axum::Json(serde_json::json!({ - "error": "No auth provider configured" - })) - ).into_response(); - }, + // Authenticate and authorize: credential → CSRF → authenticate + // → OR-of-AND permission groups (shared ras-auth-core pipeline) + let user = match ras_auth_core::authorize_request( + #method, + &headers, + &auth_transport, + auth_provider.as_deref(), + &required_permission_groups, + ).await { + Ok(user) => user, + Err(error) => return __ras_authorize_error_response(error), }; - let has_non_empty_groups = required_permission_groups.iter().any(|g| !g.is_empty()); - if has_non_empty_groups { - let mut has_permission = false; - - for permission_group in &required_permission_groups { - if permission_group.is_empty() { - has_permission = true; - break; - } else { - let group_result = auth_provider.as_ref().unwrap().check_permissions(&user, permission_group); - if group_result.is_ok() { - has_permission = true; - break; - } - } - } - - if !has_permission { - use axum::response::IntoResponse; - return ( - axum::http::StatusCode::FORBIDDEN, - axum::Json(serde_json::json!({ - "error": "Insufficient permissions" - })) - ).into_response(); - } - } - // Read and parse the body only after auth has succeeded #json_handling @@ -1718,89 +1685,19 @@ fn generate_handler_body( }; quote! { - // Extract and validate auth token - let auth_credential = match ras_auth_core::extract_auth_credential(&headers, &auth_transport) { - Ok(credential) => credential, - Err(_) => { - use axum::response::IntoResponse; - return ( - axum::http::StatusCode::UNAUTHORIZED, - axum::Json(serde_json::json!({ - "error": "Missing or invalid Authorization header" - })) - ).into_response(); - }, - }; - - if let Err(_) = ras_auth_core::validate_csrf_for_credential(#method, &headers, &auth_credential, &auth_transport) { - use axum::response::IntoResponse; - return ( - axum::http::StatusCode::FORBIDDEN, - axum::Json(serde_json::json!({ - "error": "CSRF validation failed" - })) - ).into_response(); - } - - // Authenticate user - let user = match &auth_provider { - Some(provider) => match provider.authenticate(auth_credential.token().to_string()).await { - Ok(user) => user, - Err(_) => { - use axum::response::IntoResponse; - return ( - axum::http::StatusCode::UNAUTHORIZED, - axum::Json(serde_json::json!({ - "error": "Authentication failed" - })) - ).into_response(); - }, - }, - None => { - use axum::response::IntoResponse; - return ( - axum::http::StatusCode::INTERNAL_SERVER_ERROR, - axum::Json(serde_json::json!({ - "error": "No auth provider configured" - })) - ).into_response(); - }, + // Authenticate and authorize: credential → CSRF → authenticate + // → OR-of-AND permission groups (shared ras-auth-core pipeline) + let user = match ras_auth_core::authorize_request( + #method, + &headers, + &auth_transport, + auth_provider.as_deref(), + &required_permission_groups, + ).await { + Ok(user) => user, + Err(error) => return __ras_authorize_error_response(error), }; - // Check permissions - AND within groups, OR between groups - // Only check permissions if we have non-empty groups - let has_non_empty_groups = required_permission_groups.iter().any(|g| !g.is_empty()); - if has_non_empty_groups { - let mut has_permission = false; - - // Check each permission group (OR logic between groups) - for permission_group in &required_permission_groups { - // Check if user has ALL permissions in this group (AND logic within group) - if permission_group.is_empty() { - // Empty group means any authenticated user can access - has_permission = true; - break; - } else { - // Check if user has all permissions in this group - let group_result = auth_provider.as_ref().unwrap().check_permissions(&user, permission_group); - if group_result.is_ok() { - has_permission = true; - break; - } - } - } - - if !has_permission { - use axum::response::IntoResponse; - return ( - axum::http::StatusCode::FORBIDDEN, - axum::Json(serde_json::json!({ - "error": "Insufficient permissions" - })) - ).into_response(); - } - } - // Read and parse the body only after auth has succeeded #json_handling diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-macro/src/server.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-macro/src/server.rs index b6ea1ea..df0cb71 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-macro/src/server.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-macro/src/server.rs @@ -99,43 +99,15 @@ pub fn generate_server_code( let user = context.get_user().await .ok_or_else(|| ras_jsonrpc_bidirectional_server::ServerError::AuthenticationFailed(ras_auth_core::AuthError::InvalidToken))?; - // Check permissions - AND within groups, OR between groups + // OR-of-AND permission check against the cached + // connection user (shared ras-auth-core implementation) let required_permission_groups: Vec> = #permission_groups_code; - // Only check permissions if we have non-empty groups - let has_non_empty_groups = required_permission_groups.iter().any(|g| !g.is_empty()); - if has_non_empty_groups { - let mut has_permission = false; - - // Check each permission group (OR logic between groups) - for permission_group in &required_permission_groups { - // Check if user has ALL permissions in this group (AND logic within group) - if permission_group.is_empty() { - // Empty group means any authenticated user can access - has_permission = true; - break; - } else { - // Check if user has all permissions in this group - let group_satisfied = permission_group.iter() - .all(|perm| user.permissions.contains(perm)); - if group_satisfied { - has_permission = true; - break; - } - } - } - - if !has_permission { - // Find the first non-empty group for error reporting - let first_group = required_permission_groups.iter() - .find(|g| !g.is_empty()) - .cloned() - .unwrap_or_default(); - let error_response = ras_jsonrpc_types::JsonRpcResponse::error( - ras_jsonrpc_types::JsonRpcError::new(-32002, "Insufficient permissions".to_string(), None), - request.id.clone() - ); - return Ok(Some(error_response)); - } + if !ras_auth_core::user_satisfies_permission_groups(user.as_ref(), &required_permission_groups) { + let error_response = ras_jsonrpc_types::JsonRpcResponse::error( + ras_jsonrpc_types::JsonRpcError::new(-32002, "Insufficient permissions".to_string(), None), + request.id.clone() + ); + return Ok(Some(error_response)); } // Parse parameters diff --git a/crates/rpc/ras-jsonrpc-macro/src/lib.rs b/crates/rpc/ras-jsonrpc-macro/src/lib.rs index 67aad46..6d5bc95 100644 --- a/crates/rpc/ras-jsonrpc-macro/src/lib.rs +++ b/crates/rpc/ras-jsonrpc-macro/src/lib.rs @@ -799,40 +799,18 @@ fn jsonrpc_auth_check_code( ), }; + // OR-of-AND permission check (shared ras-auth-core implementation) let required_permission_groups: Vec> = #permission_groups_code; - let has_non_empty_groups = required_permission_groups.iter().any(|g| !g.is_empty()); - if has_non_empty_groups { - let mut has_permission = false; - - for permission_group in &required_permission_groups { - if permission_group.is_empty() { - has_permission = true; - break; - } else { - let group_result = self.auth_provider - .as_ref() - .unwrap() - .check_permissions(user, permission_group); - if group_result.is_ok() { - has_permission = true; - break; - } - } - } - - if !has_permission { - let first_group = required_permission_groups.iter() - .find(|g| !g.is_empty()) - .cloned() - .unwrap_or_default(); - return ras_jsonrpc_types::JsonRpcResponse::error( - ras_jsonrpc_types::JsonRpcError::insufficient_permissions( - first_group, - user.permissions.iter().cloned().collect() - ), - request.id.clone() - ); - } + let provider = self.auth_provider.as_ref().expect("auth provider required for WITH_PERMISSIONS methods"); + if let Err(error) = ras_jsonrpc_core::check_permission_groups(provider.as_ref(), user, &required_permission_groups) { + let (required, has) = match error { + ras_jsonrpc_core::AuthError::InsufficientPermissions { required, has } => (required, has), + _ => (Vec::new(), Vec::new()), + }; + return ras_jsonrpc_types::JsonRpcResponse::error( + ras_jsonrpc_types::JsonRpcError::insufficient_permissions(required, has), + request.id.clone() + ); } }, quote! { Some(user) }, From fe184faaa7df4d945b90ea0549a181fcb881461a Mon Sep 17 00:00:00 2001 From: Mathias Myrland Date: Tue, 9 Jun 2026 23:17:44 +0200 Subject: [PATCH 07/12] refactor(oauth2): typed start_flow API instead of success-inside-Err MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit OAuth2Provider::verify() previously answered StartFlow payloads by JSON-encoding the authorization URL into Err(IdentityError::ProviderError(...)) — the success value travelled inside an error, every caller had to pattern-match an error string and parse JSON out of it, and generic error logging reported successful flow starts as provider failures. start_flow(provider_id, additional_params) is now a public method returning OAuth2Result. verify() handles only Callback payloads and answers StartFlow with UnsupportedMethod. The demo server, the google_oauth2 example, the README, and the tests now use the typed API (the provider is Clone, so callers keep a handle for flow initiation alongside the one registered with SessionService). Co-Authored-By: Claude Fable 5 --- crates/identity/ras-identity-oauth2/README.md | 36 +++---- .../examples/google_oauth2.rs | 49 +++------ .../ras-identity-oauth2/src/provider.rs | 101 ++++++++---------- .../identity/ras-identity-oauth2/src/tests.rs | 70 +++++------- examples/oauth2-demo/server/src/main.rs | 34 ++---- 5 files changed, 105 insertions(+), 185 deletions(-) diff --git a/crates/identity/ras-identity-oauth2/README.md b/crates/identity/ras-identity-oauth2/README.md index a6f1c9d..113fe0a 100644 --- a/crates/identity/ras-identity-oauth2/README.md +++ b/crates/identity/ras-identity-oauth2/README.md @@ -58,38 +58,26 @@ let oauth2_provider = OAuth2Provider::new(config, state_store); ### Integration with Session Service ```rust -use ras_identity_core::{IdentityError, IdentityProvider}; +use ras_identity_core::IdentityProvider; use ras_identity_oauth2::OAuth2Response; -use ras_identity_session::{SessionConfig, SessionError, SessionService}; +use ras_identity_session::{SessionConfig, SessionService}; -// Register with session service +// Register with session service. The provider is cheap to clone; keep one +// handle for flow initiation and register the other for verification. let session_config = SessionConfig::new("use-at-least-32-bytes-of-random-secret")?; let session_service = SessionService::new(session_config)?; -session_service.register_provider(Box::new(oauth2_provider)).await; +session_service.register_provider(Box::new(oauth2_provider.clone())).await; // Start OAuth2 flow -let start_payload = serde_json::json!({ - "type": "StartFlow", - "provider_id": "google" -}); - -// This will return an error containing the authorization URL -match session_service.begin_session("oauth2", start_payload).await { - Err(SessionError::IdentityError(IdentityError::ProviderError(json))) => { - let response: OAuth2Response = serde_json::from_str(&json)?; - match response { - OAuth2Response::AuthorizationUrl { url, state } => { - // Redirect user to `url` - println!("Redirect to: {}", url); - } - OAuth2Response::Error { message } => { - eprintln!("OAuth2 start-flow failed: {message}"); - } - } +match oauth2_provider.start_flow("google", None).await? { + OAuth2Response::AuthorizationUrl { url, state } => { + // Redirect user to `url` + println!("Redirect to: {}", url); + } + OAuth2Response::Error { message } => { + eprintln!("OAuth2 start-flow failed: {message}"); } - Ok(_) => eprintln!("OAuth2 start flow completed without a redirect"), - Err(err) => eprintln!("OAuth2 start flow failed: {err}"), } // Handle callback diff --git a/crates/identity/ras-identity-oauth2/examples/google_oauth2.rs b/crates/identity/ras-identity-oauth2/examples/google_oauth2.rs index 2951cc2..00c4629 100644 --- a/crates/identity/ras-identity-oauth2/examples/google_oauth2.rs +++ b/crates/identity/ras-identity-oauth2/examples/google_oauth2.rs @@ -6,7 +6,6 @@ //! 3. Handling the OAuth2 flow //! 4. Issuing JWTs after successful authentication -use ras_identity_core::IdentityError; use ras_identity_oauth2::{ InMemoryStateStore, OAuth2Config, OAuth2Provider, OAuth2ProviderConfig, OAuth2Response, }; @@ -46,7 +45,9 @@ async fn main() -> Result<(), Box> { .with_state_ttl(600) // 10 minutes .with_http_timeout(30); // 30 seconds - // Create state store and OAuth2 provider + // Create state store and OAuth2 provider. The provider is cheap to clone; + // keep one handle for flow initiation and register the other for + // verification through the session service. let state_store = Arc::new(InMemoryStateStore::new()); let oauth2_provider = OAuth2Provider::new(oauth2_config, state_store); @@ -57,7 +58,7 @@ async fn main() -> Result<(), Box> { // Register OAuth2 provider with session service session_service - .register_provider(Box::new(oauth2_provider)) + .register_provider(Box::new(oauth2_provider.clone())) .await; println!("OAuth2 Example - Google Authentication"); @@ -66,36 +67,20 @@ async fn main() -> Result<(), Box> { // Step 1: Start OAuth2 flow println!("\n1. Starting OAuth2 flow..."); - let start_payload = serde_json::json!({ - "type": "StartFlow", - "provider_id": "google" - }); - - match session_service.begin_session("oauth2", start_payload).await { - Err(ras_identity_session::SessionError::IdentityError(IdentityError::ProviderError( - json, - ))) => { - let response: OAuth2Response = serde_json::from_str(&json)?; - - match response { - OAuth2Response::AuthorizationUrl { url, state } => { - println!("Authorization URL: {}", url); - println!("State: {}", state); - println!("\nIn a real application, you would:"); - println!("1. Redirect the user to the authorization URL"); - println!("2. Handle the callback with the authorization code"); - println!("3. Exchange the code for a JWT token"); - - // Simulate callback (in real app, this comes from OAuth2 provider) - simulate_callback(&session_service, state).await?; - } - _ => { - println!("Unexpected response from OAuth2 provider"); - } - } + match oauth2_provider.start_flow("google", None).await { + Ok(OAuth2Response::AuthorizationUrl { url, state }) => { + println!("Authorization URL: {}", url); + println!("State: {}", state); + println!("\nIn a real application, you would:"); + println!("1. Redirect the user to the authorization URL"); + println!("2. Handle the callback with the authorization code"); + println!("3. Exchange the code for a JWT token"); + + // Simulate callback (in real app, this comes from OAuth2 provider) + simulate_callback(&session_service, state).await?; } - Ok(_) => { - println!("Unexpected success from start flow"); + Ok(OAuth2Response::Error { message }) => { + println!("OAuth2 error: {}", message); } Err(e) => { println!("Error starting OAuth2 flow: {}", e); diff --git a/crates/identity/ras-identity-oauth2/src/provider.rs b/crates/identity/ras-identity-oauth2/src/provider.rs index 596b983..f53e236 100644 --- a/crates/identity/ras-identity-oauth2/src/provider.rs +++ b/crates/identity/ras-identity-oauth2/src/provider.rs @@ -104,8 +104,13 @@ impl OAuth2Provider { }) } - /// Handle the start flow request - async fn handle_start_flow( + /// Start an OAuth2 authorization flow. + /// + /// Returns the authorization URL to redirect the user to, plus the + /// `state` parameter bound to this flow. This is the supported way to + /// initiate a flow; `verify()` only completes one (the `Callback` + /// payload). + pub async fn start_flow( &self, provider_id: &str, additional_params: Option>, @@ -263,21 +268,11 @@ impl IdentityProvider for OAuth2Provider { serde_json::from_value(auth_payload).map_err(|_| IdentityError::InvalidPayload)?; match payload { - OAuth2AuthPayload::StartFlow { - provider_id, - additional_params, - } => { - // For start flow, we return an error with the authorization URL - let response = self - .handle_start_flow(&provider_id, additional_params) - .await - .map_err(|e| IdentityError::ProviderError(e.to_string()))?; - - // Return the response as a provider error (client should handle this specially) - let response_json = - serde_json::to_string(&response).map_err(IdentityError::SerializationError)?; - - Err(IdentityError::ProviderError(response_json)) + OAuth2AuthPayload::StartFlow { .. } => { + // Flow initiation is not identity verification and has no + // identity to return. Call `OAuth2Provider::start_flow` + // directly to obtain the authorization URL. + Err(IdentityError::UnsupportedMethod) } OAuth2AuthPayload::Callback { provider_id, @@ -334,31 +329,27 @@ mod tests { async fn test_start_flow() { let provider = create_test_provider(); + let result = provider.start_flow("google", None).await.unwrap(); + match result { + OAuth2Response::AuthorizationUrl { url, state } => { + assert!(url.contains("https://accounts.google.com/o/oauth2/v2/auth")); + assert!(url.contains("response_type=code")); + assert!(url.contains("client_id=test_client_id")); + assert!(!state.is_empty()); + } + _ => panic!("Expected AuthorizationUrl response"), + } + + // StartFlow payloads are no longer routed through verify() let payload = serde_json::json!({ "type": "StartFlow", "provider_id": "google", "additional_params": null }); - - let result = provider.verify(payload).await; - - // Start flow returns an error with the authorization URL - assert!(result.is_err()); - - if let Err(IdentityError::ProviderError(response_json)) = result { - let response: OAuth2Response = serde_json::from_str(&response_json).unwrap(); - match response { - OAuth2Response::AuthorizationUrl { url, state } => { - assert!(url.contains("https://accounts.google.com/o/oauth2/v2/auth")); - assert!(url.contains("response_type=code")); - assert!(url.contains("client_id=test_client_id")); - assert!(!state.is_empty()); - } - _ => panic!("Expected AuthorizationUrl response"), - } - } else { - panic!("Expected ProviderError"); - } + assert!(matches!( + provider.verify(payload).await, + Err(IdentityError::UnsupportedMethod) + )); } #[tokio::test] @@ -379,17 +370,16 @@ mod tests { async fn verify_reports_unknown_provider() { let provider = create_test_provider(); - let result = provider - .verify(serde_json::json!({ - "type": "StartFlow", - "provider_id": "missing" - })) - .await; + let result = provider.start_flow("missing", None).await; - let Err(IdentityError::ProviderError(message)) = result else { - panic!("expected provider error for missing provider"); + let Err(error) = result else { + panic!("expected error for missing provider"); }; - assert!(message.contains("Provider 'missing' not configured")); + assert!( + error + .to_string() + .contains("Provider 'missing' not configured") + ); } #[tokio::test] @@ -398,20 +388,13 @@ mod tests { let mut provider = OAuth2Provider::new(OAuth2Config::default(), state_store); provider.add_provider(google_config()); - let result = provider - .verify(serde_json::json!({ - "type": "StartFlow", - "provider_id": "google", - "additional_params": { - "prompt": "consent" - } - })) - .await; + let mut params = HashMap::new(); + params.insert("prompt".to_string(), "consent".to_string()); + let response = provider + .start_flow("google", Some(params)) + .await + .expect("start_flow succeeds"); - let Err(IdentityError::ProviderError(response_json)) = result else { - panic!("expected authorization URL response encoded as provider error"); - }; - let response: OAuth2Response = serde_json::from_str(&response_json).unwrap(); let OAuth2Response::AuthorizationUrl { url, state } = response else { panic!("expected authorization URL response"); }; diff --git a/crates/identity/ras-identity-oauth2/src/tests.rs b/crates/identity/ras-identity-oauth2/src/tests.rs index 487934f..5097e2c 100644 --- a/crates/identity/ras-identity-oauth2/src/tests.rs +++ b/crates/identity/ras-identity-oauth2/src/tests.rs @@ -220,30 +220,27 @@ mod integration_tests { let state_store = Arc::new(InMemoryStateStore::new()); let provider = provider_with_server(provider_config, state_store, server); - // Start OAuth2 flow + // Start OAuth2 flow via the typed API + let start_result = provider.start_flow("mock_provider", None).await.unwrap(); + let auth_url = match start_result { + OAuth2Response::AuthorizationUrl { url, state } => { + assert!(url.contains("/authorize")); + assert!(url.contains("response_type=code")); + assert!(url.contains("code_challenge")); + state + } + _ => panic!("Expected authorization URL"), + }; + + // StartFlow payloads are no longer routed through verify() let start_payload = serde_json::json!({ "type": "StartFlow", "provider_id": "mock_provider" }); - - let start_result = provider.verify(start_payload).await; - assert!(start_result.is_err()); - - let auth_url = - if let Err(ras_identity_core::IdentityError::ProviderError(json)) = start_result { - let response: OAuth2Response = serde_json::from_str(&json).unwrap(); - match response { - OAuth2Response::AuthorizationUrl { url, state } => { - assert!(url.contains("/authorize")); - assert!(url.contains("response_type=code")); - assert!(url.contains("code_challenge")); - state - } - _ => panic!("Expected authorization URL"), - } - } else { - panic!("Expected provider error with auth URL"); - }; + assert!(matches!( + provider.verify(start_payload).await, + Err(ras_identity_core::IdentityError::UnsupportedMethod) + )); // Simulate callback let callback_payload = serde_json::json!({ @@ -270,12 +267,7 @@ mod integration_tests { let provider = OAuth2Provider::new(config, state_store); // Test invalid provider - let payload = serde_json::json!({ - "type": "StartFlow", - "provider_id": "nonexistent" - }); - - let result = provider.verify(payload).await; + let result = provider.start_flow("nonexistent", None).await; assert!(result.is_err()); // Test callback with invalid state @@ -334,15 +326,8 @@ mod integration_tests { let provider = OAuth2Provider::new(config, state_store.clone()); // Generate two authorization URLs - let payload1 = serde_json::json!({"type": "StartFlow", "provider_id": "test"}); - let payload2 = serde_json::json!({"type": "StartFlow", "provider_id": "test"}); - - let result1 = provider.verify(payload1).await; - let result2 = provider.verify(payload2).await; - - // Extract states - let state1 = extract_state_from_error(result1); - let state2 = extract_state_from_error(result2); + let state1 = extract_state(provider.start_flow("test", None).await); + let state2 = extract_state(provider.start_flow("test", None).await); // States should be unique assert_ne!(state1, state2); @@ -352,17 +337,10 @@ mod integration_tests { assert_eq!(state2.len(), 36); } - fn extract_state_from_error( - result: Result, - ) -> String { - if let Err(ras_identity_core::IdentityError::ProviderError(json)) = result { - let response: OAuth2Response = serde_json::from_str(&json).unwrap(); - match response { - OAuth2Response::AuthorizationUrl { state, .. } => state, - _ => panic!("Expected authorization URL"), - } - } else { - panic!("Expected provider error"); + fn extract_state(result: crate::OAuth2Result) -> String { + match result.expect("start_flow succeeds") { + OAuth2Response::AuthorizationUrl { state, .. } => state, + _ => panic!("Expected authorization URL"), } } diff --git a/examples/oauth2-demo/server/src/main.rs b/examples/oauth2-demo/server/src/main.rs index 1e2bcd1..0376112 100644 --- a/examples/oauth2-demo/server/src/main.rs +++ b/examples/oauth2-demo/server/src/main.rs @@ -6,7 +6,6 @@ use axum::{ response::{Html, Redirect}, routing::{get, post}, }; -use ras_identity_core::{IdentityError, IdentityProvider}; use ras_identity_oauth2::{ InMemoryStateStore, OAuth2AuthPayload, OAuth2Config, OAuth2Provider, OAuth2ProviderConfig, OAuth2Response, @@ -183,30 +182,17 @@ async fn start_oauth2_handler( ) -> Result, String> { info!("Starting OAuth2 flow for provider: {}", request.provider_id); - let auth_payload = OAuth2AuthPayload::StartFlow { - provider_id: request.provider_id.clone(), - additional_params: request.additional_params, - }; - - let payload_json = serde_json::to_value(auth_payload) - .map_err(|e| format!("Failed to serialize OAuth2 payload: {}", e))?; - - // The OAuth2 provider returns an error with the authorization URL for start flow - match state.oauth2_provider.verify(payload_json).await { - Err(IdentityError::ProviderError(response_json)) => { - let oauth2_response: OAuth2Response = serde_json::from_str(&response_json) - .map_err(|e| format!("Failed to parse OAuth2 response: {}", e))?; - - match oauth2_response { - OAuth2Response::AuthorizationUrl { url, state } => Ok(Json(StartOAuth2Response { - authorization_url: url, - state, - })), - OAuth2Response::Error { message } => Err(format!("OAuth2 error: {}", message)), - } - } + match state + .oauth2_provider + .start_flow(&request.provider_id, request.additional_params) + .await + { + Ok(OAuth2Response::AuthorizationUrl { url, state }) => Ok(Json(StartOAuth2Response { + authorization_url: url, + state, + })), + Ok(OAuth2Response::Error { message }) => Err(format!("OAuth2 error: {}", message)), Err(e) => Err(format!("OAuth2 provider error: {}", e)), - Ok(_) => Err("Unexpected success response from start flow".to_string()), } } From dfdabc09e287d2a38a470a23bd9bb2f69cc5fc63 Mon Sep 17 00:00:00 2001 From: Mathias Myrland Date: Tue, 9 Jun 2026 23:27:14 +0200 Subject: [PATCH 08/12] feat(macros): opt-in consumer-driven feature gating (feature_gated: true) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The four service macros decided whether to emit server/client/reqwest code via cfg!(feature = ...), which evaluates the PROC-MACRO crate's feature set. Cargo unifies those features across the whole workspace, so any crate enabling e.g. ras-rest-macro/reqwest forced client codegen (and its transport dependencies) into every other consumer's expansion — the root cause of the earlier bidirectional-chat-server build break. Proc macros cannot observe the consumer's features (Cargo passes them as --cfg flags, not env), so the fix is to emit the generated blocks wrapped in #[cfg(feature = "server")] / #[cfg(feature = "client")] / #[cfg(feature = "reqwest")] attributes, which rustc resolves against the CONSUMER crate. This is opt-in via a new service-level field: rest_service!({ ..., feature_gated: true, ... }) supported by rest_service!, jsonrpc_service!, file_service! and jsonrpc_bidirectional_service!. Consumers that opt in declare their own server/client (and optionally reqwest) features, as bidirectional-chat- api already did; feature-less consumers keep the old macro-crate-feature behavior. chat-api now uses feature_gated: true for both its services, pinning the regression. Permission-manifest gating still follows the macro crate's permissions feature (additive; unchanged scope). Co-Authored-By: Claude Fable 5 --- crates/rest/ras-file-macro/src/client.rs | 11 ++++- crates/rest/ras-file-macro/src/lib.rs | 27 ++++++++++++- crates/rest/ras-file-macro/src/parser.rs | 7 ++++ crates/rest/ras-rest-macro/src/client.rs | 12 +++++- crates/rest/ras-rest-macro/src/lib.rs | 35 +++++++++++++++- .../src/lib.rs | 40 ++++++++++++++++++- crates/rpc/ras-jsonrpc-macro/src/client.rs | 11 ++++- crates/rpc/ras-jsonrpc-macro/src/lib.rs | 30 +++++++++++++- examples/bidirectional-chat/api/Cargo.toml | 4 +- examples/bidirectional-chat/api/src/auth.rs | 1 + examples/bidirectional-chat/api/src/lib.rs | 1 + 11 files changed, 167 insertions(+), 12 deletions(-) diff --git a/crates/rest/ras-file-macro/src/client.rs b/crates/rest/ras-file-macro/src/client.rs index d4d6ff5..873045c 100644 --- a/crates/rest/ras-file-macro/src/client.rs +++ b/crates/rest/ras-file-macro/src/client.rs @@ -17,8 +17,17 @@ pub fn generate_client(definition: &FileServiceDefinition) -> TokenStream { .endpoints .iter() .map(|endpoint| generate_client_method(definition, endpoint, &base_path)); - let build_method = if cfg!(feature = "reqwest") { + // With `feature_gated: true` the convenience constructor is gated on the + // CONSUMER crate's `reqwest` feature instead of the macro crate's + // (workspace-unified) one. + let cfg_reqwest = if definition.feature_gated { + quote! { #[cfg(feature = "reqwest")] } + } else { + quote! {} + }; + let build_method = if definition.feature_gated || cfg!(feature = "reqwest") { quote! { + #cfg_reqwest pub fn build( self, ) -> Result<#client_name, Box> { diff --git a/crates/rest/ras-file-macro/src/lib.rs b/crates/rest/ras-file-macro/src/lib.rs index 19f2968..ef5d9a4 100644 --- a/crates/rest/ras-file-macro/src/lib.rs +++ b/crates/rest/ras-file-macro/src/lib.rs @@ -37,37 +37,60 @@ pub fn file_service(input: TokenStream) -> TokenStream { quote! {} }; - let server_output = if cfg!(feature = "server") { + // With `feature_gated: true` the generated code is wrapped in + // `#[cfg(feature = ...)]` attributes resolved against the CONSUMER + // crate's features, immune to workspace feature unification of the + // macro crate's own features (which `cfg!` evaluates). + let feature_gated = definition.feature_gated; + let cfg_server = if feature_gated { + quote! { #[cfg(feature = "server")] } + } else { + quote! {} + }; + let cfg_client = if feature_gated { + quote! { #[cfg(feature = "client")] } + } else { + quote! {} + }; + + let server_output = if feature_gated || cfg!(feature = "server") { quote! { + #cfg_server mod #server_mod { use super::*; #server_code } + #cfg_server pub use #server_mod::*; + #cfg_server const _: () = { #schema_checks }; + #cfg_server mod #openapi_mod { use super::*; #openapi_code } + #cfg_server pub use #openapi_mod::*; } } else { quote! {} }; - let client_output = if cfg!(feature = "client") { + let client_output = if feature_gated || cfg!(feature = "client") { quote! { + #cfg_client mod #client_mod { use super::*; #client_code } + #cfg_client pub use #client_mod::*; } } else { diff --git a/crates/rest/ras-file-macro/src/parser.rs b/crates/rest/ras-file-macro/src/parser.rs index 4d1c712..5b06482 100644 --- a/crates/rest/ras-file-macro/src/parser.rs +++ b/crates/rest/ras-file-macro/src/parser.rs @@ -9,6 +9,7 @@ pub struct FileServiceDefinition { pub service_name: Ident, pub base_path: LitStr, pub openapi: Option, + pub feature_gated: bool, pub endpoints: Vec, } @@ -103,6 +104,7 @@ impl Parse for FileServiceDefinition { let mut service_name = None; let mut base_path = None; let mut openapi = None; + let mut feature_gated = false; let mut endpoints = Vec::new(); while !content.is_empty() { @@ -112,6 +114,10 @@ impl Parse for FileServiceDefinition { match field_name.to_string().as_str() { "service_name" => service_name = Some(content.parse()?), "base_path" => base_path = Some(content.parse()?), + "feature_gated" => { + let enabled = content.parse::()?; + feature_gated = enabled.value(); + } "body_limit" => { return Err(Error::new( field_name.span(), @@ -183,6 +189,7 @@ impl Parse for FileServiceDefinition { .ok_or_else(|| Error::new(input.span(), "Missing service_name"))?, base_path: base_path.ok_or_else(|| Error::new(input.span(), "Missing base_path"))?, openapi, + feature_gated, endpoints, }) } diff --git a/crates/rest/ras-rest-macro/src/client.rs b/crates/rest/ras-rest-macro/src/client.rs index 5fa9803..88adee8 100644 --- a/crates/rest/ras-rest-macro/src/client.rs +++ b/crates/rest/ras-rest-macro/src/client.rs @@ -57,13 +57,23 @@ pub fn generate_client_code(service_def: &ServiceDefinition) -> proc_macro2::Tok .iter() .flat_map(generate_client_methods_with_timeout_for_endpoint); - let build_method = if cfg!(feature = "reqwest") { + // With `feature_gated: true` the convenience constructor compiles only + // when the CONSUMER crate enables its own `reqwest` feature (which should + // activate `ras-transport-core/reqwest`); otherwise the macro crate's + // (workspace-unified) feature decides whether to emit it at all. + let cfg_reqwest = if service_def.feature_gated { + quote! { #[cfg(feature = "reqwest")] } + } else { + quote! {} + }; + let build_method = if service_def.feature_gated || cfg!(feature = "reqwest") { quote! { /// Build the client using the default `ReqwestTransport`. /// /// # Errors /// /// Returns an error if the underlying transport fails to construct. + #cfg_reqwest pub fn build(self) -> Result<#client_name, Box> { let transport = std::sync::Arc::new(::ras_transport_core::ReqwestTransport::new()); self.build_with_transport(transport) diff --git a/crates/rest/ras-rest-macro/src/lib.rs b/crates/rest/ras-rest-macro/src/lib.rs index 185d631..e1a64af 100644 --- a/crates/rest/ras-rest-macro/src/lib.rs +++ b/crates/rest/ras-rest-macro/src/lib.rs @@ -79,6 +79,7 @@ struct ServiceDefinition { openapi: Option, static_hosting: static_hosting::StaticHostingConfig, body_limit: Option, + feature_gated: bool, endpoints: Vec, } @@ -252,6 +253,7 @@ impl Parse for ServiceDefinition { let mut openapi = None; let mut static_hosting = static_hosting::StaticHostingConfig::default(); let mut body_limit = None; + let mut feature_gated = false; // Parse optional fields while content.peek(Ident) { @@ -303,6 +305,12 @@ impl Parse for ServiceDefinition { let limit = content.parse::()?; body_limit = Some(limit.base10_parse::()?); let _ = content.parse::()?; + } else if field_name == "feature_gated" { + let _ = content.parse::()?; // "feature_gated" + let _ = content.parse::()?; + let enabled = content.parse::()?; + feature_gated = enabled.value(); + let _ = content.parse::()?; } else if field_name == "endpoints" { break; // Start parsing endpoints } else { @@ -337,6 +345,7 @@ impl Parse for ServiceDefinition { openapi, static_hosting, body_limit, + feature_gated, endpoints, }) } @@ -780,8 +789,27 @@ fn generate_service_code(service_def: ServiceDefinition) -> syn::Result syn::Result TokenStream { #[derive(Debug)] struct BidirectionalServiceDefinition { service_name: Ident, + feature_gated: bool, client_to_server: Vec, server_to_client: Vec, server_to_client_calls: Vec, @@ -68,6 +69,20 @@ impl Parse for BidirectionalServiceDefinition { let service_name = content.parse::()?; let _ = content.parse::()?; + // Optional: feature_gated: , + let mut feature_gated = false; + if content + .fork() + .parse::() + .map(|ident| ident == "feature_gated") + .unwrap_or(false) + { + let _ = content.parse::()?; // "feature_gated" + let _ = content.parse::()?; + feature_gated = content.parse::()?.value(); + let _ = content.parse::()?; + } + // Parse client_to_server: [...] let _ = content.parse::()?; // "client_to_server" let _ = content.parse::()?; @@ -128,6 +143,7 @@ impl Parse for BidirectionalServiceDefinition { Ok(BidirectionalServiceDefinition { service_name, + feature_gated, client_to_server, server_to_client, server_to_client_calls, @@ -279,28 +295,48 @@ fn generate_service_code( quote! {} }; - let server_output = if cfg!(feature = "server") { + // With `feature_gated: true` the generated code is wrapped in + // `#[cfg(feature = ...)]` attributes resolved against the CONSUMER + // crate's features, immune to workspace feature unification of the + // macro crate's own features (which `cfg!` evaluates). + let feature_gated = service_def.feature_gated; + let cfg_server = if feature_gated { + quote! { #[cfg(feature = "server")] } + } else { + quote! {} + }; + let cfg_client = if feature_gated { + quote! { #[cfg(feature = "client")] } + } else { + quote! {} + }; + + let server_output = if feature_gated || cfg!(feature = "server") { quote! { + #cfg_server mod #server_mod { use super::*; #server_code } + #cfg_server pub use #server_mod::*; } } else { quote! {} }; - let client_output = if cfg!(feature = "client") { + let client_output = if feature_gated || cfg!(feature = "client") { quote! { + #cfg_client mod #client_mod { use super::*; #client_code } + #cfg_client pub use #client_mod::*; } } else { diff --git a/crates/rpc/ras-jsonrpc-macro/src/client.rs b/crates/rpc/ras-jsonrpc-macro/src/client.rs index 623be52..1c420d0 100644 --- a/crates/rpc/ras-jsonrpc-macro/src/client.rs +++ b/crates/rpc/ras-jsonrpc-macro/src/client.rs @@ -18,9 +18,18 @@ pub fn generate_client_code(service_def: &ServiceDefinition) -> proc_macro2::Tok .iter() .flat_map(generate_client_methods_with_timeout_for_method); - let build_method = if cfg!(feature = "reqwest") { + // With `feature_gated: true` the convenience constructor is gated on the + // CONSUMER crate's `reqwest` feature instead of the macro crate's + // (workspace-unified) one. + let cfg_reqwest = if service_def.feature_gated { + quote! { #[cfg(feature = "reqwest")] } + } else { + quote! {} + }; + let build_method = if service_def.feature_gated || cfg!(feature = "reqwest") { quote! { /// Build the client using the default `ReqwestTransport`. + #cfg_reqwest pub fn build(self) -> Result<#client_name, Box> { let transport = std::sync::Arc::new(::ras_transport_core::ReqwestTransport::new()); self.build_with_transport(transport) diff --git a/crates/rpc/ras-jsonrpc-macro/src/lib.rs b/crates/rpc/ras-jsonrpc-macro/src/lib.rs index 6d5bc95..2370fd2 100644 --- a/crates/rpc/ras-jsonrpc-macro/src/lib.rs +++ b/crates/rpc/ras-jsonrpc-macro/src/lib.rs @@ -28,6 +28,7 @@ struct ServiceDefinition { service_name: Ident, openrpc: Option, explorer: Option, + feature_gated: bool, methods: Vec, } @@ -148,6 +149,7 @@ impl Parse for ServiceDefinition { // Check if openrpc field is present let mut openrpc = None; let mut explorer = None; + let mut feature_gated = false; // Parse optional fields until we hit "methods" while content.peek(Ident) { @@ -193,6 +195,9 @@ impl Parse for ServiceDefinition { let path = explorer_content.parse::()?; explorer = Some(ExplorerConfig::WithPath(path.value())); } + } else if field_name == "feature_gated" { + let enabled = content.parse::()?; + feature_gated = enabled.value(); } let _ = content.parse::()?; @@ -220,6 +225,7 @@ impl Parse for ServiceDefinition { service_name, openrpc, explorer, + feature_gated, methods, }) } @@ -460,8 +466,25 @@ fn generate_service_code(service_def: ServiceDefinition) -> syn::Result syn::Result syn::Result Server methods (with authentication/permissions) client_to_server: [ From fc5c1874ce89945301f1d44bed7ccb74b48c54e8 Mon Sep 17 00:00:00 2001 From: Mathias Myrland Date: Tue, 9 Jun 2026 23:37:15 +0200 Subject: [PATCH 09/12] fix(examples): close IDOR and content-type sniffing in file-service backend The wasm file-service backend is a template users copy, and it had three real vulnerabilities (audit HIGH-4): - get_file matched stored files by starts_with(file_id): a truncated or single-character id matched arbitrary files (IDOR/enumeration). Ids are now validated as full UUIDs and matched exactly against the file stem. - download_secure had no object-ownership check (the handler even said so in a comment). Files are now stored per-scope (public/ vs users/{owner}/), the authenticated uploader is recorded as owner, and the secure download only resolves files in the caller's own scope; anything else is indistinguishable from missing. Owner ids are charset-validated before becoming a path segment. - The served Content-Type derived from the attacker-controlled upload filename extension. Extensions are now allowlisted at save time (everything else stored as .bin) and only known-safe content types are echoed on download; everything else is application/octet-stream. Regression tests: prefix ids rejected, cross-user and public access to owned files yields 404, unauthenticated secure download yields 401, .html upload served as opaque octet-stream. Co-Authored-By: Claude Fable 5 --- .../file-service-backend/src/file_service.rs | 101 ++++++-- .../file-service-backend/src/storage.rs | 218 +++++++++++++++--- 2 files changed, 268 insertions(+), 51 deletions(-) diff --git a/examples/file-service-wasm/file-service-backend/src/file_service.rs b/examples/file-service-wasm/file-service-backend/src/file_service.rs index 3821c18..e6c72b3 100644 --- a/examples/file-service-wasm/file-service-backend/src/file_service.rs +++ b/examples/file-service-wasm/file-service-backend/src/file_service.rs @@ -24,6 +24,7 @@ impl FileServiceImpl { async fn handle_file_upload( &self, file: &mut IncomingFile<'_>, + owner: Option<&str>, ) -> Result { let file_name = file.file_name().unwrap_or("unknown").to_string(); let content_type = file.content_type().map(ToString::to_string); @@ -37,7 +38,7 @@ impl FileServiceImpl { let metadata = self .storage - .save_file(data, &file_name, content_type) + .save_file(data, &file_name, content_type, owner) .await .map_err(|e| { error!("Failed to save file: {}", e); @@ -60,8 +61,9 @@ impl FileServiceImpl { async fn download_response( &self, file_id: String, + owner: Option<&str>, ) -> Result { - let (data, metadata) = self.storage.get_file(&file_id).await.map_err(|e| { + let (data, metadata) = self.storage.get_file(&file_id, owner).await.map_err(|e| { error!("Failed to get file: {}", e); match e.to_string().contains("not found") { true => DocumentServiceFileError::NotFound, @@ -112,7 +114,8 @@ impl DocumentServiceTrait for FileServiceImpl { ) -> Result<(), DocumentServiceFileError> { match part { DocumentServiceUploadPart::File(file) => { - state.response = Some(self.handle_file_upload(file).await?); + // Anonymous endpoint: files land in the public scope. + state.response = Some(self.handle_file_upload(file, None).await?); } } Ok(()) @@ -144,14 +147,17 @@ impl DocumentServiceTrait for FileServiceImpl { async fn upload_profile_picture_part( &self, - _ctx: &FileRequestContext<'_>, + ctx: &FileRequestContext<'_>, _path: &DocumentServiceUploadProfilePicturePath, state: &mut Self::UploadProfilePictureState, part: &mut DocumentServiceUploadProfilePicturePart<'_>, ) -> Result<(), DocumentServiceFileError> { + // Authenticated endpoint: record the uploader as the file's owner so + // the secure download path can enforce ownership. + let user = ctx.user.ok_or(DocumentServiceFileError::Unauthorized)?; match part { DocumentServiceUploadProfilePicturePart::File(file) => { - state.response = Some(self.handle_file_upload(file).await?); + state.response = Some(self.handle_file_upload(file, Some(&user.user_id)).await?); } } Ok(()) @@ -176,7 +182,8 @@ impl DocumentServiceTrait for FileServiceImpl { path: DocumentServiceDownloadByFileIdPath, ) -> Result { debug!("Handling public file download: {}", path.file_id); - self.download_response(path.file_id).await + // Public endpoint: only the public scope is reachable. + self.download_response(path.file_id, None).await } async fn download_secure_by_file_id( @@ -184,15 +191,15 @@ impl DocumentServiceTrait for FileServiceImpl { ctx: &FileRequestContext<'_>, path: DocumentServiceDownloadSecureByFileIdPath, ) -> Result { - if let Some(user) = ctx.user { - debug!( - "Handling secure file download for user {}: {}", - user.user_id, path.file_id - ); - } - - // In a real app, you might check if the user has access to this file - self.download_response(path.file_id).await + // Object ownership: a user can only download files they uploaded. + // Anything outside their scope is indistinguishable from missing. + let user = ctx.user.ok_or(DocumentServiceFileError::Unauthorized)?; + debug!( + "Handling secure file download for user {}: {}", + user.user_id, path.file_id + ); + self.download_response(path.file_id, Some(&user.user_id)) + .await } } @@ -242,6 +249,7 @@ mod tests { b"download body".to_vec(), "report.txt", Some("text/plain".to_string()), + None, ) .await .expect("save file"); @@ -292,11 +300,16 @@ mod tests { } #[tokio::test] - async fn secure_download_uses_same_download_path() { + async fn secure_download_returns_own_file() { let temp_dir = TempDir::new().expect("temp dir"); let storage = Arc::new(FileStorage::new(temp_dir.path())); let saved = storage - .save_file(b"secure body".to_vec(), "secure.bin", None) + .save_file( + b"secure body".to_vec(), + "secure.bin", + None, + Some("testuser"), + ) .await .expect("save file"); let service = FileServiceImpl::new(storage); @@ -314,4 +327,58 @@ mod tests { assert_eq!(body_bytes(response), b"secure body"); } + + #[tokio::test] + async fn secure_download_denies_other_users_files() { + let temp_dir = TempDir::new().expect("temp dir"); + let storage = Arc::new(FileStorage::new(temp_dir.path())); + let saved = storage + .save_file(b"alice data".to_vec(), "doc.txt", None, Some("alice")) + .await + .expect("save file"); + let service = FileServiceImpl::new(storage); + let headers = HeaderMap::new(); + let user = test_user(); // user_id: testuser + let ctx = test_context(&headers, Some(&user)); + + let result = service + .download_secure_by_file_id( + &ctx, + DocumentServiceDownloadSecureByFileIdPath { + file_id: saved.id.clone(), + }, + ) + .await; + assert!(matches!(result, Err(DocumentServiceFileError::NotFound))); + + // It is also invisible through the public endpoint. + let result = service + .download_by_file_id( + &ctx, + DocumentServiceDownloadByFileIdPath { file_id: saved.id }, + ) + .await; + assert!(matches!(result, Err(DocumentServiceFileError::NotFound))); + } + + #[tokio::test] + async fn secure_download_requires_authenticated_user() { + let temp_dir = TempDir::new().expect("temp dir"); + let service = test_service(&temp_dir); + let headers = HeaderMap::new(); + let ctx = test_context(&headers, None); + + let result = service + .download_secure_by_file_id( + &ctx, + DocumentServiceDownloadSecureByFileIdPath { + file_id: "ignored".to_string(), + }, + ) + .await; + assert!(matches!( + result, + Err(DocumentServiceFileError::Unauthorized) + )); + } } diff --git a/examples/file-service-wasm/file-service-backend/src/storage.rs b/examples/file-service-wasm/file-service-backend/src/storage.rs index b8c490c..31fe14c 100644 --- a/examples/file-service-wasm/file-service-backend/src/storage.rs +++ b/examples/file-service-wasm/file-service-backend/src/storage.rs @@ -2,6 +2,26 @@ use anyhow::Result; use std::path::PathBuf; use uuid::Uuid; +/// Extensions preserved on disk; anything else is stored as `.bin` so an +/// attacker-supplied filename cannot smuggle a renderable extension (e.g. +/// `.html`) into the download path. +const SAFE_EXTENSIONS: &[&str] = &[ + "txt", "png", "jpg", "jpeg", "webp", "gif", "pdf", "json", "bin", +]; + +/// Content types the example will echo back on download. Everything else is +/// served as `application/octet-stream` so uploaded content can never be +/// rendered inline by a browser. +const SAFE_CONTENT_TYPES: &[&str] = &[ + "text/plain", + "image/png", + "image/jpeg", + "image/webp", + "image/gif", + "application/pdf", + "application/json", +]; + #[derive(Debug, Clone)] pub struct FileMetadata { pub id: String, @@ -22,20 +42,47 @@ impl FileStorage { } } + /// Directory for a file's scope: anonymous uploads land in `public/`, + /// owned uploads under `users/{owner}/` so downloads can enforce + /// object ownership by construction. + fn scope_dir(&self, owner: Option<&str>) -> Result { + match owner { + None => Ok(self.base_path.join("public")), + Some(owner) => { + // The owner id becomes a path segment: restrict it to a safe + // charset instead of trusting arbitrary subject strings. + if owner.is_empty() + || !owner + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_') + { + anyhow::bail!("invalid owner id"); + } + Ok(self.base_path.join("users").join(owner)) + } + } + } + pub async fn save_file( &self, data: Vec, original_name: &str, content_type: Option, + owner: Option<&str>, ) -> Result { let file_id = Uuid::new_v4().to_string(); let extension = std::path::Path::new(original_name) .extension() .and_then(|ext| ext.to_str()) - .unwrap_or("bin"); + .map(|ext| ext.to_ascii_lowercase()) + .filter(|ext| SAFE_EXTENSIONS.contains(&ext.as_str())) + .unwrap_or_else(|| "bin".to_string()); + + let dir = self.scope_dir(owner)?; + tokio::fs::create_dir_all(&dir).await?; let stored_name = format!("{}.{}", file_id, extension); - let stored_path = self.base_path.join(&stored_name); + let stored_path = dir.join(&stored_name); // Save file tokio::fs::write(&stored_path, &data).await?; @@ -49,37 +96,57 @@ impl FileStorage { }) } - pub async fn get_file(&self, file_id: &str) -> Result<(Vec, Option)> { - // Find file with matching ID - let mut entries = tokio::fs::read_dir(&self.base_path).await?; + /// Fetch a file by id within a scope. Only files saved with the same + /// `owner` are visible, so a caller can never read another user's + /// objects, and the id must be a full UUID matched exactly against the + /// stored file stem — no prefix matching. + pub async fn get_file( + &self, + file_id: &str, + owner: Option<&str>, + ) -> Result<(Vec, Option)> { + // Strict id validation: rejects traversal characters and the + // prefix/enumeration tricks a partial id would allow. + if Uuid::parse_str(file_id).is_err() { + anyhow::bail!("File not found: {}", file_id); + } + + let dir = self.scope_dir(owner)?; + let Ok(mut entries) = tokio::fs::read_dir(&dir).await else { + anyhow::bail!("File not found: {}", file_id); + }; while let Some(entry) = entries.next_entry().await? { - let file_name = entry.file_name(); - let file_name_str = file_name.to_string_lossy(); - - if file_name_str.starts_with(file_id) { - let path = entry.path(); - let data = tokio::fs::read(&path).await?; - let metadata = entry.metadata().await?; - - // Try to guess original name from extension - let extension = path - .extension() - .and_then(|ext| ext.to_str()) - .unwrap_or("bin"); - - let file_metadata = FileMetadata { - id: file_id.to_string(), - original_name: format!("file.{}", extension), - size: metadata.len(), - content_type: mime_guess::from_path(&path) - .first() - .map(|mime| mime.to_string()), - stored_path: path, - }; - - return Ok((data, Some(file_metadata))); + let path = entry.path(); + if path.file_stem().and_then(|stem| stem.to_str()) != Some(file_id) { + continue; } + + let data = tokio::fs::read(&path).await?; + let metadata = entry.metadata().await?; + + let extension = path + .extension() + .and_then(|ext| ext.to_str()) + .unwrap_or("bin"); + + // Only echo back known-safe content types; everything else is + // an opaque download. + let content_type = mime_guess::from_path(&path) + .first() + .map(|mime| mime.to_string()) + .filter(|mime| SAFE_CONTENT_TYPES.contains(&mime.as_str())) + .unwrap_or_else(|| "application/octet-stream".to_string()); + + let file_metadata = FileMetadata { + id: file_id.to_string(), + original_name: format!("file.{}", extension), + size: metadata.len(), + content_type: Some(content_type), + stored_path: path, + }; + + return Ok((data, Some(file_metadata))); } anyhow::bail!("File not found: {}", file_id) @@ -101,6 +168,7 @@ mod tests { b"hello world".to_vec(), "greeting.txt", Some("text/plain".to_string()), + None, ) .await .expect("save file"); @@ -123,7 +191,7 @@ mod tests { let storage = FileStorage::new(temp_dir.path()); let metadata = storage - .save_file(b"secret".to_vec(), "../../secret.txt", None) + .save_file(b"secret".to_vec(), "../../secret.txt", None, None) .await .expect("save file"); @@ -142,16 +210,50 @@ mod tests { ); } + #[tokio::test] + async fn save_file_replaces_unsafe_extensions_with_bin() { + let temp_dir = TempDir::new().expect("temp dir"); + let storage = FileStorage::new(temp_dir.path()); + + let metadata = storage + .save_file( + b"".to_vec(), + "evil.html", + None, + None, + ) + .await + .expect("save file"); + + assert_eq!( + metadata + .stored_path + .extension() + .and_then(|extension| extension.to_str()), + Some("bin") + ); + + // And the download side serves it as an opaque blob. + let (_, fetched) = storage + .get_file(&metadata.id, None) + .await + .expect("get file"); + assert_eq!( + fetched.expect("metadata").content_type.as_deref(), + Some("application/octet-stream") + ); + } + #[tokio::test] async fn get_file_reads_saved_file_by_id() { let temp_dir = TempDir::new().expect("temp dir"); let storage = FileStorage::new(temp_dir.path()); let saved = storage - .save_file(b"download me".to_vec(), "download.txt", None) + .save_file(b"download me".to_vec(), "download.txt", None, None) .await .expect("save file"); - let (data, metadata) = storage.get_file(&saved.id).await.expect("get file"); + let (data, metadata) = storage.get_file(&saved.id, None).await.expect("get file"); assert_eq!(data, b"download me"); let metadata = metadata.expect("file metadata"); @@ -167,10 +269,58 @@ mod tests { let storage = FileStorage::new(temp_dir.path()); let error = storage - .get_file("missing") + .get_file("missing", None) .await .expect_err("missing file should be rejected"); assert!(error.to_string().contains("File not found: missing")); } + + #[tokio::test] + async fn get_file_rejects_id_prefixes() { + let temp_dir = TempDir::new().expect("temp dir"); + let storage = FileStorage::new(temp_dir.path()); + let saved = storage + .save_file(b"target".to_vec(), "target.txt", None, None) + .await + .expect("save file"); + + // A truncated id must not match by prefix (the old behavior). + let prefix = &saved.id[..8]; + assert!(storage.get_file(prefix, None).await.is_err()); + + // Nor may a non-UUID id reach the filesystem at all. + assert!(storage.get_file("..", None).await.is_err()); + } + + #[tokio::test] + async fn get_file_enforces_owner_scope() { + let temp_dir = TempDir::new().expect("temp dir"); + let storage = FileStorage::new(temp_dir.path()); + let saved = storage + .save_file(b"alice data".to_vec(), "doc.txt", None, Some("alice")) + .await + .expect("save file"); + + // The owner can read it. + assert!(storage.get_file(&saved.id, Some("alice")).await.is_ok()); + + // Another user cannot, even with a valid id. + assert!(storage.get_file(&saved.id, Some("bob")).await.is_err()); + + // Nor can the public scope. + assert!(storage.get_file(&saved.id, None).await.is_err()); + } + + #[tokio::test] + async fn save_file_rejects_path_like_owner_ids() { + let temp_dir = TempDir::new().expect("temp dir"); + let storage = FileStorage::new(temp_dir.path()); + + let error = storage + .save_file(b"x".to_vec(), "x.txt", None, Some("../escape")) + .await + .expect_err("path-like owner must be rejected"); + assert!(error.to_string().contains("invalid owner id")); + } } From c3ded22ec5deafbeda4bf17f407cc0dfc017cc0b Mon Sep 17 00:00:00 2001 From: Mathias Myrland Date: Tue, 9 Jun 2026 23:45:01 +0200 Subject: [PATCH 10/12] fix(oauth2): OIDC nonce, id_token claim validation, login-CSRF session binding Closes audit HIGH-3. Previously the identity was taken solely from the userinfo endpoint, id_tokens were never inspected, no nonce was sent, and the state parameter was not bound to the browser session that started the flow. - Every authorization request now carries a fresh OIDC nonce, stored in the flow state. - When the token endpoint returns an id_token, its claims are validated on callback: iss must equal the new OAuth2ProviderConfig::issuer field (when configured), aud must contain the client_id (string or array form), exp must be in the future, and nonce must echo the one we sent. The signature is intentionally not verified: the token arrives directly from the token endpoint over TLS, which OIDC Core 3.1.3.7 permits for the authorization-code flow. - Login-CSRF guard: start_flow_bound / generate_authorization_url_bound accept an unguessable per-browser-session binding value (e.g. a random cookie); the callback payload must present the identical binding or the flow is rejected before any token exchange. A failed attempt also burns the one-use state. - Google configs in the examples/README now set the issuer. - oauth2-demo-api adopts feature_gated: true, fixing its pre-existing package-scoped build break (cargo check -p oauth2-demo-server failed on master with ReqwestTransport missing from ras-transport-core). Co-Authored-By: Claude Fable 5 --- crates/identity/ras-identity-oauth2/README.md | 4 + .../examples/google_oauth2.rs | 2 + .../ras-identity-oauth2/src/client.rs | 252 +++++++++++++++++- .../ras-identity-oauth2/src/config.rs | 6 + .../identity/ras-identity-oauth2/src/error.rs | 3 + .../ras-identity-oauth2/src/provider.rs | 27 +- .../identity/ras-identity-oauth2/src/state.rs | 21 ++ .../identity/ras-identity-oauth2/src/tests.rs | 5 + .../identity/ras-identity-oauth2/src/types.rs | 5 + examples/oauth2-demo/api/Cargo.toml | 4 +- examples/oauth2-demo/api/src/lib.rs | 1 + examples/oauth2-demo/server/src/main.rs | 2 + 12 files changed, 328 insertions(+), 4 deletions(-) diff --git a/crates/identity/ras-identity-oauth2/README.md b/crates/identity/ras-identity-oauth2/README.md index 113fe0a..0f5d42c 100644 --- a/crates/identity/ras-identity-oauth2/README.md +++ b/crates/identity/ras-identity-oauth2/README.md @@ -16,6 +16,9 @@ OAuth2 identity provider implementation with PKCE support for Rust Agent Stack. - **PKCE Support**: Mitigates authorization code interception attacks - **State Parameter**: CSRF protection using cryptographically random UUIDs +- **OIDC Nonce**: Sent on every authorization request and verified against the id_token +- **id_token Claim Validation**: `iss` (when `issuer` is configured), `aud`, `exp` and `nonce` are checked on callback. The signature is not verified because the token arrives directly from the token endpoint over TLS, which OIDC Core §3.1.3.7 permits for the code flow +- **Session Binding (login-CSRF guard)**: `start_flow_bound` accepts an unguessable per-browser-session value (e.g. a random cookie); the callback payload must carry the identical `binding` or it is rejected, so an attacker cannot trick a victim into completing the attacker's flow - **Input Validation**: Robust handling of malformed responses - **Single-Use State**: Callback state is removed after successful retrieval @@ -111,6 +114,7 @@ let jwt_token = session_service.begin_session("oauth2", callback_payload).await? - `authorization_endpoint`: Provider's authorization URL - `token_endpoint`: Provider's token exchange URL - `userinfo_endpoint`: Provider's user info URL (optional) +- `issuer`: Expected `iss` claim of id_tokens (e.g. `https://accounts.google.com`); when set, id_tokens with a different issuer are rejected - `redirect_uri`: Your application's callback URL - `scopes`: Requested OAuth2 scopes - `auth_params`: Additional authorization parameters diff --git a/crates/identity/ras-identity-oauth2/examples/google_oauth2.rs b/crates/identity/ras-identity-oauth2/examples/google_oauth2.rs index 00c4629..4f0ea5b 100644 --- a/crates/identity/ras-identity-oauth2/examples/google_oauth2.rs +++ b/crates/identity/ras-identity-oauth2/examples/google_oauth2.rs @@ -28,6 +28,7 @@ async fn main() -> Result<(), Box> { authorization_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(), token_endpoint: "https://oauth2.googleapis.com/token".to_string(), userinfo_endpoint: Some("https://www.googleapis.com/oauth2/v1/userinfo".to_string()), + issuer: Some("https://accounts.google.com".to_string()), redirect_uri: "http://localhost:3000/auth/google/callback".to_string(), scopes: vec![ "openid".to_string(), @@ -153,6 +154,7 @@ mod tests { authorization_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(), token_endpoint: "https://oauth2.googleapis.com/token".to_string(), userinfo_endpoint: Some("https://www.googleapis.com/oauth2/v1/userinfo".to_string()), + issuer: Some("https://accounts.google.com".to_string()), redirect_uri: "http://localhost:3000/callback".to_string(), scopes: vec!["openid".to_string(), "email".to_string()], auth_params: HashMap::new(), diff --git a/crates/identity/ras-identity-oauth2/src/client.rs b/crates/identity/ras-identity-oauth2/src/client.rs index d8ee7a3..831be92 100644 --- a/crates/identity/ras-identity-oauth2/src/client.rs +++ b/crates/identity/ras-identity-oauth2/src/client.rs @@ -207,6 +207,22 @@ impl OAuth2Client { &self, provider_config: &OAuth2ProviderConfig, additional_params: HashMap, + ) -> OAuth2Result<(String, String)> { + self.generate_authorization_url_bound(provider_config, additional_params, None) + .await + } + + /// Generate an authorization URL bound to the initiating browser session. + /// + /// `binding` should be an unguessable value the integrator can recover on + /// callback (e.g. a random cookie value); the callback must then present + /// the identical value, preventing login CSRF where an attacker tricks a + /// victim into completing the attacker's flow. + pub async fn generate_authorization_url_bound( + &self, + provider_config: &OAuth2ProviderConfig, + additional_params: HashMap, + binding: Option, ) -> OAuth2Result<(String, String)> { let mut url = Url::parse(&provider_config.authorization_endpoint)?; @@ -217,13 +233,19 @@ impl OAuth2Client { None }; + // OIDC nonce: echoed back inside the id_token and verified on + // callback, binding the token to this authorization request. + let nonce = uuid::Uuid::new_v4().to_string(); + // Create and store state let state = OAuth2State::new( provider_config.provider_id.clone(), provider_config.redirect_uri.clone(), pkce.as_ref().map(|p| p.code_verifier.clone()), self.state_ttl_seconds, - ); + ) + .with_nonce(nonce.clone()) + .with_binding(binding); let state_param = state.state.clone(); self.state_store.store(state).await?; @@ -234,6 +256,7 @@ impl OAuth2Client { params.append_pair("client_id", &provider_config.client_id); params.append_pair("redirect_uri", &provider_config.redirect_uri); params.append_pair("state", &state_param); + params.append_pair("nonce", &nonce); // Add scopes if !provider_config.scopes.is_empty() { @@ -280,6 +303,12 @@ impl OAuth2Client { return Err(OAuth2Error::InvalidState); } + // When the flow was bound to a browser session, the callback must + // present the identical binding value (login-CSRF guard). + if state.binding.is_some() && state.binding != callback_response.binding { + return Err(OAuth2Error::InvalidState); + } + // Check for errors in callback if let Some(error) = &callback_response.error { let error_desc = callback_response @@ -301,6 +330,14 @@ impl OAuth2Client { ) .await?; + // Validate id_token claims when the provider returned one. The token + // arrived directly from the token endpoint over TLS, which OIDC Core + // §3.1.3.7 permits in place of signature validation for the code + // flow — but iss / aud / exp / nonce are still mandatory checks. + if let Some(id_token) = &token_response.id_token { + validate_id_token_claims(provider_config, id_token, state.nonce.as_deref())?; + } + Ok(token_response) } @@ -350,6 +387,79 @@ impl OAuth2Client { } } +/// Claims checked on an id_token returned by the token endpoint. +#[derive(serde::Deserialize)] +struct IdTokenClaims { + iss: Option, + aud: Option, + exp: Option, + nonce: Option, +} + +fn decode_id_token_claims(id_token: &str) -> OAuth2Result { + let payload = id_token + .split('.') + .nth(1) + .ok_or_else(|| OAuth2Error::InvalidIdToken("malformed JWT".to_string()))?; + let bytes = URL_SAFE_NO_PAD + .decode(payload) + .map_err(|_| OAuth2Error::InvalidIdToken("invalid base64 payload".to_string()))?; + serde_json::from_slice(&bytes) + .map_err(|_| OAuth2Error::InvalidIdToken("invalid JSON payload".to_string())) +} + +/// Validate the mandatory id_token claims: issuer (when configured), +/// audience, expiry, and the nonce echoed from the authorization request. +/// +/// The signature is not verified: the token was received directly from the +/// token endpoint over TLS, which OIDC Core §3.1.3.7 permits as a substitute +/// for signature validation in the authorization-code flow. +pub(crate) fn validate_id_token_claims( + provider_config: &OAuth2ProviderConfig, + id_token: &str, + expected_nonce: Option<&str>, +) -> OAuth2Result<()> { + let claims = decode_id_token_claims(id_token)?; + + if let Some(expected_issuer) = &provider_config.issuer + && claims.iss.as_deref() != Some(expected_issuer.as_str()) + { + return Err(OAuth2Error::InvalidIdToken(format!( + "issuer mismatch: expected {expected_issuer}" + ))); + } + + let audience_matches = match &claims.aud { + Some(serde_json::Value::String(aud)) => aud == &provider_config.client_id, + Some(serde_json::Value::Array(auds)) => auds + .iter() + .any(|aud| aud.as_str() == Some(provider_config.client_id.as_str())), + _ => false, + }; + if !audience_matches { + return Err(OAuth2Error::InvalidIdToken( + "audience does not include this client".to_string(), + )); + } + + match claims.exp { + Some(exp) if exp > chrono::Utc::now().timestamp() => {} + _ => { + return Err(OAuth2Error::InvalidIdToken( + "token expired or missing exp".to_string(), + )); + } + } + + if let Some(expected) = expected_nonce + && claims.nonce.as_deref() != Some(expected) + { + return Err(OAuth2Error::InvalidIdToken("nonce mismatch".to_string())); + } + + Ok(()) +} + #[cfg(test)] mod tests { use super::*; @@ -436,6 +546,7 @@ mod tests { authorization_endpoint: "https://example.com/auth".to_string(), token_endpoint: "https://example.com/token".to_string(), userinfo_endpoint: Some("https://example.com/userinfo".to_string()), + issuer: None, redirect_uri: "http://localhost:3000/callback".to_string(), scopes: vec!["openid".to_string(), "email".to_string()], auth_params: HashMap::new(), @@ -554,6 +665,7 @@ mod tests { state, error: None, error_description: None, + binding: None, }, ) .await @@ -583,6 +695,7 @@ mod tests { state, error: Some("access_denied".to_string()), error_description: Some("user denied consent".to_string()), + binding: None, }, ) .await @@ -618,6 +731,7 @@ mod tests { state, error: None, error_description: None, + binding: None, }, ) .await @@ -678,4 +792,140 @@ mod tests { )] ); } + + fn fake_id_token(payload: serde_json::Value) -> String { + let header = URL_SAFE_NO_PAD.encode(br#"{"alg":"RS256","typ":"JWT"}"#); + let payload = URL_SAFE_NO_PAD.encode(serde_json::to_vec(&payload).unwrap()); + format!("{header}.{payload}.signature") + } + + #[tokio::test] + async fn authorization_url_includes_nonce_and_stores_it() { + let state_store = Arc::new(InMemoryStateStore::new()); + let client = OAuth2Client::new(state_store.clone(), 600, 30); + + let (auth_url, state) = client + .generate_authorization_url(&provider_config(), HashMap::new()) + .await + .unwrap(); + + let url = Url::parse(&auth_url).unwrap(); + let params: HashMap<_, _> = url.query_pairs().collect(); + let url_nonce = params.get("nonce").expect("nonce in URL").to_string(); + assert!(!url_nonce.is_empty()); + + let stored = state_store.retrieve(&state).await.unwrap(); + assert_eq!(stored.nonce.as_deref(), Some(url_nonce.as_str())); + } + + #[test] + fn id_token_claim_validation_covers_iss_aud_exp_and_nonce() { + let mut config = provider_config(); + config.issuer = Some("https://issuer.test".to_string()); + let exp = chrono::Utc::now().timestamp() + 600; + + let good = fake_id_token(serde_json::json!({ + "iss": "https://issuer.test", + "aud": "test_client_id", + "exp": exp, + "nonce": "nonce-1", + })); + assert!(validate_id_token_claims(&config, &good, Some("nonce-1")).is_ok()); + + // aud may be an array containing this client + let aud_array = fake_id_token(serde_json::json!({ + "iss": "https://issuer.test", + "aud": ["other", "test_client_id"], + "exp": exp, + })); + assert!(validate_id_token_claims(&config, &aud_array, None).is_ok()); + + let bad_iss = fake_id_token(serde_json::json!({ + "iss": "https://evil.test", "aud": "test_client_id", "exp": exp, + })); + assert!(matches!( + validate_id_token_claims(&config, &bad_iss, None), + Err(OAuth2Error::InvalidIdToken(_)) + )); + + let bad_aud = fake_id_token(serde_json::json!({ + "iss": "https://issuer.test", "aud": "someone_else", "exp": exp, + })); + assert!(validate_id_token_claims(&config, &bad_aud, None).is_err()); + + let expired = fake_id_token(serde_json::json!({ + "iss": "https://issuer.test", + "aud": "test_client_id", + "exp": chrono::Utc::now().timestamp() - 10, + })); + assert!(validate_id_token_claims(&config, &expired, None).is_err()); + + let wrong_nonce = fake_id_token(serde_json::json!({ + "iss": "https://issuer.test", + "aud": "test_client_id", + "exp": exp, + "nonce": "other", + })); + assert!(validate_id_token_claims(&config, &wrong_nonce, Some("nonce-1")).is_err()); + + assert!(validate_id_token_claims(&config, "garbage", None).is_err()); + } + + #[tokio::test] + async fn handle_callback_enforces_session_binding() { + let state_store = Arc::new(InMemoryStateStore::new()); + let transport = Arc::new(RecordingTransport::new()); + let client = client_with_transport(state_store.clone(), transport.clone()); + let config = provider_config(); + + // A callback missing the binding is rejected before any token + // exchange (and the one-use state is burned). + let (_, state) = client + .generate_authorization_url_bound( + &config, + HashMap::new(), + Some("cookie-123".to_string()), + ) + .await + .unwrap(); + let err = client + .handle_callback( + &config, + AuthorizationResponse { + code: "code".to_string(), + state, + error: None, + error_description: None, + binding: None, + }, + ) + .await + .unwrap_err(); + assert!(matches!(err, OAuth2Error::InvalidState)); + assert!(transport.token_requests().is_empty()); + + // The matching binding completes the flow. + let (_, state) = client + .generate_authorization_url_bound( + &config, + HashMap::new(), + Some("cookie-123".to_string()), + ) + .await + .unwrap(); + client + .handle_callback( + &config, + AuthorizationResponse { + code: "code".to_string(), + state, + error: None, + error_description: None, + binding: Some("cookie-123".to_string()), + }, + ) + .await + .expect("bound callback succeeds"); + assert_eq!(transport.token_requests().len(), 1); + } } diff --git a/crates/identity/ras-identity-oauth2/src/config.rs b/crates/identity/ras-identity-oauth2/src/config.rs index 01f8a35..380a300 100644 --- a/crates/identity/ras-identity-oauth2/src/config.rs +++ b/crates/identity/ras-identity-oauth2/src/config.rs @@ -12,6 +12,11 @@ pub struct OAuth2ProviderConfig { pub authorization_endpoint: String, pub token_endpoint: String, pub userinfo_endpoint: Option, + /// Expected `iss` claim of id_tokens returned by this provider + /// (e.g. "https://accounts.google.com"). When set, callbacks carrying + /// an id_token with a different issuer are rejected. + #[serde(default)] + pub issuer: Option, pub redirect_uri: String, pub scopes: Vec, /// Additional parameters to include in authorization request @@ -93,6 +98,7 @@ mod tests { authorization_endpoint: "https://x/auth".into(), token_endpoint: "https://x/token".into(), userinfo_endpoint: Some("https://x/info".into()), + issuer: None, redirect_uri: "https://app/cb".into(), scopes: vec!["openid".into(), "email".into()], auth_params: HashMap::new(), diff --git a/crates/identity/ras-identity-oauth2/src/error.rs b/crates/identity/ras-identity-oauth2/src/error.rs index b9e293b..ebf90d6 100644 --- a/crates/identity/ras-identity-oauth2/src/error.rs +++ b/crates/identity/ras-identity-oauth2/src/error.rs @@ -24,6 +24,9 @@ pub enum OAuth2Error { #[error("Token exchange failed: {0}")] TokenExchangeFailed(String), + #[error("Invalid id_token: {0}")] + InvalidIdToken(String), + #[error("User info request failed: {0}")] UserInfoFailed(String), diff --git a/crates/identity/ras-identity-oauth2/src/provider.rs b/crates/identity/ras-identity-oauth2/src/provider.rs index f53e236..a056c8d 100644 --- a/crates/identity/ras-identity-oauth2/src/provider.rs +++ b/crates/identity/ras-identity-oauth2/src/provider.rs @@ -28,6 +28,10 @@ pub enum OAuth2AuthPayload { state: String, error: Option, error_description: Option, + /// Session-binding value captured when the flow was started (e.g. + /// from a cookie); required when the flow was started with one. + #[serde(default)] + binding: Option, }, } @@ -114,13 +118,28 @@ impl OAuth2Provider { &self, provider_id: &str, additional_params: Option>, + ) -> OAuth2Result { + self.start_flow_bound(provider_id, additional_params, None) + .await + } + + /// Start a flow bound to the initiating browser session. + /// + /// `binding` should be an unguessable value the integrator can recover on + /// callback (e.g. a random cookie value); the callback payload must then + /// carry the identical value or it is rejected, preventing login CSRF. + pub async fn start_flow_bound( + &self, + provider_id: &str, + additional_params: Option>, + binding: Option, ) -> OAuth2Result { let provider_config = self.get_provider_config(provider_id)?; let params = additional_params.unwrap_or_default(); let (auth_url, state) = self .client - .generate_authorization_url(provider_config, params) + .generate_authorization_url_bound(provider_config, params, binding) .await?; info!("Started OAuth2 flow for provider: {}", provider_id); @@ -139,6 +158,7 @@ impl OAuth2Provider { state: String, error: Option, error_description: Option, + binding: Option, ) -> OAuth2Result { let provider_config = self.get_provider_config(provider_id)?; @@ -147,6 +167,7 @@ impl OAuth2Provider { state, error, error_description, + binding, }; // Exchange code for tokens @@ -280,9 +301,10 @@ impl IdentityProvider for OAuth2Provider { state, error, error_description, + binding, } => { // For callback, we complete the flow and return the verified identity - self.handle_callback(&provider_id, code, state, error, error_description) + self.handle_callback(&provider_id, code, state, error, error_description, binding) .await .map_err(|e| IdentityError::ProviderError(e.to_string())) } @@ -304,6 +326,7 @@ mod tests { authorization_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(), token_endpoint: "https://oauth2.googleapis.com/token".to_string(), userinfo_endpoint: Some("https://www.googleapis.com/oauth2/v1/userinfo".to_string()), + issuer: None, redirect_uri: "http://localhost:3000/callback".to_string(), scopes: vec![ "openid".to_string(), diff --git a/crates/identity/ras-identity-oauth2/src/state.rs b/crates/identity/ras-identity-oauth2/src/state.rs index a0f6362..3521a94 100644 --- a/crates/identity/ras-identity-oauth2/src/state.rs +++ b/crates/identity/ras-identity-oauth2/src/state.rs @@ -16,6 +16,13 @@ pub struct OAuth2State { pub provider_id: String, pub redirect_uri: String, pub code_verifier: Option, + /// OIDC nonce sent in the authorization request; the id_token returned + /// on callback must echo it. + pub nonce: Option, + /// Optional caller-supplied value binding this flow to the browser + /// session that started it (e.g. a random cookie value). When set, the + /// callback must present the identical value, preventing login CSRF. + pub binding: Option, pub created_at: DateTime, pub expires_at: DateTime, pub metadata: Option, @@ -37,12 +44,26 @@ impl OAuth2State { provider_id, redirect_uri, code_verifier, + nonce: None, + binding: None, created_at, expires_at, metadata: None, } } + /// Attach an OIDC nonce to the flow. + pub fn with_nonce(mut self, nonce: String) -> Self { + self.nonce = Some(nonce); + self + } + + /// Bind the flow to the initiating browser session (login-CSRF guard). + pub fn with_binding(mut self, binding: Option) -> Self { + self.binding = binding; + self + } + pub fn is_expired(&self) -> bool { Utc::now() > self.expires_at } diff --git a/crates/identity/ras-identity-oauth2/src/tests.rs b/crates/identity/ras-identity-oauth2/src/tests.rs index 5097e2c..adffb5f 100644 --- a/crates/identity/ras-identity-oauth2/src/tests.rs +++ b/crates/identity/ras-identity-oauth2/src/tests.rs @@ -85,6 +85,7 @@ mod integration_tests { authorization_endpoint: "http://oauth.test/authorize".to_string(), token_endpoint: "http://oauth.test/token".to_string(), userinfo_endpoint: Some("http://oauth.test/userinfo".to_string()), + issuer: None, redirect_uri: "http://localhost:3000/callback".to_string(), scopes: vec!["openid".to_string(), "email".to_string()], auth_params: HashMap::new(), @@ -315,6 +316,7 @@ mod integration_tests { authorization_endpoint: "https://example.com/auth".to_string(), token_endpoint: "https://example.com/token".to_string(), userinfo_endpoint: None, + issuer: None, redirect_uri: "http://localhost:3000/callback".to_string(), scopes: vec![], auth_params: HashMap::new(), @@ -358,6 +360,7 @@ mod integration_tests { authorization_endpoint: "https://example.com/auth".to_string(), token_endpoint: "https://example.com/token".to_string(), userinfo_endpoint: None, + issuer: None, redirect_uri: "http://localhost:3000/callback".to_string(), scopes: vec![], auth_params: HashMap::new(), @@ -411,6 +414,7 @@ mod integration_tests { state: state.state, error: None, error_description: None, + binding: None, }; let result = client.handle_callback(&provider_config, callback).await; @@ -438,6 +442,7 @@ mod integration_tests { state: state2.state, error: None, error_description: None, + binding: None, }; let result = client.handle_callback(&provider_config, callback2).await; diff --git a/crates/identity/ras-identity-oauth2/src/types.rs b/crates/identity/ras-identity-oauth2/src/types.rs index e20e240..375dec0 100644 --- a/crates/identity/ras-identity-oauth2/src/types.rs +++ b/crates/identity/ras-identity-oauth2/src/types.rs @@ -21,6 +21,11 @@ pub struct AuthorizationResponse { pub state: String, pub error: Option, pub error_description: Option, + /// Session-binding value captured by the integrator when the flow was + /// started (e.g. from a cookie). Must match the value given to + /// `start_flow` for the same state, when one was supplied. + #[serde(default)] + pub binding: Option, } /// OAuth2 token response diff --git a/examples/oauth2-demo/api/Cargo.toml b/examples/oauth2-demo/api/Cargo.toml index 8b485cf..5528791 100644 --- a/examples/oauth2-demo/api/Cargo.toml +++ b/examples/oauth2-demo/api/Cargo.toml @@ -13,7 +13,9 @@ readme = "README.md" [features] default = [] server = ["ras-jsonrpc-macro/server", "dep:axum", "dep:ras-jsonrpc-core"] -client = ["ras-jsonrpc-macro/reqwest", "ras-transport-core/reqwest"] +client = ["ras-jsonrpc-macro/reqwest", "reqwest"] +# Enables the generated client's default ReqwestTransport constructor +reqwest = ["ras-transport-core/reqwest"] [dependencies] # JSON-RPC infrastructure diff --git a/examples/oauth2-demo/api/src/lib.rs b/examples/oauth2-demo/api/src/lib.rs index 14da006..e002afc 100644 --- a/examples/oauth2-demo/api/src/lib.rs +++ b/examples/oauth2-demo/api/src/lib.rs @@ -129,6 +129,7 @@ pub struct GetBetaFeaturesResponse { jsonrpc_service!({ service_name: GoogleOAuth2Service, openrpc: true, + feature_gated: true, methods: [ // Public endpoints (require authentication but no specific permissions) WITH_PERMISSIONS([]) get_user_info(GetUserInfoRequest) -> GetUserInfoResponse, diff --git a/examples/oauth2-demo/server/src/main.rs b/examples/oauth2-demo/server/src/main.rs index 0376112..da96270 100644 --- a/examples/oauth2-demo/server/src/main.rs +++ b/examples/oauth2-demo/server/src/main.rs @@ -129,6 +129,7 @@ fn create_oauth2_provider(config: &AppConfig) -> Result { authorization_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(), token_endpoint: "https://oauth2.googleapis.com/token".to_string(), userinfo_endpoint: Some("https://www.googleapis.com/oauth2/v3/userinfo".to_string()), + issuer: Some("https://accounts.google.com".to_string()), redirect_uri: config.redirect_uri.clone(), scopes: vec![ "openid".to_string(), @@ -228,6 +229,7 @@ async fn oauth2_callback_handler( state: state_param, error: callback_query.error, error_description: callback_query.error_description, + binding: None, }; let payload_json = serde_json::to_value(auth_payload) From 0e72f6c83e9d01b17230d9d1173c64b635a9a872 Mon Sep 17 00:00:00 2001 From: Mathias Myrland Date: Tue, 9 Jun 2026 23:49:50 +0200 Subject: [PATCH 11/12] fix: bound in-memory stores and add WebSocket connection cap Closes the audit MEDIUM findings on unbounded resources: - OAuth2 InMemoryStateStore now prunes expired flows opportunistically on every store and caps pending flows (default 10k, configurable via with_capacity); the cap returns TooManyPendingFlows instead of growing without bound. cleanup_expired no longer needs external scheduling. - SessionService gains active_session_count() and start_cleanup_task(), a background sweeper holding only a Weak reference (stops when the service is dropped), so expired sessions are pruned even through traffic lulls instead of only on begin/verify calls. - WebSocket services accept an optional max_connections cap (builder field + WebSocketService::max_connections). Upgrades beyond the cap are refused with 503 before the socket upgrade; non-axum transports are refused with a close frame before the connection is registered. ConnectionManager gains active_connection_count() (cheap override in DefaultConnectionManager). Co-Authored-By: Claude Fable 5 --- .../identity/ras-identity-oauth2/src/error.rs | 3 + .../identity/ras-identity-oauth2/src/state.rs | 75 ++++++++++++++ .../identity/ras-identity-session/src/lib.rs | 67 +++++++++++++ .../src/manager.rs | 4 + .../src/service.rs | 99 ++++++++++++++++++- .../src/manager.rs | 6 ++ 6 files changed, 253 insertions(+), 1 deletion(-) diff --git a/crates/identity/ras-identity-oauth2/src/error.rs b/crates/identity/ras-identity-oauth2/src/error.rs index ebf90d6..1c4fa4c 100644 --- a/crates/identity/ras-identity-oauth2/src/error.rs +++ b/crates/identity/ras-identity-oauth2/src/error.rs @@ -27,6 +27,9 @@ pub enum OAuth2Error { #[error("Invalid id_token: {0}")] InvalidIdToken(String), + #[error("Too many pending OAuth2 flows")] + TooManyPendingFlows, + #[error("User info request failed: {0}")] UserInfoFailed(String), diff --git a/crates/identity/ras-identity-oauth2/src/state.rs b/crates/identity/ras-identity-oauth2/src/state.rs index 3521a94..7cc9c88 100644 --- a/crates/identity/ras-identity-oauth2/src/state.rs +++ b/crates/identity/ras-identity-oauth2/src/state.rs @@ -83,14 +83,26 @@ pub trait OAuth2StateStore: Send + Sync { } /// In-memory implementation of OAuth2StateStore +/// Default cap on concurrently pending flows held in memory. +const DEFAULT_MAX_PENDING_STATES: usize = 10_000; + pub struct InMemoryStateStore { states: Arc>>, + max_states: usize, } impl InMemoryStateStore { pub fn new() -> Self { + Self::with_capacity(DEFAULT_MAX_PENDING_STATES) + } + + /// Create a store holding at most `max_states` pending flows. Expired + /// states are pruned opportunistically on every `store`, so no external + /// cleanup scheduling is required; `store` fails once the cap is hit. + pub fn with_capacity(max_states: usize) -> Self { Self { states: Arc::new(RwLock::new(HashMap::new())), + max_states, } } } @@ -105,6 +117,16 @@ impl Default for InMemoryStateStore { impl OAuth2StateStore for InMemoryStateStore { async fn store(&self, state: OAuth2State) -> OAuth2Result<()> { let mut states = self.states.write().await; + + // Opportunistic pruning: abandoned flows must not accumulate just + // because nobody schedules cleanup_expired. + let now = Utc::now(); + states.retain(|_, stored| now <= stored.expires_at); + + if states.len() >= self.max_states { + return Err(OAuth2Error::TooManyPendingFlows); + } + states.insert(state.state.clone(), state); Ok(()) } @@ -199,4 +221,57 @@ mod tests { let result = store.retrieve(&state.state).await; assert!(matches!(result, Err(OAuth2Error::StateNotFound))); } + + #[tokio::test] + async fn store_rejects_when_capacity_reached() { + let store = InMemoryStateStore::with_capacity(2); + for _ in 0..2 { + store + .store(OAuth2State::new( + "google".to_string(), + "http://localhost/cb".to_string(), + None, + 300, + )) + .await + .unwrap(); + } + + let err = store + .store(OAuth2State::new( + "google".to_string(), + "http://localhost/cb".to_string(), + None, + 300, + )) + .await + .unwrap_err(); + assert!(matches!(err, OAuth2Error::TooManyPendingFlows)); + } + + #[tokio::test] + async fn store_prunes_expired_states_opportunistically() { + let store = InMemoryStateStore::with_capacity(1); + + // An already-expired flow occupies the only slot... + let mut expired = OAuth2State::new( + "google".to_string(), + "http://localhost/cb".to_string(), + None, + 300, + ); + expired.expires_at = Utc::now() - Duration::seconds(10); + store.store(expired).await.unwrap(); + + // ...but is pruned when the next flow is stored. + store + .store(OAuth2State::new( + "google".to_string(), + "http://localhost/cb".to_string(), + None, + 300, + )) + .await + .expect("expired state must be pruned to make room"); + } } diff --git a/crates/identity/ras-identity-session/src/lib.rs b/crates/identity/ras-identity-session/src/lib.rs index df42077..e883f65 100644 --- a/crates/identity/ras-identity-session/src/lib.rs +++ b/crates/identity/ras-identity-session/src/lib.rs @@ -382,6 +382,36 @@ impl SessionService { Ok(claims) } + /// Number of sessions currently held in the in-memory store + /// (only populated when `enforce_active_sessions` is on). + pub async fn active_session_count(&self) -> usize { + self.active_sessions.read().await.len() + } + + /// Spawn a background task pruning expired sessions every `interval`. + /// + /// Expired sessions are otherwise only pruned opportunistically when + /// begin_session/verify_session run, so a traffic lull leaves them in + /// memory indefinitely. The task holds only a weak reference and stops + /// when the service is dropped (or when the returned handle is aborted). + pub fn start_cleanup_task( + self: &std::sync::Arc, + interval: std::time::Duration, + ) -> tokio::task::JoinHandle<()> { + let service = std::sync::Arc::downgrade(self); + tokio::spawn(async move { + let mut timer = tokio::time::interval(interval); + timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + loop { + timer.tick().await; + let Some(service) = service.upgrade() else { + break; + }; + service.cleanup_expired_sessions().await; + } + }) + } + pub async fn end_session(&self, jti: &str) -> Option { let mut sessions = self.active_sessions.write().await; sessions.remove(jti) @@ -682,4 +712,41 @@ mod tests { assert!(user.permissions.contains("chat:read")); assert!(user.metadata.is_none()); } + + #[tokio::test] + async fn cleanup_task_prunes_expired_sessions_in_background() { + let config = SessionConfig::new(TEST_SECRET).unwrap(); + let service = std::sync::Arc::new(SessionService::new(config).unwrap()); + + // Plant an already-expired session directly in the store. + let now = chrono::Utc::now().timestamp(); + service.active_sessions.write().await.insert( + "expired-jti".to_string(), + JwtClaims { + sub: "alice".to_string(), + exp: now - 10, + iat: now - 20, + jti: "expired-jti".to_string(), + provider_id: "local".to_string(), + email: None, + display_name: None, + permissions: HashSet::new(), + metadata: None, + }, + ); + assert_eq!(service.active_session_count().await, 1); + + let handle = service.start_cleanup_task(std::time::Duration::from_millis(20)); + + // The sweeper prunes the expired session without any begin/verify call. + tokio::time::timeout(std::time::Duration::from_secs(5), async { + while service.active_session_count().await != 0 { + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + } + }) + .await + .expect("cleanup task prunes expired sessions"); + + handle.abort(); + } } diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/manager.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/manager.rs index 78ebeda..5213133 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/manager.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/manager.rs @@ -174,6 +174,10 @@ impl ConnectionManager for DefaultConnectionManager { .collect()) } + async fn active_connection_count(&self) -> Result { + Ok(self.connections.len()) + } + async fn get_subscribed_connections(&self, topic: &str) -> Result> { let connection_ids = self.get_topic_connections(topic); let mut connections = Vec::new(); diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/service.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/service.rs index dec3e17..bde528e 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/service.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/service.rs @@ -4,7 +4,10 @@ use crate::{ ConnectionContext, DefaultConnectionManager, MessageHandler, MessageRouter, ServerError, ServerResult, WebSocketHandler, WebSocketUpgrade, connection::ChannelMessageSender, - handler::{AuthRevalidation, AxumWebSocketIo, DEFAULT_AUTH_REVALIDATION_INTERVAL, WebSocketIo}, + handler::{ + AuthRevalidation, AxumWebSocketIo, DEFAULT_AUTH_REVALIDATION_INTERVAL, WebSocketIo, + WebSocketIoMessage, + }, }; use axum::{ extract::{State, ws::WebSocketUpgrade as AxumWebSocketUpgrade}, @@ -62,12 +65,35 @@ pub trait WebSocketService: Clone + Send + Sync + 'static { DEFAULT_AUTH_REVALIDATION_INTERVAL } + /// Maximum simultaneous connections; further upgrades are refused. + /// + /// The check is advisory (checked-then-added, not atomic), so a burst + /// can briefly overshoot by a few connections. + fn max_connections(&self) -> Option { + None + } + /// Handle WebSocket upgrade async fn handle_upgrade( &self, upgrade: AxumWebSocketUpgrade, headers: HeaderMap, ) -> Result { + // Refuse before upgrading when the connection cap is reached. + if let Some(limit) = self.max_connections() { + let current = self + .connection_manager() + .active_connection_count() + .await + .unwrap_or(0); + if current >= limit { + return Err(( + axum::http::StatusCode::SERVICE_UNAVAILABLE, + "Connection limit reached".to_string(), + )); + } + } + let ws_upgrade = WebSocketUpgrade::new(upgrade, headers); // Captured pre-upgrade so the connection can periodically re-validate it. let auth_token = ws_upgrade.extract_auth_token(); @@ -132,6 +158,25 @@ where let connection_id = ConnectionId::new(); info!("New WebSocket connection: {}", connection_id); + // Enforce the connection cap for transports that bypass handle_upgrade. + if let Some(limit) = service.max_connections() { + let current = service + .connection_manager() + .active_connection_count() + .await + .unwrap_or(0); + if current >= limit { + let _ = socket + .send(WebSocketIoMessage::Close(Some( + "connection limit reached".to_string(), + ))) + .await; + return Err(ServerError::Internal( + "connection limit reached".to_string(), + )); + } + } + let channel_capacity = service.message_channel_capacity().max(1); let (message_tx, message_rx) = mpsc::channel(channel_capacity); let sender = ChannelMessageSender::new(connection_id, message_tx); @@ -203,6 +248,8 @@ pub struct WebSocketServiceBuilder { /// Interval between credential re-validations for live connections #[builder(default = DEFAULT_AUTH_REVALIDATION_INTERVAL)] auth_revalidation_interval: Duration, + /// Maximum simultaneous connections (None = unbounded) + max_connections: Option, } impl WebSocketServiceBuilder @@ -222,6 +269,7 @@ where message_channel_capacity: self.message_channel_capacity, max_message_size: self.max_message_size, auth_revalidation_interval: self.auth_revalidation_interval, + max_connections: self.max_connections, } } } @@ -242,6 +290,7 @@ where message_channel_capacity: self.message_channel_capacity, max_message_size: self.max_message_size, auth_revalidation_interval: self.auth_revalidation_interval, + max_connections: self.max_connections, } } } @@ -255,6 +304,7 @@ pub struct BuiltWebSocketService { message_channel_capacity: usize, max_message_size: usize, auth_revalidation_interval: Duration, + max_connections: Option, } impl Clone for BuiltWebSocketService { @@ -267,6 +317,7 @@ impl Clone for BuiltWebSocketService { message_channel_capacity: self.message_channel_capacity, max_message_size: self.max_message_size, auth_revalidation_interval: self.auth_revalidation_interval, + max_connections: self.max_connections, } } } @@ -308,6 +359,10 @@ where fn auth_revalidation_interval(&self) -> Duration { self.auth_revalidation_interval } + + fn max_connections(&self) -> Option { + self.max_connections + } } /// Convenience function to create a simple router-based service @@ -448,6 +503,48 @@ mod tests { assert_eq!(service.auth_revalidation_interval(), Duration::from_secs(5)); } + #[tokio::test] + async fn connection_cap_refuses_excess_connections() { + let manager = Arc::new(DefaultConnectionManager::new()); + let builder = WebSocketServiceBuilder::builder() + .handler(Arc::new(MessageRouter::new())) + .auth_provider(Arc::new(MockAuthProvider)) + .max_connections(0) + .build(); + let service = builder.build_with_manager(manager.clone()); + + let mut socket = InMemorySocket::closing([]); + let result = service + .handle_connection_with_io(&mut socket, Some(test_user())) + .await; + + assert!(result.is_err(), "connection over the cap must be refused"); + assert!( + socket + .outgoing + .iter() + .any(|message| matches!(message, WebSocketIoMessage::Close(_))) + ); + assert_eq!(manager.connection_count(), 0); + } + + #[tokio::test] + async fn connection_cap_admits_connections_under_the_limit() { + let manager = Arc::new(DefaultConnectionManager::new()); + let builder = WebSocketServiceBuilder::builder() + .handler(Arc::new(MessageRouter::new())) + .auth_provider(Arc::new(MockAuthProvider)) + .max_connections(1) + .build(); + let service = builder.build_with_manager(manager.clone()); + + let mut socket = InMemorySocket::closing([]); + service + .handle_connection_with_io(&mut socket, Some(test_user())) + .await + .expect("connection under the cap is admitted"); + } + #[tokio::test] async fn test_service_with_auth_required() { let router = MessageRouter::new(); diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-types/src/manager.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-types/src/manager.rs index c62da3c..ed8bd0d 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-types/src/manager.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-types/src/manager.rs @@ -30,6 +30,12 @@ pub trait ConnectionManager: Send + Sync { /// Get all active connections async fn get_all_connections(&self) -> Result>; + /// Number of active connections. Implementations should override this + /// with a cheap counter; the default clones every ConnectionInfo. + async fn active_connection_count(&self) -> Result { + Ok(self.get_all_connections().await?.len()) + } + /// Get connections subscribed to a topic async fn get_subscribed_connections(&self, topic: &str) -> Result>; From c07af39fd1b1ffad71116fb47924efbce02c0d30 Mon Sep 17 00:00:00 2001 From: Mathias Myrland Date: Wed, 10 Jun 2026 00:32:39 +0200 Subject: [PATCH 12/12] docs(oauth2): fix bare URL in issuer field doc comment rustdoc runs with -D warnings in CI; the bare https URL tripped rustdoc::bare-urls. Wrap it in a code span. Co-Authored-By: Claude Fable 5 --- crates/identity/ras-identity-oauth2/src/config.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/identity/ras-identity-oauth2/src/config.rs b/crates/identity/ras-identity-oauth2/src/config.rs index 380a300..982245f 100644 --- a/crates/identity/ras-identity-oauth2/src/config.rs +++ b/crates/identity/ras-identity-oauth2/src/config.rs @@ -13,7 +13,7 @@ pub struct OAuth2ProviderConfig { pub token_endpoint: String, pub userinfo_endpoint: Option, /// Expected `iss` claim of id_tokens returned by this provider - /// (e.g. "https://accounts.google.com"). When set, callbacks carrying + /// (e.g. `https://accounts.google.com`). When set, callbacks carrying /// an id_token with a different issuer are rejected. #[serde(default)] pub issuer: Option,