Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ async def _call_model() -> Message:
result_message: Message | None = None
async for event in process_stream(chunks):
if "stop" in event:
_, result_message, _, _ = event["stop"]
_, result_message, *_ = event["stop"]

if result_message is None:
raise RuntimeError("Failed to generate summary: no response from model")
Expand Down
4 changes: 3 additions & 1 deletion src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ async def _handle_model_execution(
):
yield event

stop_reason, message, usage, metrics = event["stop"]
stop_reason, message, usage, metrics, cost = event["stop"]
invocation_state.setdefault("request_state", {})

after_model_call_event = AfterModelCallEvent(
Expand Down Expand Up @@ -412,6 +412,8 @@ async def _handle_model_execution(
# Update metrics
agent.event_loop_metrics.update_usage(usage)
agent.event_loop_metrics.update_metrics(metrics)
if cost is not None:
agent.event_loop_metrics.update_cost(cost)

except Exception as e:
yield ForceStopEvent(reason=e)
Expand Down
5 changes: 4 additions & 1 deletion src/strands/event_loop/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,7 @@ async def process_stream(

usage: Usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0)
metrics: Metrics = Metrics(latencyMs=0, timeToFirstByteMs=0)
cost: float | None = None

async for chunk in chunks:
# Check for cancellation during stream processing
Expand Down Expand Up @@ -433,10 +434,12 @@ async def process_stream(
int(1000 * (first_byte_time - start_time)) if (start_time and first_byte_time) else None
)
usage, metrics = extract_usage_metrics(chunk["metadata"], time_to_first_byte_ms)
if "cost" in chunk["metadata"]:
cost = chunk["metadata"]["cost"]
elif "redactContent" in chunk:
handle_redact_content(chunk["redactContent"], state)

yield ModelStopReason(stop_reason=stop_reason, message=state["message"], usage=usage, metrics=metrics)
yield ModelStopReason(stop_reason=stop_reason, message=state["message"], usage=usage, metrics=metrics, cost=cost)


async def stream_messages(
Expand Down
2 changes: 1 addition & 1 deletion src/strands/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ async def structured_output(
async for event in process_stream(response):
yield event

stop_reason, messages, _, _ = event["stop"]
stop_reason, messages, *_ = event["stop"]

if stop_reason != "tool_use":
raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".')
Expand Down
2 changes: 1 addition & 1 deletion src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,7 +1074,7 @@ async def structured_output(
async for event in streaming.process_stream(response):
yield event

stop_reason, messages, _, _ = event["stop"]
stop_reason, messages, *_ = event["stop"]

if stop_reason != "tool_use":
raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".')
Expand Down
61 changes: 53 additions & 8 deletions src/strands/models/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent:
"""Format a LiteLLM response event into a standardized message chunk.

This method overrides OpenAI's format_chunk to handle the metadata case
with prompt caching support. All other chunk types use the parent implementation.
with prompt caching support and cost tracking. All other chunk types use the parent implementation.

Args:
event: A response event from the LiteLLM model.
Expand All @@ -223,23 +223,68 @@ def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent:

# Only LiteLLM over Anthropic supports cache write tokens
# Waiting until a more general approach is available to set cacheWriteInputTokens
cache_read_tokens = 0
cache_write_tokens = 0
if tokens_details := getattr(event["data"], "prompt_tokens_details", None):
if cached := getattr(tokens_details, "cached_tokens", None):
usage_data["cacheReadInputTokens"] = cached
cache_read_tokens = cached
if creation := getattr(event["data"], "cache_creation_input_tokens", None):
usage_data["cacheWriteInputTokens"] = creation
cache_write_tokens = creation

return StreamEvent(
metadata=MetadataEvent(
metrics={
"latencyMs": 0, # TODO
},
usage=usage_data,
)
metadata_event = MetadataEvent(
metrics={
"latencyMs": 0, # TODO
},
usage=usage_data,
)

cost = self._calculate_cost(
prompt_tokens=event["data"].prompt_tokens,
completion_tokens=event["data"].completion_tokens,
cache_read_input_tokens=cache_read_tokens,
cache_creation_input_tokens=cache_write_tokens,
)
if cost is not None:
metadata_event["cost"] = cost

return StreamEvent(metadata=metadata_event)
# For all other cases, use the parent implementation
return super().format_chunk(event)

def _calculate_cost(
    self,
    prompt_tokens: int,
    completion_tokens: int,
    cache_read_input_tokens: int = 0,
    cache_creation_input_tokens: int = 0,
) -> float | None:
    """Calculate the cost for a model invocation using LiteLLM's cost tracking.

    Args:
        prompt_tokens: Number of input tokens.
        completion_tokens: Number of output tokens.
        cache_read_input_tokens: Number of tokens read from cache.
        cache_creation_input_tokens: Number of tokens written to cache.

    Returns:
        Cost in USD, or None if cost calculation is not available for the model.
    """
    # Look up the model id once; .get avoids a KeyError masquerading as a pricing failure.
    model_id = self.get_config().get("model_id")
    try:
        prompt_cost, completion_cost = litellm.cost_per_token(
            model=model_id,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            cache_read_input_tokens=cache_read_input_tokens,
            cache_creation_input_tokens=cache_creation_input_tokens,
        )
    except Exception as e:
        # Cost mapping is best-effort: unknown or custom models are absent from
        # LiteLLM's pricing table. Include the reason so the failure is diagnosable.
        logger.debug("model_id=<%s>, error=<%s> | could not calculate completion cost", model_id, e)
        return None
    return prompt_cost + completion_cost

@override
async def stream(
self,
Expand Down
12 changes: 12 additions & 0 deletions src/strands/telemetry/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ class EventLoopMetrics:
traces: list[Trace] = field(default_factory=list)
accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0))
accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
accumulated_cost: float = 0.0

@property
def _metrics_client(self) -> "MetricsClient":
Expand Down Expand Up @@ -348,6 +349,14 @@ def reset_usage_metrics(self) -> None:
"""
self.agent_invocations.append(AgentInvocation())

def update_cost(self, cost: float) -> None:
    """Add one invocation's cost to the running accumulated total.

    Args:
        cost: The cost in USD to add to the accumulated total.
    """
    self.accumulated_cost = self.accumulated_cost + cost

def update_metrics(self, metrics: Metrics) -> None:
"""Update the accumulated performance metrics with new metrics data.

Expand Down Expand Up @@ -391,6 +400,7 @@ def get_summary(self) -> dict[str, Any]:
"traces": [trace.to_dict() for trace in self.traces],
"accumulated_usage": self.accumulated_usage,
"accumulated_metrics": self.accumulated_metrics,
"accumulated_cost": self.accumulated_cost,
"agent_invocations": [
{
"usage": invocation.usage,
Expand Down Expand Up @@ -436,6 +446,8 @@ def _metrics_summary_to_lines(event_loop_metrics: EventLoopMetrics, allowed_name
token_parts.append(f"cache_write_input_tokens={summary['accumulated_usage']['cacheWriteInputTokens']}")

yield f"├─ Tokens: {', '.join(token_parts)}"
if summary["accumulated_cost"] > 0:
yield f"├─ Cost: ${summary['accumulated_cost']:.6f}"
yield f"├─ Bedrock Latency: {summary['accumulated_metrics']['latencyMs']}ms"

yield "├─ Tool Usage:"
Expand Down
4 changes: 3 additions & 1 deletion src/strands/types/_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def __init__(
message: Message,
usage: Usage,
metrics: Metrics,
cost: float | None = None,
) -> None:
"""Initialize with the final execution results.

Expand All @@ -206,8 +207,9 @@ def __init__(
message: Final message from the model
usage: Usage information from the model
metrics: Execution metrics and performance data
cost: Cost in USD for the model invocation, if available from the provider.
"""
super().__init__({"stop": (stop_reason, message, usage, metrics)})
super().__init__({"stop": (stop_reason, message, usage, metrics, cost)})

@property
@override
Expand Down
2 changes: 2 additions & 0 deletions src/strands/types/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,13 @@ class MetadataEvent(TypedDict, total=False):
metrics: Performance metrics related to the model invocation.
trace: Trace information for debugging and monitoring.
usage: Resource usage information for the model invocation.
cost: Cost in USD for the model invocation, as calculated by the model provider (e.g. LiteLLM).
"""

metrics: Metrics
trace: Trace | None
usage: Usage
cost: float


class ExceptionEvent(TypedDict):
Expand Down
6 changes: 6 additions & 0 deletions tests/strands/event_loop/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,7 @@ def test_extract_usage_metrics_empty_metadata():
},
{"inputTokens": 1, "outputTokens": 1, "totalTokens": 1},
{"latencyMs": 1},
None,
)
},
],
Expand Down Expand Up @@ -833,6 +834,7 @@ def test_extract_usage_metrics_empty_metadata():
},
{"inputTokens": 5, "outputTokens": 10, "totalTokens": 15},
{"latencyMs": 100},
None,
)
},
],
Expand All @@ -853,6 +855,7 @@ def test_extract_usage_metrics_empty_metadata():
},
{"inputTokens": 0, "outputTokens": 0, "totalTokens": 0},
{"latencyMs": 0, "timeToFirstByteMs": 0},
None,
),
},
],
Expand Down Expand Up @@ -938,6 +941,7 @@ async def test_process_stream(response, exp_events, agenerator, alist):
{"role": "assistant", "content": [{"text": "REDACTED."}]},
{"inputTokens": 1, "outputTokens": 1, "totalTokens": 1},
{"latencyMs": 1},
None,
)
},
],
Expand Down Expand Up @@ -998,6 +1002,7 @@ async def test_process_stream(response, exp_events, agenerator, alist):
},
{"inputTokens": 1, "outputTokens": 1, "totalTokens": 1},
{"latencyMs": 1},
None,
)
},
],
Expand Down Expand Up @@ -1144,6 +1149,7 @@ async def test_stream_messages(agenerator, alist):
{"role": "assistant", "content": [{"text": "test"}]},
{"inputTokens": 0, "outputTokens": 0, "totalTokens": 0},
{"latencyMs": 0, "timeToFirstByteMs": 0},
None,
)
},
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ async def test_stream_messages_with_forced_structured_output(agenerator, alist):
break

assert stop_event is not None
stop_reason, message, usage, metrics = stop_event["stop"]
stop_reason, message, usage, metrics, _cost = stop_event["stop"]

assert stop_reason == "tool_use"
assert message["role"] == "assistant"
Expand Down
93 changes: 93 additions & 0 deletions tests/strands/models/test_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,3 +848,96 @@ def test_format_request_messages_with_tool_calls_no_content():
},
]
assert tru_result == exp_result


def test_format_chunk_metadata_includes_cost():
    """format_chunk attaches a cost entry when litellm.cost_per_token succeeds."""
    model = LiteLLMModel(model_id="openai/gpt-4o")

    usage = unittest.mock.Mock(
        prompt_tokens=100,
        completion_tokens=50,
        total_tokens=150,
        prompt_tokens_details=None,
        cache_creation_input_tokens=None,
    )

    patcher = unittest.mock.patch.object(
        strands.models.litellm.litellm, "cost_per_token", return_value=(0.0025, 0.005)
    )
    with patcher:
        chunk = model.format_chunk({"chunk_type": "metadata", "data": usage})

    # prompt cost + completion cost
    assert chunk["metadata"]["cost"] == 0.0075


def test_format_chunk_metadata_omits_cost_on_failure():
    """format_chunk leaves cost out (usage intact) when the pricing lookup raises."""
    model = LiteLLMModel(model_id="unknown/model")

    usage = unittest.mock.Mock(
        prompt_tokens=100,
        completion_tokens=50,
        total_tokens=150,
        prompt_tokens_details=None,
        cache_creation_input_tokens=None,
    )

    with unittest.mock.patch.object(
        strands.models.litellm.litellm,
        "cost_per_token",
        side_effect=Exception("model not mapped"),
    ):
        chunk = model.format_chunk({"chunk_type": "metadata", "data": usage})

    metadata = chunk["metadata"]
    assert "cost" not in metadata
    # The usage payload must survive a failed cost calculation.
    assert metadata["usage"]["inputTokens"] == 100


def test_format_chunk_metadata_cost_with_cache_tokens():
    """Cache read/write token counts are forwarded to litellm.cost_per_token."""
    model = LiteLLMModel(model_id="anthropic/claude-3-sonnet")

    details = unittest.mock.Mock(cached_tokens=25)
    usage = unittest.mock.Mock(
        prompt_tokens=100,
        completion_tokens=50,
        total_tokens=150,
        prompt_tokens_details=details,
        cache_creation_input_tokens=10,
    )

    with unittest.mock.patch.object(
        strands.models.litellm.litellm, "cost_per_token", return_value=(0.001, 0.002)
    ) as cost_mock:
        chunk = model.format_chunk({"chunk_type": "metadata", "data": usage})

    cost_mock.assert_called_once_with(
        model="anthropic/claude-3-sonnet",
        prompt_tokens=100,
        completion_tokens=50,
        cache_read_input_tokens=25,
        cache_creation_input_tokens=10,
    )
    assert chunk["metadata"]["cost"] == 0.003


def test_calculate_cost():
    """_calculate_cost sums the prompt and completion cost components."""
    model = LiteLLMModel(model_id="openai/gpt-4o")

    patcher = unittest.mock.patch.object(
        strands.models.litellm.litellm, "cost_per_token", return_value=(0.01, 0.02)
    )
    with patcher:
        total = model._calculate_cost(prompt_tokens=1000, completion_tokens=500)

    assert total == 0.03


def test_calculate_cost_returns_none_on_failure():
    """_calculate_cost swallows pricing errors and returns None."""
    model = LiteLLMModel(model_id="unknown/model")

    patcher = unittest.mock.patch.object(
        strands.models.litellm.litellm, "cost_per_token", side_effect=Exception("not mapped")
    )
    with patcher:
        assert model._calculate_cost(prompt_tokens=100, completion_tokens=50) is None
Loading