From edbec4e30b2a9bd3eb0b466be0fab69a0854bfaa Mon Sep 17 00:00:00 2001 From: habema Date: Thu, 5 Feb 2026 15:23:10 +0300 Subject: [PATCH 1/7] feat: Add tool origin tracking to ToolCallItem and ToolCallOutputItem - Add ToolOriginType enum and ToolOrigin dataclass - Add _tool_origin field to FunctionTool - Set tool_origin for MCP tools and agent-as-tool - Extract and set tool_origin in ToolCallItem and ToolCallOutputItem creation - Add comprehensive tests for tool origin tracking --- src/agents/agent.py | 7 + src/agents/items.py | 7 + src/agents/mcp/util.py | 9 +- src/agents/run_internal/run_loop.py | 9 +- src/agents/run_internal/tool_execution.py | 3 + src/agents/run_internal/turn_resolution.py | 9 +- src/agents/tool.py | 58 ++++ tests/test_tool_origin.py | 333 +++++++++++++++++++++ 8 files changed, 431 insertions(+), 4 deletions(-) create mode 100644 tests/test_tool_origin.py diff --git a/src/agents/agent.py b/src/agents/agent.py index b0368e8698..1afc33757b 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -45,6 +45,8 @@ FunctionToolResult, Tool, ToolErrorFunction, + ToolOrigin, + ToolOriginType, _extract_tool_argument_json_error, default_tool_error_function, ) @@ -802,6 +804,11 @@ async def _run_agent_tool(context: ToolContext, input_json: str) -> Any: ) run_agent_tool._is_agent_tool = True run_agent_tool._agent_instance = self + # Set origin tracking on run_agent (the FunctionTool returned by @function_tool) + run_agent_tool._tool_origin = ToolOrigin( + type=ToolOriginType.AGENT_AS_TOOL, + agent_as_tool=self, + ) return run_agent_tool diff --git a/src/agents/items.py b/src/agents/items.py index 94ab5daa35..64565b6037 100644 --- a/src/agents/items.py +++ b/src/agents/items.py @@ -49,6 +49,7 @@ from .exceptions import AgentsException, ModelBehaviorError from .logger import logger from .tool import ( + ToolOrigin, ToolOutputFileContent, ToolOutputImage, ToolOutputText, @@ -248,6 +249,9 @@ class ToolCallItem(RunItemBase[Any]): description: str | None = None """Optional tool description if known at item creation time.""" + tool_origin: ToolOrigin | None = field(default=None, repr=False) + """Information about the origin/source of the tool call. Only set for FunctionTool calls.""" + ToolCallOutputTypes: TypeAlias = Union[ FunctionCallOutput, @@ -271,6 +275,9 @@ class ToolCallOutputItem(RunItemBase[Any]): type: Literal["tool_call_output_item"] = "tool_call_output_item" + tool_origin: ToolOrigin | None = field(default=None, repr=False) + """Information about the origin/source of the tool call. Only set for FunctionTool calls.""" + def to_input_item(self) -> TResponseInputItem: """Converts the tool output into an input item for the next model turn. diff --git a/src/agents/mcp/util.py b/src/agents/mcp/util.py index 9c9a59f683..a72c55ccaf 100644 --- a/src/agents/mcp/util.py +++ b/src/agents/mcp/util.py @@ -20,6 +20,8 @@ FunctionTool, Tool, ToolErrorFunction, + ToolOrigin, + ToolOriginType, ToolOutputImageDict, ToolOutputTextDict, default_tool_error_function, @@ -301,7 +303,7 @@ async def invoke_func(ctx: ToolContext[Any], input_json: str) -> ToolOutput: bool | Callable[[RunContextWrapper[Any], dict[str, Any], str], Awaitable[bool]] ) = server._get_needs_approval_for_tool(tool, agent) - return FunctionTool( + function_tool = FunctionTool( name=tool.name, description=tool.description or "", params_json_schema=schema, @@ -309,6 +311,11 @@ async def invoke_func(ctx: ToolContext[Any], input_json: str) -> ToolOutput: strict_json_schema=is_strict, needs_approval=needs_approval, ) + function_tool._tool_origin = ToolOrigin( + type=ToolOriginType.MCP, + mcp_server=server, + ) + return function_tool @staticmethod def _merge_mcp_meta( diff --git a/src/agents/run_internal/run_loop.py b/src/agents/run_internal/run_loop.py index e807c0cb11..4404868ed8 100644 --- a/src/agents/run_internal/run_loop.py +++ b/src/agents/run_internal/run_loop.py @@ -49,7 +49,7 @@ RawResponsesStreamEvent, RunItemStreamEvent, ) -from ..tool import Tool, dispose_resolved_computers +from ..tool import FunctionTool, Tool, _get_tool_origin_info, dispose_resolved_computers from ..tracing import Span, SpanError, agent_span, get_current_trace from ..tracing.model_tracing import get_model_tracing_impl from ..tracing.span_data import AgentSpanData @@ -1216,13 +1216,18 @@ async def run_single_turn_streamed( # execution behavior in process_model_response). tool_name = getattr(output_item, "name", None) tool_description: str | None = None + tool_origin = None if isinstance(tool_name, str) and tool_name in tool_map: - tool_description = getattr(tool_map[tool_name], "description", None) + tool = tool_map[tool_name] + tool_description = getattr(tool, "description", None) + if isinstance(tool, FunctionTool): + tool_origin = _get_tool_origin_info(tool) tool_item = ToolCallItem( raw_item=cast(ToolCallItemTypes, output_item), agent=agent, description=tool_description, + tool_origin=tool_origin, ) streamed_result._event_queue.put_nowait( RunItemStreamEvent(item=tool_item, name="tool_called") diff --git a/src/agents/run_internal/tool_execution.py b/src/agents/run_internal/tool_execution.py index bc370ea611..a22f9a5cdc 100644 --- a/src/agents/run_internal/tool_execution.py +++ b/src/agents/run_internal/tool_execution.py @@ -52,6 +52,7 @@ ShellCallOutcome, ShellCommandOutput, Tool, + _get_tool_origin_info, resolve_computer, ) from ..tool_context import ToolContext @@ -973,10 +974,12 @@ async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionTo run_item: RunItem | None = None if not nested_interruptions: + tool_origin = _get_tool_origin_info(tool_run.function_tool) run_item = ToolCallOutputItem( output=result, raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, result), agent=agent, + tool_origin=tool_origin, ) else: # Skip tool output until nested interruptions are resolved. diff --git a/src/agents/run_internal/turn_resolution.py b/src/agents/run_internal/turn_resolution.py index fed661ea9a..86872f4d27 100644 --- a/src/agents/run_internal/turn_resolution.py +++ b/src/agents/run_internal/turn_resolution.py @@ -62,6 +62,7 @@ LocalShellTool, ShellTool, Tool, + _get_tool_origin_info, ) from ..tool_guardrails import ToolInputGuardrailResult, ToolOutputGuardrailResult from ..tracing import SpanError, handoff_span @@ -1473,8 +1474,14 @@ def process_model_response( raise ModelBehaviorError(error) func_tool = function_map[output.name] + tool_origin = _get_tool_origin_info(func_tool) items.append( - ToolCallItem(raw_item=output, agent=agent, description=func_tool.description) + ToolCallItem( + raw_item=output, + agent=agent, + description=func_tool.description, + tool_origin=tool_origin, + ) ) functions.append( ToolRunFunction( diff --git a/src/agents/tool.py b/src/agents/tool.py index 4f70adc0f8..06cc25a734 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import enum import inspect import json import weakref @@ -48,6 +49,7 @@ if TYPE_CHECKING: from .agent import Agent, AgentBase from .items import RunItem, ToolApprovalItem + from .mcp.server import MCPServer ToolParams = ParamSpec("ToolParams") @@ -182,6 +184,59 @@ class ComputerProvider(Generic[ComputerT]): ] +class ToolOriginType(str, enum.Enum): + """The type of tool origin.""" + + FUNCTION = "function" + """Regular Python function tool created via @function_tool decorator.""" + + MCP = "mcp" + """MCP server tool converted via MCPUtil.to_function_tool().""" + + AGENT_AS_TOOL = "agent_as_tool" + """Agent converted to tool via agent.as_tool().""" + + +@dataclass +class ToolOrigin: + """Information about the origin/source of a function tool.""" + + type: ToolOriginType + """The type of tool origin.""" + + mcp_server: MCPServer | None = None + """The MCP server object. Only set when type is MCP.""" + + agent_as_tool: Agent[Any] | None = None + """The agent object. Only set when type is AGENT_AS_TOOL.""" + + def __repr__(self) -> str: + """Custom repr that only includes relevant fields.""" + parts = [f"type={self.type.value!r}"] + if self.mcp_server is not None: + parts.append(f"mcp_server_name={self.mcp_server.name!r}") + if self.agent_as_tool is not None: + parts.append(f"agent_as_tool_name={self.agent_as_tool.name!r}") + return f"ToolOrigin({', '.join(parts)})" + + +def _get_tool_origin_info(function_tool: FunctionTool) -> ToolOrigin | None: + """Extract origin information from a FunctionTool. + + Args: + function_tool: The function tool to extract origin info from. + + Returns: + ToolOrigin object if origin is set, otherwise None (defaults to FUNCTION type). + """ + origin = function_tool._tool_origin + if origin is None: + # Default to FUNCTION if not explicitly set + return ToolOrigin(type=ToolOriginType.FUNCTION) + + return origin + + @dataclass class FunctionToolResult: tool: FunctionTool @@ -264,6 +319,9 @@ class FunctionTool: _agent_instance: Any = field(default=None, init=False, repr=False) """Internal reference to the agent instance if this is an agent-as-tool.""" + _tool_origin: ToolOrigin | None = field(default=None, init=False, repr=False) + """Internal field tracking the origin of this tool (FUNCTION, MCP, or AGENT_AS_TOOL).""" + def __post_init__(self): if self.strict_json_schema: self.params_json_schema = ensure_strict_json_schema(self.params_json_schema) diff --git a/tests/test_tool_origin.py b/tests/test_tool_origin.py new file mode 100644 index 0000000000..5800b87099 --- /dev/null +++ b/tests/test_tool_origin.py @@ -0,0 +1,333 @@ +"""Tests for tool origin tracking feature.""" + +from __future__ import annotations + +import sys +from typing import cast + +import pytest + +from agents import Agent, FunctionTool, RunContextWrapper, Runner, function_tool +from agents.items import ToolCallItem, ToolCallItemTypes, ToolCallOutputItem +from agents.tool import ToolOrigin, ToolOriginType + +from .fake_model import FakeModel +from .test_responses import get_function_tool_call, get_text_message + +if sys.version_info >= (3, 10): + from .mcp.helpers import FakeMCPServer + + +@pytest.mark.asyncio +async def test_function_tool_origin(): + """Test that regular function tools have FUNCTION origin.""" + model = FakeModel() + + @function_tool + def test_tool(x: int) -> str: + """Test tool.""" + return f"result: {x}" + + agent = Agent(name="test", model=model, tools=[test_tool]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("test_tool", '{"x": 42}')], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 1 + assert tool_call_items[0].tool_origin is not None + assert tool_call_items[0].tool_origin.type == ToolOriginType.FUNCTION + assert tool_call_items[0].tool_origin.mcp_server is None + assert tool_call_items[0].tool_origin.agent_as_tool is None + + assert len(tool_output_items) == 1 + assert tool_output_items[0].tool_origin is not None + assert tool_output_items[0].tool_origin.type == ToolOriginType.FUNCTION + assert tool_output_items[0].tool_origin.mcp_server is None + assert tool_output_items[0].tool_origin.agent_as_tool is None + + +@pytest.mark.asyncio +@pytest.mark.skipif(sys.version_info < (3, 10), reason="MCP tests require Python 3.10+") +async def test_mcp_tool_origin(): + """Test that MCP tools have MCP origin with server name.""" + model = FakeModel() + server = FakeMCPServer(server_name="test_mcp_server") + server.add_tool("mcp_tool", {}) + + agent = Agent(name="test", model=model, mcp_servers=[server]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("mcp_tool", "")], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 1 + assert tool_call_items[0].tool_origin is not None + assert tool_call_items[0].tool_origin.type == ToolOriginType.MCP + assert tool_call_items[0].tool_origin.mcp_server is not None + assert tool_call_items[0].tool_origin.mcp_server.name == "test_mcp_server" + assert tool_call_items[0].tool_origin.agent_as_tool is None + + assert len(tool_output_items) == 1 + assert tool_output_items[0].tool_origin is not None + assert tool_output_items[0].tool_origin.type == ToolOriginType.MCP + assert tool_output_items[0].tool_origin.mcp_server is not None + assert tool_output_items[0].tool_origin.mcp_server.name == "test_mcp_server" + assert tool_output_items[0].tool_origin.agent_as_tool is None + + +@pytest.mark.asyncio +async def test_agent_as_tool_origin(): + """Test that agent-as-tool has AGENT_AS_TOOL origin with agent name.""" + model = FakeModel() + nested_model = FakeModel() + + nested_agent = Agent( + name="nested_agent", + model=nested_model, + instructions="You are a nested agent.", + ) + + nested_model.add_multiple_turn_outputs([[get_text_message("nested response")]]) + + tool = nested_agent.as_tool( + tool_name="nested_tool", + tool_description="A nested agent tool", + ) + + orchestrator = Agent(name="orchestrator", model=model, tools=[tool]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("nested_tool", '{"input": "test"}')], + [get_text_message("done")], + ] + ) + + result = await Runner.run(orchestrator, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 1 + assert tool_call_items[0].tool_origin is not None + assert tool_call_items[0].tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert tool_call_items[0].tool_origin.mcp_server is None + assert tool_call_items[0].tool_origin.agent_as_tool is not None + assert tool_call_items[0].tool_origin.agent_as_tool.name == "nested_agent" + + assert len(tool_output_items) == 1 + assert tool_output_items[0].tool_origin is not None + assert tool_output_items[0].tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert tool_output_items[0].tool_origin.mcp_server is None + assert tool_output_items[0].tool_origin.agent_as_tool is not None + assert tool_output_items[0].tool_origin.agent_as_tool.name == "nested_agent" + + +@pytest.mark.asyncio +@pytest.mark.skipif(sys.version_info < (3, 10), reason="MCP tests require Python 3.10+") +async def test_multiple_tool_origins(): + """Test that multiple tools from different origins work together.""" + model = FakeModel() + nested_model = FakeModel() + + @function_tool + def func_tool(x: int) -> str: + """Function tool.""" + return f"function: {x}" + + mcp_server = FakeMCPServer(server_name="mcp_server") + mcp_server.add_tool("mcp_tool", {}) + + nested_agent = Agent(name="nested", model=nested_model, instructions="Nested agent") + nested_model.add_multiple_turn_outputs([[get_text_message("nested response")]]) + agent_tool = nested_agent.as_tool(tool_name="agent_tool", tool_description="Agent tool") + + agent = Agent( + name="test", + model=model, + tools=[func_tool, agent_tool], + mcp_servers=[mcp_server], + ) + + model.add_multiple_turn_outputs( + [ + [ + get_function_tool_call("func_tool", '{"x": 1}'), + get_function_tool_call("mcp_tool", ""), + get_function_tool_call("agent_tool", '{"input": "test"}'), + ], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 3 + assert len(tool_output_items) == 3 + + # Find items by tool name + function_item = next( + item for item in tool_call_items if getattr(item.raw_item, "name", None) == "func_tool" + ) + mcp_item = next( + item for item in tool_call_items if getattr(item.raw_item, "name", None) == "mcp_tool" + ) + agent_item = next( + item for item in tool_call_items if getattr(item.raw_item, "name", None) == "agent_tool" + ) + + assert function_item.tool_origin is not None + assert function_item.tool_origin.type == ToolOriginType.FUNCTION + assert mcp_item.tool_origin is not None + assert mcp_item.tool_origin.type == ToolOriginType.MCP + assert mcp_item.tool_origin.mcp_server is not None + assert mcp_item.tool_origin.mcp_server.name == "mcp_server" + assert agent_item.tool_origin is not None + assert agent_item.tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert agent_item.tool_origin.agent_as_tool is not None + assert agent_item.tool_origin.agent_as_tool.name == "nested" + + +@pytest.mark.asyncio +@pytest.mark.skipif(sys.version_info < (3, 10), reason="MCP tests require Python 3.10+") +async def test_tool_origin_streaming(): + """Test that tool origin is populated correctly in streaming scenarios.""" + model = FakeModel() + server = FakeMCPServer(server_name="streaming_server") + server.add_tool("streaming_tool", {}) + + agent = Agent(name="test", model=model, mcp_servers=[server]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("streaming_tool", "")], + [get_text_message("done")], + ] + ) + + result = Runner.run_streamed(agent, input="test") + tool_call_items = [] + tool_output_items = [] + + async for event in result.stream_events(): + if event.type == "run_item_stream_event": + if isinstance(event.item, ToolCallItem): + tool_call_items.append(event.item) + elif isinstance(event.item, ToolCallOutputItem): + tool_output_items.append(event.item) + + assert len(tool_call_items) == 1 + assert tool_call_items[0].tool_origin is not None + assert tool_call_items[0].tool_origin.type == ToolOriginType.MCP + assert tool_call_items[0].tool_origin.mcp_server is not None + assert tool_call_items[0].tool_origin.mcp_server.name == "streaming_server" + + assert len(tool_output_items) == 1 + assert tool_output_items[0].tool_origin is not None + assert tool_output_items[0].tool_origin.type == ToolOriginType.MCP + assert tool_output_items[0].tool_origin.mcp_server is not None + assert tool_output_items[0].tool_origin.mcp_server.name == "streaming_server" + + +@pytest.mark.asyncio +async def test_tool_origin_repr(): + """Test that ToolOrigin repr only shows relevant fields.""" + # FUNCTION origin + function_origin = ToolOrigin(type=ToolOriginType.FUNCTION) + assert "mcp_server_name" not in repr(function_origin) + assert "agent_as_tool_name" not in repr(function_origin) + + # MCP origin + if sys.version_info >= (3, 10): + from .mcp.helpers import FakeMCPServer + + test_server = FakeMCPServer(server_name="test_server") + mcp_origin = ToolOrigin(type=ToolOriginType.MCP, mcp_server=test_server) + assert "mcp_server_name='test_server'" in repr(mcp_origin) + assert "agent_as_tool_name" not in repr(mcp_origin) + + # AGENT_AS_TOOL origin + model = FakeModel() + test_agent = Agent(name="test_agent", model=model, instructions="Test agent") + agent_origin = ToolOrigin(type=ToolOriginType.AGENT_AS_TOOL, agent_as_tool=test_agent) + assert "agent_as_tool_name='test_agent'" in repr(agent_origin) + assert "mcp_server_name" not in repr(agent_origin) + + +@pytest.mark.asyncio +async def test_tool_origin_defaults_to_function(): + """Test that tools without explicit origin default to FUNCTION.""" + model = FakeModel() + + # Create a FunctionTool directly without using @function_tool decorator + async def test_func(ctx: RunContextWrapper, args: str) -> str: + return "result" + + tool = FunctionTool( + name="direct_tool", + description="Direct tool", + params_json_schema={"type": "object", "properties": {}}, + on_invoke_tool=test_func, + ) + + agent = Agent(name="test", model=model, tools=[tool]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("direct_tool", "")], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + + assert len(tool_call_items) == 1 + # Even though _tool_origin is None, _get_tool_origin_info defaults to FUNCTION + assert tool_call_items[0].tool_origin is not None + assert tool_call_items[0].tool_origin.type == ToolOriginType.FUNCTION + + +@pytest.mark.asyncio +async def test_non_function_tool_items_have_no_origin(): + """Test that non-FunctionTool items (computer, shell, etc.) don't have tool_origin.""" + model = FakeModel() + + @function_tool + def func_tool() -> str: + """Function tool.""" + return "result" + + agent = Agent(name="test", model=model, tools=[func_tool]) + + # Create a ToolCallItem for a non-function tool (simulating computer/shell tool) + computer_call = { + "type": "computer_use_preview", + "call_id": "call_123", + "actions": [], + } + + # This simulates what happens for non-FunctionTool items + # They should not have tool_origin set + item = ToolCallItem( + raw_item=cast(ToolCallItemTypes, computer_call), + agent=agent, + ) + + assert item.tool_origin is None From e1b635702c2f3f2f6b3e8ec98fe187c00beab71f Mon Sep 17 00:00:00 2001 From: habema Date: Thu, 5 Feb 2026 15:30:28 +0300 Subject: [PATCH 2/7] fix memory leak in code review and add test --- src/agents/items.py | 12 +++++++++ src/agents/tool.py | 41 ++++++++++++++++++++++++++++-- tests/test_tool_origin.py | 53 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 104 insertions(+), 2 deletions(-) diff --git a/src/agents/items.py b/src/agents/items.py index 64565b6037..7139e07f99 100644 --- a/src/agents/items.py +++ b/src/agents/items.py @@ -252,6 +252,12 @@ class ToolCallItem(RunItemBase[Any]): tool_origin: ToolOrigin | None = field(default=None, repr=False) """Information about the origin/source of the tool call. Only set for FunctionTool calls.""" + def release_agent(self) -> None: + """Release agent references including tool_origin.agent_as_tool.""" + super().release_agent() + if self.tool_origin is not None: + self.tool_origin.release_agent() + ToolCallOutputTypes: TypeAlias = Union[ FunctionCallOutput, @@ -278,6 +284,12 @@ class ToolCallOutputItem(RunItemBase[Any]): tool_origin: ToolOrigin | None = field(default=None, repr=False) """Information about the origin/source of the tool call. Only set for FunctionTool calls.""" + def release_agent(self) -> None: + """Release agent references including tool_origin.agent_as_tool.""" + super().release_agent() + if self.tool_origin is not None: + self.tool_origin.release_agent() + def to_input_item(self) -> TResponseInputItem: """Converts the tool output into an input item for the next model turn. diff --git a/src/agents/tool.py b/src/agents/tool.py index 06cc25a734..2e0e043581 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -210,13 +210,50 @@ class ToolOrigin: agent_as_tool: Agent[Any] | None = None """The agent object. Only set when type is AGENT_AS_TOOL.""" + _agent_as_tool_ref: weakref.ReferenceType[Agent[Any]] | None = field( + default=None, init=False, repr=False + ) + """Weak reference to agent_as_tool for memory management.""" + + def __post_init__(self) -> None: + """Initialize weak reference for agent_as_tool.""" + if self.agent_as_tool is not None: + self._agent_as_tool_ref = weakref.ref(self.agent_as_tool) + + def __getattribute__(self, name: str) -> Any: + """Lazily resolve agent_as_tool via weakref when strong ref is cleared.""" + if name == "agent_as_tool": + # Check if strong reference still exists + value = object.__getattribute__(self, "__dict__").get("agent_as_tool") + if value is not None: + return value + # Try to resolve via weakref + ref = object.__getattribute__(self, "_agent_as_tool_ref") + if ref is not None: + agent = ref() + if agent is not None: + return agent + return None + return super().__getattribute__(name) + + def release_agent(self) -> None: + """Release the strong reference to agent_as_tool while keeping a weak reference.""" + if "agent_as_tool" not in self.__dict__: + return + agent = self.__dict__.get("agent_as_tool") + if agent is not None: + self._agent_as_tool_ref = weakref.ref(agent) + # Set to None instead of deleting so dataclass repr/asdict keep working. + self.__dict__["agent_as_tool"] = None + def __repr__(self) -> str: """Custom repr that only includes relevant fields.""" parts = [f"type={self.type.value!r}"] if self.mcp_server is not None: parts.append(f"mcp_server_name={self.mcp_server.name!r}") - if self.agent_as_tool is not None: - parts.append(f"agent_as_tool_name={self.agent_as_tool.name!r}") + agent = self.agent_as_tool + if agent is not None: + parts.append(f"agent_as_tool_name={agent.name!r}") return f"ToolOrigin({', '.join(parts)})" diff --git a/tests/test_tool_origin.py b/tests/test_tool_origin.py index 5800b87099..245f982491 100644 --- a/tests/test_tool_origin.py +++ b/tests/test_tool_origin.py @@ -2,7 +2,9 @@ from __future__ import annotations +import gc import sys +import weakref from typing import cast import pytest @@ -331,3 +333,54 @@ def func_tool() -> str: ) assert item.tool_origin is None + + +def test_tool_origin_release_agent_clears_strong_reference(): + """Test that release_agent() clears strong reference to agent_as_tool.""" + # Create a ToolOrigin with an agent_as_tool + nested_agent = Agent( + name="nested_agent", + model=FakeModel(), + instructions="You are a nested agent.", + ) + + tool_origin = ToolOrigin( + type=ToolOriginType.AGENT_AS_TOOL, + agent_as_tool=nested_agent, + ) + + # Create a ToolCallItem with this tool_origin + tool_call_item = ToolCallItem( + raw_item=cast( + ToolCallItemTypes, + { + "type": "function_call", + "name": "test_tool", + "call_id": "call_123", + "arguments": "{}", + }, + ), + agent=nested_agent, + tool_origin=tool_origin, + ) + + # Verify agent_as_tool is set + assert tool_call_item.tool_origin is not None + assert tool_call_item.tool_origin.agent_as_tool is nested_agent + + # Create weak reference to verify GC behavior + nested_agent_ref = weakref.ref(nested_agent) + + # Release agent - this should clear strong reference in tool_origin + tool_call_item.release_agent() + + # After release, agent_as_tool should still be accessible via weakref + assert tool_call_item.tool_origin.agent_as_tool is nested_agent + + # Delete the agent and force GC + del nested_agent + gc.collect() + + # After GC, agent_as_tool should be None since strong refs were cleared + assert nested_agent_ref() is None + assert tool_call_item.tool_origin.agent_as_tool is None From 5b2835b82a4a1732ab17fd363f25a119bd4724bc Mon Sep 17 00:00:00 2001 From: habema Date: Sat, 7 Feb 2026 21:46:39 +0300 Subject: [PATCH 3/7] address code review and add test --- src/agents/run_state.py | 86 ++++++++- tests/test_tool_origin_serialization.py | 228 ++++++++++++++++++++++++ 2 files changed, 313 insertions(+), 1 deletion(-) create mode 100644 tests/test_tool_origin_serialization.py diff --git a/src/agents/run_state.py b/src/agents/run_state.py index d02d298140..08b23e2505 100644 --- a/src/agents/run_state.py +++ b/src/agents/run_state.py @@ -63,6 +63,8 @@ HostedMCPTool, LocalShellTool, ShellTool, + ToolOrigin, + ToolOriginType, ) from .tool_guardrails import ( AllowBehavior, @@ -635,6 +637,13 @@ def _serialize_item(self, item: RunItem) -> dict[str, Any]: result["tool_name"] = item.tool_name if hasattr(item, "description") and item.description is not None: result["description"] = item.description + if hasattr(item, "tool_origin") and item.tool_origin is not None: + tool_origin_data: dict[str, Any] = {"type": item.tool_origin.type.value} + if item.tool_origin.agent_as_tool is not None: + tool_origin_data["agent_as_tool"] = {"name": item.tool_origin.agent_as_tool.name} + if item.tool_origin.mcp_server is not None: + tool_origin_data["mcp_server"] = {"name": item.tool_origin.mcp_server.name} + result["tool_origin"] = tool_origin_data return result @@ -1918,6 +1927,67 @@ def _build_agent_map(initial_agent: Agent[Any]) -> dict[str, Agent[Any]]: return agent_map +def _deserialize_tool_origin( + tool_origin_data: dict[str, Any] | None, agent_map: dict[str, Agent[Any]], agent: Agent[Any] +) -> ToolOrigin | None: + """Deserialize ToolOrigin from JSON data. + + Args: + tool_origin_data: Serialized tool origin dictionary. + agent_map: Map of agent names to agent instances. + agent: The agent associated with this item (used for MCP server lookup). + + Returns: + ToolOrigin instance or None if data is missing/invalid. + """ + if not tool_origin_data: + return None + + origin_type_str = tool_origin_data.get("type") + if not origin_type_str: + return None + + try: + origin_type = ToolOriginType(origin_type_str) + except ValueError: + logger.warning(f"Unknown tool origin type: {origin_type_str}") + return None + + agent_as_tool: Agent[Any] | None = None + mcp_server: Any | None = None + + if origin_type == ToolOriginType.AGENT_AS_TOOL: + agent_data = tool_origin_data.get("agent_as_tool") + if agent_data and isinstance(agent_data, Mapping): + agent_name = agent_data.get("name") + if agent_name: + agent_as_tool = agent_map.get(agent_name) + if not agent_as_tool: + logger.warning(f"Agent {agent_name} not found in agent map for tool_origin") + + elif origin_type == ToolOriginType.MCP: + mcp_data = tool_origin_data.get("mcp_server") + if mcp_data and isinstance(mcp_data, Mapping): + server_name = mcp_data.get("name") + if server_name: + # Try to find the MCP server from the agent's mcp_servers list + mcp_servers = getattr(agent, "mcp_servers", []) + for server in mcp_servers: + if hasattr(server, "name") and server.name == server_name: + mcp_server = server + break + if not mcp_server: + logger.debug( + f"MCP server {server_name} not found in agent's mcp_servers for tool_origin" + ) + + return ToolOrigin( + type=origin_type, + agent_as_tool=agent_as_tool, + mcp_server=mcp_server, + ) + + def _deserialize_model_responses(responses_data: list[dict[str, Any]]) -> list[ModelResponse]: """Deserialize model responses from JSON data. @@ -2019,8 +2089,17 @@ def _resolve_agent_info( raw_item_tool = _deserialize_tool_call_raw_item(normalized_raw_item) # Preserve description if it was stored with the item description = item_data.get("description") + # Preserve tool_origin if it was stored with the item + tool_origin = _deserialize_tool_origin( + item_data.get("tool_origin"), agent_map, agent + ) result.append( - ToolCallItem(agent=agent, raw_item=raw_item_tool, description=description) + ToolCallItem( + agent=agent, + raw_item=raw_item_tool, + description=description, + tool_origin=tool_origin, + ) ) elif item_type == "tool_call_output_item": @@ -2029,11 +2108,16 @@ def _resolve_agent_info( raw_item_output = _deserialize_tool_call_output_raw_item(normalized_raw_item) if raw_item_output is None: continue + # Preserve tool_origin if it was stored with the item + tool_origin = _deserialize_tool_origin( + item_data.get("tool_origin"), agent_map, agent + ) result.append( ToolCallOutputItem( agent=agent, raw_item=raw_item_output, output=item_data.get("output", ""), + tool_origin=tool_origin, ) ) diff --git a/tests/test_tool_origin_serialization.py b/tests/test_tool_origin_serialization.py new file mode 100644 index 0000000000..87bca9fcc4 --- /dev/null +++ b/tests/test_tool_origin_serialization.py @@ -0,0 +1,228 @@ +"""Tests for tool_origin serialization in RunState.""" + +from __future__ import annotations + +import sys + +import pytest + +from agents import Agent, Runner, function_tool +from agents.items import ToolCallItem, ToolCallOutputItem +from agents.run_state import RunState +from agents.tool import ToolOriginType + +from .fake_model import FakeModel +from .test_responses import get_function_tool_call, get_text_message + +if sys.version_info >= (3, 10): + from .mcp.helpers import FakeMCPServer + + +@pytest.mark.asyncio +async def test_serialize_tool_origin_function(): + """Test that FUNCTION tool_origin is serialized and deserialized.""" + model = FakeModel() + + @function_tool + def test_tool(x: int) -> str: + """Test tool.""" + return f"result: {x}" + + agent = Agent(name="test", model=model, tools=[test_tool]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("test_tool", '{"x": 42}')], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 1 + assert len(tool_output_items) == 1 + + tool_call_item = tool_call_items[0] + tool_output_item = tool_output_items[0] + + # Verify tool_origin is set + assert tool_call_item.tool_origin is not None + assert tool_call_item.tool_origin.type == ToolOriginType.FUNCTION + assert tool_output_item.tool_origin is not None + assert tool_output_item.tool_origin.type == ToolOriginType.FUNCTION + + # Serialize and deserialize + context = result.context_wrapper + state = RunState( + context=context, + original_input="test", + starting_agent=agent, + max_turns=5, + ) + state._generated_items = [tool_call_item, tool_output_item] + + json_data = state.to_json() + deserialized_state = await RunState.from_json(agent, json_data) + + # Verify tool_origin was preserved + deserialized_tool_call = next( + item for item in deserialized_state._generated_items if isinstance(item, ToolCallItem) + ) + deserialized_tool_output = next( + item for item in deserialized_state._generated_items if isinstance(item, ToolCallOutputItem) + ) + + assert deserialized_tool_call.tool_origin is not None + assert deserialized_tool_call.tool_origin.type == ToolOriginType.FUNCTION + assert deserialized_tool_output.tool_origin is not None + assert deserialized_tool_output.tool_origin.type == ToolOriginType.FUNCTION + + +@pytest.mark.asyncio +async def test_serialize_tool_origin_agent_as_tool(): + """Test that AGENT_AS_TOOL tool_origin is serialized and deserialized.""" + model = FakeModel() + nested_model = FakeModel() + + nested_agent = Agent( + name="nested_agent", + model=nested_model, + instructions="You are a nested agent.", + ) + + nested_model.add_multiple_turn_outputs([[get_text_message("nested response")]]) + + tool = nested_agent.as_tool( + tool_name="nested_tool", + tool_description="A nested agent tool", + ) + + orchestrator = Agent(name="orchestrator", model=model, tools=[tool]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("nested_tool", '{"input": "test"}')], + [get_text_message("done")], + ] + ) + + result = await Runner.run(orchestrator, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 1 + assert len(tool_output_items) == 1 + + tool_call_item = tool_call_items[0] + tool_output_item = tool_output_items[0] + + # Verify tool_origin is set + assert tool_call_item.tool_origin is not None + assert tool_call_item.tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert tool_call_item.tool_origin.agent_as_tool is not None + assert tool_call_item.tool_origin.agent_as_tool.name == "nested_agent" + assert tool_output_item.tool_origin is not None + assert tool_output_item.tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert tool_output_item.tool_origin.agent_as_tool is not None + assert tool_output_item.tool_origin.agent_as_tool.name == "nested_agent" + + # Serialize and deserialize + context = result.context_wrapper + state = RunState( + context=context, + original_input="test", + starting_agent=orchestrator, + max_turns=5, + ) + state._generated_items = [tool_call_item, tool_output_item] + + json_data = state.to_json() + deserialized_state = await RunState.from_json(orchestrator, json_data) + + # Verify tool_origin was preserved + deserialized_tool_call = next( + item for item in deserialized_state._generated_items if isinstance(item, ToolCallItem) + ) + deserialized_tool_output = next( + item for item in deserialized_state._generated_items if isinstance(item, ToolCallOutputItem) + ) + + assert deserialized_tool_call.tool_origin is not None + assert deserialized_tool_call.tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert deserialized_tool_call.tool_origin.agent_as_tool is not None + assert deserialized_tool_call.tool_origin.agent_as_tool.name == "nested_agent" + assert deserialized_tool_output.tool_origin is not None + assert deserialized_tool_output.tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert deserialized_tool_output.tool_origin.agent_as_tool is not None + assert deserialized_tool_output.tool_origin.agent_as_tool.name == "nested_agent" + + +@pytest.mark.asyncio +@pytest.mark.skipif(sys.version_info < (3, 10), reason="MCP tests require Python 3.10+") +async def test_serialize_tool_origin_mcp(): + """Test that MCP tool_origin is serialized and deserialized.""" + model = FakeModel() + server = FakeMCPServer(server_name="test_mcp_server") + server.add_tool("mcp_tool", {}) + + agent = Agent(name="test", model=model, mcp_servers=[server]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("mcp_tool", "")], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 1 + assert len(tool_output_items) == 1 + + tool_call_item = tool_call_items[0] + tool_output_item = tool_output_items[0] + + # Verify tool_origin is set + assert tool_call_item.tool_origin is not None + assert tool_call_item.tool_origin.type == ToolOriginType.MCP + assert tool_call_item.tool_origin.mcp_server is not None + assert tool_call_item.tool_origin.mcp_server.name == "test_mcp_server" + assert tool_output_item.tool_origin is not None + assert tool_output_item.tool_origin.type == ToolOriginType.MCP + assert tool_output_item.tool_origin.mcp_server is not None + assert tool_output_item.tool_origin.mcp_server.name == "test_mcp_server" + + # Serialize and deserialize + context = result.context_wrapper + state = RunState( + context=context, + original_input="test", + starting_agent=agent, + max_turns=5, + ) + state._generated_items = [tool_call_item, tool_output_item] + + json_data = state.to_json() + deserialized_state = await RunState.from_json(agent, json_data) + + # Verify tool_origin was preserved + deserialized_tool_call = next( + item for item in deserialized_state._generated_items if isinstance(item, ToolCallItem) + ) + deserialized_tool_output = next( + item for item in deserialized_state._generated_items if isinstance(item, ToolCallOutputItem) + ) + + assert deserialized_tool_call.tool_origin is not None + assert deserialized_tool_call.tool_origin.type == ToolOriginType.MCP + # MCP server should be reconstructed from agent's mcp_servers + assert deserialized_tool_call.tool_origin.mcp_server is not None + assert deserialized_tool_call.tool_origin.mcp_server.name == "test_mcp_server" + assert deserialized_tool_output.tool_origin is not None + assert deserialized_tool_output.tool_origin.type == ToolOriginType.MCP + assert deserialized_tool_output.tool_origin.mcp_server is not None + assert deserialized_tool_output.tool_origin.mcp_server.name == "test_mcp_server" From 68c9dde2f5f55f687a5ad510004c152fae8d5169 Mon Sep 17 00:00:00 2001 From: habema Date: Sun, 8 Feb 2026 15:00:32 +0300 Subject: [PATCH 4/7] address code review --- src/agents/run_internal/items.py | 4 +- src/agents/run_internal/tool_execution.py | 2 + src/agents/run_internal/turn_resolution.py | 5 +- tests/test_tool_origin_rejection.py | 205 +++++++++++++++++++++ 4 files changed, 214 insertions(+), 2 deletions(-) create mode 100644 tests/test_tool_origin_rejection.py diff --git a/src/agents/run_internal/items.py b/src/agents/run_internal/items.py index 04e00f598f..015d73afdc 100644 --- a/src/agents/run_internal/items.py +++ b/src/agents/run_internal/items.py @@ -15,7 +15,7 @@ from ..agent_tool_state import drop_agent_tool_run_result from ..items import ItemHelpers, ToolCallOutputItem, TResponseInputItem from ..models.fake_id import FAKE_RESPONSES_ID -from ..tool import DEFAULT_APPROVAL_REJECTION_MESSAGE +from ..tool import DEFAULT_APPROVAL_REJECTION_MESSAGE, ToolOrigin REJECTION_MESSAGE = DEFAULT_APPROVAL_REJECTION_MESSAGE _TOOL_CALL_TO_OUTPUT_TYPE: dict[str, str] = { @@ -191,6 +191,7 @@ def function_rejection_item( tool_call: Any, *, rejection_message: str = REJECTION_MESSAGE, + tool_origin: ToolOrigin | None = None, ) -> ToolCallOutputItem: """Build a ToolCallOutputItem representing a rejected function tool call.""" if isinstance(tool_call, ResponseFunctionToolCall): @@ -199,6 +200,7 @@ def function_rejection_item( output=rejection_message, raw_item=ItemHelpers.tool_call_output_item(tool_call, rejection_message), agent=agent, + tool_origin=tool_origin, ) diff --git a/src/agents/run_internal/tool_execution.py b/src/agents/run_internal/tool_execution.py index a22f9a5cdc..e8a26ff121 100644 --- a/src/agents/run_internal/tool_execution.py +++ b/src/agents/run_internal/tool_execution.py @@ -868,6 +868,7 @@ async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionTo ) result = rejection_message span_fn.span_data.output = result + tool_origin = _get_tool_origin_info(func_tool) return FunctionToolResult( tool=func_tool, output=result, @@ -875,6 +876,7 @@ async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionTo agent, tool_call, rejection_message=rejection_message, + tool_origin=tool_origin, ), ) diff --git a/src/agents/run_internal/turn_resolution.py b/src/agents/run_internal/turn_resolution.py index 86872f4d27..ac28ff3a1a 100644 --- a/src/agents/run_internal/turn_resolution.py +++ b/src/agents/run_internal/turn_resolution.py @@ -692,8 +692,11 @@ async def _record_function_rejection( tool_name=function_tool.name, call_id=call_id, ) + tool_origin = _get_tool_origin_info(function_tool) rejected_function_outputs.append( - function_rejection_item(agent, tool_call, rejection_message=rejection_message) + function_rejection_item( + agent, tool_call, rejection_message=rejection_message, tool_origin=tool_origin + ) ) if isinstance(call_id, str): rejected_function_call_ids.add(call_id) diff --git a/tests/test_tool_origin_rejection.py b/tests/test_tool_origin_rejection.py new file mode 100644 index 0000000000..8582e03fdb --- /dev/null +++ b/tests/test_tool_origin_rejection.py @@ -0,0 +1,205 @@ +"""Tests for tool_origin preservation on rejected function tool calls.""" + +from __future__ import annotations + +import sys + +import pytest + +from agents import Agent, function_tool +from agents.items import ToolCallOutputItem +from agents.tool import ToolOriginType + +from .fake_model import FakeModel +from .test_responses import get_function_tool_call, get_text_message +from .utils.hitl import reject_tool_call + +if sys.version_info >= (3, 10): + from .mcp.helpers import FakeMCPServer + + +@pytest.mark.asyncio +async def test_rejected_function_tool_preserves_tool_origin(): + """Test that rejected function tools preserve tool_origin.""" + model = FakeModel() + + @function_tool + def test_tool(x: int) -> str: + """Test tool.""" + return f"result: {x}" + + agent = Agent(name="test", model=model, tools=[test_tool]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("test_tool", '{"x": 42}')], + [get_text_message("done")], + ] + ) + + # Pre-reject the tool call + tool_call = get_function_tool_call("test_tool", '{"x": 42}') + from openai.types.responses import ResponseFunctionToolCall + + from agents.lifecycle import RunHooks + from agents.run_config import RunConfig + from agents.run_context import RunContextWrapper + from agents.run_internal.run_steps import ToolRunFunction + from agents.run_internal.tool_execution import execute_function_tool_calls + + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + assert isinstance(tool_call, ResponseFunctionToolCall) + reject_tool_call(context, agent, tool_call, "test_tool") + + # Execute the tool call which should be rejected + tool_run = ToolRunFunction(tool_call=tool_call, function_tool=test_tool) + results, _, _ = await execute_function_tool_calls( + agent=agent, + tool_runs=[tool_run], + hooks=RunHooks(), + context_wrapper=context, + config=RunConfig(), + ) + + # Should have a rejection result + assert len(results) == 1 + result = results[0] + assert result.run_item is not None + assert isinstance(result.run_item, ToolCallOutputItem) + + # Verify tool_origin is preserved on rejection + assert result.run_item.tool_origin is not None + assert result.run_item.tool_origin.type == ToolOriginType.FUNCTION + + +@pytest.mark.asyncio +async def test_rejected_agent_as_tool_preserves_tool_origin(): + """Test that rejected agent-as-tool preserves tool_origin.""" + model = FakeModel() + nested_model = FakeModel() + + nested_agent = Agent( + name="nested_agent", + model=nested_model, + instructions="You are a nested agent.", + ) + + nested_model.add_multiple_turn_outputs([[get_text_message("nested response")]]) + + tool = nested_agent.as_tool( + tool_name="nested_tool", + tool_description="A nested agent tool", + ) + + orchestrator = Agent(name="orchestrator", model=model, tools=[tool]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("nested_tool", '{"input": "test"}')], + [get_text_message("done")], + ] + ) + + # Pre-reject the tool call + tool_call = get_function_tool_call("nested_tool", '{"input": "test"}') + from openai.types.responses import ResponseFunctionToolCall + + from agents.lifecycle import RunHooks + from agents.run_config import RunConfig + from agents.run_context import RunContextWrapper + from agents.run_internal.run_steps import ToolRunFunction + from agents.run_internal.tool_execution import execute_function_tool_calls + from agents.tool import FunctionTool + + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + assert isinstance(tool_call, ResponseFunctionToolCall) + assert isinstance(tool, FunctionTool) + reject_tool_call(context, orchestrator, tool_call, "nested_tool") + + # Execute the tool call which should be rejected + tool_run = ToolRunFunction(tool_call=tool_call, function_tool=tool) + results, _, _ = await execute_function_tool_calls( + agent=orchestrator, + tool_runs=[tool_run], + hooks=RunHooks(), + context_wrapper=context, + config=RunConfig(), + ) + + # Should have a rejection result + assert len(results) == 1 + result = results[0] + assert result.run_item is not None + assert isinstance(result.run_item, ToolCallOutputItem) + + # Verify tool_origin is preserved on rejection + assert result.run_item.tool_origin is not None + assert result.run_item.tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert result.run_item.tool_origin.agent_as_tool is not None + assert result.run_item.tool_origin.agent_as_tool.name == "nested_agent" + + +@pytest.mark.asyncio +@pytest.mark.skipif(sys.version_info < (3, 10), reason="MCP tests require Python 3.10+") +async def test_rejected_mcp_tool_preserves_tool_origin(): + """Test that rejected MCP tools preserve tool_origin.""" + model = FakeModel() + server = FakeMCPServer(server_name="test_mcp_server") + server.add_tool("mcp_tool", {}) + + agent = Agent(name="test", model=model, mcp_servers=[server]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("mcp_tool", "")], + [get_text_message("done")], + ] + ) + + # Pre-reject the tool call + tool_call = get_function_tool_call("mcp_tool", "") + from openai.types.responses import ResponseFunctionToolCall + + from agents.lifecycle import RunHooks + from agents.mcp import MCPUtil + from agents.run_config import RunConfig + from agents.run_context import RunContextWrapper + from agents.run_internal.run_steps import ToolRunFunction + from agents.run_internal.tool_execution import execute_function_tool_calls + from agents.tool import FunctionTool + + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + assert isinstance(tool_call, ResponseFunctionToolCall) + reject_tool_call(context, agent, tool_call, "mcp_tool") + + # Get the MCP tool as FunctionTool + mcp_tools = await MCPUtil.get_all_function_tools( + agent.mcp_servers, + convert_schemas_to_strict=False, + run_context=context, + agent=agent, + ) + mcp_tool = next(tool for tool in mcp_tools if tool.name == "mcp_tool") + assert isinstance(mcp_tool, FunctionTool) + + # Execute the tool call which should be rejected + tool_run = ToolRunFunction(tool_call=tool_call, function_tool=mcp_tool) + results, _, _ = await execute_function_tool_calls( + agent=agent, + tool_runs=[tool_run], + hooks=RunHooks(), + context_wrapper=context, + config=RunConfig(), + ) + + # Should have a rejection result + assert len(results) == 1 + result = results[0] + assert result.run_item is not None + assert isinstance(result.run_item, ToolCallOutputItem) + + # Verify tool_origin is preserved on rejection + assert result.run_item.tool_origin is not None + assert result.run_item.tool_origin.type == ToolOriginType.MCP + assert result.run_item.tool_origin.mcp_server is not None + assert result.run_item.tool_origin.mcp_server.name == "test_mcp_server" From c2ea42d1a27e6a92720994b9f501a4e93160ac29 Mon Sep 17 00:00:00 2001 From: habema Date: Sun, 8 Feb 2026 15:13:16 +0300 Subject: [PATCH 5/7] address code review --- src/agents/run_internal/turn_resolution.py | 12 +++- tests/test_tool_origin_output_schema.py | 69 ++++++++++++++++++++++ 2 files changed, 79 insertions(+), 2 deletions(-) create mode 100644 tests/test_tool_origin_output_schema.py diff --git a/src/agents/run_internal/turn_resolution.py b/src/agents/run_internal/turn_resolution.py index ac28ff3a1a..806a112361 100644 --- a/src/agents/run_internal/turn_resolution.py +++ b/src/agents/run_internal/turn_resolution.py @@ -1459,11 +1459,19 @@ def process_model_response( else: if output.name not in function_map: if output_schema is not None and output.name == "json_tool_call": - items.append(ToolCallItem(raw_item=output, agent=agent)) + json_tool = build_litellm_json_tool_call(output) + tool_origin = _get_tool_origin_info(json_tool) + items.append( + ToolCallItem( + raw_item=output, + agent=agent, + tool_origin=tool_origin, + ) + ) functions.append( ToolRunFunction( tool_call=output, - function_tool=build_litellm_json_tool_call(output), + function_tool=json_tool, ) ) continue diff --git a/tests/test_tool_origin_output_schema.py b/tests/test_tool_origin_output_schema.py new file mode 100644 index 0000000000..3b7144a650 --- /dev/null +++ b/tests/test_tool_origin_output_schema.py @@ -0,0 +1,69 @@ +"""Tests for tool_origin with output_schema json_tool_call.""" + +from __future__ import annotations + +from pydantic import BaseModel + +from agents import Agent +from agents.agent_output import AgentOutputSchema +from agents.items import ModelResponse, ToolCallItem +from agents.run_internal.turn_resolution import process_model_response +from agents.tool import ToolOriginType +from agents.usage import Usage + +from .test_responses import get_function_tool_call + + +class OutputSchema(BaseModel): + """Test output schema.""" + + result: str + + +def test_output_schema_json_tool_call_has_tool_origin(): + """Test that json_tool_call ToolCallItem has tool_origin when output_schema is enabled.""" + agent = Agent(name="test", output_type=OutputSchema) + + # Get the output_schema + from agents.run_internal.run_loop import get_output_schema + + output_schema = get_output_schema(agent) + assert output_schema is not None + assert isinstance(output_schema, AgentOutputSchema) + + # Simulate a json_tool_call response + json_output = OutputSchema(result="test").model_dump_json() + json_tool_call = get_function_tool_call("json_tool_call", json_output) + + response = ModelResponse( + output=[json_tool_call], + usage=Usage(), + response_id=None, + ) + + # Process the response + processed = process_model_response( + agent=agent, + all_tools=[], + response=response, + output_schema=output_schema, + handoffs=[], + ) + + # Find the json_tool_call item + json_tool_call_item = next( + item + for item in processed.new_items + if isinstance(item, ToolCallItem) + and hasattr(item.raw_item, "name") + and item.raw_item.name == "json_tool_call" + ) + + # Verify tool_origin is set on ToolCallItem + assert json_tool_call_item.tool_origin is not None + assert json_tool_call_item.tool_origin.type == ToolOriginType.FUNCTION + + # Verify that a ToolRunFunction was created for execution + assert len(processed.functions) == 1 + function_run = processed.functions[0] + assert function_run.function_tool.name == "json_tool_call" From 6825a7453d34f2a9bb1210f94719098eddeded3d Mon Sep 17 00:00:00 2001 From: habema Date: Sun, 8 Feb 2026 15:38:17 +0300 Subject: [PATCH 6/7] address code review --- src/agents/run_internal/run_loop.py | 17 ++++++++++- tests/test_tool_origin_output_schema.py | 40 +++++++++++++++++++++++-- 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/src/agents/run_internal/run_loop.py b/src/agents/run_internal/run_loop.py index 4404868ed8..0116e04d14 100644 --- a/src/agents/run_internal/run_loop.py +++ b/src/agents/run_internal/run_loop.py @@ -10,7 +10,11 @@ from collections.abc import Awaitable, Callable from typing import Any, TypeVar, cast -from openai.types.responses import ResponseCompletedEvent, ResponseOutputItemDoneEvent +from openai.types.responses import ( + ResponseCompletedEvent, + ResponseFunctionToolCall, + ResponseOutputItemDoneEvent, +) from openai.types.responses.response_prompt_param import ResponsePromptParam from openai.types.responses.response_reasoning_item import ResponseReasoningItem @@ -113,6 +117,7 @@ from .streaming import stream_step_items_to_queue, stream_step_result_to_queue from .tool_actions import ApplyPatchAction, ComputerAction, LocalShellAction, ShellAction from .tool_execution import ( + build_litellm_json_tool_call, coerce_shell_call, execute_apply_patch_calls, execute_computer_actions, @@ -1222,6 +1227,16 @@ async def run_single_turn_streamed( tool_description = getattr(tool, "description", None) if isinstance(tool, FunctionTool): tool_origin = _get_tool_origin_info(tool) + elif ( + isinstance(tool_name, str) + and tool_name == "json_tool_call" + and output_schema is not None + and isinstance(output_item, ResponseFunctionToolCall) + ): + # json_tool_call is synthesized dynamically and not in tool_map. + # Synthesize it here to get tool_origin, matching process_model_response. + json_tool = build_litellm_json_tool_call(output_item) + tool_origin = _get_tool_origin_info(json_tool) tool_item = ToolCallItem( raw_item=cast(ToolCallItemTypes, output_item), diff --git a/tests/test_tool_origin_output_schema.py b/tests/test_tool_origin_output_schema.py index 3b7144a650..8556c9e9f3 100644 --- a/tests/test_tool_origin_output_schema.py +++ b/tests/test_tool_origin_output_schema.py @@ -2,16 +2,18 @@ from __future__ import annotations +import pytest from pydantic import BaseModel -from agents import Agent +from agents import Agent, Runner from agents.agent_output import AgentOutputSchema from agents.items import ModelResponse, ToolCallItem from agents.run_internal.turn_resolution import process_model_response from agents.tool import ToolOriginType from agents.usage import Usage -from .test_responses import get_function_tool_call +from .fake_model import FakeModel +from .test_responses import get_final_output_message, get_function_tool_call class OutputSchema(BaseModel): @@ -67,3 +69,37 @@ def test_output_schema_json_tool_call_has_tool_origin(): assert len(processed.functions) == 1 function_run = processed.functions[0] assert function_run.function_tool.name == "json_tool_call" + + +@pytest.mark.asyncio +async def test_output_schema_json_tool_call_streaming_has_tool_origin(): + """ + Test that streamed json_tool_call ToolCallItem has tool_origin when output_schema is enabled. + """ + model = FakeModel() + agent = Agent(name="test", model=model, output_type=OutputSchema) + + # Simulate a json_tool_call response followed by completion + json_output = OutputSchema(result="test").model_dump_json() + json_tool_call = get_function_tool_call("json_tool_call", json_output) + final_output = get_final_output_message(json_output) + model.add_multiple_turn_outputs([[json_tool_call], [final_output]]) + + # Collect streamed events + streamed_tool_call_items: list[ToolCallItem] = [] + + result = Runner.run_streamed(agent, input="test") + async for event in result.stream_events(): + if event.type == "run_item_stream_event" and isinstance(event.item, ToolCallItem): + streamed_tool_call_items.append(event.item) + + # Find the json_tool_call item + json_tool_call_item = next( + item + for item in streamed_tool_call_items + if hasattr(item.raw_item, "name") and item.raw_item.name == "json_tool_call" + ) + + # Verify tool_origin is set on streamed ToolCallItem + assert json_tool_call_item.tool_origin is not None + assert json_tool_call_item.tool_origin.type == ToolOriginType.FUNCTION From f850af4c7b174ea88794acacbb8e1288368e8d81 Mon Sep 17 00:00:00 2001 From: habema Date: Sun, 8 Feb 2026 15:43:59 +0300 Subject: [PATCH 7/7] export ToolOrigin and ToolOriginType --- src/agents/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/agents/__init__.py b/src/agents/__init__.py index c4f1de30f2..6393b45b0e 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -126,6 +126,8 @@ ShellResult, ShellTool, Tool, + ToolOrigin, + ToolOriginType, ToolOutputFileContent, ToolOutputFileContentDict, ToolOutputImage, @@ -359,6 +361,8 @@ def enable_verbose_stdout_logging(): "ApplyPatchResult", "ApplyPatchTool", "Tool", + "ToolOrigin", + "ToolOriginType", "WebSearchTool", "HostedMCPTool", "MCPToolApprovalFunction",