diff --git a/temporalio/contrib/openai_agents/_invoke_model_activity.py b/temporalio/contrib/openai_agents/_invoke_model_activity.py index a43f9aeaf..ebc984460 100644 --- a/temporalio/contrib/openai_agents/_invoke_model_activity.py +++ b/temporalio/contrib/openai_agents/_invoke_model_activity.py @@ -30,6 +30,7 @@ from agents.items import TResponseStreamEvent from agents.tool import ( ApplyPatchTool, + CustomTool, LocalShellTool, ShellTool, ShellToolEnvironment, @@ -39,6 +40,7 @@ APIStatusError, AsyncOpenAI, ) +from openai.types.responses import CustomToolParam from openai.types.responses.tool_param import Mcp from typing_extensions import Required, TypedDict @@ -112,6 +114,15 @@ class ApplyPatchToolInput: name: str = "apply_patch" +@dataclass +class CustomToolInput: + """Data conversion friendly representation of a CustomTool. Contains only the fields which are needed by the model + execution to determine what tool to call, not the actual tool invocation, which remains in the workflow context. + """ + + tool_config: CustomToolParam + + ToolInput = ( FunctionToolInput | FileSearchTool @@ -122,6 +133,7 @@ class ApplyPatchToolInput: | ShellToolInput | LocalShellTool | ApplyPatchToolInput + | CustomToolInput | ToolSearchTool ) @@ -235,6 +247,14 @@ def _build_tool(tool: ToolInput) -> Tool: return ApplyPatchTool(name=tool.name, editor=_NoopApplyPatchEditor()) elif isinstance(tool, HostedMCPToolInput): return HostedMCPTool(tool_config=tool.tool_config) + elif isinstance(tool, CustomToolInput): + return CustomTool( + name=tool.tool_config["name"], + description=tool.tool_config.get("description", ""), + on_invoke_tool=_empty_on_invoke_tool, + format=tool.tool_config.get("format"), + defer_loading=tool.tool_config.get("defer_loading", False), + ) elif isinstance(tool, FunctionToolInput): return FunctionTool( name=tool.name, diff --git a/temporalio/contrib/openai_agents/_temporal_model_stub.py b/temporalio/contrib/openai_agents/_temporal_model_stub.py index 7f9ab11d9..d184daa4a 100644 --- a/temporalio/contrib/openai_agents/_temporal_model_stub.py +++ b/temporalio/contrib/openai_agents/_temporal_model_stub.py @@ -22,7 +22,13 @@ WebSearchTool, ) from agents.items import TResponseStreamEvent -from agents.tool import ApplyPatchTool, LocalShellTool, ShellTool, ToolSearchTool +from agents.tool import ( + ApplyPatchTool, + CustomTool, + LocalShellTool, + ShellTool, + ToolSearchTool, +) from openai.types.responses.response_prompt_param import ResponsePromptParam from temporalio import workflow @@ -30,6 +36,7 @@ ActivityModelInput, AgentOutputSchemaInput, ApplyPatchToolInput, + CustomToolInput, FunctionToolInput, HandoffInput, HostedMCPToolInput, @@ -92,6 +99,8 @@ def make_tool_info(tool: Tool) -> ToolInput: return ApplyPatchToolInput(name=tool.name) elif isinstance(tool, HostedMCPTool): return HostedMCPToolInput(tool_config=tool.tool_config) + elif isinstance(tool, CustomTool): + return CustomToolInput(tool_config=tool.tool_config) elif isinstance(tool, FunctionTool): return FunctionToolInput( name=tool.name, diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index de0af3923..96cc25133 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -58,9 +58,13 @@ TResponseStreamEvent, ) from agents.mcp import MCPServer, MCPServerStdio +from agents.sandbox.capabilities.tools import SandboxApplyPatchTool +from agents.tool import CustomTool +from agents.tool_context import ToolContext from openai import APIStatusError, AsyncOpenAI, BaseModel from openai.types.responses import ( ResponseCodeInterpreterToolCall, + ResponseCustomToolCall, ResponseFileSearchToolCall, ResponseFunctionWebSearch, ) @@ -83,6 +87,7 @@ StatefulMCPServerProvider, StatelessMCPServerProvider, ) +from temporalio.contrib.openai_agents._invoke_model_activity import _build_tool from temporalio.contrib.openai_agents._model_parameters import ModelSummaryProvider from temporalio.contrib.openai_agents._openai_runner import _convert_agent from temporalio.contrib.openai_agents._temporal_model_stub import ( @@ -1996,6 +2001,66 @@ async def test_hosted_mcp_tool(client: Client): assert result == "Some language" +def custom_tool_mock_model(): + return TestModel.returning_responses( + [ + ModelResponse( + output=[ + ResponseCustomToolCall( + call_id="c1", + input="ping", + name="echo", + type="custom_tool_call", + ) + ], + usage=Usage(), + response_id=None, + ), + ResponseBuilders.output_message("done"), + ] + ) + + +@workflow.defn +class CustomToolWorkflow: + @workflow.run + async def run(self) -> str: + captured: list[str] = [] + + async def echo(ctx: ToolContext[Any], input: str) -> str: # type: ignore[reportUnusedParameter] + captured.append(input) + return input + + agent = Agent[str]( + name="custom-tool-agent", + instructions="Use the echo tool.", + tools=[ + CustomTool( + name="echo", + description="Echo the input string back.", + on_invoke_tool=echo, + ) + ], + ) + result = await Runner.run(starting_agent=agent, input="say something") + return f"{result.final_output}:{captured[0]}" + + +async def test_custom_tool_workflow(client: Client): + async with AgentEnvironment(model=custom_tool_mock_model()) as env: + client = env.applied_on_client(client) + + async with new_worker(client, CustomToolWorkflow) as worker: + workflow_handle = await client.start_workflow( + CustomToolWorkflow.run, + id=f"custom-tool-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=30), + ) + result = await workflow_handle.result() + assert result == "done:ping" + + class AssertDifferentModelProvider(ModelProvider): model_names: set[str | None] @@ -2538,6 +2603,79 @@ async def test_model_conversion_loops(): assert isinstance(triage_agent.model, _TemporalModelStub) +def test_sandbox_apply_patch_tool_round_trips_through_activity_input(): + class FakeSandboxSession: + pass + + tool = SandboxApplyPatchTool(session=FakeSandboxSession()) # type: ignore[arg-type] + + stub = _TemporalModelStub( + model_name="gpt-5", + model_params=ModelActivityParameters(), + agent=None, + ) + + activity_input, _summary = stub._build_activity_input( + system_instructions=None, + input="hi", + model_settings=ModelSettings(), + tools=[tool], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + conversation_id=None, + prompt=None, + ) + + tool_inputs = activity_input.get("tools") or [] + assert len(tool_inputs) == 1 + rebuilt = _build_tool(tool_inputs[0]) + assert isinstance(rebuilt, CustomTool) + assert rebuilt.name == tool.name + assert rebuilt.description == tool.description + assert rebuilt.format == tool.format + assert rebuilt.tool_config == tool.tool_config + + +def test_custom_tool_with_defer_loading_round_trips_through_activity_input(): + async def stub(_ctx: Any, _payload: str) -> str: + return "" + + tool = CustomTool( + name="deferred_tool", + description="A custom tool with defer_loading enabled", + on_invoke_tool=stub, + defer_loading=True, + ) + + stub_model = _TemporalModelStub( + model_name="gpt-5", + model_params=ModelActivityParameters(), + agent=None, + ) + + activity_input, _summary = stub_model._build_activity_input( + system_instructions=None, + input="hi", + model_settings=ModelSettings(), + tools=[tool], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + conversation_id=None, + prompt=None, + ) + + tool_inputs = activity_input.get("tools") or [] + assert len(tool_inputs) == 1 + rebuilt = _build_tool(tool_inputs[0]) + assert isinstance(rebuilt, CustomTool) + assert rebuilt.tool_config == tool.tool_config + assert rebuilt.defer_loading is True + + async def test_local_hello_world_agent(client: Client): async with AgentEnvironment( model=hello_mock_model(),