diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index be5337f0d..321871cf0 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -19,12 +19,16 @@ from ..types.event_loop import Usage from ..types.exceptions import ContextWindowOverflowException from ..types.streaming import MetadataEvent, StreamEvent -from ..types.tools import ToolChoice, ToolSpec +from ..types.tools import ToolChoice, ToolSpec, ToolUse from ._validation import validate_config_keys from .openai import OpenAIModel logger = logging.getLogger(__name__) +# Separator used by LiteLLM to embed thought signatures inside tool call IDs. +# See: https://ai.google.dev/gemini-api/docs/thought-signatures +_THOUGHT_SIGNATURE_SEPARATOR = "__thought__" + T = TypeVar("T", bound=BaseModel) @@ -114,6 +118,31 @@ def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> return super().format_request_message_content(content) + @classmethod + @override + def format_request_message_tool_call(cls, tool_use: ToolUse, **kwargs: Any) -> dict[str, Any]: + """Format a LiteLLM compatible tool call, encoding thought signatures into the tool call ID. + + Gemini thinking models attach a thought_signature to each function call. LiteLLM's OpenAI-compatible + interface embeds this signature inside the tool call ID using the ``__thought__`` separator. When + ``reasoningSignature`` is present and the tool call ID does not already contain the separator, this + method encodes it so LiteLLM can reconstruct the Gemini-native format on the next request. + + Args: + tool_use: Tool use requested by the model. + **kwargs: Additional keyword arguments for future extensibility. + + Returns: + LiteLLM compatible tool call dict with thought signature encoded in the ID when present. + """ + tool_call = super().format_request_message_tool_call(tool_use, **kwargs) + + reasoning_signature = tool_use.get("reasoningSignature") + if reasoning_signature and _THOUGHT_SIGNATURE_SEPARATOR not in tool_call["id"]: + tool_call["id"] = f"{tool_call['id']}{_THOUGHT_SIGNATURE_SEPARATOR}{reasoning_signature}" + + return tool_call + def _stream_switch_content(self, data_type: str, prev_data_type: str | None) -> tuple[list[StreamEvent], str]: """Handle switching to a new content stream. @@ -200,8 +229,9 @@ def format_request_messages( def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent: """Format a LiteLLM response event into a standardized message chunk. - This method overrides OpenAI's format_chunk to handle the metadata case - with prompt caching support. All other chunk types use the parent implementation. + Extends OpenAI's format_chunk to: + 1. Handle metadata with prompt caching support. + 2. Extract thought signatures that LiteLLM embeds in tool call IDs for Gemini thinking models. Args: event: A response event from the LiteLLM model. @@ -237,6 +267,43 @@ def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent: usage=usage_data, ) ) + + # Extract thought signature from tool call content_start events. + # LiteLLM embeds Gemini thought signatures in the tool call ID using the __thought__ separator. + # We extract it into reasoningSignature so the streaming layer can preserve it through to + # the internal ToolUse representation. The full encoded ID is kept in toolUseId so that + # tool result messages (which reference toolUseId) continue to match the assistant message. + if event["chunk_type"] == "content_start" and event.get("data_type") == "tool": + data = event.get("data") + tool_call_id = getattr(data, "id", None) or "" + if not isinstance(tool_call_id, str): + tool_call_id = "" + # Also check provider_specific_fields for the signature (non-streaming responses) + psf = getattr(data, "provider_specific_fields", None) or {} + if isinstance(psf, dict): + psf_signature = psf.get("thought_signature") + else: + psf_signature = None + # Extract from encoded ID as fallback + id_signature = None + if _THOUGHT_SIGNATURE_SEPARATOR in tool_call_id: + _, id_signature = tool_call_id.split(_THOUGHT_SIGNATURE_SEPARATOR, 1) + # Also check function-level provider_specific_fields + func = getattr(data, "function", None) + func_psf = getattr(func, "provider_specific_fields", None) or {} + if isinstance(func_psf, dict): + func_signature = func_psf.get("thought_signature") + else: + func_signature = None + + signature = psf_signature or func_signature or id_signature + + chunk = super().format_chunk(event, **kwargs) + if signature: + tool_use = chunk.get("contentBlockStart", {}).get("start", {}).get("toolUse", {}) + tool_use["reasoningSignature"] = signature + return chunk + # For all other cases, use the parent implementation return super().format_chunk(event) diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index 9bb0e09ca..d35a1806e 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -848,3 +848,141 @@ def test_format_request_messages_with_tool_calls_no_content(): }, ] assert tru_result == exp_result + + +# --- Thought Signature Tests --- + + +def test_format_chunk_tool_start_extracts_thought_signature_from_id(): + """Test that format_chunk extracts thought_signature from LiteLLM-encoded tool call ID.""" + model = LiteLLMModel(model_id="test") + + mock_data = unittest.mock.Mock() + mock_data.id = "call_abc123__thought__dGhpcy1pcy1hLXNpZw==" + mock_data.function = unittest.mock.Mock() + mock_data.function.name = "get_weather" + mock_data.provider_specific_fields = None + + event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_data} + result = model.format_chunk(event) + + tool_use = result["contentBlockStart"]["start"]["toolUse"] + assert tool_use["reasoningSignature"] == "dGhpcy1pcy1hLXNpZw==" + # toolUseId keeps the full encoded string so tool result IDs match + assert tool_use["toolUseId"] == "call_abc123__thought__dGhpcy1pcy1hLXNpZw==" + + +def test_format_chunk_tool_start_extracts_thought_signature_from_provider_specific_fields(): + """Test that format_chunk extracts thought_signature from provider_specific_fields.""" + model = LiteLLMModel(model_id="test") + + mock_data = unittest.mock.Mock() + mock_data.id = "call_abc123" # No __thought__ in ID + mock_data.function = unittest.mock.Mock() + mock_data.function.name = "get_weather" + mock_data.function.provider_specific_fields = None + mock_data.provider_specific_fields = {"thought_signature": "cHNmLXNpZw=="} + + event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_data} + result = model.format_chunk(event) + + tool_use = result["contentBlockStart"]["start"]["toolUse"] + assert tool_use["reasoningSignature"] == "cHNmLXNpZw==" + assert tool_use["toolUseId"] == "call_abc123" + + +def test_format_chunk_tool_start_no_thought_signature(): + """Test that format_chunk works normally when no thought_signature is present.""" + model = LiteLLMModel(model_id="test") + + mock_data = unittest.mock.Mock() + mock_data.id = "call_plain123" + mock_data.function = unittest.mock.Mock() + mock_data.function.name = "get_weather" + mock_data.provider_specific_fields = None + mock_data.function.provider_specific_fields = None + + event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_data} + result = model.format_chunk(event) + + tool_use = result["contentBlockStart"]["start"]["toolUse"] + assert tool_use["toolUseId"] == "call_plain123" + assert "reasoningSignature" not in tool_use + + +def test_format_request_message_tool_call_encodes_thought_signature(): + """Test that format_request_message_tool_call encodes reasoningSignature into the tool call ID.""" + tool_use = { + "toolUseId": "call_abc123", + "name": "get_weather", + "input": {"city": "Seattle"}, + "reasoningSignature": "dGhpcy1pcy1hLXNpZw==", + } + + result = LiteLLMModel.format_request_message_tool_call(tool_use) + + assert result["id"] == "call_abc123__thought__dGhpcy1pcy1hLXNpZw==" + assert result["function"]["name"] == "get_weather" + assert result["function"]["arguments"] == '{"city": "Seattle"}' + + +def test_format_request_message_tool_call_skips_encoding_when_already_present(): + """Test that format_request_message_tool_call does not double-encode the signature.""" + tool_use = { + "toolUseId": "call_abc123__thought__dGhpcy1pcy1hLXNpZw==", + "name": "get_weather", + "input": {"city": "Seattle"}, + "reasoningSignature": "dGhpcy1pcy1hLXNpZw==", + } + + result = LiteLLMModel.format_request_message_tool_call(tool_use) + + # Should NOT double-encode + assert result["id"] == "call_abc123__thought__dGhpcy1pcy1hLXNpZw==" + + +def test_format_request_message_tool_call_no_reasoning_signature(): + """Test that format_request_message_tool_call works normally without reasoningSignature.""" + tool_use = { + "toolUseId": "call_plain123", + "name": "get_weather", + "input": {"city": "Seattle"}, + } + + result = LiteLLMModel.format_request_message_tool_call(tool_use) + + assert result["id"] == "call_plain123" + assert "__thought__" not in result["id"] + + +def test_thought_signature_round_trip(): + """Test that thought signature is preserved through a full response -> internal -> request cycle.""" + model = LiteLLMModel(model_id="test") + signature = "R2VtaW5pVGhvdWdodFNpZw==" + tool_call_id = f"call_xyz789__thought__{signature}" + + # 1. Response path: format_chunk extracts the signature + mock_data = unittest.mock.Mock() + mock_data.id = tool_call_id + mock_data.function = unittest.mock.Mock() + mock_data.function.name = "current_time" + mock_data.provider_specific_fields = None + mock_data.function.provider_specific_fields = None + + event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_data} + chunk = model.format_chunk(event) + tool_use_data = chunk["contentBlockStart"]["start"]["toolUse"] + assert tool_use_data["reasoningSignature"] == signature + + # 2. Simulate internal storage (streaming layer stores reasoningSignature) + internal_tool_use = { + "toolUseId": tool_use_data["toolUseId"], + "name": tool_use_data["name"], + "input": {"timezone": "UTC"}, + "reasoningSignature": tool_use_data["reasoningSignature"], + } + + # 3. Request path: format_request_message_tool_call re-encodes the signature + tool_call = LiteLLMModel.format_request_message_tool_call(internal_tool_use) + assert "__thought__" in tool_call["id"] + assert signature in tool_call["id"]