diff --git a/Cargo.lock b/Cargo.lock
index e5f66ed..f97130b 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2646,6 +2646,7 @@ dependencies = [
"serde",
"serde_json",
"thiserror 2.0.18",
+ "tokio",
]
[[package]]
diff --git a/crates/core/ras-auth-core/Cargo.toml b/crates/core/ras-auth-core/Cargo.toml
index 3e382f5..5d85cfc 100644
--- a/crates/core/ras-auth-core/Cargo.toml
+++ b/crates/core/ras-auth-core/Cargo.toml
@@ -15,3 +15,6 @@ http = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
+
+[dev-dependencies]
+tokio = { workspace = true }
diff --git a/crates/core/ras-auth-core/src/authorize.rs b/crates/core/ras-auth-core/src/authorize.rs
new file mode 100644
index 0000000..e245f52
--- /dev/null
+++ b/crates/core/ras-auth-core/src/authorize.rs
@@ -0,0 +1,253 @@
+//! Shared request-authorization pipeline for generated services.
+//!
+//! Every service macro (REST, file, JSON-RPC, bidirectional WebSocket) used
+//! to inline its own copy of the credential → CSRF → authenticate →
+//! permission-group sequence. These helpers are the single implementation;
+//! generated code maps the returned [`AuthorizeError`] to its own protocol's
+//! response shape.
+
+use crate::{
+ AuthError, AuthProvider, AuthTransportConfig, AuthenticatedUser, extract_auth_credential,
+ validate_csrf_for_credential,
+};
+use http::HeaderMap;
+
+/// Why [`authorize_request`] rejected a request.
+#[derive(Debug)]
+pub enum AuthorizeError {
+ /// No usable credential was found in the request
+ MissingCredential,
+ /// Double-submit CSRF validation failed for a cookie credential
+ CsrfValidationFailed,
+ /// The credential did not authenticate
+ AuthenticationFailed(AuthError),
+ /// The service was built without an auth provider
+ NoAuthProvider,
+ /// Authenticated, but no required permission group was satisfied
+ InsufficientPermissions(AuthError),
+}
+
+/// OR-of-AND permission check shared by all generated services.
+///
+/// `groups` is a disjunction of conjunctions: access is granted when the user
+/// holds every permission of at least one group (verified through the
+/// provider's `check_permissions`, which custom providers may override). A
+/// group list with no non-empty groups — `WITH_PERMISSIONS([])` or any empty
+/// inner group — grants access to any authenticated user.
+pub fn check_permission_groups
(
+ provider: &P,
+ user: &AuthenticatedUser,
+ groups: &[Vec],
+) -> Result<(), AuthError>
+where
+ P: AuthProvider + ?Sized,
+{
+ if !groups.iter().any(|group| !group.is_empty()) {
+ return Ok(());
+ }
+
+ for group in groups {
+ if group.is_empty() || provider.check_permissions(user, group).is_ok() {
+ return Ok(());
+ }
+ }
+
+ Err(AuthError::InsufficientPermissions {
+ required: groups
+ .iter()
+ .find(|group| !group.is_empty())
+ .cloned()
+ .unwrap_or_default(),
+ has: user.permissions.iter().cloned().collect(),
+ })
+}
+
+/// Set-membership variant of [`check_permission_groups`] for contexts without
+/// an auth provider (e.g. the bidirectional WebSocket handler, which
+/// authorizes against the cached connection user).
+pub fn user_satisfies_permission_groups(user: &AuthenticatedUser, groups: &[Vec]) -> bool {
+ if !groups.iter().any(|group| !group.is_empty()) {
+ return true;
+ }
+
+ groups
+ .iter()
+ .any(|group| !group.is_empty() && group.iter().all(|perm| user.permissions.contains(perm)))
+ || groups.iter().any(|group| group.is_empty())
+}
+
+/// The credential → CSRF → authenticate → permission pipeline shared by the
+/// generated REST and file-service servers.
+///
+/// `method` is the HTTP method, used to scope CSRF validation to unsafe
+/// requests. Errors are ordered so no work happens for unauthenticated
+/// callers: the request body has not been touched when this returns `Err`.
+pub async fn authorize_request
(
+ method: &str,
+ headers: &HeaderMap,
+ auth_transport: &AuthTransportConfig,
+ auth_provider: Option<&P>,
+ required_permission_groups: &[Vec],
+) -> Result
+where
+ P: AuthProvider + ?Sized,
+{
+ let credential = extract_auth_credential(headers, auth_transport)
+ .map_err(|_| AuthorizeError::MissingCredential)?;
+
+ validate_csrf_for_credential(method, headers, &credential, auth_transport)
+ .map_err(|_| AuthorizeError::CsrfValidationFailed)?;
+
+ let provider = auth_provider.ok_or(AuthorizeError::NoAuthProvider)?;
+
+ let user = provider
+ .authenticate(credential.token().to_string())
+ .await
+ .map_err(AuthorizeError::AuthenticationFailed)?;
+
+ check_permission_groups(provider, &user, required_permission_groups)
+ .map_err(AuthorizeError::InsufficientPermissions)?;
+
+ Ok(user)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::AuthFuture;
+ use std::collections::HashSet;
+
+ struct StaticProvider;
+
+ impl AuthProvider for StaticProvider {
+ fn authenticate(&self, token: String) -> AuthFuture<'_> {
+ Box::pin(async move {
+ if token == "good" {
+ Ok(user(&["read", "write"]))
+ } else {
+ Err(AuthError::InvalidToken)
+ }
+ })
+ }
+ }
+
+ fn user(perms: &[&str]) -> AuthenticatedUser {
+ AuthenticatedUser {
+ user_id: "u".into(),
+ permissions: perms.iter().map(|p| p.to_string()).collect::>(),
+ metadata: None,
+ }
+ }
+
+ fn groups(groups: &[&[&str]]) -> Vec> {
+ groups
+ .iter()
+ .map(|g| g.iter().map(|p| p.to_string()).collect())
+ .collect()
+ }
+
+ #[test]
+ fn empty_group_list_is_authenticated_only() {
+ assert!(check_permission_groups(&StaticProvider, &user(&[]), &[]).is_ok());
+ assert!(user_satisfies_permission_groups(&user(&[]), &[]));
+ }
+
+ #[test]
+ fn empty_inner_group_grants_any_authenticated_user() {
+ let g = groups(&[&["admin"], &[]]);
+ assert!(check_permission_groups(&StaticProvider, &user(&[]), &g).is_ok());
+ assert!(user_satisfies_permission_groups(&user(&[]), &g));
+ }
+
+ #[test]
+ fn and_within_group_or_between_groups() {
+ let g = groups(&[&["read", "write"], &["admin"]]);
+
+ // Satisfies the first group (all permissions present).
+ assert!(check_permission_groups(&StaticProvider, &user(&["read", "write"]), &g).is_ok());
+ assert!(user_satisfies_permission_groups(
+ &user(&["read", "write"]),
+ &g
+ ));
+
+ // Satisfies the second group.
+ assert!(check_permission_groups(&StaticProvider, &user(&["admin"]), &g).is_ok());
+ assert!(user_satisfies_permission_groups(&user(&["admin"]), &g));
+
+ // Partial match on the first group, none on the second: denied.
+ let denied = check_permission_groups(&StaticProvider, &user(&["read"]), &g).unwrap_err();
+ assert!(matches!(
+ denied,
+ AuthError::InsufficientPermissions { required, .. } if required == vec!["read", "write"]
+ ));
+ assert!(!user_satisfies_permission_groups(&user(&["read"]), &g));
+ }
+
+ #[tokio::test]
+ async fn authorize_request_full_pipeline() {
+ let transport = AuthTransportConfig::default();
+ let mut headers = HeaderMap::new();
+
+ // No credential
+ let err = authorize_request(
+ "POST",
+ &headers,
+ &transport,
+ Some(&StaticProvider),
+ &groups(&[&["read"]]),
+ )
+ .await
+ .unwrap_err();
+ assert!(matches!(err, AuthorizeError::MissingCredential));
+
+ headers.insert("authorization", "Bearer bad".parse().unwrap());
+ let err = authorize_request(
+ "POST",
+ &headers,
+ &transport,
+ Some(&StaticProvider),
+ &groups(&[&["read"]]),
+ )
+ .await
+ .unwrap_err();
+ assert!(matches!(err, AuthorizeError::AuthenticationFailed(_)));
+
+ headers.insert("authorization", "Bearer good".parse().unwrap());
+
+ // Missing provider
+ let err = authorize_request(
+ "POST",
+ &headers,
+ &transport,
+ None::<&StaticProvider>,
+ &groups(&[&["read"]]),
+ )
+ .await
+ .unwrap_err();
+ assert!(matches!(err, AuthorizeError::NoAuthProvider));
+
+ // Insufficient permissions
+ let err = authorize_request(
+ "POST",
+ &headers,
+ &transport,
+ Some(&StaticProvider),
+ &groups(&[&["admin"]]),
+ )
+ .await
+ .unwrap_err();
+ assert!(matches!(err, AuthorizeError::InsufficientPermissions(_)));
+
+ // Success
+ let user = authorize_request(
+ "POST",
+ &headers,
+ &transport,
+ Some(&StaticProvider),
+ &groups(&[&["read", "write"]]),
+ )
+ .await
+ .unwrap();
+ assert_eq!(user.user_id, "u");
+ }
+}
diff --git a/crates/core/ras-auth-core/src/lib.rs b/crates/core/ras-auth-core/src/lib.rs
index 54c3056..516bfa4 100644
--- a/crates/core/ras-auth-core/src/lib.rs
+++ b/crates/core/ras-auth-core/src/lib.rs
@@ -1,5 +1,6 @@
//! Authentication and authorization traits for JSON-RPC services.
+mod authorize;
mod transport;
use std::collections::HashSet;
@@ -9,6 +10,7 @@ use std::pin::Pin;
use serde::{Deserialize, Serialize};
use thiserror::Error;
+pub use authorize::*;
pub use transport::*;
/// Errors that can occur during authentication or authorization.
diff --git a/crates/identity/ras-identity-oauth2/README.md b/crates/identity/ras-identity-oauth2/README.md
index a6f1c9d..0f5d42c 100644
--- a/crates/identity/ras-identity-oauth2/README.md
+++ b/crates/identity/ras-identity-oauth2/README.md
@@ -16,6 +16,9 @@ OAuth2 identity provider implementation with PKCE support for Rust Agent Stack.
- **PKCE Support**: Mitigates authorization code interception attacks
- **State Parameter**: CSRF protection using cryptographically random UUIDs
+- **OIDC Nonce**: Sent on every authorization request and verified against the id_token
+- **id_token Claim Validation**: `iss` (when `issuer` is configured), `aud`, `exp` and `nonce` are checked on callback. The signature is not verified because the token arrives directly from the token endpoint over TLS, which OIDC Core §3.1.3.7 permits for the code flow
+- **Session Binding (login-CSRF guard)**: `start_flow_bound` accepts an unguessable per-browser-session value (e.g. a random cookie); the callback payload must carry the identical `binding` or it is rejected, so an attacker cannot trick a victim into completing the attacker's flow
- **Input Validation**: Robust handling of malformed responses
- **Single-Use State**: Callback state is removed after successful retrieval
@@ -58,38 +61,26 @@ let oauth2_provider = OAuth2Provider::new(config, state_store);
### Integration with Session Service
```rust
-use ras_identity_core::{IdentityError, IdentityProvider};
+use ras_identity_core::IdentityProvider;
use ras_identity_oauth2::OAuth2Response;
-use ras_identity_session::{SessionConfig, SessionError, SessionService};
+use ras_identity_session::{SessionConfig, SessionService};
-// Register with session service
+// Register with session service. The provider is cheap to clone; keep one
+// handle for flow initiation and register the other for verification.
let session_config = SessionConfig::new("use-at-least-32-bytes-of-random-secret")?;
let session_service = SessionService::new(session_config)?;
-session_service.register_provider(Box::new(oauth2_provider)).await;
+session_service.register_provider(Box::new(oauth2_provider.clone())).await;
// Start OAuth2 flow
-let start_payload = serde_json::json!({
- "type": "StartFlow",
- "provider_id": "google"
-});
-
-// This will return an error containing the authorization URL
-match session_service.begin_session("oauth2", start_payload).await {
- Err(SessionError::IdentityError(IdentityError::ProviderError(json))) => {
- let response: OAuth2Response = serde_json::from_str(&json)?;
- match response {
- OAuth2Response::AuthorizationUrl { url, state } => {
- // Redirect user to `url`
- println!("Redirect to: {}", url);
- }
- OAuth2Response::Error { message } => {
- eprintln!("OAuth2 start-flow failed: {message}");
- }
- }
+match oauth2_provider.start_flow("google", None).await? {
+ OAuth2Response::AuthorizationUrl { url, state } => {
+ // Redirect user to `url`
+ println!("Redirect to: {}", url);
+ }
+ OAuth2Response::Error { message } => {
+ eprintln!("OAuth2 start-flow failed: {message}");
}
- Ok(_) => eprintln!("OAuth2 start flow completed without a redirect"),
- Err(err) => eprintln!("OAuth2 start flow failed: {err}"),
}
// Handle callback
@@ -123,6 +114,7 @@ let jwt_token = session_service.begin_session("oauth2", callback_payload).await?
- `authorization_endpoint`: Provider's authorization URL
- `token_endpoint`: Provider's token exchange URL
- `userinfo_endpoint`: Provider's user info URL (optional)
+- `issuer`: Expected `iss` claim of id_tokens (e.g. `https://accounts.google.com`); when set, id_tokens with a different issuer are rejected
- `redirect_uri`: Your application's callback URL
- `scopes`: Requested OAuth2 scopes
- `auth_params`: Additional authorization parameters
diff --git a/crates/identity/ras-identity-oauth2/examples/google_oauth2.rs b/crates/identity/ras-identity-oauth2/examples/google_oauth2.rs
index 2951cc2..4f0ea5b 100644
--- a/crates/identity/ras-identity-oauth2/examples/google_oauth2.rs
+++ b/crates/identity/ras-identity-oauth2/examples/google_oauth2.rs
@@ -6,7 +6,6 @@
//! 3. Handling the OAuth2 flow
//! 4. Issuing JWTs after successful authentication
-use ras_identity_core::IdentityError;
use ras_identity_oauth2::{
InMemoryStateStore, OAuth2Config, OAuth2Provider, OAuth2ProviderConfig, OAuth2Response,
};
@@ -29,6 +28,7 @@ async fn main() -> Result<(), Box> {
authorization_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
token_endpoint: "https://oauth2.googleapis.com/token".to_string(),
userinfo_endpoint: Some("https://www.googleapis.com/oauth2/v1/userinfo".to_string()),
+ issuer: Some("https://accounts.google.com".to_string()),
redirect_uri: "http://localhost:3000/auth/google/callback".to_string(),
scopes: vec![
"openid".to_string(),
@@ -46,7 +46,9 @@ async fn main() -> Result<(), Box> {
.with_state_ttl(600) // 10 minutes
.with_http_timeout(30); // 30 seconds
- // Create state store and OAuth2 provider
+ // Create state store and OAuth2 provider. The provider is cheap to clone;
+ // keep one handle for flow initiation and register the other for
+ // verification through the session service.
let state_store = Arc::new(InMemoryStateStore::new());
let oauth2_provider = OAuth2Provider::new(oauth2_config, state_store);
@@ -57,7 +59,7 @@ async fn main() -> Result<(), Box> {
// Register OAuth2 provider with session service
session_service
- .register_provider(Box::new(oauth2_provider))
+ .register_provider(Box::new(oauth2_provider.clone()))
.await;
println!("OAuth2 Example - Google Authentication");
@@ -66,36 +68,20 @@ async fn main() -> Result<(), Box> {
// Step 1: Start OAuth2 flow
println!("\n1. Starting OAuth2 flow...");
- let start_payload = serde_json::json!({
- "type": "StartFlow",
- "provider_id": "google"
- });
-
- match session_service.begin_session("oauth2", start_payload).await {
- Err(ras_identity_session::SessionError::IdentityError(IdentityError::ProviderError(
- json,
- ))) => {
- let response: OAuth2Response = serde_json::from_str(&json)?;
-
- match response {
- OAuth2Response::AuthorizationUrl { url, state } => {
- println!("Authorization URL: {}", url);
- println!("State: {}", state);
- println!("\nIn a real application, you would:");
- println!("1. Redirect the user to the authorization URL");
- println!("2. Handle the callback with the authorization code");
- println!("3. Exchange the code for a JWT token");
-
- // Simulate callback (in real app, this comes from OAuth2 provider)
- simulate_callback(&session_service, state).await?;
- }
- _ => {
- println!("Unexpected response from OAuth2 provider");
- }
- }
+ match oauth2_provider.start_flow("google", None).await {
+ Ok(OAuth2Response::AuthorizationUrl { url, state }) => {
+ println!("Authorization URL: {}", url);
+ println!("State: {}", state);
+ println!("\nIn a real application, you would:");
+ println!("1. Redirect the user to the authorization URL");
+ println!("2. Handle the callback with the authorization code");
+ println!("3. Exchange the code for a JWT token");
+
+ // Simulate callback (in real app, this comes from OAuth2 provider)
+ simulate_callback(&session_service, state).await?;
}
- Ok(_) => {
- println!("Unexpected success from start flow");
+ Ok(OAuth2Response::Error { message }) => {
+ println!("OAuth2 error: {}", message);
}
Err(e) => {
println!("Error starting OAuth2 flow: {}", e);
@@ -168,6 +154,7 @@ mod tests {
authorization_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
token_endpoint: "https://oauth2.googleapis.com/token".to_string(),
userinfo_endpoint: Some("https://www.googleapis.com/oauth2/v1/userinfo".to_string()),
+ issuer: Some("https://accounts.google.com".to_string()),
redirect_uri: "http://localhost:3000/callback".to_string(),
scopes: vec!["openid".to_string(), "email".to_string()],
auth_params: HashMap::new(),
diff --git a/crates/identity/ras-identity-oauth2/src/client.rs b/crates/identity/ras-identity-oauth2/src/client.rs
index d8ee7a3..831be92 100644
--- a/crates/identity/ras-identity-oauth2/src/client.rs
+++ b/crates/identity/ras-identity-oauth2/src/client.rs
@@ -207,6 +207,22 @@ impl OAuth2Client {
&self,
provider_config: &OAuth2ProviderConfig,
additional_params: HashMap,
+ ) -> OAuth2Result<(String, String)> {
+ self.generate_authorization_url_bound(provider_config, additional_params, None)
+ .await
+ }
+
+ /// Generate an authorization URL bound to the initiating browser session.
+ ///
+ /// `binding` should be an unguessable value the integrator can recover on
+ /// callback (e.g. a random cookie value); the callback must then present
+ /// the identical value, preventing login CSRF where an attacker tricks a
+ /// victim into completing the attacker's flow.
+ pub async fn generate_authorization_url_bound(
+ &self,
+ provider_config: &OAuth2ProviderConfig,
+ additional_params: HashMap,
+ binding: Option,
) -> OAuth2Result<(String, String)> {
let mut url = Url::parse(&provider_config.authorization_endpoint)?;
@@ -217,13 +233,19 @@ impl OAuth2Client {
None
};
+ // OIDC nonce: echoed back inside the id_token and verified on
+ // callback, binding the token to this authorization request.
+ let nonce = uuid::Uuid::new_v4().to_string();
+
// Create and store state
let state = OAuth2State::new(
provider_config.provider_id.clone(),
provider_config.redirect_uri.clone(),
pkce.as_ref().map(|p| p.code_verifier.clone()),
self.state_ttl_seconds,
- );
+ )
+ .with_nonce(nonce.clone())
+ .with_binding(binding);
let state_param = state.state.clone();
self.state_store.store(state).await?;
@@ -234,6 +256,7 @@ impl OAuth2Client {
params.append_pair("client_id", &provider_config.client_id);
params.append_pair("redirect_uri", &provider_config.redirect_uri);
params.append_pair("state", &state_param);
+ params.append_pair("nonce", &nonce);
// Add scopes
if !provider_config.scopes.is_empty() {
@@ -280,6 +303,12 @@ impl OAuth2Client {
return Err(OAuth2Error::InvalidState);
}
+ // When the flow was bound to a browser session, the callback must
+ // present the identical binding value (login-CSRF guard).
+ if state.binding.is_some() && state.binding != callback_response.binding {
+ return Err(OAuth2Error::InvalidState);
+ }
+
// Check for errors in callback
if let Some(error) = &callback_response.error {
let error_desc = callback_response
@@ -301,6 +330,14 @@ impl OAuth2Client {
)
.await?;
+ // Validate id_token claims when the provider returned one. The token
+ // arrived directly from the token endpoint over TLS, which OIDC Core
+ // §3.1.3.7 permits in place of signature validation for the code
+ // flow — but iss / aud / exp / nonce are still mandatory checks.
+ if let Some(id_token) = &token_response.id_token {
+ validate_id_token_claims(provider_config, id_token, state.nonce.as_deref())?;
+ }
+
Ok(token_response)
}
@@ -350,6 +387,79 @@ impl OAuth2Client {
}
}
+/// Claims checked on an id_token returned by the token endpoint.
+#[derive(serde::Deserialize)]
+struct IdTokenClaims {
+ iss: Option,
+ aud: Option,
+ exp: Option,
+ nonce: Option,
+}
+
+fn decode_id_token_claims(id_token: &str) -> OAuth2Result {
+ let payload = id_token
+ .split('.')
+ .nth(1)
+ .ok_or_else(|| OAuth2Error::InvalidIdToken("malformed JWT".to_string()))?;
+ let bytes = URL_SAFE_NO_PAD
+ .decode(payload)
+ .map_err(|_| OAuth2Error::InvalidIdToken("invalid base64 payload".to_string()))?;
+ serde_json::from_slice(&bytes)
+ .map_err(|_| OAuth2Error::InvalidIdToken("invalid JSON payload".to_string()))
+}
+
+/// Validate the mandatory id_token claims: issuer (when configured),
+/// audience, expiry, and the nonce echoed from the authorization request.
+///
+/// The signature is not verified: the token was received directly from the
+/// token endpoint over TLS, which OIDC Core §3.1.3.7 permits as a substitute
+/// for signature validation in the authorization-code flow.
+pub(crate) fn validate_id_token_claims(
+ provider_config: &OAuth2ProviderConfig,
+ id_token: &str,
+ expected_nonce: Option<&str>,
+) -> OAuth2Result<()> {
+ let claims = decode_id_token_claims(id_token)?;
+
+ if let Some(expected_issuer) = &provider_config.issuer
+ && claims.iss.as_deref() != Some(expected_issuer.as_str())
+ {
+ return Err(OAuth2Error::InvalidIdToken(format!(
+ "issuer mismatch: expected {expected_issuer}"
+ )));
+ }
+
+ let audience_matches = match &claims.aud {
+ Some(serde_json::Value::String(aud)) => aud == &provider_config.client_id,
+ Some(serde_json::Value::Array(auds)) => auds
+ .iter()
+ .any(|aud| aud.as_str() == Some(provider_config.client_id.as_str())),
+ _ => false,
+ };
+ if !audience_matches {
+ return Err(OAuth2Error::InvalidIdToken(
+ "audience does not include this client".to_string(),
+ ));
+ }
+
+ match claims.exp {
+ Some(exp) if exp > chrono::Utc::now().timestamp() => {}
+ _ => {
+ return Err(OAuth2Error::InvalidIdToken(
+ "token expired or missing exp".to_string(),
+ ));
+ }
+ }
+
+ if let Some(expected) = expected_nonce
+ && claims.nonce.as_deref() != Some(expected)
+ {
+ return Err(OAuth2Error::InvalidIdToken("nonce mismatch".to_string()));
+ }
+
+ Ok(())
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -436,6 +546,7 @@ mod tests {
authorization_endpoint: "https://example.com/auth".to_string(),
token_endpoint: "https://example.com/token".to_string(),
userinfo_endpoint: Some("https://example.com/userinfo".to_string()),
+ issuer: None,
redirect_uri: "http://localhost:3000/callback".to_string(),
scopes: vec!["openid".to_string(), "email".to_string()],
auth_params: HashMap::new(),
@@ -554,6 +665,7 @@ mod tests {
state,
error: None,
error_description: None,
+ binding: None,
},
)
.await
@@ -583,6 +695,7 @@ mod tests {
state,
error: Some("access_denied".to_string()),
error_description: Some("user denied consent".to_string()),
+ binding: None,
},
)
.await
@@ -618,6 +731,7 @@ mod tests {
state,
error: None,
error_description: None,
+ binding: None,
},
)
.await
@@ -678,4 +792,140 @@ mod tests {
)]
);
}
+
+ fn fake_id_token(payload: serde_json::Value) -> String {
+ let header = URL_SAFE_NO_PAD.encode(br#"{"alg":"RS256","typ":"JWT"}"#);
+ let payload = URL_SAFE_NO_PAD.encode(serde_json::to_vec(&payload).unwrap());
+ format!("{header}.{payload}.signature")
+ }
+
+ #[tokio::test]
+ async fn authorization_url_includes_nonce_and_stores_it() {
+ let state_store = Arc::new(InMemoryStateStore::new());
+ let client = OAuth2Client::new(state_store.clone(), 600, 30);
+
+ let (auth_url, state) = client
+ .generate_authorization_url(&provider_config(), HashMap::new())
+ .await
+ .unwrap();
+
+ let url = Url::parse(&auth_url).unwrap();
+ let params: HashMap<_, _> = url.query_pairs().collect();
+ let url_nonce = params.get("nonce").expect("nonce in URL").to_string();
+ assert!(!url_nonce.is_empty());
+
+ let stored = state_store.retrieve(&state).await.unwrap();
+ assert_eq!(stored.nonce.as_deref(), Some(url_nonce.as_str()));
+ }
+
+ #[test]
+ fn id_token_claim_validation_covers_iss_aud_exp_and_nonce() {
+ let mut config = provider_config();
+ config.issuer = Some("https://issuer.test".to_string());
+ let exp = chrono::Utc::now().timestamp() + 600;
+
+ let good = fake_id_token(serde_json::json!({
+ "iss": "https://issuer.test",
+ "aud": "test_client_id",
+ "exp": exp,
+ "nonce": "nonce-1",
+ }));
+ assert!(validate_id_token_claims(&config, &good, Some("nonce-1")).is_ok());
+
+ // aud may be an array containing this client
+ let aud_array = fake_id_token(serde_json::json!({
+ "iss": "https://issuer.test",
+ "aud": ["other", "test_client_id"],
+ "exp": exp,
+ }));
+ assert!(validate_id_token_claims(&config, &aud_array, None).is_ok());
+
+ let bad_iss = fake_id_token(serde_json::json!({
+ "iss": "https://evil.test", "aud": "test_client_id", "exp": exp,
+ }));
+ assert!(matches!(
+ validate_id_token_claims(&config, &bad_iss, None),
+ Err(OAuth2Error::InvalidIdToken(_))
+ ));
+
+ let bad_aud = fake_id_token(serde_json::json!({
+ "iss": "https://issuer.test", "aud": "someone_else", "exp": exp,
+ }));
+ assert!(validate_id_token_claims(&config, &bad_aud, None).is_err());
+
+ let expired = fake_id_token(serde_json::json!({
+ "iss": "https://issuer.test",
+ "aud": "test_client_id",
+ "exp": chrono::Utc::now().timestamp() - 10,
+ }));
+ assert!(validate_id_token_claims(&config, &expired, None).is_err());
+
+ let wrong_nonce = fake_id_token(serde_json::json!({
+ "iss": "https://issuer.test",
+ "aud": "test_client_id",
+ "exp": exp,
+ "nonce": "other",
+ }));
+ assert!(validate_id_token_claims(&config, &wrong_nonce, Some("nonce-1")).is_err());
+
+ assert!(validate_id_token_claims(&config, "garbage", None).is_err());
+ }
+
+ #[tokio::test]
+ async fn handle_callback_enforces_session_binding() {
+ let state_store = Arc::new(InMemoryStateStore::new());
+ let transport = Arc::new(RecordingTransport::new());
+ let client = client_with_transport(state_store.clone(), transport.clone());
+ let config = provider_config();
+
+ // A callback missing the binding is rejected before any token
+ // exchange (and the one-use state is burned).
+ let (_, state) = client
+ .generate_authorization_url_bound(
+ &config,
+ HashMap::new(),
+ Some("cookie-123".to_string()),
+ )
+ .await
+ .unwrap();
+ let err = client
+ .handle_callback(
+ &config,
+ AuthorizationResponse {
+ code: "code".to_string(),
+ state,
+ error: None,
+ error_description: None,
+ binding: None,
+ },
+ )
+ .await
+ .unwrap_err();
+ assert!(matches!(err, OAuth2Error::InvalidState));
+ assert!(transport.token_requests().is_empty());
+
+ // The matching binding completes the flow.
+ let (_, state) = client
+ .generate_authorization_url_bound(
+ &config,
+ HashMap::new(),
+ Some("cookie-123".to_string()),
+ )
+ .await
+ .unwrap();
+ client
+ .handle_callback(
+ &config,
+ AuthorizationResponse {
+ code: "code".to_string(),
+ state,
+ error: None,
+ error_description: None,
+ binding: Some("cookie-123".to_string()),
+ },
+ )
+ .await
+ .expect("bound callback succeeds");
+ assert_eq!(transport.token_requests().len(), 1);
+ }
}
diff --git a/crates/identity/ras-identity-oauth2/src/config.rs b/crates/identity/ras-identity-oauth2/src/config.rs
index 01f8a35..982245f 100644
--- a/crates/identity/ras-identity-oauth2/src/config.rs
+++ b/crates/identity/ras-identity-oauth2/src/config.rs
@@ -12,6 +12,11 @@ pub struct OAuth2ProviderConfig {
pub authorization_endpoint: String,
pub token_endpoint: String,
pub userinfo_endpoint: Option,
+ /// Expected `iss` claim of id_tokens returned by this provider
+ /// (e.g. `https://accounts.google.com`). When set, callbacks carrying
+ /// an id_token with a different issuer are rejected.
+ #[serde(default)]
+ pub issuer: Option,
pub redirect_uri: String,
pub scopes: Vec,
/// Additional parameters to include in authorization request
@@ -93,6 +98,7 @@ mod tests {
authorization_endpoint: "https://x/auth".into(),
token_endpoint: "https://x/token".into(),
userinfo_endpoint: Some("https://x/info".into()),
+ issuer: None,
redirect_uri: "https://app/cb".into(),
scopes: vec!["openid".into(), "email".into()],
auth_params: HashMap::new(),
diff --git a/crates/identity/ras-identity-oauth2/src/error.rs b/crates/identity/ras-identity-oauth2/src/error.rs
index b9e293b..1c4fa4c 100644
--- a/crates/identity/ras-identity-oauth2/src/error.rs
+++ b/crates/identity/ras-identity-oauth2/src/error.rs
@@ -24,6 +24,12 @@ pub enum OAuth2Error {
#[error("Token exchange failed: {0}")]
TokenExchangeFailed(String),
+ #[error("Invalid id_token: {0}")]
+ InvalidIdToken(String),
+
+ #[error("Too many pending OAuth2 flows")]
+ TooManyPendingFlows,
+
#[error("User info request failed: {0}")]
UserInfoFailed(String),
diff --git a/crates/identity/ras-identity-oauth2/src/provider.rs b/crates/identity/ras-identity-oauth2/src/provider.rs
index 596b983..a056c8d 100644
--- a/crates/identity/ras-identity-oauth2/src/provider.rs
+++ b/crates/identity/ras-identity-oauth2/src/provider.rs
@@ -28,6 +28,10 @@ pub enum OAuth2AuthPayload {
state: String,
error: Option,
error_description: Option,
+ /// Session-binding value captured when the flow was started (e.g.
+ /// from a cookie); required when the flow was started with one.
+ #[serde(default)]
+ binding: Option,
},
}
@@ -104,18 +108,38 @@ impl OAuth2Provider {
})
}
- /// Handle the start flow request
- async fn handle_start_flow(
+ /// Start an OAuth2 authorization flow.
+ ///
+ /// Returns the authorization URL to redirect the user to, plus the
+ /// `state` parameter bound to this flow. This is the supported way to
+ /// initiate a flow; `verify()` only completes one (the `Callback`
+ /// payload).
+ pub async fn start_flow(
&self,
provider_id: &str,
additional_params: Option>,
+ ) -> OAuth2Result {
+ self.start_flow_bound(provider_id, additional_params, None)
+ .await
+ }
+
+ /// Start a flow bound to the initiating browser session.
+ ///
+ /// `binding` should be an unguessable value the integrator can recover on
+ /// callback (e.g. a random cookie value); the callback payload must then
+ /// carry the identical value or it is rejected, preventing login CSRF.
+ pub async fn start_flow_bound(
+ &self,
+ provider_id: &str,
+ additional_params: Option>,
+ binding: Option,
) -> OAuth2Result {
let provider_config = self.get_provider_config(provider_id)?;
let params = additional_params.unwrap_or_default();
let (auth_url, state) = self
.client
- .generate_authorization_url(provider_config, params)
+ .generate_authorization_url_bound(provider_config, params, binding)
.await?;
info!("Started OAuth2 flow for provider: {}", provider_id);
@@ -134,6 +158,7 @@ impl OAuth2Provider {
state: String,
error: Option,
error_description: Option,
+ binding: Option,
) -> OAuth2Result {
let provider_config = self.get_provider_config(provider_id)?;
@@ -142,6 +167,7 @@ impl OAuth2Provider {
state,
error,
error_description,
+ binding,
};
// Exchange code for tokens
@@ -263,21 +289,11 @@ impl IdentityProvider for OAuth2Provider {
serde_json::from_value(auth_payload).map_err(|_| IdentityError::InvalidPayload)?;
match payload {
- OAuth2AuthPayload::StartFlow {
- provider_id,
- additional_params,
- } => {
- // For start flow, we return an error with the authorization URL
- let response = self
- .handle_start_flow(&provider_id, additional_params)
- .await
- .map_err(|e| IdentityError::ProviderError(e.to_string()))?;
-
- // Return the response as a provider error (client should handle this specially)
- let response_json =
- serde_json::to_string(&response).map_err(IdentityError::SerializationError)?;
-
- Err(IdentityError::ProviderError(response_json))
+ OAuth2AuthPayload::StartFlow { .. } => {
+ // Flow initiation is not identity verification and has no
+ // identity to return. Call `OAuth2Provider::start_flow`
+ // directly to obtain the authorization URL.
+ Err(IdentityError::UnsupportedMethod)
}
OAuth2AuthPayload::Callback {
provider_id,
@@ -285,9 +301,10 @@ impl IdentityProvider for OAuth2Provider {
state,
error,
error_description,
+ binding,
} => {
// For callback, we complete the flow and return the verified identity
- self.handle_callback(&provider_id, code, state, error, error_description)
+ self.handle_callback(&provider_id, code, state, error, error_description, binding)
.await
.map_err(|e| IdentityError::ProviderError(e.to_string()))
}
@@ -309,6 +326,7 @@ mod tests {
authorization_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
token_endpoint: "https://oauth2.googleapis.com/token".to_string(),
userinfo_endpoint: Some("https://www.googleapis.com/oauth2/v1/userinfo".to_string()),
+ issuer: None,
redirect_uri: "http://localhost:3000/callback".to_string(),
scopes: vec![
"openid".to_string(),
@@ -334,31 +352,27 @@ mod tests {
async fn test_start_flow() {
let provider = create_test_provider();
+ let result = provider.start_flow("google", None).await.unwrap();
+ match result {
+ OAuth2Response::AuthorizationUrl { url, state } => {
+ assert!(url.contains("https://accounts.google.com/o/oauth2/v2/auth"));
+ assert!(url.contains("response_type=code"));
+ assert!(url.contains("client_id=test_client_id"));
+ assert!(!state.is_empty());
+ }
+ _ => panic!("Expected AuthorizationUrl response"),
+ }
+
+ // StartFlow payloads are no longer routed through verify()
let payload = serde_json::json!({
"type": "StartFlow",
"provider_id": "google",
"additional_params": null
});
-
- let result = provider.verify(payload).await;
-
- // Start flow returns an error with the authorization URL
- assert!(result.is_err());
-
- if let Err(IdentityError::ProviderError(response_json)) = result {
- let response: OAuth2Response = serde_json::from_str(&response_json).unwrap();
- match response {
- OAuth2Response::AuthorizationUrl { url, state } => {
- assert!(url.contains("https://accounts.google.com/o/oauth2/v2/auth"));
- assert!(url.contains("response_type=code"));
- assert!(url.contains("client_id=test_client_id"));
- assert!(!state.is_empty());
- }
- _ => panic!("Expected AuthorizationUrl response"),
- }
- } else {
- panic!("Expected ProviderError");
- }
+ assert!(matches!(
+ provider.verify(payload).await,
+ Err(IdentityError::UnsupportedMethod)
+ ));
}
#[tokio::test]
@@ -379,17 +393,16 @@ mod tests {
async fn verify_reports_unknown_provider() {
let provider = create_test_provider();
- let result = provider
- .verify(serde_json::json!({
- "type": "StartFlow",
- "provider_id": "missing"
- }))
- .await;
+ let result = provider.start_flow("missing", None).await;
- let Err(IdentityError::ProviderError(message)) = result else {
- panic!("expected provider error for missing provider");
+ let Err(error) = result else {
+ panic!("expected error for missing provider");
};
- assert!(message.contains("Provider 'missing' not configured"));
+ assert!(
+ error
+ .to_string()
+ .contains("Provider 'missing' not configured")
+ );
}
#[tokio::test]
@@ -398,20 +411,13 @@ mod tests {
let mut provider = OAuth2Provider::new(OAuth2Config::default(), state_store);
provider.add_provider(google_config());
- let result = provider
- .verify(serde_json::json!({
- "type": "StartFlow",
- "provider_id": "google",
- "additional_params": {
- "prompt": "consent"
- }
- }))
- .await;
+ let mut params = HashMap::new();
+ params.insert("prompt".to_string(), "consent".to_string());
+ let response = provider
+ .start_flow("google", Some(params))
+ .await
+ .expect("start_flow succeeds");
- let Err(IdentityError::ProviderError(response_json)) = result else {
- panic!("expected authorization URL response encoded as provider error");
- };
- let response: OAuth2Response = serde_json::from_str(&response_json).unwrap();
let OAuth2Response::AuthorizationUrl { url, state } = response else {
panic!("expected authorization URL response");
};
diff --git a/crates/identity/ras-identity-oauth2/src/state.rs b/crates/identity/ras-identity-oauth2/src/state.rs
index a0f6362..7cc9c88 100644
--- a/crates/identity/ras-identity-oauth2/src/state.rs
+++ b/crates/identity/ras-identity-oauth2/src/state.rs
@@ -16,6 +16,13 @@ pub struct OAuth2State {
pub provider_id: String,
pub redirect_uri: String,
pub code_verifier: Option,
+ /// OIDC nonce sent in the authorization request; the id_token returned
+ /// on callback must echo it.
+ pub nonce: Option,
+ /// Optional caller-supplied value binding this flow to the browser
+ /// session that started it (e.g. a random cookie value). When set, the
+ /// callback must present the identical value, preventing login CSRF.
+ pub binding: Option,
pub created_at: DateTime,
pub expires_at: DateTime,
pub metadata: Option,
@@ -37,12 +44,26 @@ impl OAuth2State {
provider_id,
redirect_uri,
code_verifier,
+ nonce: None,
+ binding: None,
created_at,
expires_at,
metadata: None,
}
}
+ /// Attach an OIDC nonce to the flow.
+ pub fn with_nonce(mut self, nonce: String) -> Self {
+ self.nonce = Some(nonce);
+ self
+ }
+
+ /// Bind the flow to the initiating browser session (login-CSRF guard).
+ pub fn with_binding(mut self, binding: Option) -> Self {
+ self.binding = binding;
+ self
+ }
+
pub fn is_expired(&self) -> bool {
Utc::now() > self.expires_at
}
@@ -62,14 +83,26 @@ pub trait OAuth2StateStore: Send + Sync {
}
/// In-memory implementation of OAuth2StateStore
+/// Default cap on concurrently pending flows held in memory.
+const DEFAULT_MAX_PENDING_STATES: usize = 10_000;
+
pub struct InMemoryStateStore {
states: Arc>>,
+ max_states: usize,
}
impl InMemoryStateStore {
pub fn new() -> Self {
+ Self::with_capacity(DEFAULT_MAX_PENDING_STATES)
+ }
+
+ /// Create a store holding at most `max_states` pending flows. Expired
+ /// states are pruned opportunistically on every `store`, so no external
+ /// cleanup scheduling is required; `store` fails once the cap is hit.
+ pub fn with_capacity(max_states: usize) -> Self {
Self {
states: Arc::new(RwLock::new(HashMap::new())),
+ max_states,
}
}
}
@@ -84,6 +117,16 @@ impl Default for InMemoryStateStore {
impl OAuth2StateStore for InMemoryStateStore {
async fn store(&self, state: OAuth2State) -> OAuth2Result<()> {
let mut states = self.states.write().await;
+
+ // Opportunistic pruning: abandoned flows must not accumulate just
+ // because nobody schedules cleanup_expired.
+ let now = Utc::now();
+ states.retain(|_, stored| now <= stored.expires_at);
+
+ if states.len() >= self.max_states {
+ return Err(OAuth2Error::TooManyPendingFlows);
+ }
+
states.insert(state.state.clone(), state);
Ok(())
}
@@ -178,4 +221,57 @@ mod tests {
let result = store.retrieve(&state.state).await;
assert!(matches!(result, Err(OAuth2Error::StateNotFound)));
}
+
+ #[tokio::test]
+ async fn store_rejects_when_capacity_reached() {
+ let store = InMemoryStateStore::with_capacity(2);
+ for _ in 0..2 {
+ store
+ .store(OAuth2State::new(
+ "google".to_string(),
+ "http://localhost/cb".to_string(),
+ None,
+ 300,
+ ))
+ .await
+ .unwrap();
+ }
+
+ let err = store
+ .store(OAuth2State::new(
+ "google".to_string(),
+ "http://localhost/cb".to_string(),
+ None,
+ 300,
+ ))
+ .await
+ .unwrap_err();
+ assert!(matches!(err, OAuth2Error::TooManyPendingFlows));
+ }
+
+ #[tokio::test]
+ async fn store_prunes_expired_states_opportunistically() {
+ let store = InMemoryStateStore::with_capacity(1);
+
+ // An already-expired flow occupies the only slot...
+ let mut expired = OAuth2State::new(
+ "google".to_string(),
+ "http://localhost/cb".to_string(),
+ None,
+ 300,
+ );
+ expired.expires_at = Utc::now() - Duration::seconds(10);
+ store.store(expired).await.unwrap();
+
+ // ...but is pruned when the next flow is stored.
+ store
+ .store(OAuth2State::new(
+ "google".to_string(),
+ "http://localhost/cb".to_string(),
+ None,
+ 300,
+ ))
+ .await
+ .expect("expired state must be pruned to make room");
+ }
}
diff --git a/crates/identity/ras-identity-oauth2/src/tests.rs b/crates/identity/ras-identity-oauth2/src/tests.rs
index 487934f..adffb5f 100644
--- a/crates/identity/ras-identity-oauth2/src/tests.rs
+++ b/crates/identity/ras-identity-oauth2/src/tests.rs
@@ -85,6 +85,7 @@ mod integration_tests {
authorization_endpoint: "http://oauth.test/authorize".to_string(),
token_endpoint: "http://oauth.test/token".to_string(),
userinfo_endpoint: Some("http://oauth.test/userinfo".to_string()),
+ issuer: None,
redirect_uri: "http://localhost:3000/callback".to_string(),
scopes: vec!["openid".to_string(), "email".to_string()],
auth_params: HashMap::new(),
@@ -220,30 +221,27 @@ mod integration_tests {
let state_store = Arc::new(InMemoryStateStore::new());
let provider = provider_with_server(provider_config, state_store, server);
- // Start OAuth2 flow
+ // Start OAuth2 flow via the typed API
+ let start_result = provider.start_flow("mock_provider", None).await.unwrap();
+ let auth_url = match start_result {
+ OAuth2Response::AuthorizationUrl { url, state } => {
+ assert!(url.contains("/authorize"));
+ assert!(url.contains("response_type=code"));
+ assert!(url.contains("code_challenge"));
+ state
+ }
+ _ => panic!("Expected authorization URL"),
+ };
+
+ // StartFlow payloads are no longer routed through verify()
let start_payload = serde_json::json!({
"type": "StartFlow",
"provider_id": "mock_provider"
});
-
- let start_result = provider.verify(start_payload).await;
- assert!(start_result.is_err());
-
- let auth_url =
- if let Err(ras_identity_core::IdentityError::ProviderError(json)) = start_result {
- let response: OAuth2Response = serde_json::from_str(&json).unwrap();
- match response {
- OAuth2Response::AuthorizationUrl { url, state } => {
- assert!(url.contains("/authorize"));
- assert!(url.contains("response_type=code"));
- assert!(url.contains("code_challenge"));
- state
- }
- _ => panic!("Expected authorization URL"),
- }
- } else {
- panic!("Expected provider error with auth URL");
- };
+ assert!(matches!(
+ provider.verify(start_payload).await,
+ Err(ras_identity_core::IdentityError::UnsupportedMethod)
+ ));
// Simulate callback
let callback_payload = serde_json::json!({
@@ -270,12 +268,7 @@ mod integration_tests {
let provider = OAuth2Provider::new(config, state_store);
// Test invalid provider
- let payload = serde_json::json!({
- "type": "StartFlow",
- "provider_id": "nonexistent"
- });
-
- let result = provider.verify(payload).await;
+ let result = provider.start_flow("nonexistent", None).await;
assert!(result.is_err());
// Test callback with invalid state
@@ -323,6 +316,7 @@ mod integration_tests {
authorization_endpoint: "https://example.com/auth".to_string(),
token_endpoint: "https://example.com/token".to_string(),
userinfo_endpoint: None,
+ issuer: None,
redirect_uri: "http://localhost:3000/callback".to_string(),
scopes: vec![],
auth_params: HashMap::new(),
@@ -334,15 +328,8 @@ mod integration_tests {
let provider = OAuth2Provider::new(config, state_store.clone());
// Generate two authorization URLs
- let payload1 = serde_json::json!({"type": "StartFlow", "provider_id": "test"});
- let payload2 = serde_json::json!({"type": "StartFlow", "provider_id": "test"});
-
- let result1 = provider.verify(payload1).await;
- let result2 = provider.verify(payload2).await;
-
- // Extract states
- let state1 = extract_state_from_error(result1);
- let state2 = extract_state_from_error(result2);
+ let state1 = extract_state(provider.start_flow("test", None).await);
+ let state2 = extract_state(provider.start_flow("test", None).await);
// States should be unique
assert_ne!(state1, state2);
@@ -352,17 +339,10 @@ mod integration_tests {
assert_eq!(state2.len(), 36);
}
- fn extract_state_from_error(
- result: Result,
- ) -> String {
- if let Err(ras_identity_core::IdentityError::ProviderError(json)) = result {
- let response: OAuth2Response = serde_json::from_str(&json).unwrap();
- match response {
- OAuth2Response::AuthorizationUrl { state, .. } => state,
- _ => panic!("Expected authorization URL"),
- }
- } else {
- panic!("Expected provider error");
+ fn extract_state(result: crate::OAuth2Result) -> String {
+ match result.expect("start_flow succeeds") {
+ OAuth2Response::AuthorizationUrl { state, .. } => state,
+ _ => panic!("Expected authorization URL"),
}
}
@@ -380,6 +360,7 @@ mod integration_tests {
authorization_endpoint: "https://example.com/auth".to_string(),
token_endpoint: "https://example.com/token".to_string(),
userinfo_endpoint: None,
+ issuer: None,
redirect_uri: "http://localhost:3000/callback".to_string(),
scopes: vec![],
auth_params: HashMap::new(),
@@ -433,6 +414,7 @@ mod integration_tests {
state: state.state,
error: None,
error_description: None,
+ binding: None,
};
let result = client.handle_callback(&provider_config, callback).await;
@@ -460,6 +442,7 @@ mod integration_tests {
state: state2.state,
error: None,
error_description: None,
+ binding: None,
};
let result = client.handle_callback(&provider_config, callback2).await;
diff --git a/crates/identity/ras-identity-oauth2/src/types.rs b/crates/identity/ras-identity-oauth2/src/types.rs
index e20e240..375dec0 100644
--- a/crates/identity/ras-identity-oauth2/src/types.rs
+++ b/crates/identity/ras-identity-oauth2/src/types.rs
@@ -21,6 +21,11 @@ pub struct AuthorizationResponse {
pub state: String,
pub error: Option,
pub error_description: Option,
+ /// Session-binding value captured by the integrator when the flow was
+ /// started (e.g. from a cookie). Must match the value given to
+ /// `start_flow` for the same state, when one was supplied.
+ #[serde(default)]
+ pub binding: Option,
}
/// OAuth2 token response
diff --git a/crates/identity/ras-identity-session/src/lib.rs b/crates/identity/ras-identity-session/src/lib.rs
index df42077..e883f65 100644
--- a/crates/identity/ras-identity-session/src/lib.rs
+++ b/crates/identity/ras-identity-session/src/lib.rs
@@ -382,6 +382,36 @@ impl SessionService {
Ok(claims)
}
+ /// Number of sessions currently held in the in-memory store
+ /// (only populated when `enforce_active_sessions` is on).
+ pub async fn active_session_count(&self) -> usize {
+ self.active_sessions.read().await.len()
+ }
+
+ /// Spawn a background task pruning expired sessions every `interval`.
+ ///
+ /// Expired sessions are otherwise only pruned opportunistically when
+ /// begin_session/verify_session run, so a traffic lull leaves them in
+ /// memory indefinitely. The task holds only a weak reference and stops
+ /// when the service is dropped (or when the returned handle is aborted).
+ pub fn start_cleanup_task(
+ self: &std::sync::Arc,
+ interval: std::time::Duration,
+ ) -> tokio::task::JoinHandle<()> {
+ let service = std::sync::Arc::downgrade(self);
+ tokio::spawn(async move {
+ let mut timer = tokio::time::interval(interval);
+ timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
+ loop {
+ timer.tick().await;
+ let Some(service) = service.upgrade() else {
+ break;
+ };
+ service.cleanup_expired_sessions().await;
+ }
+ })
+ }
+
pub async fn end_session(&self, jti: &str) -> Option {
let mut sessions = self.active_sessions.write().await;
sessions.remove(jti)
@@ -682,4 +712,41 @@ mod tests {
assert!(user.permissions.contains("chat:read"));
assert!(user.metadata.is_none());
}
+
+ #[tokio::test]
+ async fn cleanup_task_prunes_expired_sessions_in_background() {
+ let config = SessionConfig::new(TEST_SECRET).unwrap();
+ let service = std::sync::Arc::new(SessionService::new(config).unwrap());
+
+ // Plant an already-expired session directly in the store.
+ let now = chrono::Utc::now().timestamp();
+ service.active_sessions.write().await.insert(
+ "expired-jti".to_string(),
+ JwtClaims {
+ sub: "alice".to_string(),
+ exp: now - 10,
+ iat: now - 20,
+ jti: "expired-jti".to_string(),
+ provider_id: "local".to_string(),
+ email: None,
+ display_name: None,
+ permissions: HashSet::new(),
+ metadata: None,
+ },
+ );
+ assert_eq!(service.active_session_count().await, 1);
+
+ let handle = service.start_cleanup_task(std::time::Duration::from_millis(20));
+
+ // The sweeper prunes the expired session without any begin/verify call.
+ tokio::time::timeout(std::time::Duration::from_secs(5), async {
+ while service.active_session_count().await != 0 {
+ tokio::time::sleep(std::time::Duration::from_millis(10)).await;
+ }
+ })
+ .await
+ .expect("cleanup task prunes expired sessions");
+
+ handle.abort();
+ }
}
diff --git a/crates/rest/ras-file-macro/src/client.rs b/crates/rest/ras-file-macro/src/client.rs
index d4d6ff5..873045c 100644
--- a/crates/rest/ras-file-macro/src/client.rs
+++ b/crates/rest/ras-file-macro/src/client.rs
@@ -17,8 +17,17 @@ pub fn generate_client(definition: &FileServiceDefinition) -> TokenStream {
.endpoints
.iter()
.map(|endpoint| generate_client_method(definition, endpoint, &base_path));
- let build_method = if cfg!(feature = "reqwest") {
+ // With `feature_gated: true` the convenience constructor is gated on the
+ // CONSUMER crate's `reqwest` feature instead of the macro crate's
+ // (workspace-unified) one.
+ let cfg_reqwest = if definition.feature_gated {
+ quote! { #[cfg(feature = "reqwest")] }
+ } else {
+ quote! {}
+ };
+ let build_method = if definition.feature_gated || cfg!(feature = "reqwest") {
quote! {
+ #cfg_reqwest
pub fn build(
self,
) -> Result<#client_name, Box> {
diff --git a/crates/rest/ras-file-macro/src/lib.rs b/crates/rest/ras-file-macro/src/lib.rs
index 19f2968..ef5d9a4 100644
--- a/crates/rest/ras-file-macro/src/lib.rs
+++ b/crates/rest/ras-file-macro/src/lib.rs
@@ -37,37 +37,60 @@ pub fn file_service(input: TokenStream) -> TokenStream {
quote! {}
};
- let server_output = if cfg!(feature = "server") {
+ // With `feature_gated: true` the generated code is wrapped in
+ // `#[cfg(feature = ...)]` attributes resolved against the CONSUMER
+ // crate's features, immune to workspace feature unification of the
+ // macro crate's own features (which `cfg!` evaluates).
+ let feature_gated = definition.feature_gated;
+ let cfg_server = if feature_gated {
+ quote! { #[cfg(feature = "server")] }
+ } else {
+ quote! {}
+ };
+ let cfg_client = if feature_gated {
+ quote! { #[cfg(feature = "client")] }
+ } else {
+ quote! {}
+ };
+
+ let server_output = if feature_gated || cfg!(feature = "server") {
quote! {
+ #cfg_server
mod #server_mod {
use super::*;
#server_code
}
+ #cfg_server
pub use #server_mod::*;
+ #cfg_server
const _: () = {
#schema_checks
};
+ #cfg_server
mod #openapi_mod {
use super::*;
#openapi_code
}
+ #cfg_server
pub use #openapi_mod::*;
}
} else {
quote! {}
};
- let client_output = if cfg!(feature = "client") {
+ let client_output = if feature_gated || cfg!(feature = "client") {
quote! {
+ #cfg_client
mod #client_mod {
use super::*;
#client_code
}
+ #cfg_client
pub use #client_mod::*;
}
} else {
diff --git a/crates/rest/ras-file-macro/src/parser.rs b/crates/rest/ras-file-macro/src/parser.rs
index 4d1c712..5b06482 100644
--- a/crates/rest/ras-file-macro/src/parser.rs
+++ b/crates/rest/ras-file-macro/src/parser.rs
@@ -9,6 +9,7 @@ pub struct FileServiceDefinition {
pub service_name: Ident,
pub base_path: LitStr,
pub openapi: Option,
+ pub feature_gated: bool,
pub endpoints: Vec,
}
@@ -103,6 +104,7 @@ impl Parse for FileServiceDefinition {
let mut service_name = None;
let mut base_path = None;
let mut openapi = None;
+ let mut feature_gated = false;
let mut endpoints = Vec::new();
while !content.is_empty() {
@@ -112,6 +114,10 @@ impl Parse for FileServiceDefinition {
match field_name.to_string().as_str() {
"service_name" => service_name = Some(content.parse()?),
"base_path" => base_path = Some(content.parse()?),
+ "feature_gated" => {
+ let enabled = content.parse::()?;
+ feature_gated = enabled.value();
+ }
"body_limit" => {
return Err(Error::new(
field_name.span(),
@@ -183,6 +189,7 @@ impl Parse for FileServiceDefinition {
.ok_or_else(|| Error::new(input.span(), "Missing service_name"))?,
base_path: base_path.ok_or_else(|| Error::new(input.span(), "Missing base_path"))?,
openapi,
+ feature_gated,
endpoints,
})
}
diff --git a/crates/rest/ras-file-macro/src/server.rs b/crates/rest/ras-file-macro/src/server.rs
index 21c1ac9..b16c977 100644
--- a/crates/rest/ras-file-macro/src/server.rs
+++ b/crates/rest/ras-file-macro/src/server.rs
@@ -929,12 +929,12 @@ fn generate_permission_check(auth: &AuthRequirement) -> TokenStream {
});
quote! {
+ // OR-of-AND permission check (shared ras-auth-core implementation).
+ // A group list with no non-empty groups means "any authenticated
+ // user", consistent with the REST and JSON-RPC macros.
let required_permission_groups: Vec> = vec![#(#groups),*];
let authenticated_user = user.as_ref().expect("authenticated user exists after auth check");
- let has_permission = required_permission_groups.iter().any(|group| {
- group.is_empty() || auth_provider.check_permissions(authenticated_user, group).is_ok()
- });
- if !has_permission {
+ if ::ras_auth_core::check_permission_groups(auth_provider.as_ref(), authenticated_user, &required_permission_groups).is_err() {
return __ras_file_error_response(::ras_file_core::FileError::Forbidden);
}
}
diff --git a/crates/rest/ras-rest-macro/src/client.rs b/crates/rest/ras-rest-macro/src/client.rs
index 5fa9803..88adee8 100644
--- a/crates/rest/ras-rest-macro/src/client.rs
+++ b/crates/rest/ras-rest-macro/src/client.rs
@@ -57,13 +57,23 @@ pub fn generate_client_code(service_def: &ServiceDefinition) -> proc_macro2::Tok
.iter()
.flat_map(generate_client_methods_with_timeout_for_endpoint);
- let build_method = if cfg!(feature = "reqwest") {
+ // With `feature_gated: true` the convenience constructor compiles only
+ // when the CONSUMER crate enables its own `reqwest` feature (which should
+ // activate `ras-transport-core/reqwest`); otherwise the macro crate's
+ // (workspace-unified) feature decides whether to emit it at all.
+ let cfg_reqwest = if service_def.feature_gated {
+ quote! { #[cfg(feature = "reqwest")] }
+ } else {
+ quote! {}
+ };
+ let build_method = if service_def.feature_gated || cfg!(feature = "reqwest") {
quote! {
/// Build the client using the default `ReqwestTransport`.
///
/// # Errors
///
/// Returns an error if the underlying transport fails to construct.
+ #cfg_reqwest
pub fn build(self) -> Result<#client_name, Box> {
let transport = std::sync::Arc::new(::ras_transport_core::ReqwestTransport::new());
self.build_with_transport(transport)
diff --git a/crates/rest/ras-rest-macro/src/lib.rs b/crates/rest/ras-rest-macro/src/lib.rs
index 6d33a01..e1a64af 100644
--- a/crates/rest/ras-rest-macro/src/lib.rs
+++ b/crates/rest/ras-rest-macro/src/lib.rs
@@ -78,9 +78,14 @@ struct ServiceDefinition {
base_path: String,
openapi: Option,
static_hosting: static_hosting::StaticHostingConfig,
+ body_limit: Option,
+ feature_gated: bool,
endpoints: Vec,
}
+/// Default maximum JSON body size in bytes (matches axum's default).
+const DEFAULT_BODY_LIMIT: usize = 2 * 1024 * 1024;
+
#[derive(Debug)]
enum OpenApiConfig {
Enabled,
@@ -244,9 +249,11 @@ impl Parse for ServiceDefinition {
let base_path = base_path_lit.value();
let _ = content.parse::()?;
- // Parse optional fields (openapi, serve_docs, docs_path, ui_theme)
+ // Parse optional fields (openapi, serve_docs, docs_path, ui_theme, body_limit)
let mut openapi = None;
let mut static_hosting = static_hosting::StaticHostingConfig::default();
+ let mut body_limit = None;
+ let mut feature_gated = false;
// Parse optional fields
while content.peek(Ident) {
@@ -292,6 +299,18 @@ impl Parse for ServiceDefinition {
let theme = content.parse::()?;
static_hosting.ui_theme = theme.value();
let _ = content.parse::()?;
+ } else if field_name == "body_limit" {
+ let _ = content.parse::()?; // "body_limit"
+ let _ = content.parse::()?;
+ let limit = content.parse::()?;
+ body_limit = Some(limit.base10_parse::()?);
+ let _ = content.parse::()?;
+ } else if field_name == "feature_gated" {
+ let _ = content.parse::()?; // "feature_gated"
+ let _ = content.parse::()?;
+ let enabled = content.parse::()?;
+ feature_gated = enabled.value();
+ let _ = content.parse::()?;
} else if field_name == "endpoints" {
break; // Start parsing endpoints
} else {
@@ -325,6 +344,8 @@ impl Parse for ServiceDefinition {
base_path,
openapi,
static_hosting,
+ body_limit,
+ feature_gated,
endpoints,
})
}
@@ -766,11 +787,65 @@ fn generate_service_code(service_def: ServiceDefinition) -> syn::Result axum::response::Response {
+ use axum::response::IntoResponse;
+ let (status, message) = match error {
+ ras_auth_core::AuthorizeError::MissingCredential => (
+ axum::http::StatusCode::UNAUTHORIZED,
+ "Missing or invalid Authorization header",
+ ),
+ ras_auth_core::AuthorizeError::CsrfValidationFailed => (
+ axum::http::StatusCode::FORBIDDEN,
+ "CSRF validation failed",
+ ),
+ ras_auth_core::AuthorizeError::AuthenticationFailed(_) => (
+ axum::http::StatusCode::UNAUTHORIZED,
+ "Authentication failed",
+ ),
+ ras_auth_core::AuthorizeError::NoAuthProvider => (
+ axum::http::StatusCode::INTERNAL_SERVER_ERROR,
+ "No auth provider configured",
+ ),
+ ras_auth_core::AuthorizeError::InsufficientPermissions(_) => (
+ axum::http::StatusCode::FORBIDDEN,
+ "Insufficient permissions",
+ ),
+ };
+ (status, axum::Json(serde_json::json!({ "error": message }))).into_response()
+ }
+
/// Generated service trait
#[async_trait::async_trait]
#[allow(private_interfaces, private_bounds)]
@@ -894,20 +969,23 @@ fn generate_service_code(service_def: ServiceDefinition) -> syn::Result json.0,
- Err(_) => {
- use axum::response::IntoResponse;
- return (
- axum::http::StatusCode::BAD_REQUEST,
- axum::Json(serde_json::json!({
- "error": "Invalid JSON"
- }))
- ).into_response();
- },
- };
- }
+ generate_body_extraction()
} else {
quote! {}
};
@@ -1365,82 +1430,21 @@ fn generate_legacy_handler_body(
canonical_args.insert(0, quote! { &user });
quote! {
- #json_handling
-
- let auth_credential = match ras_auth_core::extract_auth_credential(&headers, &auth_transport) {
- Ok(credential) => credential,
- Err(_) => {
- use axum::response::IntoResponse;
- return (
- axum::http::StatusCode::UNAUTHORIZED,
- axum::Json(serde_json::json!({
- "error": "Missing or invalid Authorization header"
- }))
- ).into_response();
- },
+ // Authenticate and authorize: credential → CSRF → authenticate
+ // → OR-of-AND permission groups (shared ras-auth-core pipeline)
+ let user = match ras_auth_core::authorize_request(
+ #method,
+ &headers,
+ &auth_transport,
+ auth_provider.as_deref(),
+ &required_permission_groups,
+ ).await {
+ Ok(user) => user,
+ Err(error) => return __ras_authorize_error_response(error),
};
- if let Err(_) = ras_auth_core::validate_csrf_for_credential(#method, &headers, &auth_credential, &auth_transport) {
- use axum::response::IntoResponse;
- return (
- axum::http::StatusCode::FORBIDDEN,
- axum::Json(serde_json::json!({
- "error": "CSRF validation failed"
- }))
- ).into_response();
- }
-
- let user = match &auth_provider {
- Some(provider) => match provider.authenticate(auth_credential.token().to_string()).await {
- Ok(user) => user,
- Err(_) => {
- use axum::response::IntoResponse;
- return (
- axum::http::StatusCode::UNAUTHORIZED,
- axum::Json(serde_json::json!({
- "error": "Authentication failed"
- }))
- ).into_response();
- },
- },
- None => {
- use axum::response::IntoResponse;
- return (
- axum::http::StatusCode::INTERNAL_SERVER_ERROR,
- axum::Json(serde_json::json!({
- "error": "No auth provider configured"
- }))
- ).into_response();
- },
- };
-
- let has_non_empty_groups = required_permission_groups.iter().any(|g| !g.is_empty());
- if has_non_empty_groups {
- let mut has_permission = false;
-
- for permission_group in &required_permission_groups {
- if permission_group.is_empty() {
- has_permission = true;
- break;
- } else {
- let group_result = auth_provider.as_ref().unwrap().check_permissions(&user, permission_group);
- if group_result.is_ok() {
- has_permission = true;
- break;
- }
- }
- }
-
- if !has_permission {
- use axum::response::IntoResponse;
- return (
- axum::http::StatusCode::FORBIDDEN,
- axum::Json(serde_json::json!({
- "error": "Insufficient permissions"
- }))
- ).into_response();
- }
- }
+ // Read and parse the body only after auth has succeeded
+ #json_handling
if let Some(tracker) = &with_usage_tracker {
let tracker_headers =
@@ -1546,9 +1550,11 @@ fn generate_axum_handler(
});
}
- // Add request body extractor if present - use Result to handle JSON parsing errors
+ // Take the raw request when a body is declared. The body is read and
+ // deserialized inside the handler AFTER auth/CSRF/permission checks, so
+ // unauthenticated clients cannot make the server buffer or parse payloads.
if request_type.is_some() {
- extractors.push(quote! { body_result: Result, axum::extract::rejection::JsonRejection> });
+ extractors.push(quote! { request: axum::extract::Request });
}
quote! {
@@ -1556,6 +1562,43 @@ fn generate_axum_handler(
}
}
+/// Generated code that reads and JSON-deserializes the request body from the
+/// raw `request` extractor, bounded by `__RAS_BODY_LIMIT`.
+///
+/// For authenticated endpoints this must be emitted AFTER the
+/// auth/CSRF/permission block so unauthenticated clients cannot make the
+/// server buffer or parse payloads.
+fn generate_body_extraction() -> proc_macro2::TokenStream {
+ quote! {
+ let body = {
+ let body_bytes = match ::axum::body::to_bytes(request.into_body(), __RAS_BODY_LIMIT).await {
+ Ok(bytes) => bytes,
+ Err(_) => {
+ use axum::response::IntoResponse;
+ return (
+ axum::http::StatusCode::PAYLOAD_TOO_LARGE,
+ axum::Json(serde_json::json!({
+ "error": "Request body too large or unreadable"
+ }))
+ ).into_response();
+ },
+ };
+ match serde_json::from_slice(&body_bytes) {
+ Ok(body) => body,
+ Err(_) => {
+ use axum::response::IntoResponse;
+ return (
+ axum::http::StatusCode::BAD_REQUEST,
+ axum::Json(serde_json::json!({
+ "error": "Invalid JSON"
+ }))
+ ).into_response();
+ },
+ }
+ };
+ }
+}
+
fn generate_handler_body(
endpoint: &EndpointDefinition,
handler_name: &Ident,
@@ -1587,21 +1630,7 @@ fn generate_handler_body(
// Handle JSON body extraction with error handling
let json_handling = if endpoint.request_type.is_some() {
args.push(quote! { body });
- quote! {
- // Handle JSON parsing errors
- let body = match body_result {
- Ok(json) => json.0,
- Err(_) => {
- use axum::response::IntoResponse;
- return (
- axum::http::StatusCode::BAD_REQUEST,
- axum::Json(serde_json::json!({
- "error": "Invalid JSON"
- }))
- ).into_response();
- },
- };
- }
+ generate_body_extraction()
} else {
quote! {}
};
@@ -1681,110 +1710,27 @@ fn generate_handler_body(
// Handle JSON body extraction with error handling
let json_handling = if endpoint.request_type.is_some() {
args.push(quote! { body });
- quote! {
- // Handle JSON parsing errors
- let body = match body_result {
- Ok(json) => json.0,
- Err(_) => {
- use axum::response::IntoResponse;
- return (
- axum::http::StatusCode::BAD_REQUEST,
- axum::Json(serde_json::json!({
- "error": "Invalid JSON"
- }))
- ).into_response();
- },
- };
- }
+ generate_body_extraction()
} else {
quote! {}
};
quote! {
- #json_handling
-
- // Extract and validate auth token
- let auth_credential = match ras_auth_core::extract_auth_credential(&headers, &auth_transport) {
- Ok(credential) => credential,
- Err(_) => {
- use axum::response::IntoResponse;
- return (
- axum::http::StatusCode::UNAUTHORIZED,
- axum::Json(serde_json::json!({
- "error": "Missing or invalid Authorization header"
- }))
- ).into_response();
- },
- };
-
- if let Err(_) = ras_auth_core::validate_csrf_for_credential(#method, &headers, &auth_credential, &auth_transport) {
- use axum::response::IntoResponse;
- return (
- axum::http::StatusCode::FORBIDDEN,
- axum::Json(serde_json::json!({
- "error": "CSRF validation failed"
- }))
- ).into_response();
- }
-
- // Authenticate user
- let user = match &auth_provider {
- Some(provider) => match provider.authenticate(auth_credential.token().to_string()).await {
- Ok(user) => user,
- Err(_) => {
- use axum::response::IntoResponse;
- return (
- axum::http::StatusCode::UNAUTHORIZED,
- axum::Json(serde_json::json!({
- "error": "Authentication failed"
- }))
- ).into_response();
- },
- },
- None => {
- use axum::response::IntoResponse;
- return (
- axum::http::StatusCode::INTERNAL_SERVER_ERROR,
- axum::Json(serde_json::json!({
- "error": "No auth provider configured"
- }))
- ).into_response();
- },
+ // Authenticate and authorize: credential → CSRF → authenticate
+ // → OR-of-AND permission groups (shared ras-auth-core pipeline)
+ let user = match ras_auth_core::authorize_request(
+ #method,
+ &headers,
+ &auth_transport,
+ auth_provider.as_deref(),
+ &required_permission_groups,
+ ).await {
+ Ok(user) => user,
+ Err(error) => return __ras_authorize_error_response(error),
};
- // Check permissions - AND within groups, OR between groups
- // Only check permissions if we have non-empty groups
- let has_non_empty_groups = required_permission_groups.iter().any(|g| !g.is_empty());
- if has_non_empty_groups {
- let mut has_permission = false;
-
- // Check each permission group (OR logic between groups)
- for permission_group in &required_permission_groups {
- // Check if user has ALL permissions in this group (AND logic within group)
- if permission_group.is_empty() {
- // Empty group means any authenticated user can access
- has_permission = true;
- break;
- } else {
- // Check if user has all permissions in this group
- let group_result = auth_provider.as_ref().unwrap().check_permissions(&user, permission_group);
- if group_result.is_ok() {
- has_permission = true;
- break;
- }
- }
- }
-
- if !has_permission {
- use axum::response::IntoResponse;
- return (
- axum::http::StatusCode::FORBIDDEN,
- axum::Json(serde_json::json!({
- "error": "Insufficient permissions"
- }))
- ).into_response();
- }
- }
+ // Read and parse the body only after auth has succeeded
+ #json_handling
// Call usage tracker if configured
if let Some(tracker) = &with_usage_tracker {
diff --git a/crates/rest/ras-rest-macro/tests/http_integration.rs b/crates/rest/ras-rest-macro/tests/http_integration.rs
index aa4348b..5fd7604 100644
--- a/crates/rest/ras-rest-macro/tests/http_integration.rs
+++ b/crates/rest/ras-rest-macro/tests/http_integration.rs
@@ -1308,3 +1308,70 @@ async fn test_query_parameters_with_path_params() {
let posts_response: PostsResponse = response.json();
assert_eq!(posts_response.posts.len(), 20); // Default per_page
}
+
+// Minimal service exercising the body_limit option.
+rest_service!({
+ service_name: TinyBodyService,
+ base_path: "/tiny",
+ body_limit: 64,
+ endpoints: [
+ POST UNAUTHORIZED echo(Value) -> Value,
+ ]
+});
+
+struct TinyBodyServiceImpl;
+
+#[async_trait::async_trait]
+impl TinyBodyServiceTrait for TinyBodyServiceImpl {
+ async fn post_echo(&self, request: Value) -> ras_rest_core::RestResult {
+ Ok(RestResponse::ok(request))
+ }
+}
+
+#[tokio::test]
+async fn test_body_is_not_parsed_before_auth() {
+ let server = create_rest_test_server();
+
+ // Invalid JSON without credentials must be rejected by auth (401, not
+ // 400), proving the body is neither read nor parsed before the
+ // auth/CSRF/permission checks succeed.
+ let response = server
+ .post("/api/v1/users")
+ .text("{invalid json")
+ .content_type("application/json")
+ .await;
+ assert_eq!(response.status_code().as_u16(), 401);
+
+ // Same body with an invalid token: still rejected by auth.
+ let response = server
+ .post("/api/v1/users")
+ .authorization_bearer("wrong-token")
+ .text("{invalid json")
+ .content_type("application/json")
+ .await;
+ assert_eq!(response.status_code().as_u16(), 401);
+
+ // With valid credentials the malformed body is now parsed and rejected.
+ let response = server
+ .post("/api/v1/users")
+ .authorization_bearer("admin-token")
+ .text("{invalid json")
+ .content_type("application/json")
+ .await;
+ assert_eq!(response.status_code().as_u16(), 400);
+}
+
+#[tokio::test]
+async fn test_body_limit_option_enforced() {
+ let app = TinyBodyServiceBuilder::new(TinyBodyServiceImpl).build();
+ let server = TestServer::builder().mock_transport().build(app).unwrap();
+
+ let response = server.post("/tiny/echo").json(&json!({"ok": true})).await;
+ assert_eq!(response.status_code().as_u16(), 200);
+
+ let response = server
+ .post("/tiny/echo")
+ .json(&json!({"data": "x".repeat(200)}))
+ .await;
+ assert_eq!(response.status_code().as_u16(), 413);
+}
diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/Cargo.toml b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/Cargo.toml
index 62e1a68..61b0d39 100644
--- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/Cargo.toml
+++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/Cargo.toml
@@ -64,3 +64,4 @@ wasm = ["tokio", "web-sys", "wasm-bindgen", "wasm-bindgen-futures", "js-sys"]
[dev-dependencies]
tracing-subscriber = { workspace = true }
chrono = { workspace = true }
+tokio = { workspace = true, features = ["test-util"] }
diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/client.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/client.rs
index d1e2e03..ff0864b 100644
--- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/client.rs
+++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/client.rs
@@ -42,6 +42,8 @@ pub struct Client {
request_id_counter: Arc,
shutdown_tx: Arc>>>,
message_tx: Arc>>>,
+ /// Signaled when the server's ConnectionEstablished message arrives
+ connected_notify: Arc,
}
struct IncomingMessageContext<'a> {
@@ -52,6 +54,7 @@ struct IncomingMessageContext<'a> {
connection_event_handlers: &'a DashMap,
connection_id: &'a RwLock