Commit afa224b

Gemini 3 Pro support and cross-model conversation compatibility (openai#2158)

1 parent c4ed605 commit afa224b

13 files changed: 1103 additions & 180 deletions

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ Repository = "https://github.com/openai/openai-agents-python"
 [project.optional-dependencies]
 voice = ["numpy>=2.2.0, <3; python_version>='3.10'", "websockets>=15.0, <16"]
 viz = ["graphviz>=0.17"]
-litellm = ["litellm>=1.67.4.post1, <2"]
+litellm = ["litellm>=1.80.8, <2"]
 realtime = ["websockets>=15.0, <16"]
 sqlalchemy = ["SQLAlchemy>=2.0", "asyncpg>=0.29.0"]
 encrypt = ["cryptography>=45.0, <46"]
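With the litellm floor raised to a release that knows about Gemini 3, the model can be driven through the existing LiteLLM adapter. A minimal usage sketch — the Gemini model identifier and API-key plumbing here are assumptions for illustration, not part of this commit:

import asyncio

from agents import Agent, Runner
from agents.extensions.models.litellm_model import LitellmModel

# "gemini/gemini-3-pro-preview" is an assumed LiteLLM model string; check
# LiteLLM's docs for the exact Gemini 3 Pro identifier.
agent = Agent(
    name="Assistant",
    instructions="Reply concisely.",
    model=LitellmModel(model="gemini/gemini-3-pro-preview", api_key="YOUR_API_KEY"),
)

async def main() -> None:
    result = await Runner.run(agent, "Say hello.")
    print(result.final_output)

asyncio.run(main())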

src/agents/extensions/models/litellm_model.py

Lines changed: 131 additions & 7 deletions
@@ -62,6 +62,15 @@ class InternalChatCompletionMessage(ChatCompletionMessage):
     thinking_blocks: list[dict[str, Any]] | None = None
 
 
+class InternalToolCall(ChatCompletionMessageFunctionToolCall):
+    """
+    An internal subclass to carry provider-specific metadata (e.g., Gemini thought signatures)
+    without modifying the original model.
+    """
+
+    extra_content: dict[str, Any] | None = None
+
+
 class LitellmModel(Model):
     """This class enables using any model via LiteLLM. LiteLLM allows you to acess OpenAPI,
     Anthropic, Gemini, Mistral, and many other models.
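Since InternalToolCall subclasses the Pydantic ChatCompletionMessageFunctionToolCall, it passes isinstance checks and serializes like the base type while carrying the extra metadata. A rough sketch with invented values (import paths assumed to match those already used in this module):

from openai.types.chat import ChatCompletionMessageFunctionToolCall
from openai.types.chat.chat_completion_message_function_tool_call import Function

call = InternalToolCall(
    id="call_123",
    type="function",
    function=Function(name="get_weather", arguments='{"city": "Tokyo"}'),
    extra_content={"google": {"thought_signature": "sig_abc"}},
)
assert isinstance(call, ChatCompletionMessageFunctionToolCall)
# model_dump() keeps the extra field alongside the standard ones:
# {'id': 'call_123', 'type': 'function', 'function': {...},
#  'extra_content': {'google': {'thought_signature': 'sig_abc'}}}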
@@ -168,9 +177,15 @@ async def get_response(
                 "output_tokens": usage.output_tokens,
             }
 
+        # Build provider_data for provider specific fields
+        provider_data: dict[str, Any] = {"model": self.model}
+        if message is not None and hasattr(response, "id"):
+            provider_data["response_id"] = response.id
+
         items = (
             Converter.message_to_output_items(
-                LitellmConverter.convert_message_to_openai(message)
+                LitellmConverter.convert_message_to_openai(message, model=self.model),
+                provider_data=provider_data,
             )
             if message is not None
             else []
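The upshot is that every output item converted from a LiteLLM message now carries a provider_data payload recording which model produced it, which is what later enables cross-model conversation replay. Sketch of the shape, with invented values:

# Hypothetical provider_data attached to the converted output items:
provider_data = {
    "model": "gemini/gemini-3-pro-preview",  # self.model
    "response_id": "resp_abc123",            # response.id, when the response has one
}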
@@ -215,7 +230,9 @@ async def stream_response(
         )
 
         final_response: Response | None = None
-        async for chunk in ChatCmplStreamHandler.handle_stream(response, stream):
+        async for chunk in ChatCmplStreamHandler.handle_stream(
+            response, stream, model=self.model
+        ):
             yield chunk
 
             if chunk.type == "response.completed":
@@ -283,12 +300,19 @@ async def _fetch_response(
             input,
             preserve_thinking_blocks=preserve_thinking_blocks,
             preserve_tool_output_all_content=True,
+            model=self.model,
         )
 
         # Fix for interleaved thinking bug: reorder messages to ensure tool_use comes before tool_result  # noqa: E501
         if "anthropic" in self.model.lower() or "claude" in self.model.lower():
             converted_messages = self._fix_tool_message_ordering(converted_messages)
 
+        # Convert Google's extra_content to litellm's provider_specific_fields format
+        if "gemini" in self.model.lower():
+            converted_messages = self._convert_gemini_extra_content_to_provider_specific_fields(
+                converted_messages
+            )
+
         if system_instructions:
             converted_messages.insert(
                 0,
@@ -438,6 +462,65 @@ async def _fetch_response(
         )
         return response, ret
 
+    def _convert_gemini_extra_content_to_provider_specific_fields(
+        self, messages: list[ChatCompletionMessageParam]
+    ) -> list[ChatCompletionMessageParam]:
+        """
+        Convert Gemini model's extra_content format to provider_specific_fields format for litellm.
+
+        Transforms tool calls from internal format:
+            extra_content={"google": {"thought_signature": "..."}}
+        To litellm format:
+            provider_specific_fields={"thought_signature": "..."}
+
+        Only processes tool_calls that appear after the last user message.
+        See: https://ai.google.dev/gemini-api/docs/thought-signatures
+        """
+
+        # Find the index of the last user message
+        last_user_index = -1
+        for i in range(len(messages) - 1, -1, -1):
+            if isinstance(messages[i], dict) and messages[i].get("role") == "user":
+                last_user_index = i
+                break
+
+        for i, message in enumerate(messages):
+            if not isinstance(message, dict):
+                continue
+
+            # Only process assistant messages that come after the last user message
+            # If no user message found (last_user_index == -1), process all messages
+            if last_user_index != -1 and i <= last_user_index:
+                continue
+
+            # Check if this is an assistant message with tool calls
+            if message.get("role") == "assistant" and message.get("tool_calls"):
+                tool_calls = message.get("tool_calls", [])
+
+                for tool_call in tool_calls:  # type: ignore[attr-defined]
+                    if not isinstance(tool_call, dict):
+                        continue
+
+                    # Default to skip validator, overridden if valid thought signature exists
+                    tool_call["provider_specific_fields"] = {
+                        "thought_signature": "skip_thought_signature_validator"
+                    }
+
+                    # Override with actual thought signature if extra_content exists
+                    if "extra_content" in tool_call:
+                        extra_content = tool_call.pop("extra_content")
+                        if isinstance(extra_content, dict):
+                            # Extract google-specific fields
+                            google_fields = extra_content.get("google")
+                            if google_fields and isinstance(google_fields, dict):
+                                thought_sig = google_fields.get("thought_signature")
+                                if thought_sig:
+                                    tool_call["provider_specific_fields"] = {
+                                        "thought_signature": thought_sig
+                                    }
+
+        return messages
+
     def _fix_tool_message_ordering(
         self, messages: list[ChatCompletionMessageParam]
     ) -> list[ChatCompletionMessageParam]:
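A before/after sketch of what this method does to a Gemini tool call that follows the last user message (IDs and signatures invented):

messages = [
    {"role": "user", "content": "What's the weather in Tokyo?"},
    {
        "role": "assistant",
        "tool_calls": [
            {
                "id": "call_123",
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"city": "Tokyo"}'},
                # Internal format written by convert_tool_call_to_openai:
                "extra_content": {"google": {"thought_signature": "sig_abc"}},
            }
        ],
    },
]

# After _convert_gemini_extra_content_to_provider_specific_fields(messages), the
# tool call instead carries litellm's format:
#     "provider_specific_fields": {"thought_signature": "sig_abc"}
# and "extra_content" is popped. Tool calls with no stored signature fall back to
# {"thought_signature": "skip_thought_signature_validator"}.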
@@ -565,15 +648,26 @@ def _merge_headers(self, model_settings: ModelSettings):
 class LitellmConverter:
     @classmethod
     def convert_message_to_openai(
-        cls, message: litellm.types.utils.Message
+        cls, message: litellm.types.utils.Message, model: str | None = None
     ) -> ChatCompletionMessage:
+        """
+        Convert a LiteLLM message to OpenAI ChatCompletionMessage format.
+
+        Args:
+            message: The LiteLLM message to convert
+            model: The target model to convert to. Used to handle provider-specific
+                transformations.
+        """
         if message.role != "assistant":
             raise ModelBehaviorError(f"Unsupported role: {message.role}")
 
         tool_calls: (
             list[ChatCompletionMessageFunctionToolCall | ChatCompletionMessageCustomToolCall] | None
         ) = (
-            [LitellmConverter.convert_tool_call_to_openai(tool) for tool in message.tool_calls]
+            [
+                LitellmConverter.convert_tool_call_to_openai(tool, model=model)
+                for tool in message.tool_calls
+            ]
             if message.tool_calls
             else None
         )
@@ -643,13 +737,43 @@ def convert_annotations_to_openai(
 
     @classmethod
     def convert_tool_call_to_openai(
-        cls, tool_call: litellm.types.utils.ChatCompletionMessageToolCall
+        cls, tool_call: litellm.types.utils.ChatCompletionMessageToolCall, model: str | None = None
     ) -> ChatCompletionMessageFunctionToolCall:
-        return ChatCompletionMessageFunctionToolCall(
-            id=tool_call.id,
+        # Clean up litellm's addition of __thought__ suffix to tool_call.id for
+        # Gemini models. See: https://github.com/BerriAI/litellm/pull/16895
+        # This suffix is redundant since we can get thought_signature from
+        # provider_specific_fields, and this hack causes validation errors when
+        # cross-model passing to other models.
+        tool_call_id = tool_call.id
+        if model and "gemini" in model.lower() and "__thought__" in tool_call_id:
+            tool_call_id = tool_call_id.split("__thought__")[0]
+
+        # Convert litellm's tool call format to chat completion message format
+        base_tool_call = ChatCompletionMessageFunctionToolCall(
+            id=tool_call_id,
             type="function",
             function=Function(
                 name=tool_call.function.name or "",
                 arguments=tool_call.function.arguments,
             ),
         )
+
+        # Preserve provider-specific fields if present (e.g., Gemini thought signatures)
+        if hasattr(tool_call, "provider_specific_fields") and tool_call.provider_specific_fields:
+            # Convert to nested extra_content structure
+            extra_content: dict[str, Any] = {}
+            provider_fields = tool_call.provider_specific_fields
+
+            # Check for thought_signature (Gemini specific)
+            if model and "gemini" in model.lower():
+                if "thought_signature" in provider_fields:
+                    extra_content["google"] = {
+                        "thought_signature": provider_fields["thought_signature"]
+                    }
+
+            return InternalToolCall(
+                **base_tool_call.model_dump(),
+                extra_content=extra_content if extra_content else None,
+            )
+
+        return base_tool_call
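A round-trip sketch of the new conversion for a Gemini tool call (ID and signature invented; the exact suffix litellm appends comes from the PR linked above):

# Hypothetical LiteLLM tool call returned for a Gemini model:
#     id="call_123__thought__sig_abc"
#     provider_specific_fields={"thought_signature": "sig_abc"}
#
# convert_tool_call_to_openai(tool_call, model="gemini/gemini-3-pro-preview")
# now yields:
#     InternalToolCall(
#         id="call_123",  # __thought__ suffix stripped
#         type="function",
#         function=Function(name="get_weather", arguments='{"city": "Tokyo"}'),
#         extra_content={"google": {"thought_signature": "sig_abc"}},
#     )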

src/agents/handoffs/history.py

Lines changed: 1 addition & 1 deletion
@@ -144,7 +144,7 @@ def _format_transcript_item(item: TResponseInputItem) -> str:
         return f"{prefix}: {content_str}" if content_str else prefix
 
     item_type = item.get("type", "item")
-    rest = {k: v for k, v in item.items() if k != "type"}
+    rest = {k: v for k, v in item.items() if k not in ("type", "provider_data")}
     try:
         serialized = json.dumps(rest, ensure_ascii=False, default=str)
     except TypeError:
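Effect of the change, sketched on an invented transcript item: the provider_data added by the model layer above no longer leaks into handoff transcripts.

item = {
    "type": "function_call",
    "name": "get_weather",
    "arguments": '{"city": "Tokyo"}',
    "provider_data": {"model": "gemini/gemini-3-pro-preview", "response_id": "resp_abc123"},
}
rest = {k: v for k, v in item.items() if k not in ("type", "provider_data")}
# rest == {"name": "get_weather", "arguments": '{"city": "Tokyo"}'}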
