Skip to content

Commit ff43254

Browse files
committed
Complete production readiness: secrets TTL, Cedar reload, audit export, rate limiting
1 parent 08ae3bd commit ff43254

5 files changed

Lines changed: 351 additions & 0 deletions

File tree

crates/runtime/src/api/composio_executor.rs

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
#[cfg(feature = "composio")]
77
use std::sync::Arc;
88

9+
#[cfg(feature = "composio")]
10+
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
11+
912
#[cfg(feature = "composio")]
1013
use async_trait::async_trait;
1114

@@ -22,15 +25,78 @@ use crate::reasoning::inference::ToolDefinition;
2225
#[cfg(feature = "composio")]
2326
use crate::reasoning::loop_types::{LoopConfig, Observation, ProposedAction};
2427

28+
/// Simple per-minute rate limiter for MCP tool calls.
///
/// Uses a fixed 60-second window tracked with two independent relaxed
/// atomics, so counting is approximate under concurrent access (the window
/// reset and the call counter are not updated as one atomic unit). That is
/// acceptable for coarse per-server throttling.
#[cfg(feature = "composio")]
struct McpRateLimiter {
    /// Maximum calls allowed per minute (0 = unlimited).
    max_per_minute: u32,
    /// Calls made in the current window.
    calls: AtomicU32,
    /// Start of the current window (seconds since UNIX epoch).
    window_start: AtomicU64,
}
38+
39+
#[cfg(feature = "composio")]
impl McpRateLimiter {
    /// Create a limiter allowing `max_per_minute` calls per 60-second window.
    ///
    /// `None` is stored as 0, which `check` treats as "unlimited". The first
    /// window starts at construction time.
    fn new(max_per_minute: Option<u32>) -> Self {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        Self {
            max_per_minute: max_per_minute.unwrap_or(0),
            calls: AtomicU32::new(0),
            window_start: AtomicU64::new(now),
        }
    }

    /// Check if a call is allowed. Returns `true` if within limits.
    ///
    /// Every invocation is counted, so rejected calls also consume window
    /// capacity (matches the original behavior of the unconditional
    /// `fetch_add`).
    fn check(&self) -> bool {
        if self.max_per_minute == 0 {
            return true; // unlimited
        }

        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let window = self.window_start.load(Ordering::Relaxed);

        // Reset the window once 60+ seconds have passed. `saturating_sub`
        // guards against the system clock stepping backwards between the
        // constructor and this call, which would otherwise underflow the
        // unsigned subtraction (a panic in debug builds).
        if now.saturating_sub(window) >= 60 {
            // Only the thread that wins the CAS performs the reset and takes
            // the "first call" of the new window; losers fall through and
            // count against the freshly reset window. The previous
            // load-then-store sequence let every racing thread reset the
            // counter and grant itself a free call.
            if self
                .window_start
                .compare_exchange(window, now, Ordering::Relaxed, Ordering::Relaxed)
                .is_ok()
            {
                self.calls.store(1, Ordering::Relaxed);
                return true;
            }
        }

        // `fetch_add` returns the previous count, so with a limit of N the
        // first N calls observe 0..N-1 and pass; call N+1 observes N and fails.
        let current = self.calls.fetch_add(1, Ordering::Relaxed);
        current < self.max_per_minute
    }
}
76+
2577
/// An [`ActionExecutor`] that dispatches tool calls to Composio via JSON-RPC.
#[cfg(feature = "composio")]
pub struct ComposioToolExecutor {
    /// Shared SSE transport used for JSON-RPC calls to the MCP endpoint.
    transport: Arc<SseTransport>,
    /// Tool definitions discovered from the MCP endpoint by `discover`.
    tool_definitions: Vec<ToolDefinition>,
    /// Per-executor rate limiter consulted in `call_tool`; constructed
    /// unlimited by `discover`, configurable via `discover_with_rate_limit`.
    rate_limiter: McpRateLimiter,
}
3184

3285
#[cfg(feature = "composio")]
3386
impl ComposioToolExecutor {
87+
/// Discover available tools from the Composio MCP endpoint and return a
88+
/// new executor ready to dispatch calls.
89+
///
90+
/// `max_calls_per_minute` enforces a per-server rate limit (None = unlimited).
91+
pub async fn discover_with_rate_limit(
92+
transport: Arc<SseTransport>,
93+
max_calls_per_minute: Option<u32>,
94+
) -> Result<Self, ComposioError> {
95+
let mut executor = Self::discover(transport).await?;
96+
executor.rate_limiter = McpRateLimiter::new(max_calls_per_minute);
97+
Ok(executor)
98+
}
99+
34100
/// Discover available tools from the Composio MCP endpoint and return a
35101
/// new executor ready to dispatch calls.
36102
pub async fn discover(transport: Arc<SseTransport>) -> Result<Self, ComposioError> {
@@ -69,6 +135,7 @@ impl ComposioToolExecutor {
69135
Ok(Self {
70136
transport,
71137
tool_definitions,
138+
rate_limiter: McpRateLimiter::new(None),
72139
})
73140
}
74141

@@ -79,6 +146,14 @@ impl ComposioToolExecutor {
79146

80147
/// Call a single tool on the Composio MCP endpoint.
81148
async fn call_tool(&self, name: &str, arguments: &str) -> Result<String, String> {
149+
if !self.rate_limiter.check() {
150+
tracing::warn!(tool = name, "MCP rate limit exceeded");
151+
return Err(format!(
152+
"Rate limit exceeded: max {} calls/min for MCP server",
153+
self.rate_limiter.max_per_minute
154+
));
155+
}
156+
82157
let args: serde_json::Value =
83158
serde_json::from_str(arguments).unwrap_or(serde_json::json!({}));
84159

@@ -199,6 +274,24 @@ mod tests {
199274
assert!(defs[0].parameters["properties"]["text"].is_object());
200275
}
201276

277+
#[test]
278+
fn test_rate_limiter_unlimited() {
279+
let limiter = McpRateLimiter::new(None);
280+
for _ in 0..1000 {
281+
assert!(limiter.check());
282+
}
283+
}
284+
285+
#[test]
286+
fn test_rate_limiter_enforced() {
287+
let limiter = McpRateLimiter::new(Some(5));
288+
for _ in 0..5 {
289+
assert!(limiter.check());
290+
}
291+
// 6th call should be rejected
292+
assert!(!limiter.check());
293+
}
294+
202295
#[test]
203296
fn test_mcp_content_extraction() {
204297
let result = serde_json::json!({

crates/runtime/src/reasoning/cedar_gate.rs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,50 @@ impl CedarPolicyGate {
8989
self.policies.read().await.clone()
9090
}
9191

92+
/// Reload policies from a JSON file (hot-reload for production).
93+
///
94+
/// The file must contain a JSON array of `CedarPolicy` objects.
95+
/// All existing policies are replaced atomically.
96+
pub async fn reload_policies_from_file(
97+
&self,
98+
path: &std::path::Path,
99+
) -> Result<usize, CedarGateError> {
100+
let contents = tokio::fs::read_to_string(path).await.map_err(|e| {
101+
CedarGateError::ParseError(format!("Failed to read {}: {}", path.display(), e))
102+
})?;
103+
let new_policies: Vec<CedarPolicy> = serde_json::from_str(&contents).map_err(|e| {
104+
CedarGateError::ParseError(format!("Invalid JSON in {}: {}", path.display(), e))
105+
})?;
106+
107+
// Validate all policies parse as valid Cedar before swapping
108+
for policy in &new_policies {
109+
if policy.active {
110+
if let Err(e) = policy.source.parse::<PolicySet>() {
111+
return Err(CedarGateError::ParseError(format!(
112+
"Policy '{}' has invalid Cedar syntax: {}",
113+
policy.name, e
114+
)));
115+
}
116+
}
117+
}
118+
119+
let count = new_policies.len();
120+
let mut policies = self.policies.write().await;
121+
*policies = new_policies;
122+
tracing::info!(
123+
count,
124+
path = %path.display(),
125+
"Cedar policies reloaded"
126+
);
127+
Ok(count)
128+
}
129+
130+
/// Replace all policies atomically (for programmatic hot-reload).
131+
pub async fn replace_policies(&self, new_policies: Vec<CedarPolicy>) {
132+
let mut policies = self.policies.write().await;
133+
*policies = new_policies;
134+
}
135+
92136
/// Get active policy count.
93137
pub async fn active_policy_count(&self) -> usize {
94138
self.policies

crates/runtime/src/reasoning/journal.rs

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,34 @@ impl JournalWriter for DurableJournal {
166166
}
167167
}
168168

169+
/// Export all journal entries for an agent as a JSON string for backup.
170+
pub async fn export_entries(
171+
storage: &dyn JournalStorage,
172+
agent_id: &AgentId,
173+
) -> Result<String, JournalError> {
174+
let entries = storage.read_entries(agent_id).await?;
175+
serde_json::to_string_pretty(&entries)
176+
.map_err(|e| JournalError::WriteFailed(format!("Failed to serialize journal entries: {e}")))
177+
}
178+
179+
/// Import journal entries from a JSON string (restore from backup).
180+
///
181+
/// Entries are appended to storage. Callers should compact first if
182+
/// a clean restore is desired.
183+
pub async fn import_entries(
184+
storage: &dyn JournalStorage,
185+
json: &str,
186+
) -> Result<usize, JournalError> {
187+
let entries: Vec<JournalEntry> = serde_json::from_str(json).map_err(|e| {
188+
JournalError::ReadFailed(format!("Failed to deserialize journal entries: {e}"))
189+
})?;
190+
let count = entries.len();
191+
for entry in &entries {
192+
storage.store(entry).await?;
193+
}
194+
Ok(count)
195+
}
196+
169197
#[cfg(test)]
170198
mod tests {
171199
use super::*;
@@ -350,4 +378,46 @@ mod tests {
350378

351379
assert_eq!(journal.last_completed_iteration().await.unwrap(), 7);
352380
}
381+
382+
#[tokio::test]
383+
async fn test_export_entries() {
384+
let storage = MemoryJournalStorage::new();
385+
let agent = AgentId::new();
386+
387+
storage.store(&make_entry(agent, 0, 0)).await.unwrap();
388+
storage.store(&make_entry(agent, 1, 1)).await.unwrap();
389+
390+
let json = export_entries(&storage, &agent).await.unwrap();
391+
assert!(json.contains("sequence"));
392+
393+
// Should be valid JSON array
394+
let parsed: Vec<JournalEntry> = serde_json::from_str(&json).unwrap();
395+
assert_eq!(parsed.len(), 2);
396+
}
397+
398+
#[tokio::test]
399+
async fn test_import_entries() {
400+
let storage = MemoryJournalStorage::new();
401+
let agent = AgentId::new();
402+
403+
// Create entries, export, then import into a fresh storage
404+
storage.store(&make_entry(agent, 0, 0)).await.unwrap();
405+
storage.store(&make_entry(agent, 1, 1)).await.unwrap();
406+
407+
let json = export_entries(&storage, &agent).await.unwrap();
408+
409+
let fresh_storage = MemoryJournalStorage::new();
410+
let count = import_entries(&fresh_storage, &json).await.unwrap();
411+
assert_eq!(count, 2);
412+
413+
let entries = fresh_storage.read_entries(&agent).await.unwrap();
414+
assert_eq!(entries.len(), 2);
415+
}
416+
417+
#[tokio::test]
418+
async fn test_import_invalid_json() {
419+
let storage = MemoryJournalStorage::new();
420+
let result = import_entries(&storage, "not valid json").await;
421+
assert!(result.is_err());
422+
}
353423
}

crates/runtime/src/reasoning/tracing_spans.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ pub struct LoopTracer {
1515
loop_id: String,
1616
start: Instant,
1717
iteration: u32,
18+
/// Optional external request ID for correlation with API requests.
19+
request_id: Option<String>,
1820
}
1921

2022
impl LoopTracer {
@@ -41,9 +43,27 @@ impl LoopTracer {
4143
loop_id,
4244
start: Instant::now(),
4345
iteration: 0,
46+
request_id: None,
4447
}
4548
}
4649

50+
    /// Set an external request ID for correlation with API requests.
    ///
    /// Builder-style: consumes the tracer and returns it with the ID set.
    /// Emits an info-level event immediately (before the ID is stored) so
    /// logs can be joined on `request_id` from the moment of correlation.
    pub fn with_request_id(mut self, request_id: String) -> Self {
        tracing::info!(
            agent_id = %self.agent_id,
            loop_id = %self.loop_id,
            request_id = %request_id,
            "Correlated with API request"
        );
        self.request_id = Some(request_id);
        self
    }
61+
62+
    /// Get the request ID, if set.
    ///
    /// Returns a borrowed `&str`; `None` until [`Self::with_request_id`]
    /// has been called.
    pub fn request_id(&self) -> Option<&str> {
        self.request_id.as_deref()
    }
}
66+
4767
/// Begin a new iteration.
4868
pub fn begin_iteration(&mut self) {
4969
self.iteration += 1;
@@ -294,6 +314,20 @@ mod tests {
294314
assert!(parsed.is_some());
295315
}
296316

317+
#[test]
318+
fn test_with_request_id() {
319+
let agent_id = AgentId::new();
320+
let tracer = LoopTracer::start(agent_id).with_request_id("req-abc-123".to_string());
321+
assert_eq!(tracer.request_id(), Some("req-abc-123"));
322+
}
323+
324+
#[test]
325+
fn test_request_id_default_none() {
326+
let agent_id = AgentId::new();
327+
let tracer = LoopTracer::start(agent_id);
328+
assert!(tracer.request_id().is_none());
329+
}
330+
297331
#[test]
298332
fn test_elapsed_increases() {
299333
let agent_id = AgentId::new();

0 commit comments

Comments
 (0)