Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 49 additions & 3 deletions crates/rustapi-context/src/cost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,55 @@ impl CostTracker {

/// Record a cost delta and check budget.
///
/// Returns `Err(ContextError::BudgetExceeded)` if any limit is breached
/// **after** applying the delta (fail-open, then check).
/// Returns `Err(ContextError::BudgetExceeded)` if any limit *would be*
/// breached by this delta. The delta is only applied when all checks
/// pass, so a rejected call never modifies the running totals.
///
/// # Concurrency note
/// Budget enforcement uses relaxed atomic reads followed by atomic
/// increments, so a TOCTOU race is theoretically possible under very high
/// concurrent load. The design is intentional: the lock-free accounting
/// keeps overhead minimal and a slight overshoot under extreme concurrency
/// is acceptable. For hard limits, callers should pair this with an
/// external quota gate.
pub fn record(&self, delta: &CostDelta) -> Result<(), ContextError> {
// Pre-check: verify that adding this delta will not breach the budget
// *before* touching any counters.
if let Some(ref budget) = self.budget {
let new_tokens =
self.total_tokens() + delta.input_tokens + delta.output_tokens;
if let Some(max) = budget.max_tokens {
if new_tokens > max {
return Err(ContextError::budget_exceeded(format!(
"Token limit {max} would be exceeded \
(current {}, delta {})",
self.total_tokens(),
delta.input_tokens + delta.output_tokens
)));
}
}
let new_cost = self.total_cost_micros() + delta.cost_micros;
if let Some(max) = budget.max_cost_micros {
if new_cost > max {
return Err(ContextError::budget_exceeded(format!(
"Cost limit ${:.4} would be exceeded \
(current ${:.4}, delta ${:.4})",
max as f64 / 1_000_000.0,
self.total_cost_micros() as f64 / 1_000_000.0,
delta.cost_micros as f64 / 1_000_000.0
)));
}
}
let new_calls = u64::from(self.api_calls()) + 1;
if let Some(max) = budget.max_api_calls {
if new_calls > u64::from(max) {
return Err(ContextError::budget_exceeded(format!(
"API call limit {max} would be exceeded ({new_calls} calls)"
)));
}
}
}

self.input_tokens
.fetch_add(delta.input_tokens, Ordering::Relaxed);
self.output_tokens
Expand All @@ -141,7 +187,7 @@ impl CostTracker {
.fetch_add(delta.cost_micros, Ordering::Relaxed);
self.api_calls.fetch_add(1, Ordering::Relaxed);

self.check_budget()
Ok(())
}

/// Check whether the current totals exceed the budget.
Expand Down
8 changes: 6 additions & 2 deletions crates/rustapi-context/src/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,13 @@ impl SpanGuard {

impl Drop for SpanGuard {
fn drop(&mut self) {
// If not explicitly completed/failed, mark as error and record.
// Panic-safety: if the span was not explicitly completed or failed
// (e.g. the thread panicked during step execution), record it with an
// error status so the trace tree is never left with a dangling
// in-progress node. Explicit complete()/fail() callers use
// std::mem::forget to skip this path entirely.
if self.node.status == TraceStatus::InProgress {
self.node.fail("span dropped without completion");
self.node.fail("span dropped without completion (likely due to a panic)");
}
self.tree.add_root_child(self.node.clone());
}
Expand Down
22 changes: 12 additions & 10 deletions crates/rustapi-tools/src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,20 +227,19 @@ impl ToolGraph {
}

ToolNode::Parallel { id: _, nodes } => {
// For parallel execution, we need to collect results.
// Since outputs is &mut, we execute sequentially here for safety.
// A production implementation would use JoinSet with per-node output maps.
let mut handles = Vec::new();
// Spawn all child nodes concurrently via JoinSet so results
// are collected as each task finishes rather than in
// submission order. Any failure aborts the remaining tasks.
let mut set = tokio::task::JoinSet::new();
let registry = registry.clone();
let ctx = ctx.clone();

for child_node in nodes {
let reg = registry.clone();
let c = ctx.clone();
let node = child_node.clone();
handles.push(tokio::spawn(async move {
set.spawn(async move {
let mut local_outputs = HashMap::new();
// We create a temporary graph to execute the child node.
let graph = ToolGraph::new("parallel_child", node.clone());
match graph
.execute_node(&node, &reg, &c, &mut local_outputs)
Expand All @@ -249,15 +248,18 @@ impl ToolGraph {
Ok(()) => Ok(local_outputs),
Err(e) => Err(e),
}
}));
});
}

for handle in handles {
match handle.await {
while let Some(result) = set.join_next().await {
match result {
Ok(Ok(local_outputs)) => {
outputs.extend(local_outputs);
}
Ok(Err(e)) => return Err(e),
Ok(Err(e)) => {
set.abort_all();
return Err(e);
}
Err(e) => {
return Err(ToolError::internal(format!("Join error: {e}")));
}
Expand Down
32 changes: 19 additions & 13 deletions crates/rustapi-tools/src/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,24 @@ use rustapi_context::{CostDelta, RequestContext};
use serde::{Deserialize, Serialize};

// ---------------------------------------------------------------------------
// ToolOutput — result of a tool execution
// ToolFuture — type alias for async closure return types
// ---------------------------------------------------------------------------

/// Pinned, heap-allocated future resolving to `Result<ToolOutput, ToolError>`.
///
/// The boxed trait object is `Send` and may borrow data for the lifetime
/// `'a`. The alias is shorthand: closure signatures can name `ToolFuture<'_>`
/// rather than spelling out `Pin<Box<dyn Future<...> + Send + '_>>` each
/// time.
///
/// # Examples
///
/// ```rust,ignore
/// ClosureTool::new("my_tool", "...", schema, |ctx, input| {
///     Box::pin(async move { Ok(ToolOutput::value(input)) })
/// });
/// ```
pub type ToolFuture<'a> =
    std::pin::Pin<Box<dyn std::future::Future<Output = Result<ToolOutput, ToolError>> + Send + 'a>>;


/// Result of a single tool execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolOutput {
Expand Down Expand Up @@ -153,10 +168,7 @@ pub trait Tool: Send + Sync + 'static {
/// A tool created from a closure, for quick prototyping.
pub struct ClosureTool<F>
where
F: Fn(&RequestContext, serde_json::Value) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<ToolOutput, ToolError>> + Send + '_>>
+ Send
+ Sync
+ 'static,
F: Fn(&RequestContext, serde_json::Value) -> ToolFuture<'_> + Send + Sync + 'static,
{
name: String,
description: String,
Expand All @@ -166,10 +178,7 @@ where

impl<F> ClosureTool<F>
where
F: Fn(&RequestContext, serde_json::Value) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<ToolOutput, ToolError>> + Send + '_>>
+ Send
+ Sync
+ 'static,
F: Fn(&RequestContext, serde_json::Value) -> ToolFuture<'_> + Send + Sync + 'static,
{
/// Create a new closure-based tool.
pub fn new(
Expand All @@ -190,10 +199,7 @@ where
#[async_trait]
impl<F> Tool for ClosureTool<F>
where
F: Fn(&RequestContext, serde_json::Value) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<ToolOutput, ToolError>> + Send + '_>>
+ Send
+ Sync
+ 'static,
F: Fn(&RequestContext, serde_json::Value) -> ToolFuture<'_> + Send + Sync + 'static,
{
fn name(&self) -> &str {
&self.name
Expand Down
Loading