diff --git a/examples/tools/codex_same_thread.py b/examples/tools/codex_same_thread.py new file mode 100644 index 0000000000..e7db01eb08 --- /dev/null +++ b/examples/tools/codex_same_thread.py @@ -0,0 +1,125 @@ +import asyncio +from collections.abc import Mapping +from datetime import datetime + +from pydantic import BaseModel + +from agents import Agent, ModelSettings, Runner, gen_trace_id, trace + +# This tool is still in experimental phase and the details could be changed until being GAed. +from agents.extensions.experimental.codex import ( + CodexToolStreamEvent, + ThreadErrorEvent, + ThreadOptions, + ThreadStartedEvent, + TurnCompletedEvent, + TurnFailedEvent, + TurnStartedEvent, + codex_tool, +) + +# Derived from codex_tool(name="codex_engineer") when run_context_thread_id_key is omitted. +THREAD_ID_KEY = "codex_thread_id_engineer" + + +async def on_codex_stream(payload: CodexToolStreamEvent) -> None: + event = payload.event + + if isinstance(event, ThreadStartedEvent): + log(f"codex thread started: {event.thread_id}") + return + if isinstance(event, TurnStartedEvent): + log("codex turn started") + return + if isinstance(event, TurnCompletedEvent): + log(f"codex turn completed, usage: {event.usage}") + return + if isinstance(event, TurnFailedEvent): + log(f"codex turn failed: {event.error.message}") + return + if isinstance(event, ThreadErrorEvent): + log(f"codex stream error: {event.message}") + + +def _timestamp() -> str: + return datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + +def log(message: str) -> None: + timestamp = _timestamp() + lines = str(message).splitlines() or [""] + for line in lines: + print(f"{timestamp} {line}") + + +def read_context_value(context: Mapping[str, str] | BaseModel, key: str) -> str | None: + # either dict or pydantic model + if isinstance(context, Mapping): + return context.get(key) + return getattr(context, key, None) + + +async def main() -> None: + agent = Agent( + name="Codex Agent (same thread)", + instructions=( + "Always use the Codex tool answer the user's question. " + "Even when you don't have enough context, the Codex tool may know. " + "In that case, you can simply forward the question to the Codex tool." + ), + tools=[ + codex_tool( + # Give each Codex tool a unique `codex_` name when you run multiple tools in one agent. + # Name-based defaults keep their run-context thread IDs separated. + name="codex_engineer", + sandbox_mode="workspace-write", + default_thread_options=ThreadOptions( + model="gpt-5.2-codex", + model_reasoning_effort="low", + network_access_enabled=True, + web_search_enabled=False, + approval_policy="never", + ), + on_stream=on_codex_stream, + # Reuse the same Codex thread across runs that share this context object. + use_run_context_thread_id=True, + ) + ], + model_settings=ModelSettings(tool_choice="required"), + ) + + class MyContext(BaseModel): + something: str | None = None + # the default is "codex_thread_id"; missing this works as well + codex_thread_id_engineer: str | None = None # aligns with run_context_thread_id_key + + context = MyContext() + + # Simple dict object works as well: + # context: dict[str, str] = {} + + trace_id = gen_trace_id() + log(f"View trace: https://platform.openai.com/traces/trace?trace_id={trace_id}") + + with trace("Codex same thread example", trace_id=trace_id): + log("Turn 1: ask writing python code") + first_prompt = "Write working python code example demonstrating how to call OpenAI's Responses API with web search tool." + first_result = await Runner.run(agent, first_prompt, context=context) + first_thread_id = read_context_value(context, THREAD_ID_KEY) + log(first_result.final_output) + log(f"thread id after turn 1: {first_thread_id}") + + log("Turn 2: continue with the same Codex thread.") + second_prompt = "Write the same code in TypeScript." + second_result = await Runner.run(agent, second_prompt, context=context) + second_thread_id = read_context_value(context, THREAD_ID_KEY) + log(second_result.final_output) + log(f"thread id after turn 2: {second_thread_id}") + log( + "same thread reused: " + + str(first_thread_id is not None and first_thread_id == second_thread_id) + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/agents/agent.py b/src/agents/agent.py index b0368e8698..5e3376a05f 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -25,7 +25,7 @@ peek_agent_tool_run_result, record_agent_tool_run_result, ) -from .exceptions import ModelBehaviorError +from .exceptions import ModelBehaviorError, UserError from .guardrail import InputGuardrail, OutputGuardrail from .handoffs import Handoff from .logger import logger @@ -88,6 +88,32 @@ class ToolsToFinalOutputResult: """ +def _validate_codex_tool_name_collisions(tools: list[Tool]) -> None: + codex_tool_names = { + tool.name + for tool in tools + if isinstance(tool, FunctionTool) and bool(getattr(tool, "_is_codex_tool", False)) + } + if not codex_tool_names: + return + + name_counts: dict[str, int] = {} + for tool in tools: + tool_name = getattr(tool, "name", None) + if isinstance(tool_name, str) and tool_name: + name_counts[tool_name] = name_counts.get(tool_name, 0) + 1 + + duplicate_codex_names = sorted( + name for name in codex_tool_names if name_counts.get(name, 0) > 1 + ) + if duplicate_codex_names: + raise UserError( + "Duplicate Codex tool names found: " + + ", ".join(duplicate_codex_names) + + ". Provide a unique codex_tool(name=...) per tool instance." + ) + + class AgentToolStreamEvent(TypedDict): """Streaming event emitted when an agent is invoked as a tool.""" @@ -182,7 +208,9 @@ async def _check_tool_enabled(tool: Tool) -> bool: results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools)) enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok] - return [*mcp_tools, *enabled] + all_tools: list[Tool] = [*mcp_tools, *enabled] + _validate_codex_tool_name_collisions(all_tools) + return all_tools @dataclass diff --git a/src/agents/extensions/experimental/codex/codex_tool.py b/src/agents/extensions/experimental/codex/codex_tool.py index 50d6300377..4ec4653e46 100644 --- a/src/agents/extensions/experimental/codex/codex_tool.py +++ b/src/agents/extensions/experimental/codex/codex_tool.py @@ -5,9 +5,10 @@ import inspect import json import os -from collections.abc import AsyncGenerator, Awaitable, Mapping +import re +from collections.abc import AsyncGenerator, Awaitable, Mapping, MutableMapping from dataclasses import dataclass -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Union from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails from pydantic import BaseModel, ConfigDict, Field, ValidationError, model_validator @@ -34,6 +35,7 @@ ItemUpdatedEvent, ThreadErrorEvent, ThreadEvent, + ThreadStartedEvent, TurnCompletedEvent, TurnFailedEvent, Usage, @@ -62,6 +64,9 @@ "changes", "items", ) +DEFAULT_CODEX_TOOL_NAME = "codex" +DEFAULT_RUN_CONTEXT_THREAD_ID_KEY = "codex_thread_id" +CODEX_TOOL_NAME_PREFIX = "codex_" class CodexToolInputItem(BaseModel): @@ -102,6 +107,37 @@ class CodexToolParameters(BaseModel): "Structured inputs appended to the Codex task. Provide at least one input item." ), ) + thread_id: str | None = Field( + default=None, + description=( + "Optional Codex thread ID to resume. If omitted, a new thread is started unless " + "configured elsewhere." + ), + ) + + model_config = ConfigDict(extra="forbid") + + @model_validator(mode="after") + def validate_thread_id(self) -> CodexToolParameters: + if self.thread_id is None: + return self + + normalized = self.thread_id.strip() + if not normalized: + raise ValueError('When provided, "thread_id" must be a non-empty string.') + + self.thread_id = normalized + return self + + +class CodexToolRunContextParameters(BaseModel): + inputs: list[CodexToolInputItem] = Field( + ..., + min_length=1, + description=( + "Structured inputs appended to the Codex task. Provide at least one input item." + ), + ) model_config = ConfigDict(extra="forbid") @@ -177,9 +213,13 @@ class CodexToolOptions: on_stream: Callable[[CodexToolStreamEvent], MaybeAwaitable[None]] | None = None is_enabled: bool | Callable[[RunContextWrapper[Any], Any], MaybeAwaitable[bool]] = True failure_error_function: ToolErrorFunction | None = default_tool_error_function + use_run_context_thread_id: bool = False + run_context_thread_id_key: str | None = None -CodexToolCallArguments: TypeAlias = dict[str, Optional[list[UserInput]]] +class CodexToolCallArguments(TypedDict): + inputs: list[UserInput] | None + thread_id: str | None class _UnsetType: @@ -209,6 +249,8 @@ def codex_tool( on_stream: Callable[[CodexToolStreamEvent], MaybeAwaitable[None]] | None = None, is_enabled: bool | Callable[[RunContextWrapper[Any], Any], MaybeAwaitable[bool]] | None = None, failure_error_function: ToolErrorFunction | None | _UnsetType = _UNSET, + use_run_context_thread_id: bool | None = None, + run_context_thread_id_key: str | None = None, ) -> FunctionTool: resolved_options = _coerce_tool_options(options) if name is not None: @@ -245,7 +287,10 @@ def codex_tool( resolved_options.is_enabled = is_enabled if not isinstance(failure_error_function, _UnsetType): resolved_options.failure_error_function = failure_error_function - + if use_run_context_thread_id is not None: + resolved_options.use_run_context_thread_id = use_run_context_thread_id + if run_context_thread_id_key is not None: + resolved_options.run_context_thread_id_key = run_context_thread_id_key resolved_options.codex_options = coerce_codex_options(resolved_options.codex_options) resolved_options.default_thread_options = coerce_thread_options( resolved_options.default_thread_options @@ -253,11 +298,22 @@ def codex_tool( resolved_options.default_turn_options = coerce_turn_options( resolved_options.default_turn_options ) - name = resolved_options.name or "codex" + name = _resolve_codex_tool_name(resolved_options.name) + resolved_run_context_thread_id_key = _resolve_run_context_thread_id_key( + tool_name=name, + configured_key=resolved_options.run_context_thread_id_key, + strict_default_key=resolved_options.use_run_context_thread_id, + ) description = resolved_options.description or ( "Executes an agentic Codex task against the current workspace." ) - parameters_model = resolved_options.parameters or CodexToolParameters + if resolved_options.parameters is not None: + parameters_model = resolved_options.parameters + elif resolved_options.use_run_context_thread_id: + # In run-context mode, hide thread_id from the default tool schema. + parameters_model = CodexToolRunContextParameters + else: + parameters_model = CodexToolParameters params_schema = ensure_strict_json_schema(parameters_model.model_json_schema()) resolved_codex_options = _resolve_codex_options(resolved_options.codex_options) @@ -275,45 +331,77 @@ def codex_tool( async def _on_invoke_tool(ctx: ToolContext[Any], input_json: str) -> Any: nonlocal persisted_thread + resolved_thread_id: str | None = None try: parsed = _parse_tool_input(parameters_model, input_json) args = _normalize_parameters(parsed) + if resolved_options.use_run_context_thread_id: + _validate_run_context_thread_id_context(ctx, resolved_run_context_thread_id_key) + codex = await resolve_codex() + call_thread_id = _resolve_call_thread_id( + args=args, + ctx=ctx, + configured_thread_id=resolved_options.thread_id, + use_run_context_thread_id=resolved_options.use_run_context_thread_id, + run_context_thread_id_key=resolved_run_context_thread_id_key, + ) if resolved_options.persist_session: # Reuse a single Codex thread across tool calls. thread = _get_or_create_persisted_thread( codex, - resolved_options.thread_id, + call_thread_id, resolved_thread_options, persisted_thread, ) if persisted_thread is None: persisted_thread = thread else: - thread = _get_thread(codex, resolved_options.thread_id, resolved_thread_options) + thread = _get_thread(codex, call_thread_id, resolved_thread_options) turn_options = _build_turn_options( resolved_options.default_turn_options, validated_output_schema ) codex_input = _build_codex_input(args) + resolved_thread_id = thread.id or call_thread_id # Always stream and aggregate locally to enable on_stream callbacks. stream_result = await thread.run_streamed(codex_input, turn_options) - response, usage = await _consume_events( - stream_result.events, - args, - ctx, - thread, - resolved_options.on_stream, - resolved_options.span_data_max_chars, - ) + resolved_thread_id_holder: dict[str, str | None] = {"thread_id": resolved_thread_id} + try: + response, usage, resolved_thread_id = await _consume_events( + stream_result.events, + args, + ctx, + thread, + resolved_options.on_stream, + resolved_options.span_data_max_chars, + resolved_thread_id_holder=resolved_thread_id_holder, + ) + except Exception: + resolved_thread_id = resolved_thread_id_holder["thread_id"] + raise if usage is not None: ctx.usage.add(_to_agent_usage(usage)) - return CodexToolResult(thread_id=thread.id, response=response, usage=usage) + if resolved_options.use_run_context_thread_id: + _store_thread_id_in_run_context( + ctx, + resolved_run_context_thread_id_key, + resolved_thread_id, + ) + + return CodexToolResult(thread_id=resolved_thread_id, response=response, usage=usage) except Exception as exc: # noqa: BLE001 + _try_store_thread_id_in_run_context_after_error( + ctx=ctx, + key=resolved_run_context_thread_id_key, + thread_id=resolved_thread_id, + enabled=resolved_options.use_run_context_thread_id, + ) + if resolved_options.failure_error_function is None: raise @@ -333,7 +421,7 @@ async def _on_invoke_tool(ctx: ToolContext[Any], input_json: str) -> Any: logger.error("Codex tool failed: %s", exc, exc_info=exc) return result - return FunctionTool( + function_tool = FunctionTool( name=name, description=description, params_json_schema=params_schema, @@ -341,14 +429,17 @@ async def _on_invoke_tool(ctx: ToolContext[Any], input_json: str) -> Any: strict_json_schema=True, is_enabled=resolved_options.is_enabled, ) + # Internal marker used for codex-tool specific runtime validation. + function_tool._is_codex_tool = True + return function_tool def _coerce_tool_options( options: CodexToolOptions | Mapping[str, Any] | None, ) -> CodexToolOptions: if options is None: - return CodexToolOptions() - if isinstance(options, CodexToolOptions): + resolved = CodexToolOptions() + elif isinstance(options, CodexToolOptions): resolved = options else: if not isinstance(options, Mapping): @@ -364,9 +455,87 @@ def _coerce_tool_options( resolved.codex_options = coerce_codex_options(resolved.codex_options) resolved.default_thread_options = coerce_thread_options(resolved.default_thread_options) resolved.default_turn_options = coerce_turn_options(resolved.default_turn_options) + key = resolved.run_context_thread_id_key + if key is not None: + resolved.run_context_thread_id_key = _validate_run_context_thread_id_key(key) + return resolved +def _validate_run_context_thread_id_key(value: Any) -> str: + if not isinstance(value, str): + raise UserError("run_context_thread_id_key must be a string.") + + key = value.strip() + if not key: + raise UserError("run_context_thread_id_key must be a non-empty string.") + + return key + + +def _resolve_codex_tool_name(configured_name: str | None) -> str: + if configured_name is None: + return DEFAULT_CODEX_TOOL_NAME + + if not isinstance(configured_name, str): + raise UserError("Codex tool name must be a string.") + + normalized = configured_name.strip() + if not normalized: + raise UserError("Codex tool name must be a non-empty string.") + + if normalized != DEFAULT_CODEX_TOOL_NAME and not normalized.startswith(CODEX_TOOL_NAME_PREFIX): + raise UserError( + f'Codex tool name must be "{DEFAULT_CODEX_TOOL_NAME}" or start with ' + f'"{CODEX_TOOL_NAME_PREFIX}".' + ) + + return normalized + + +def _resolve_run_context_thread_id_key( + tool_name: str, configured_key: str | None, *, strict_default_key: bool = False +) -> str: + if configured_key is not None: + return _validate_run_context_thread_id_key(configured_key) + + if tool_name == DEFAULT_CODEX_TOOL_NAME: + return DEFAULT_RUN_CONTEXT_THREAD_ID_KEY + + suffix = tool_name[len(CODEX_TOOL_NAME_PREFIX) :] + if strict_default_key: + suffix = _validate_default_run_context_thread_id_suffix(suffix) + return f"{DEFAULT_RUN_CONTEXT_THREAD_ID_KEY}_{suffix}" + suffix = _normalize_name_for_context_key(suffix) + return f"{DEFAULT_RUN_CONTEXT_THREAD_ID_KEY}_{suffix}" + + +def _normalize_name_for_context_key(value: str) -> str: + # Keep generated context keys deterministic and broadly attribute-safe. + normalized = re.sub(r"[^0-9a-zA-Z_]+", "_", value.strip().lower()) + normalized = normalized.strip("_") + return normalized or "tool" + + +def _validate_default_run_context_thread_id_suffix(value: str) -> str: + suffix = value.strip() + if not suffix: + raise UserError( + "When use_run_context_thread_id=True and run_context_thread_id_key is omitted, " + 'codex tool names must include a non-empty suffix after "codex_".' + ) + + if not re.fullmatch(r"[A-Za-z0-9_]+", suffix): + raise UserError( + "When use_run_context_thread_id=True and run_context_thread_id_key is omitted, " + 'the codex tool name suffix (after "codex_") must match [A-Za-z0-9_]+. ' + "Use only letters, numbers, and underscores, " + "or set run_context_thread_id_key explicitly." + ) + + return suffix + + def _parse_tool_input(parameters_model: type[BaseModel], input_json: str) -> BaseModel: try: json_data = json.loads(input_json) if input_json else {} @@ -387,6 +556,7 @@ def _normalize_parameters(params: BaseModel) -> CodexToolCallArguments: inputs_value = getattr(params, "inputs", None) if inputs_value is None: raise UserError("Codex tool parameters must include an inputs field.") + thread_id_value = getattr(params, "thread_id", None) inputs = [{"type": item.type, "text": item.text, "path": item.path} for item in inputs_value] @@ -397,7 +567,10 @@ def _normalize_parameters(params: BaseModel) -> CodexToolCallArguments: else: normalized_inputs.append({"type": "local_image", "path": item["path"] or ""}) - return {"inputs": normalized_inputs if normalized_inputs else None} + return { + "inputs": normalized_inputs if normalized_inputs else None, + "thread_id": _normalize_thread_id(thread_id_value), + } def _build_codex_input(args: CodexToolCallArguments) -> Input: @@ -639,6 +812,186 @@ def _get_thread(codex: Codex, thread_id: str | None, defaults: ThreadOptions | N return codex.start_thread(defaults) +def _normalize_thread_id(value: Any) -> str | None: + if value is None: + return None + if not isinstance(value, str): + raise UserError("Codex thread_id must be a string when provided.") + + normalized = value.strip() + if not normalized: + return None + return normalized + + +def _resolve_call_thread_id( + args: CodexToolCallArguments, + ctx: RunContextWrapper[Any], + configured_thread_id: str | None, + use_run_context_thread_id: bool, + run_context_thread_id_key: str, +) -> str | None: + explicit_thread_id = _normalize_thread_id(args.get("thread_id")) + if explicit_thread_id: + return explicit_thread_id + + if use_run_context_thread_id: + context_thread_id = _read_thread_id_from_run_context(ctx, run_context_thread_id_key) + if context_thread_id: + return context_thread_id + + return configured_thread_id + + +def _read_thread_id_from_run_context(ctx: RunContextWrapper[Any], key: str) -> str | None: + context = ctx.context + if context is None: + return None + + if isinstance(context, Mapping): + value = context.get(key) + else: + value = getattr(context, key, None) + + if value is None: + return None + if not isinstance(value, str): + raise UserError(f'Run context "{key}" must be a string when provided.') + + normalized = value.strip() + if not normalized: + return None + + return normalized + + +def _validate_run_context_thread_id_context(ctx: RunContextWrapper[Any], key: str) -> None: + context = ctx.context + if context is None: + raise UserError( + "use_run_context_thread_id=True requires a mutable run context object. " + "Pass context={} (or an object) to Runner.run()." + ) + + if isinstance(context, MutableMapping): + return + + if isinstance(context, Mapping): + raise UserError( + "use_run_context_thread_id=True requires a mutable run context mapping " + "or a writable object context." + ) + + if isinstance(context, BaseModel): + if bool(context.model_config.get("frozen", False)): + raise UserError( + "use_run_context_thread_id=True requires a mutable run context object. " + "Frozen Pydantic models are not supported." + ) + return + + if dataclasses.is_dataclass(context): + params = getattr(type(context), "__dataclass_params__", None) + if params is not None and bool(getattr(params, "frozen", False)): + raise UserError( + "use_run_context_thread_id=True requires a mutable run context object. " + "Frozen dataclass contexts are not supported." + ) + + slots = getattr(type(context), "__slots__", None) + if slots is not None and not hasattr(context, "__dict__"): + slot_names = (slots,) if isinstance(slots, str) else tuple(slots) + if key not in slot_names: + raise UserError( + "use_run_context_thread_id=True requires the run context to support field " + + f'"{key}". ' + "Use a mutable dict context, or add a writable field/slot to the context object." + ) + return + + if not hasattr(context, "__dict__"): + raise UserError( + "use_run_context_thread_id=True requires a mutable run context mapping " + "or a writable object context." + ) + + +def _store_thread_id_in_run_context( + ctx: RunContextWrapper[Any], key: str, thread_id: str | None +) -> None: + if thread_id is None: + return + + _validate_run_context_thread_id_context(ctx, key) + context = ctx.context + assert context is not None + + if isinstance(context, MutableMapping): + context[key] = thread_id + return + + if isinstance(context, BaseModel): + if _set_pydantic_context_value(context, key, thread_id): + return + raise UserError( + f'Unable to store Codex thread_id in run context field "{key}". ' + "Use a mutable dict context or set a writable attribute." + ) + + try: + setattr(context, key, thread_id) + except Exception as exc: # noqa: BLE001 + raise UserError( + f'Unable to store Codex thread_id in run context field "{key}". ' + "Use a mutable dict context or set a writable attribute." + ) from exc + + +def _try_store_thread_id_in_run_context_after_error( + *, + ctx: RunContextWrapper[Any], + key: str, + thread_id: str | None, + enabled: bool, +) -> None: + if not enabled or thread_id is None: + return + + try: + _store_thread_id_in_run_context(ctx, key, thread_id) + except Exception: + logger.exception("Failed to store Codex thread id in run context after error.") + + +def _set_pydantic_context_value(context: BaseModel, key: str, value: str) -> bool: + model_config = context.model_config + if bool(model_config.get("frozen", False)): + return False + + model_fields = type(context).model_fields + if key in model_fields: + try: + setattr(context, key, value) + except Exception: # noqa: BLE001 + return False + return True + + try: + setattr(context, key, value) + return True + except ValueError: + pass + except Exception: # noqa: BLE001 + return False + + state = getattr(context, "__dict__", None) + if isinstance(state, dict): + state[key] = value + return True + + return False + + def _get_or_create_persisted_thread( codex: Codex, thread_id: str | None, @@ -676,11 +1029,17 @@ async def _consume_events( thread: Thread, on_stream: Callable[[CodexToolStreamEvent], MaybeAwaitable[None]] | None, span_data_max_chars: int | None, -) -> tuple[str, Usage | None]: + resolved_thread_id_holder: dict[str, str | None] | None = None, +) -> tuple[str, Usage | None, str | None]: # Track spans keyed by item id for command/mcp/reasoning events. active_spans: dict[str, Any] = {} final_response = "" usage: Usage | None = None + resolved_thread_id = thread.id + if resolved_thread_id is None and resolved_thread_id_holder is not None: + resolved_thread_id = resolved_thread_id_holder.get("thread_id") + if resolved_thread_id_holder is not None: + resolved_thread_id_holder["thread_id"] = resolved_thread_id event_queue: asyncio.Queue[CodexToolStreamEvent | None] | None = None dispatch_task: asyncio.Task[None] | None = None @@ -735,6 +1094,10 @@ async def _dispatch() -> None: final_response = event.item.text elif isinstance(event, TurnCompletedEvent): usage = event.usage + elif isinstance(event, ThreadStartedEvent): + resolved_thread_id = event.thread_id + if resolved_thread_id_holder is not None: + resolved_thread_id_holder["thread_id"] = resolved_thread_id elif isinstance(event, TurnFailedEvent): error = event.error.message raise UserError(f"Codex turn failed{(': ' + error) if error else ''}") @@ -755,7 +1118,7 @@ async def _dispatch() -> None: if not final_response: final_response = _build_default_response(args) - return final_response, usage + return final_response, usage, resolved_thread_id def _handle_item_started( diff --git a/src/agents/tool.py b/src/agents/tool.py index 4f70adc0f8..fd5b8e0119 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -261,6 +261,9 @@ class FunctionTool: _is_agent_tool: bool = field(default=False, init=False, repr=False) """Internal flag indicating if this tool is an agent-as-tool.""" + _is_codex_tool: bool = field(default=False, init=False, repr=False) + """Internal flag indicating if this tool is a Codex tool wrapper.""" + _agent_instance: Any = field(default=None, init=False, repr=False) """Internal reference to the agent instance if this is an agent-as-tool.""" diff --git a/tests/extensions/experiemental/codex/test_codex_tool.py b/tests/extensions/experiemental/codex/test_codex_tool.py index a54e13854f..3f53146b47 100644 --- a/tests/extensions/experiemental/codex/test_codex_tool.py +++ b/tests/extensions/experiemental/codex/test_codex_tool.py @@ -3,13 +3,14 @@ import importlib import inspect import json -from dataclasses import fields -from types import SimpleNamespace +from dataclasses import dataclass, fields +from types import MappingProxyType, SimpleNamespace from typing import Any, cast import pytest from pydantic import BaseModel, ConfigDict +from agents import Agent, function_tool from agents.exceptions import ModelBehaviorError, UserError from agents.extensions.experimental.codex import ( Codex, @@ -31,10 +32,11 @@ class CodexMockState: def __init__(self) -> None: self.events: list[dict[str, Any]] = [] - self.thread_id = "thread-1" + self.thread_id: str | None = "thread-1" self.last_turn_options: Any = None self.start_calls = 0 self.resume_calls = 0 + self.last_resumed_thread_id: str | None = None self.options: Any = None @@ -65,6 +67,7 @@ def start_thread(self, _options: Any = None) -> FakeThread: def resume_thread(self, _thread_id: str, _options: Any = None) -> FakeThread: self._state.resume_calls += 1 + self._state.last_resumed_thread_id = _thread_id return FakeThread(self._state) @@ -469,7 +472,7 @@ def __init__(self, options: Any = None) -> None: monkeypatch.setattr(codex_tool_module, "Codex", CaptureCodex) - tool = codex_tool(name="codex-keyword", codex_options={"api_key": "from-kwargs"}) + tool = codex_tool(name="codex_keyword", codex_options={"api_key": "from-kwargs"}) input_json = '{"inputs": [{"type": "text", "text": "Check keyword options", "path": ""}]}' context = ToolContext( context=None, @@ -480,7 +483,7 @@ def __init__(self, options: Any = None) -> None: await tool.on_invoke_tool(context, input_json) - assert tool.name == "codex-keyword" + assert tool.name == "codex_keyword" assert state.options is not None assert getattr(state.options, "api_key", None) == "from-kwargs" @@ -602,6 +605,707 @@ async def test_codex_tool_persists_session() -> None: assert state.resume_calls == 0 +@pytest.mark.asyncio +async def test_codex_tool_accepts_thread_id_from_tool_input() -> None: + state = CodexMockState() + state.thread_id = "thread-from-input" + state.events = [ + {"type": "thread.started", "thread_id": "thread-from-input"}, + { + "type": "item.completed", + "item": {"id": "agent-1", "type": "agent_message", "text": "Codex done."}, + }, + { + "type": "turn.completed", + "usage": {"input_tokens": 1, "cached_input_tokens": 0, "output_tokens": 1}, + }, + ] + + tool = codex_tool(CodexToolOptions(codex=cast(Codex, FakeCodex(state)))) + input_json = ( + '{"inputs": [{"type": "text", "text": "Continue thread", "path": ""}], ' + '"thread_id": "thread-xyz"}' + ) + context = ToolContext( + context=None, + tool_name=tool.name, + tool_call_id="call-1", + tool_arguments=input_json, + ) + + result = await tool.on_invoke_tool(context, input_json) + + assert isinstance(result, CodexToolResult) + assert state.resume_calls == 1 + assert state.last_resumed_thread_id == "thread-xyz" + assert result.thread_id == "thread-from-input" + + +@pytest.mark.asyncio +async def test_codex_tool_uses_run_context_thread_id_and_persists_latest() -> None: + state = CodexMockState() + state.thread_id = "thread-next" + state.events = [ + {"type": "thread.started", "thread_id": "thread-next"}, + { + "type": "item.completed", + "item": {"id": "agent-1", "type": "agent_message", "text": "Codex done."}, + }, + { + "type": "turn.completed", + "usage": {"input_tokens": 1, "cached_input_tokens": 0, "output_tokens": 1}, + }, + ] + + tool = codex_tool( + CodexToolOptions( + codex=cast(Codex, FakeCodex(state)), + use_run_context_thread_id=True, + run_context_thread_id_key="codex_agent_thread_id", + ) + ) + input_json = '{"inputs": [{"type": "text", "text": "Continue thread", "path": ""}]}' + run_context = {"codex_agent_thread_id": "thread-prev"} + context = ToolContext( + context=run_context, + tool_name=tool.name, + tool_call_id="call-1", + tool_arguments=input_json, + ) + + result = await tool.on_invoke_tool(context, input_json) + + assert isinstance(result, CodexToolResult) + assert state.resume_calls == 1 + assert state.last_resumed_thread_id == "thread-prev" + assert run_context["codex_agent_thread_id"] == "thread-next" + assert result.thread_id == "thread-next" + + +@pytest.mark.asyncio +async def test_codex_tool_persists_thread_started_id_when_thread_object_id_is_none() -> None: + state = CodexMockState() + state.thread_id = None + state.events = [ + {"type": "thread.started", "thread_id": "thread-next"}, + { + "type": "item.completed", + "item": {"id": "agent-1", "type": "agent_message", "text": "Codex done."}, + }, + { + "type": "turn.completed", + "usage": {"input_tokens": 1, "cached_input_tokens": 0, "output_tokens": 1}, + }, + ] + + tool = codex_tool( + CodexToolOptions( + codex=cast(Codex, FakeCodex(state)), + use_run_context_thread_id=True, + run_context_thread_id_key="codex_agent_thread_id", + ) + ) + input_json = '{"inputs": [{"type": "text", "text": "Continue thread", "path": ""}]}' + run_context: dict[str, str] = {} + context = ToolContext( + context=run_context, + tool_name=tool.name, + tool_call_id="call-1", + tool_arguments=input_json, + ) + + first_result = await tool.on_invoke_tool(context, input_json) + second_result = await tool.on_invoke_tool(context, input_json) + + assert isinstance(first_result, CodexToolResult) + assert isinstance(second_result, CodexToolResult) + assert first_result.thread_id == "thread-next" + assert second_result.thread_id == "thread-next" + assert run_context["codex_agent_thread_id"] == "thread-next" + assert state.start_calls == 1 + assert state.resume_calls == 1 + assert state.last_resumed_thread_id == "thread-next" + + +@pytest.mark.asyncio +async def test_codex_tool_persists_thread_id_for_recoverable_turn_failure() -> None: + state = CodexMockState() + state.thread_id = None + state.events = [ + {"type": "thread.started", "thread_id": "thread-next"}, + {"type": "turn.failed", "error": {"message": "boom"}}, + ] + + tool = codex_tool( + CodexToolOptions( + codex=cast(Codex, FakeCodex(state)), + use_run_context_thread_id=True, + run_context_thread_id_key="codex_agent_thread_id", + failure_error_function=lambda _ctx, _exc: "handled", + ) + ) + input_json = '{"inputs": [{"type": "text", "text": "Continue thread", "path": ""}]}' + run_context: dict[str, str] = {} + context = ToolContext( + context=run_context, + tool_name=tool.name, + tool_call_id="call-1", + tool_arguments=input_json, + ) + + first_result = await tool.on_invoke_tool(context, input_json) + second_result = await tool.on_invoke_tool(context, input_json) + + assert first_result == "handled" + assert second_result == "handled" + assert run_context["codex_agent_thread_id"] == "thread-next" + assert state.start_calls == 1 + assert state.resume_calls == 1 + assert state.last_resumed_thread_id == "thread-next" + + +@pytest.mark.asyncio +async def test_codex_tool_persists_thread_id_for_raised_turn_failure() -> None: + state = CodexMockState() + state.thread_id = None + state.events = [ + {"type": "thread.started", "thread_id": "thread-next"}, + {"type": "turn.failed", "error": {"message": "boom"}}, + ] + + tool = codex_tool( + CodexToolOptions( + codex=cast(Codex, FakeCodex(state)), + use_run_context_thread_id=True, + run_context_thread_id_key="codex_agent_thread_id", + failure_error_function=None, + ) + ) + input_json = '{"inputs": [{"type": "text", "text": "Continue thread", "path": ""}]}' + run_context: dict[str, str] = {} + context = ToolContext( + context=run_context, + tool_name=tool.name, + tool_call_id="call-1", + tool_arguments=input_json, + ) + + with pytest.raises(UserError, match="Codex turn failed: boom"): + await tool.on_invoke_tool(context, input_json) + + assert run_context["codex_agent_thread_id"] == "thread-next" + + with pytest.raises(UserError, match="Codex turn failed: boom"): + await tool.on_invoke_tool(context, input_json) + + assert run_context["codex_agent_thread_id"] == "thread-next" + assert state.start_calls == 1 + assert state.resume_calls == 1 + assert state.last_resumed_thread_id == "thread-next" + + +@pytest.mark.asyncio +async def test_codex_tool_falls_back_to_call_thread_id_when_thread_object_id_is_none() -> None: + state = CodexMockState() + state.thread_id = None + state.events = [ + { + "type": "item.completed", + "item": {"id": "agent-1", "type": "agent_message", "text": "Codex done."}, + }, + { + "type": "turn.completed", + "usage": {"input_tokens": 1, "cached_input_tokens": 0, "output_tokens": 1}, + }, + ] + + tool = codex_tool( + CodexToolOptions( + codex=cast(Codex, FakeCodex(state)), + parameters=codex_tool_module.CodexToolParameters, + use_run_context_thread_id=True, + ) + ) + first_input_json = ( + '{"inputs": [{"type": "text", "text": "Continue thread", "path": ""}], ' + '"thread_id": "thread-explicit"}' + ) + second_input_json = '{"inputs": [{"type": "text", "text": "Continue thread", "path": ""}]}' + run_context: dict[str, str] = {} + context = ToolContext( + context=run_context, + tool_name=tool.name, + tool_call_id="call-1", + tool_arguments=first_input_json, + ) + + first_result = await tool.on_invoke_tool(context, first_input_json) + second_result = await tool.on_invoke_tool(context, second_input_json) + + assert isinstance(first_result, CodexToolResult) + assert isinstance(second_result, CodexToolResult) + assert first_result.thread_id == "thread-explicit" + assert second_result.thread_id == "thread-explicit" + assert run_context["codex_thread_id"] == "thread-explicit" + assert state.start_calls == 0 + assert state.resume_calls == 2 + assert state.last_resumed_thread_id == "thread-explicit" + + +@pytest.mark.asyncio +async def test_codex_tool_uses_run_context_thread_id_with_pydantic_context() -> None: + class RunContext(BaseModel): + model_config = ConfigDict(extra="forbid") + user_id: str + + state = CodexMockState() + state.thread_id = "thread-next" + state.events = [ + {"type": "thread.started", "thread_id": "thread-next"}, + { + "type": "item.completed", + "item": {"id": "agent-1", "type": "agent_message", "text": "Codex done."}, + }, + { + "type": "turn.completed", + "usage": {"input_tokens": 1, "cached_input_tokens": 0, "output_tokens": 1}, + }, + ] + + tool = codex_tool( + CodexToolOptions( + codex=cast(Codex, FakeCodex(state)), + use_run_context_thread_id=True, + ) + ) + input_json = '{"inputs": [{"type": "text", "text": "Continue thread", "path": ""}]}' + run_context = RunContext(user_id="abc") + context = ToolContext( + context=run_context, + tool_name=tool.name, + tool_call_id="call-1", + tool_arguments=input_json, + ) + + await tool.on_invoke_tool(context, input_json) + await tool.on_invoke_tool(context, input_json) + + assert state.start_calls == 1 + assert state.resume_calls == 1 + assert state.last_resumed_thread_id == "thread-next" + assert run_context.__dict__["codex_thread_id"] == "thread-next" + + +@pytest.mark.asyncio +async def test_codex_tool_uses_pydantic_context_field_matching_thread_id_key() -> None: + class RunContext(BaseModel): + model_config = ConfigDict(extra="forbid") + user_id: str + codex_thread_id: str | None = None + + state = CodexMockState() + state.thread_id = "thread-next" + state.events = [ + {"type": "thread.started", "thread_id": "thread-next"}, + { + "type": "item.completed", + "item": {"id": "agent-1", "type": "agent_message", "text": "Codex done."}, + }, + { + "type": "turn.completed", + "usage": {"input_tokens": 1, "cached_input_tokens": 0, "output_tokens": 1}, + }, + ] + + tool = codex_tool( + CodexToolOptions( + codex=cast(Codex, FakeCodex(state)), + use_run_context_thread_id=True, + ) + ) + input_json = '{"inputs": [{"type": "text", "text": "Continue thread", "path": ""}]}' + run_context = RunContext(user_id="abc", codex_thread_id="thread-prev") + context = ToolContext( + context=run_context, + tool_name=tool.name, + tool_call_id="call-1", + tool_arguments=input_json, + ) + + await tool.on_invoke_tool(context, input_json) + + assert state.start_calls == 0 + assert state.resume_calls == 1 + assert state.last_resumed_thread_id == "thread-prev" + assert run_context.codex_thread_id == "thread-next" + + +@pytest.mark.asyncio +async def test_codex_tool_default_run_context_key_follows_tool_name() -> None: + state = CodexMockState() + state.thread_id = "thread-next" + state.events = [ + {"type": "thread.started", "thread_id": "thread-next"}, + { + "type": "item.completed", + "item": {"id": "agent-1", "type": "agent_message", "text": "Codex done."}, + }, + { + "type": "turn.completed", + "usage": {"input_tokens": 1, "cached_input_tokens": 0, "output_tokens": 1}, + }, + ] + + tool = codex_tool( + CodexToolOptions( + codex=cast(Codex, FakeCodex(state)), + use_run_context_thread_id=True, + ), + name="codex_engineer", + ) + input_json = '{"inputs": [{"type": "text", "text": "Continue thread", "path": ""}]}' + run_context = {"codex_thread_id_engineer": "thread-prev"} + context = ToolContext( + context=run_context, + tool_name=tool.name, + tool_call_id="call-1", + tool_arguments=input_json, + ) + + await tool.on_invoke_tool(context, input_json) + + assert state.last_resumed_thread_id == "thread-prev" + assert run_context["codex_thread_id_engineer"] == "thread-next" + + +def test_codex_tool_rejects_custom_name_without_codex_prefix() -> None: + with pytest.raises(UserError, match='must be "codex" or start with "codex_"'): + codex_tool(name="engineer") + + +def test_codex_tool_allows_non_alnum_suffix_when_run_context_thread_id_disabled() -> None: + tool = codex_tool(name="codex_a-b") + assert tool.name == "codex_a-b" + + +def test_codex_tool_rejects_lossy_default_run_context_thread_id_key_suffix() -> None: + with pytest.raises(UserError, match="run_context_thread_id_key"): + codex_tool(name="codex_a-b", use_run_context_thread_id=True) + + +@pytest.mark.asyncio +async def test_codex_tool_tool_input_thread_id_overrides_run_context_thread_id() -> None: + state = CodexMockState() + state.thread_id = "thread-from-tool-input" + state.events = [ + {"type": "thread.started", "thread_id": "thread-from-tool-input"}, + { + "type": "item.completed", + "item": {"id": "agent-1", "type": "agent_message", "text": "Codex done."}, + }, + { + "type": "turn.completed", + "usage": {"input_tokens": 1, "cached_input_tokens": 0, "output_tokens": 1}, + }, + ] + + tool = codex_tool( + CodexToolOptions( + codex=cast(Codex, FakeCodex(state)), + parameters=codex_tool_module.CodexToolParameters, + use_run_context_thread_id=True, + failure_error_function=None, + ) + ) + input_json = ( + '{"inputs": [{"type": "text", "text": "Continue thread", "path": ""}], ' + '"thread_id": "thread-from-args"}' + ) + context = ToolContext( + context={"codex_thread_id": "thread-from-context"}, + tool_name=tool.name, + tool_call_id="call-1", + tool_arguments=input_json, + ) + + await tool.on_invoke_tool(context, input_json) + + assert state.last_resumed_thread_id == "thread-from-args" + + +def test_codex_tool_run_context_mode_hides_thread_id_in_default_parameters() -> None: + tool = codex_tool(use_run_context_thread_id=True) + assert "thread_id" not in tool.params_json_schema["properties"] + + +@pytest.mark.asyncio +async def test_codex_tool_duplicate_names_fail_fast() -> None: + agent = Agent( + name="test", + tools=[ + codex_tool(), + codex_tool(), + ], + ) + + with pytest.raises(UserError, match="Duplicate Codex tool names found"): + await agent.get_all_tools(RunContextWrapper(context=None)) + + +@pytest.mark.asyncio +async def test_codex_tool_name_collision_with_other_tool_fails_fast() -> None: + @function_tool(name_override="codex") + def other_tool() -> str: + return "ok" + + agent = Agent( + name="test", + tools=[ + codex_tool(), + other_tool, + ], + ) + + with pytest.raises(UserError, match="Duplicate Codex tool names found"): + await agent.get_all_tools(RunContextWrapper(context=None)) + + +@pytest.mark.asyncio +async def test_codex_tool_run_context_thread_id_requires_mutable_context() -> None: + state = CodexMockState() + state.events = [ + {"type": "thread.started", "thread_id": "thread-1"}, + { + "type": "item.completed", + "item": {"id": "agent-1", "type": "agent_message", "text": "Codex done."}, + }, + { + "type": "turn.completed", + "usage": {"input_tokens": 1, "cached_input_tokens": 0, "output_tokens": 1}, + }, + ] + + tool = codex_tool( + CodexToolOptions( + codex=cast(Codex, FakeCodex(state)), + use_run_context_thread_id=True, + failure_error_function=None, + ) + ) + input_json = '{"inputs": [{"type": "text", "text": "No context", "path": ""}]}' + context = ToolContext( + context=None, + tool_name=tool.name, + tool_call_id="call-1", + tool_arguments=input_json, + ) + + with pytest.raises(UserError, match="use_run_context_thread_id=True"): + await tool.on_invoke_tool(context, input_json) + + assert state.start_calls == 0 + assert state.resume_calls == 0 + + +@pytest.mark.asyncio +async def test_codex_tool_run_context_thread_id_rejects_immutable_mapping_context() -> None: + state = CodexMockState() + state.events = [ + {"type": "thread.started", "thread_id": "thread-1"}, + { + "type": "item.completed", + "item": {"id": "agent-1", "type": "agent_message", "text": "Codex done."}, + }, + { + "type": "turn.completed", + "usage": {"input_tokens": 1, "cached_input_tokens": 0, "output_tokens": 1}, + }, + ] + + tool = codex_tool( + CodexToolOptions( + codex=cast(Codex, FakeCodex(state)), + use_run_context_thread_id=True, + failure_error_function=None, + ) + ) + input_json = '{"inputs": [{"type": "text", "text": "Immutable context", "path": ""}]}' + context = ToolContext( + context=MappingProxyType({"codex_thread_id": "thread-prev"}), + tool_name=tool.name, + tool_call_id="call-1", + tool_arguments=input_json, + ) + + with pytest.raises(UserError, match="use_run_context_thread_id=True"): + await tool.on_invoke_tool(context, input_json) + + assert state.start_calls == 0 + assert state.resume_calls == 0 + + +@pytest.mark.asyncio +async def test_codex_tool_run_context_thread_id_rejects_frozen_pydantic_context() -> None: + class FrozenRunContext(BaseModel): + model_config = ConfigDict(frozen=True) + user_id: str + + state = CodexMockState() + state.events = [ + {"type": "thread.started", "thread_id": "thread-1"}, + { + "type": "item.completed", + "item": {"id": "agent-1", "type": "agent_message", "text": "Codex done."}, + }, + { + "type": "turn.completed", + "usage": {"input_tokens": 1, "cached_input_tokens": 0, "output_tokens": 1}, + }, + ] + + tool = codex_tool( + CodexToolOptions( + codex=cast(Codex, FakeCodex(state)), + use_run_context_thread_id=True, + failure_error_function=None, + ) + ) + input_json = '{"inputs": [{"type": "text", "text": "Frozen context", "path": ""}]}' + context = ToolContext( + context=FrozenRunContext(user_id="abc"), + tool_name=tool.name, + tool_call_id="call-1", + tool_arguments=input_json, + ) + + with pytest.raises(UserError, match="Frozen Pydantic models"): + await tool.on_invoke_tool(context, input_json) + + assert state.start_calls == 0 + assert state.resume_calls == 0 + + +@pytest.mark.asyncio +async def test_codex_tool_run_context_thread_id_rejects_frozen_dataclass_context() -> None: + @dataclass(frozen=True) + class FrozenRunContext: + user_id: str + + state = CodexMockState() + state.events = [ + {"type": "thread.started", "thread_id": "thread-1"}, + { + "type": "item.completed", + "item": {"id": "agent-1", "type": "agent_message", "text": "Codex done."}, + }, + { + "type": "turn.completed", + "usage": {"input_tokens": 1, "cached_input_tokens": 0, "output_tokens": 1}, + }, + ] + + tool = codex_tool( + CodexToolOptions( + codex=cast(Codex, FakeCodex(state)), + use_run_context_thread_id=True, + failure_error_function=None, + ) + ) + input_json = '{"inputs": [{"type": "text", "text": "Frozen dataclass", "path": ""}]}' + context = ToolContext( + context=FrozenRunContext(user_id="abc"), + tool_name=tool.name, + tool_call_id="call-1", + tool_arguments=input_json, + ) + + with pytest.raises(UserError, match="Frozen dataclass contexts"): + await tool.on_invoke_tool(context, input_json) + + assert state.start_calls == 0 + assert state.resume_calls == 0 + + +@pytest.mark.asyncio +async def test_codex_tool_run_context_thread_id_rejects_slots_object_without_thread_field() -> None: + class SlotsRunContext: + __slots__ = ("user_id",) + + def __init__(self, user_id: str) -> None: + self.user_id = user_id + + state = CodexMockState() + state.events = [ + {"type": "thread.started", "thread_id": "thread-1"}, + { + "type": "item.completed", + "item": {"id": "agent-1", "type": "agent_message", "text": "Codex done."}, + }, + { + "type": "turn.completed", + "usage": {"input_tokens": 1, "cached_input_tokens": 0, "output_tokens": 1}, + }, + ] + + tool = codex_tool( + CodexToolOptions( + codex=cast(Codex, FakeCodex(state)), + use_run_context_thread_id=True, + failure_error_function=None, + ) + ) + input_json = '{"inputs": [{"type": "text", "text": "Slots context", "path": ""}]}' + context = ToolContext( + context=SlotsRunContext(user_id="abc"), + tool_name=tool.name, + tool_call_id="call-1", + tool_arguments=input_json, + ) + + with pytest.raises(UserError, match='support field "codex_thread_id"'): + await tool.on_invoke_tool(context, input_json) + + assert state.start_calls == 0 + assert state.resume_calls == 0 + + +@pytest.mark.asyncio +async def test_codex_tool_run_context_thread_id_rejects_non_writable_object_context() -> None: + state = CodexMockState() + state.events = [ + {"type": "thread.started", "thread_id": "thread-1"}, + { + "type": "item.completed", + "item": {"id": "agent-1", "type": "agent_message", "text": "Codex done."}, + }, + { + "type": "turn.completed", + "usage": {"input_tokens": 1, "cached_input_tokens": 0, "output_tokens": 1}, + }, + ] + + tool = codex_tool( + CodexToolOptions( + codex=cast(Codex, FakeCodex(state)), + use_run_context_thread_id=True, + failure_error_function=None, + ) + ) + input_json = '{"inputs": [{"type": "text", "text": "List context", "path": ""}]}' + context: ToolContext[Any] = ToolContext( + context=cast(Any, []), + tool_name=tool.name, + tool_call_id="call-1", + tool_arguments=input_json, + ) + + with pytest.raises(UserError, match="use_run_context_thread_id=True"): + await tool.on_invoke_tool(context, input_json) + + assert state.start_calls == 0 + assert state.resume_calls == 0 + + @pytest.mark.parametrize( ("payload", "message"), [ @@ -639,6 +1343,11 @@ def test_codex_tool_coerce_options_rejects_unknown_fields() -> None: codex_tool_module._coerce_tool_options({"unknown": "value"}) +def test_codex_tool_keyword_rejects_empty_run_context_key() -> None: + with pytest.raises(UserError, match="run_context_thread_id_key"): + codex_tool(run_context_thread_id_key=" ") + + def test_codex_tool_resolve_output_schema_validation_errors() -> None: with pytest.raises(UserError, match="must include properties"): codex_tool_module._resolve_output_schema({"properties": []}) @@ -780,6 +1489,15 @@ def test_codex_tool_normalize_parameters_handles_local_image() -> None: {"type": "text", "text": "hello"}, {"type": "local_image", "path": "/tmp/img.png"}, ] + assert normalized["thread_id"] is None + + +def test_codex_tool_input_thread_id_validation_errors() -> None: + with pytest.raises(ValueError, match="non-empty string"): + codex_tool_module.CodexToolParameters( + inputs=[codex_tool_module.CodexToolInputItem(type="text", text="hello")], + thread_id=" ", + ) def test_codex_tool_build_codex_input_empty() -> None: @@ -960,7 +1678,7 @@ def on_stream(payload: CodexToolStreamEvent) -> None: ) with trace("codex-test"): - response, usage = await codex_tool_module._consume_events( + response, usage, thread_id = await codex_tool_module._consume_events( event_stream(), {"inputs": [{"type": "text", "text": "hello"}]}, context, @@ -971,6 +1689,7 @@ def on_stream(payload: CodexToolStreamEvent) -> None: assert response == "done" assert usage == Usage(input_tokens=1, cached_input_tokens=0, output_tokens=1) + assert thread_id == "thread-1" assert "item.started" in callbacks @@ -994,7 +1713,7 @@ async def event_stream(): tool_arguments="{}", ) - response, usage = await codex_tool_module._consume_events( + response, usage, thread_id = await codex_tool_module._consume_events( event_stream(), {"inputs": [{"type": "text", "text": "hello"}]}, context, @@ -1005,6 +1724,7 @@ async def event_stream(): assert response == "Codex task completed with inputs." assert usage == Usage(input_tokens=1, cached_input_tokens=0, output_tokens=1) + assert thread_id == "thread-1" @pytest.mark.asyncio @@ -1097,7 +1817,7 @@ class CustomParams(BaseModel): tool = codex_tool( CodexToolOptions(codex=cast(Codex, FakeCodex(state))), - name="codex-overrides", + name="codex_overrides", description="desc", parameters=CustomParams, output_schema={"type": "object", "properties": {}, "additionalProperties": False}, @@ -1114,6 +1834,18 @@ class CustomParams(BaseModel): on_stream=lambda _payload: None, is_enabled=False, failure_error_function=lambda _ctx, _exc: "handled", + use_run_context_thread_id=True, + run_context_thread_id_key="thread_key", ) - assert tool.name == "codex-overrides" + assert tool.name == "codex_overrides" + + +def test_codex_tool_coerce_options_rejects_empty_run_context_key() -> None: + with pytest.raises(UserError, match="run_context_thread_id_key"): + codex_tool_module._coerce_tool_options( + { + "use_run_context_thread_id": True, + "run_context_thread_id_key": " ", + } + )