From 29496bcf84c9d201f28ebe8302cac472d55badb5 Mon Sep 17 00:00:00 2001 From: giulio-leone Date: Fri, 13 Mar 2026 03:48:54 +0100 Subject: [PATCH] fix: track all tool invocations in ToolMetrics instead of overwriting Previously, ToolMetrics.add_call() overwrote self.tool with the latest invocation, losing all previous tool inputs and IDs. Users inspecting AgentResult.metrics.tool_metrics could only see the last call's data. Changes: - Add 'invocations' list field to ToolMetrics that accumulates every call - Update add_call() to append each tool invocation to the list - Update to_dict() / get_summary() to expose invocations array - Keep self.tool update for backwards compatibility Closes #301 --- src/strands/telemetry/metrics.py | 14 ++++++- tests/strands/telemetry/test_metrics.py | 52 +++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 2 deletions(-) diff --git a/src/strands/telemetry/metrics.py b/src/strands/telemetry/metrics.py index 163df803a..f3232e4df 100644 --- a/src/strands/telemetry/metrics.py +++ b/src/strands/telemetry/metrics.py @@ -109,7 +109,8 @@ class ToolMetrics: """Metrics for a specific tool's usage. Attributes: - tool: The tool being tracked. + tool: The most recent tool invocation being tracked (for backwards compatibility). + invocations: List of all tool invocations with their inputs. call_count: Number of times the tool has been called. success_count: Number of successful tool calls. error_count: Number of failed tool calls. @@ -117,6 +118,7 @@ class ToolMetrics: """ tool: ToolUse + invocations: list[ToolUse] = field(default_factory=list) call_count: int = 0 success_count: int = 0 error_count: int = 0 @@ -139,7 +141,8 @@ def add_call( metrics_client: The metrics client for recording the metrics. attributes: attributes of the metrics. """ - self.tool = tool # Update with latest tool state + self.tool = tool # Update with latest tool state for backwards compatibility + self.invocations.append(tool) self.call_count += 1 self.total_time += duration metrics_client.tool_call_count.add(1, attributes=attributes) @@ -377,6 +380,13 @@ def get_summary(self) -> dict[str, Any]: "name": metrics.tool.get("name", "unknown"), "input_params": metrics.tool.get("input", {}), }, + "invocations": [ + { + "tool_use_id": inv.get("toolUseId", "N/A"), + "input_params": inv.get("input", {}), + } + for inv in metrics.invocations + ], "execution_stats": { "call_count": metrics.call_count, "success_count": metrics.success_count, diff --git a/tests/strands/telemetry/test_metrics.py b/tests/strands/telemetry/test_metrics.py index 800bcebc4..ee2b0b34c 100644 --- a/tests/strands/telemetry/test_metrics.py +++ b/tests/strands/telemetry/test_metrics.py @@ -222,6 +222,7 @@ def test_tool_metrics_add_call(success, tool, tool_metrics, mock_get_meter_provi tru_attrs = dataclasses.asdict(tool_metrics) exp_attrs = { "tool": tool, + "invocations": [tool], "call_count": 1, "success_count": success, "error_count": not success, @@ -309,6 +310,7 @@ def test_event_loop_metrics_add_tool_usage(mock_time, trace, tool, event_loop_me "tool_metrics": { "tool1": strands.telemetry.metrics.ToolMetrics( tool=tool, + invocations=[tool], call_count=1, success_count=1, error_count=0, @@ -411,6 +413,12 @@ def test_event_loop_metrics_get_summary(trace, tool, event_loop_metrics, mock_ge "name": "tool1", "tool_use_id": "123", }, + "invocations": [ + { + "tool_use_id": "123", + "input_params": {}, + }, + ], }, }, "total_cycles": 0, @@ -566,3 +574,47 @@ def test_reset_usage_metrics(usage, event_loop_metrics, mock_get_meter_provider) # Verify accumulated_usage is NOT cleared assert event_loop_metrics.accumulated_usage["inputTokens"] == 11 + + +def test_tool_metrics_tracks_all_invocations(mock_get_meter_provider): + """ToolMetrics.invocations should contain all tool calls, not just the latest.""" + tool1 = {"name": "weather", "toolUseId": "id1", "input": {"city": "Berlin"}} + tool2 = {"name": "weather", "toolUseId": "id2", "input": {"city": "Paris"}} + tool3 = {"name": "weather", "toolUseId": "id3", "input": {"city": "Rome"}} + + metrics_client = MetricsClient() + tm = strands.telemetry.metrics.ToolMetrics(tool=tool1) + tm.add_call(tool1, 0.1, True, metrics_client) + tm.add_call(tool2, 0.2, True, metrics_client) + tm.add_call(tool3, 0.15, False, metrics_client) + + assert tm.call_count == 3 + assert tm.success_count == 2 + assert tm.error_count == 1 + assert len(tm.invocations) == 3 + assert tm.invocations[0]["input"] == {"city": "Berlin"} + assert tm.invocations[1]["input"] == {"city": "Paris"} + assert tm.invocations[2]["input"] == {"city": "Rome"} + # Backwards compatibility: .tool still points to the latest + assert tm.tool["input"] == {"city": "Rome"} + assert tm.tool["toolUseId"] == "id3" + + +def test_tool_metrics_invocations_in_summary(tool, event_loop_metrics, mock_get_meter_provider): + """get_summary() should include all invocations for repeatedly called tools.""" + tool1 = {"name": "search", "toolUseId": "s1", "input": {"query": "python"}} + tool2 = {"name": "search", "toolUseId": "s2", "input": {"query": "rust"}} + + trace = strands.telemetry.metrics.Trace("test") + msg = {"role": "user", "content": [{"toolResult": {"toolUseId": "s1", "tool_name": "search"}}]} + + event_loop_metrics.add_tool_usage(tool1, 0.1, trace, True, msg) + event_loop_metrics.add_tool_usage(tool2, 0.2, trace, True, msg) + + summary = event_loop_metrics.get_summary() + tool_usage = summary["tool_usage"]["search"] + + assert tool_usage["execution_stats"]["call_count"] == 2 + assert len(tool_usage["invocations"]) == 2 + assert tool_usage["invocations"][0]["input_params"] == {"query": "python"} + assert tool_usage["invocations"][1]["input_params"] == {"query": "rust"}