diff --git a/crates/rustapi-context/src/cost.rs b/crates/rustapi-context/src/cost.rs index a572b7b..cb2ec5b 100644 --- a/crates/rustapi-context/src/cost.rs +++ b/crates/rustapi-context/src/cost.rs @@ -130,9 +130,55 @@ impl CostTracker { /// Record a cost delta and check budget. /// - /// Returns `Err(ContextError::BudgetExceeded)` if any limit is breached - /// **after** applying the delta (fail-open, then check). + /// Returns `Err(ContextError::BudgetExceeded)` if any limit *would be* + /// breached by this delta. The delta is only applied when all checks + /// pass, so a rejected call never modifies the running totals. + /// + /// # Concurrency note + /// Budget enforcement uses relaxed atomic reads followed by atomic + /// increments, so a TOCTOU race is theoretically possible under very high + /// concurrent load. The design is intentional: the lock-free accounting + /// keeps overhead minimal and a slight overshoot under extreme concurrency + /// is acceptable. For hard limits, callers should pair this with an + /// external quota gate. pub fn record(&self, delta: &CostDelta) -> Result<(), ContextError> { + // Pre-check: verify that adding this delta will not breach the budget + // *before* touching any counters. 
+ if let Some(ref budget) = self.budget { + let new_tokens = + self.total_tokens() + delta.input_tokens + delta.output_tokens; + if let Some(max) = budget.max_tokens { + if new_tokens > max { + return Err(ContextError::budget_exceeded(format!( + "Token limit {max} would be exceeded \ + (current {}, delta {})", + self.total_tokens(), + delta.input_tokens + delta.output_tokens + ))); + } + } + let new_cost = self.total_cost_micros() + delta.cost_micros; + if let Some(max) = budget.max_cost_micros { + if new_cost > max { + return Err(ContextError::budget_exceeded(format!( + "Cost limit ${:.4} would be exceeded \ + (current ${:.4}, delta ${:.4})", + max as f64 / 1_000_000.0, + self.total_cost_micros() as f64 / 1_000_000.0, + delta.cost_micros as f64 / 1_000_000.0 + ))); + } + } + let new_calls = u64::from(self.api_calls()) + 1; + if let Some(max) = budget.max_api_calls { + if new_calls > u64::from(max) { + return Err(ContextError::budget_exceeded(format!( + "API call limit {max} would be exceeded ({new_calls} calls)" + ))); + } + } + } + self.input_tokens .fetch_add(delta.input_tokens, Ordering::Relaxed); self.output_tokens @@ -141,7 +187,7 @@ impl CostTracker { .fetch_add(delta.cost_micros, Ordering::Relaxed); self.api_calls.fetch_add(1, Ordering::Relaxed); - self.check_budget() + Ok(()) } /// Check whether the current totals exceed the budget. diff --git a/crates/rustapi-context/src/trace.rs b/crates/rustapi-context/src/trace.rs index c826700..3a1a264 100644 --- a/crates/rustapi-context/src/trace.rs +++ b/crates/rustapi-context/src/trace.rs @@ -288,9 +288,13 @@ impl SpanGuard { impl Drop for SpanGuard { fn drop(&mut self) { - // If not explicitly completed/failed, mark as error and record. + // Panic-safety: if the span was not explicitly completed or failed + // (e.g. the thread panicked during step execution), record it with an + // error status so the trace tree is never left with a dangling + // in-progress node. 
Explicit complete()/fail() callers use + // std::mem::forget to skip this path entirely. if self.node.status == TraceStatus::InProgress { - self.node.fail("span dropped without completion"); + self.node.fail("span dropped without completion (likely due to a panic)"); } self.tree.add_root_child(self.node.clone()); } diff --git a/crates/rustapi-tools/src/graph.rs b/crates/rustapi-tools/src/graph.rs index 1acf32a..9eb9a62 100644 --- a/crates/rustapi-tools/src/graph.rs +++ b/crates/rustapi-tools/src/graph.rs @@ -227,10 +227,10 @@ impl ToolGraph { } ToolNode::Parallel { id: _, nodes } => { - // For parallel execution, we need to collect results. - // Since outputs is &mut, we execute sequentially here for safety. - // A production implementation would use JoinSet with per-node output maps. - let mut handles = Vec::new(); + // Spawn all child nodes concurrently via JoinSet so results + // are collected as each task finishes rather than in + // submission order. Any failure aborts the remaining tasks. + let mut set = tokio::task::JoinSet::new(); let registry = registry.clone(); let ctx = ctx.clone(); @@ -238,9 +238,8 @@ impl ToolGraph { let reg = registry.clone(); let c = ctx.clone(); let node = child_node.clone(); - handles.push(tokio::spawn(async move { + set.spawn(async move { let mut local_outputs = HashMap::new(); - // We create a temporary graph to execute the child node. 
let graph = ToolGraph::new("parallel_child", node.clone()); match graph .execute_node(&node, &reg, &c, &mut local_outputs) @@ -249,15 +248,18 @@ impl ToolGraph { Ok(()) => Ok(local_outputs), Err(e) => Err(e), } - })); + }); } - for handle in handles { - match handle.await { + while let Some(result) = set.join_next().await { + match result { Ok(Ok(local_outputs)) => { outputs.extend(local_outputs); } - Ok(Err(e)) => return Err(e), + Ok(Err(e)) => { + set.abort_all(); + return Err(e); + } Err(e) => { return Err(ToolError::internal(format!("Join error: {e}"))); } diff --git a/crates/rustapi-tools/src/tool.rs b/crates/rustapi-tools/src/tool.rs index e70190b..434f126 100644 --- a/crates/rustapi-tools/src/tool.rs +++ b/crates/rustapi-tools/src/tool.rs @@ -4,9 +4,24 @@ use rustapi_context::{CostDelta, RequestContext}; use serde::{Deserialize, Serialize}; // --------------------------------------------------------------------------- -// ToolOutput — result of a tool execution +// ToolFuture — type alias for async closure return types // --------------------------------------------------------------------------- +/// Boxed, `Send`-safe future returned by closure-based tools and steps. +/// +/// Use this alias to avoid writing out `Pin<Box<dyn Future<Output = …> + Send + '_>>` +/// in every closure signature: +/// +/// ```rust,ignore +/// ClosureTool::new("my_tool", "...", schema, |ctx, input| { +/// Box::pin(async move { Ok(ToolOutput::value(input)) }) +/// }); +/// ``` +pub type ToolFuture<'a> = std::pin::Pin< + Box> + Send + 'a>, +>; + + /// Result of a single tool execution. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolOutput { @@ -153,10 +168,7 @@ pub trait Tool: Send + Sync + 'static { /// A tool created from a closure, for quick prototyping. 
pub struct ClosureTool where - F: Fn(&RequestContext, serde_json::Value) -> std::pin::Pin> + Send + '_>> - + Send - + Sync - + 'static, + F: Fn(&RequestContext, serde_json::Value) -> ToolFuture<'_> + Send + Sync + 'static, { name: String, description: String, @@ -166,10 +178,7 @@ where impl ClosureTool where - F: Fn(&RequestContext, serde_json::Value) -> std::pin::Pin> + Send + '_>> - + Send - + Sync - + 'static, + F: Fn(&RequestContext, serde_json::Value) -> ToolFuture<'_> + Send + Sync + 'static, { /// Create a new closure-based tool. pub fn new( @@ -190,10 +199,7 @@ where #[async_trait] impl Tool for ClosureTool where - F: Fn(&RequestContext, serde_json::Value) -> std::pin::Pin> + Send + '_>> - + Send - + Sync - + 'static, + F: Fn(&RequestContext, serde_json::Value) -> ToolFuture<'_> + Send + Sync + 'static, { fn name(&self) -> &str { &self.name