Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions temporalio/contrib/openai_agents/_invoke_model_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from agents.items import TResponseStreamEvent
from agents.tool import (
ApplyPatchTool,
CustomTool,
LocalShellTool,
ShellTool,
ShellToolEnvironment,
Expand All @@ -39,6 +40,7 @@
APIStatusError,
AsyncOpenAI,
)
from openai.types.responses import CustomToolParam
from openai.types.responses.tool_param import Mcp
from typing_extensions import Required, TypedDict

Expand Down Expand Up @@ -112,6 +114,15 @@ class ApplyPatchToolInput:
name: str = "apply_patch"


@dataclass
class CustomToolInput:
"""Data conversion friendly representation of a CustomTool. Contains only the fields which are needed by the model
execution to determine what tool to call, not the actual tool invocation, which remains in the workflow context.
"""

tool_config: CustomToolParam


ToolInput = (
FunctionToolInput
| FileSearchTool
Expand All @@ -122,6 +133,7 @@ class ApplyPatchToolInput:
| ShellToolInput
| LocalShellTool
| ApplyPatchToolInput
| CustomToolInput
| ToolSearchTool
)

Expand Down Expand Up @@ -235,6 +247,14 @@ def _build_tool(tool: ToolInput) -> Tool:
return ApplyPatchTool(name=tool.name, editor=_NoopApplyPatchEditor())
elif isinstance(tool, HostedMCPToolInput):
return HostedMCPTool(tool_config=tool.tool_config)
elif isinstance(tool, CustomToolInput):
return CustomTool(
name=tool.tool_config["name"],
description=tool.tool_config.get("description", ""),
on_invoke_tool=_empty_on_invoke_tool,
format=tool.tool_config.get("format"),
defer_loading=tool.tool_config.get("defer_loading", False),
)
elif isinstance(tool, FunctionToolInput):
return FunctionTool(
name=tool.name,
Expand Down
11 changes: 10 additions & 1 deletion temporalio/contrib/openai_agents/_temporal_model_stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,21 @@
WebSearchTool,
)
from agents.items import TResponseStreamEvent
from agents.tool import ApplyPatchTool, LocalShellTool, ShellTool, ToolSearchTool
from agents.tool import (
ApplyPatchTool,
CustomTool,
LocalShellTool,
ShellTool,
ToolSearchTool,
)
from openai.types.responses.response_prompt_param import ResponsePromptParam

from temporalio import workflow
from temporalio.contrib.openai_agents._invoke_model_activity import (
ActivityModelInput,
AgentOutputSchemaInput,
ApplyPatchToolInput,
CustomToolInput,
FunctionToolInput,
HandoffInput,
HostedMCPToolInput,
Expand Down Expand Up @@ -92,6 +99,8 @@ def make_tool_info(tool: Tool) -> ToolInput:
return ApplyPatchToolInput(name=tool.name)
elif isinstance(tool, HostedMCPTool):
return HostedMCPToolInput(tool_config=tool.tool_config)
elif isinstance(tool, CustomTool):
return CustomToolInput(tool_config=tool.tool_config)
elif isinstance(tool, FunctionTool):
return FunctionToolInput(
name=tool.name,
Expand Down
138 changes: 138 additions & 0 deletions tests/contrib/openai_agents/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,13 @@
TResponseStreamEvent,
)
from agents.mcp import MCPServer, MCPServerStdio
from agents.sandbox.capabilities.tools import SandboxApplyPatchTool
from agents.tool import CustomTool
from agents.tool_context import ToolContext
from openai import APIStatusError, AsyncOpenAI, BaseModel
from openai.types.responses import (
ResponseCodeInterpreterToolCall,
ResponseCustomToolCall,
ResponseFileSearchToolCall,
ResponseFunctionWebSearch,
)
Expand All @@ -83,6 +87,7 @@
StatefulMCPServerProvider,
StatelessMCPServerProvider,
)
from temporalio.contrib.openai_agents._invoke_model_activity import _build_tool
from temporalio.contrib.openai_agents._model_parameters import ModelSummaryProvider
from temporalio.contrib.openai_agents._openai_runner import _convert_agent
from temporalio.contrib.openai_agents._temporal_model_stub import (
Expand Down Expand Up @@ -1996,6 +2001,66 @@ async def test_hosted_mcp_tool(client: Client):
assert result == "Some language"


def custom_tool_mock_model():
return TestModel.returning_responses(
[
ModelResponse(
output=[
ResponseCustomToolCall(
call_id="c1",
input="ping",
name="echo",
type="custom_tool_call",
)
],
usage=Usage(),
response_id=None,
),
ResponseBuilders.output_message("done"),
]
)


@workflow.defn
class CustomToolWorkflow:
@workflow.run
async def run(self) -> str:
captured: list[str] = []

async def echo(ctx: ToolContext[Any], input: str) -> str: # type: ignore[reportUnusedParameter]
captured.append(input)
return input

agent = Agent[str](
name="custom-tool-agent",
instructions="Use the echo tool.",
tools=[
CustomTool(
name="echo",
description="Echo the input string back.",
on_invoke_tool=echo,
)
],
)
result = await Runner.run(starting_agent=agent, input="say something")
return f"{result.final_output}:{captured[0]}"


async def test_custom_tool_workflow(client: Client):
async with AgentEnvironment(model=custom_tool_mock_model()) as env:
client = env.applied_on_client(client)

async with new_worker(client, CustomToolWorkflow) as worker:
workflow_handle = await client.start_workflow(
CustomToolWorkflow.run,
id=f"custom-tool-workflow-{uuid.uuid4()}",
task_queue=worker.task_queue,
execution_timeout=timedelta(seconds=30),
)
result = await workflow_handle.result()
assert result == "done:ping"


class AssertDifferentModelProvider(ModelProvider):
model_names: set[str | None]

Expand Down Expand Up @@ -2538,6 +2603,79 @@ async def test_model_conversion_loops():
assert isinstance(triage_agent.model, _TemporalModelStub)


def test_sandbox_apply_patch_tool_round_trips_through_activity_input():
class FakeSandboxSession:
pass

tool = SandboxApplyPatchTool(session=FakeSandboxSession()) # type: ignore[arg-type]

stub = _TemporalModelStub(
model_name="gpt-5",
model_params=ModelActivityParameters(),
agent=None,
)

activity_input, _summary = stub._build_activity_input(
system_instructions=None,
input="hi",
model_settings=ModelSettings(),
tools=[tool],
output_schema=None,
handoffs=[],
tracing=ModelTracing.DISABLED,
previous_response_id=None,
conversation_id=None,
prompt=None,
)

tool_inputs = activity_input.get("tools") or []
assert len(tool_inputs) == 1
rebuilt = _build_tool(tool_inputs[0])
assert isinstance(rebuilt, CustomTool)
assert rebuilt.name == tool.name
assert rebuilt.description == tool.description
assert rebuilt.format == tool.format
assert rebuilt.tool_config == tool.tool_config


def test_custom_tool_with_defer_loading_round_trips_through_activity_input():
async def stub(_ctx: Any, _payload: str) -> str:
return ""

tool = CustomTool(
name="deferred_tool",
description="A custom tool with defer_loading enabled",
on_invoke_tool=stub,
defer_loading=True,
)

stub_model = _TemporalModelStub(
model_name="gpt-5",
model_params=ModelActivityParameters(),
agent=None,
)

activity_input, _summary = stub_model._build_activity_input(
system_instructions=None,
input="hi",
model_settings=ModelSettings(),
tools=[tool],
output_schema=None,
handoffs=[],
tracing=ModelTracing.DISABLED,
previous_response_id=None,
conversation_id=None,
prompt=None,
)

tool_inputs = activity_input.get("tools") or []
assert len(tool_inputs) == 1
rebuilt = _build_tool(tool_inputs[0])
assert isinstance(rebuilt, CustomTool)
assert rebuilt.tool_config == tool.tool_config
assert rebuilt.defer_loading is True


Comment thread
xumaple marked this conversation as resolved.
async def test_local_hello_world_agent(client: Client):
async with AgentEnvironment(
model=hello_mock_model(),
Expand Down
Loading