diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py index 85adc81a1e..bcfed97d77 100644 --- a/src/agents/models/openai_chatcompletions.py +++ b/src/agents/models/openai_chatcompletions.py @@ -5,7 +5,7 @@ from collections.abc import AsyncIterator from typing import TYPE_CHECKING, Any, Literal, cast, overload -from openai import AsyncOpenAI, AsyncStream, Omit, omit +from openai import AsyncOpenAI, AsyncStream, NotGiven, Omit, omit from openai.types import ChatModel from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage from openai.types.chat.chat_completion import Choice @@ -45,6 +45,20 @@ from ..model_settings import ModelSettings +def _is_openai_omitted_value(value: Any) -> bool: + return isinstance(value, Omit | NotGiven) + + +# Keys whose first-class create_kwargs entry is reserved by the SDK and must +# never be overridden via ModelSettings.extra_args, even when the SDK passes +# the value as an OpenAI omit sentinel. ``stream`` is the canonical example: +# get_response() pins it to ``omit`` and then expects a non-streaming +# ChatCompletion, so allowing extra_args to flip it to True would cause the +# OpenAI client to return an async stream that the non-streaming code path +# cannot consume. +_RESERVED_CHAT_COMPLETIONS_KEYS = frozenset({"stream"}) + + class OpenAIChatCompletionsModel(Model): _OFFICIAL_OPENAI_SUPPORTED_INPUT_CONTENT_TYPES = frozenset( {"input_text", "input_image", "input_audio", "input_file"} @@ -423,8 +437,15 @@ async def _fetch_response( "extra_body": model_settings.extra_body, "metadata": self._non_null_or_omit(model_settings.metadata), } + extra_args = model_settings.extra_args or {} duplicate_extra_arg_keys = sorted( - set(create_kwargs).intersection(model_settings.extra_args or {}) + k + for k in extra_args + if k in create_kwargs + and ( + k in _RESERVED_CHAT_COMPLETIONS_KEYS + or not _is_openai_omitted_value(create_kwargs[k]) + ) ) if duplicate_extra_arg_keys: if len(duplicate_extra_arg_keys) == 1: @@ -436,7 +457,7 @@ async def _fetch_response( raise TypeError( f"chat.completions.create() got multiple values for keyword arguments {keys}" ) - create_kwargs.update(model_settings.extra_args or {}) + create_kwargs.update(extra_args) ret = await self._get_client().chat.completions.create(**create_kwargs) diff --git a/tests/models/test_openai_chatcompletions.py b/tests/models/test_openai_chatcompletions.py index b2f8affd60..e7d80c3920 100644 --- a/tests/models/test_openai_chatcompletions.py +++ b/tests/models/test_openai_chatcompletions.py @@ -770,3 +770,138 @@ def __init__(self): assert ChatCmplHelpers.get_store_param(client, model_settings) is True, ( "Should respect explicitly set store=True" ) + + +def _build_chat_completions_dummy_client() -> tuple[Any, Any]: + class DummyCompletions: + def __init__(self) -> None: + self.kwargs: dict[str, Any] = {} + + async def create(self, **kwargs: Any) -> Any: + self.kwargs = kwargs + msg = ChatCompletionMessage(role="assistant", content="ok") + choice = Choice(index=0, finish_reason="stop", message=msg) + return ChatCompletion( + id="resp-id", + created=0, + model="fake", + object="chat.completion", + choices=[choice], + ) + + class DummyClient: + def __init__(self, completions: DummyCompletions) -> None: + self.chat = type("_Chat", (), {"completions": completions})() + self.base_url = httpx.URL("http://fake") + + completions = DummyCompletions() + return completions, DummyClient(completions) + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_fetch_response_chat_completions_allows_extra_arg_when_explicit_arg_is_omitted() -> ( + None +): + """An extra_args key must not collide with a create_kwargs entry whose + value is the OpenAI omit sentinel — the user simply has not set the + first-class field, so there is no real duplicate. + """ + + completions, dummy_client = _build_chat_completions_dummy_client() + model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=cast(AsyncOpenAI, dummy_client)) + with generation_span(disabled=True) as span: + await model._fetch_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(extra_args={"reasoning_effort": "high"}), + tools=[], + output_schema=None, + handoffs=[], + span=span, + tracing=ModelTracing.DISABLED, + stream=False, + ) + + assert completions.kwargs["reasoning_effort"] == "high" + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_fetch_response_chat_completions_rejects_duplicate_extra_args_keys() -> None: + """When the same key is supplied through both first-class settings and + extra_args, the duplicate must still be reported. + """ + + _completions, dummy_client = _build_chat_completions_dummy_client() + model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=cast(AsyncOpenAI, dummy_client)) + with generation_span(disabled=True) as span: + with pytest.raises(TypeError, match="multiple values.*temperature"): + await model._fetch_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(temperature=0.5, extra_args={"temperature": 0.7}), + tools=[], + output_schema=None, + handoffs=[], + span=span, + tracing=ModelTracing.DISABLED, + stream=False, + ) + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_fetch_response_chat_completions_rejects_stream_via_extra_args_in_non_streaming_call() -> ( # noqa: E501 + None +): + """``stream`` is reserved by the SDK. The non-streaming get_response path + pins create_kwargs["stream"] to the OpenAI omit sentinel and then expects + a ChatCompletion. Allowing ``extra_args={"stream": True}`` to slip + through the duplicate check would make the OpenAI client return an async + stream that the non-streaming code path cannot consume. + """ + + _completions, dummy_client = _build_chat_completions_dummy_client() + model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=cast(AsyncOpenAI, dummy_client)) + with generation_span(disabled=True) as span: + with pytest.raises(TypeError, match="multiple values.*stream"): + await model._fetch_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(extra_args={"stream": True}), + tools=[], + output_schema=None, + handoffs=[], + span=span, + tracing=ModelTracing.DISABLED, + stream=False, + ) + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_fetch_response_chat_completions_rejects_stream_via_extra_args_in_streaming_call() -> ( # noqa: E501 + None +): + """``stream`` is also reserved on the streaming path. ``stream=True`` sets + create_kwargs["stream"] to ``True`` (a real value), so the original + intersection check would already catch this. Cover it explicitly so the + invariant is enforced for both directions. + """ + + _completions, dummy_client = _build_chat_completions_dummy_client() + model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=cast(AsyncOpenAI, dummy_client)) + with generation_span(disabled=True) as span: + with pytest.raises(TypeError, match="multiple values.*stream"): + await model._fetch_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(extra_args={"stream": False}), + tools=[], + output_schema=None, + handoffs=[], + span=span, + tracing=ModelTracing.DISABLED, + stream=True, + )