
Commit 5d7f638

giulio-leone and Copilot committed
fix: preserve Gemini thought_signature in LiteLLM multi-turn tool calls
When using Gemini thinking models (e.g., gemini-2.5-flash) through the LiteLLM model provider, multi-turn conversations with tool calls fail because thought_signature is lost during the response-to-request round trip. LiteLLM encodes Gemini's thought_signature into the tool call ID using a __thought__ separator. The OpenAI parent format_chunk passes this through as-is, but the signature is never extracted into Strands' reasoningSignature field, which the streaming layer already supports.

Changes:
- Override format_chunk in LiteLLMModel to detect __thought__ in tool call IDs and provider_specific_fields, extracting the signature into reasoningSignature for proper streaming-layer storage
- Override format_request_message_tool_call to re-encode reasoningSignature back into the tool call ID when it is not already present, ensuring LiteLLM can reconstruct the Gemini-native format
- Add 7 unit tests covering extraction from ID, extraction from provider_specific_fields, no-signature passthrough, encoding, double-encode prevention, and full round-trip preservation

Closes #1764

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent fca208b commit 5d7f638
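The whole fix rests on one string convention: LiteLLM appends Gemini's thought_signature to the OpenAI-style tool call ID behind a __thought__ separator. A minimal standalone sketch of that round trip (the helper names here are illustrative, not part of the commit):

SEPARATOR = "__thought__"

def encode_signature(tool_call_id: str, signature: str | None) -> str:
    # Append the signature unless the ID already carries one.
    if signature and SEPARATOR not in tool_call_id:
        return f"{tool_call_id}{SEPARATOR}{signature}"
    return tool_call_id

def decode_signature(encoded_id: str) -> tuple[str, str | None]:
    # Split on the first separator; the full encoded ID is kept so that
    # tool result messages referencing it still match.
    if SEPARATOR in encoded_id:
        _, signature = encoded_id.split(SEPARATOR, 1)
        return encoded_id, signature
    return encoded_id, None

assert encode_signature("call_abc123", "sig") == "call_abc123__thought__sig"
assert decode_signature("call_abc123__thought__sig") == ("call_abc123__thought__sig", "sig")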

2 files changed

Lines changed: 208 additions & 3 deletions


src/strands/models/litellm.py

Lines changed: 70 additions & 3 deletions
@@ -19,12 +19,16 @@
 from ..types.event_loop import Usage
 from ..types.exceptions import ContextWindowOverflowException
 from ..types.streaming import MetadataEvent, StreamEvent
-from ..types.tools import ToolChoice, ToolSpec
+from ..types.tools import ToolChoice, ToolSpec, ToolUse
 from ._validation import validate_config_keys
 from .openai import OpenAIModel
 
 logger = logging.getLogger(__name__)
 
+# Separator used by LiteLLM to embed thought signatures inside tool call IDs.
+# See: https://ai.google.dev/gemini-api/docs/thought-signatures
+_THOUGHT_SIGNATURE_SEPARATOR = "__thought__"
+
 T = TypeVar("T", bound=BaseModel)
 

@@ -114,6 +118,31 @@ def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) ->
 
         return super().format_request_message_content(content)
 
+    @classmethod
+    @override
+    def format_request_message_tool_call(cls, tool_use: ToolUse, **kwargs: Any) -> dict[str, Any]:
+        """Format a LiteLLM compatible tool call, encoding thought signatures into the tool call ID.
+
+        Gemini thinking models attach a thought_signature to each function call. LiteLLM's OpenAI-compatible
+        interface embeds this signature inside the tool call ID using the ``__thought__`` separator. When
+        ``reasoningSignature`` is present and the tool call ID does not already contain the separator, this
+        method encodes it so LiteLLM can reconstruct the Gemini-native format on the next request.
+
+        Args:
+            tool_use: Tool use requested by the model.
+            **kwargs: Additional keyword arguments for future extensibility.
+
+        Returns:
+            LiteLLM compatible tool call dict with thought signature encoded in the ID when present.
+        """
+        tool_call = super().format_request_message_tool_call(tool_use, **kwargs)
+
+        reasoning_signature = tool_use.get("reasoningSignature")
+        if reasoning_signature and _THOUGHT_SIGNATURE_SEPARATOR not in tool_call["id"]:
+            tool_call["id"] = f"{tool_call['id']}{_THOUGHT_SIGNATURE_SEPARATOR}{reasoning_signature}"
+
+        return tool_call
+
     def _stream_switch_content(self, data_type: str, prev_data_type: str | None) -> tuple[list[StreamEvent], str]:
         """Handle switching to a new content stream.
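In isolation, the new request-path override behaves like this (a usage sketch mirroring the unit tests below; the import path follows this repo's src/strands layout):

from strands.models.litellm import LiteLLMModel

tool_call = LiteLLMModel.format_request_message_tool_call(
    {
        "toolUseId": "call_abc123",
        "name": "get_weather",
        "input": {"city": "Seattle"},
        "reasoningSignature": "dGhpcy1pcy1hLXNpZw==",
    }
)
assert tool_call["id"] == "call_abc123__thought__dGhpcy1pcy1hLXNpZw=="
assert tool_call["function"]["arguments"] == '{"city": "Seattle"}'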
@@ -200,8 +229,9 @@ def format_request_messages(
     def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent:
         """Format a LiteLLM response event into a standardized message chunk.
 
-        This method overrides OpenAI's format_chunk to handle the metadata case
-        with prompt caching support. All other chunk types use the parent implementation.
+        Extends OpenAI's format_chunk to:
+        1. Handle metadata with prompt caching support.
+        2. Extract thought signatures that LiteLLM embeds in tool call IDs for Gemini thinking models.
 
         Args:
             event: A response event from the LiteLLM model.
@@ -237,6 +267,43 @@ def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent:
                     usage=usage_data,
                 )
             )
+
+        # Extract thought signature from tool call content_start events.
+        # LiteLLM embeds Gemini thought signatures in the tool call ID using the __thought__ separator.
+        # We extract it into reasoningSignature so the streaming layer can preserve it through to
+        # the internal ToolUse representation. The full encoded ID is kept in toolUseId so that
+        # tool result messages (which reference toolUseId) continue to match the assistant message.
+        if event["chunk_type"] == "content_start" and event.get("data_type") == "tool":
+            data = event.get("data")
+            tool_call_id = getattr(data, "id", None) or ""
+            if not isinstance(tool_call_id, str):
+                tool_call_id = ""
+            # Also check provider_specific_fields for the signature (non-streaming responses)
+            psf = getattr(data, "provider_specific_fields", None) or {}
+            if isinstance(psf, dict):
+                psf_signature = psf.get("thought_signature")
+            else:
+                psf_signature = None
+            # Extract from encoded ID as fallback
+            id_signature = None
+            if _THOUGHT_SIGNATURE_SEPARATOR in tool_call_id:
+                _, id_signature = tool_call_id.split(_THOUGHT_SIGNATURE_SEPARATOR, 1)
+            # Also check function-level provider_specific_fields
+            func = getattr(data, "function", None)
+            func_psf = getattr(func, "provider_specific_fields", None) or {}
+            if isinstance(func_psf, dict):
+                func_signature = func_psf.get("thought_signature")
+            else:
+                func_signature = None
+
+            signature = psf_signature or func_signature or id_signature
+
+            chunk = super().format_chunk(event, **kwargs)
+            if signature:
+                tool_use = chunk.get("contentBlockStart", {}).get("start", {}).get("toolUse", {})
+                tool_use["reasoningSignature"] = signature
+            return chunk
+
         # For all other cases, use the parent implementation
         return super().format_chunk(event)

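On the response path the signature can surface in three places, and format_chunk resolves them in priority order: message-level provider_specific_fields, then function-level provider_specific_fields, then the __thought__-encoded ID as a fallback. Condensed to just that resolution step (a sketch; the function name is illustrative):

def resolve_thought_signature(data) -> str | None:
    # 1. Message-level provider_specific_fields (non-streaming responses).
    psf = getattr(data, "provider_specific_fields", None) or {}
    psf_signature = psf.get("thought_signature") if isinstance(psf, dict) else None

    # 2. Function-level provider_specific_fields.
    func = getattr(data, "function", None)
    func_psf = getattr(func, "provider_specific_fields", None) or {}
    func_signature = func_psf.get("thought_signature") if isinstance(func_psf, dict) else None

    # 3. Fallback: the signature encoded into the tool call ID itself.
    tool_call_id = getattr(data, "id", None) or ""
    id_signature = None
    if isinstance(tool_call_id, str) and "__thought__" in tool_call_id:
        _, id_signature = tool_call_id.split("__thought__", 1)

    return psf_signature or func_signature or id_signature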
tests/strands/models/test_litellm.py

Lines changed: 138 additions & 0 deletions
@@ -848,3 +848,141 @@ def test_format_request_messages_with_tool_calls_no_content():
         },
     ]
     assert tru_result == exp_result
+
+
+# --- Thought Signature Tests ---
+
+
+def test_format_chunk_tool_start_extracts_thought_signature_from_id():
+    """Test that format_chunk extracts thought_signature from LiteLLM-encoded tool call ID."""
+    model = LiteLLMModel(model_id="test")
+
+    mock_data = unittest.mock.Mock()
+    mock_data.id = "call_abc123__thought__dGhpcy1pcy1hLXNpZw=="
+    mock_data.function = unittest.mock.Mock()
+    mock_data.function.name = "get_weather"
+    mock_data.provider_specific_fields = None
+
+    event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_data}
+    result = model.format_chunk(event)
+
+    tool_use = result["contentBlockStart"]["start"]["toolUse"]
+    assert tool_use["reasoningSignature"] == "dGhpcy1pcy1hLXNpZw=="
+    # toolUseId keeps the full encoded string so tool result IDs match
+    assert tool_use["toolUseId"] == "call_abc123__thought__dGhpcy1pcy1hLXNpZw=="
+
+
+def test_format_chunk_tool_start_extracts_thought_signature_from_provider_specific_fields():
+    """Test that format_chunk extracts thought_signature from provider_specific_fields."""
+    model = LiteLLMModel(model_id="test")
+
+    mock_data = unittest.mock.Mock()
+    mock_data.id = "call_abc123"  # No __thought__ in ID
+    mock_data.function = unittest.mock.Mock()
+    mock_data.function.name = "get_weather"
+    mock_data.function.provider_specific_fields = None
+    mock_data.provider_specific_fields = {"thought_signature": "cHNmLXNpZw=="}
+
+    event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_data}
+    result = model.format_chunk(event)
+
+    tool_use = result["contentBlockStart"]["start"]["toolUse"]
+    assert tool_use["reasoningSignature"] == "cHNmLXNpZw=="
+    assert tool_use["toolUseId"] == "call_abc123"
+
+
+def test_format_chunk_tool_start_no_thought_signature():
+    """Test that format_chunk works normally when no thought_signature is present."""
+    model = LiteLLMModel(model_id="test")
+
+    mock_data = unittest.mock.Mock()
+    mock_data.id = "call_plain123"
+    mock_data.function = unittest.mock.Mock()
+    mock_data.function.name = "get_weather"
+    mock_data.provider_specific_fields = None
+    mock_data.function.provider_specific_fields = None
+
+    event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_data}
+    result = model.format_chunk(event)
+
+    tool_use = result["contentBlockStart"]["start"]["toolUse"]
+    assert tool_use["toolUseId"] == "call_plain123"
+    assert "reasoningSignature" not in tool_use
+
+
+def test_format_request_message_tool_call_encodes_thought_signature():
+    """Test that format_request_message_tool_call encodes reasoningSignature into the tool call ID."""
+    tool_use = {
+        "toolUseId": "call_abc123",
+        "name": "get_weather",
+        "input": {"city": "Seattle"},
+        "reasoningSignature": "dGhpcy1pcy1hLXNpZw==",
+    }
+
+    result = LiteLLMModel.format_request_message_tool_call(tool_use)
+
+    assert result["id"] == "call_abc123__thought__dGhpcy1pcy1hLXNpZw=="
+    assert result["function"]["name"] == "get_weather"
+    assert result["function"]["arguments"] == '{"city": "Seattle"}'
+
+
+def test_format_request_message_tool_call_skips_encoding_when_already_present():
+    """Test that format_request_message_tool_call does not double-encode the signature."""
+    tool_use = {
+        "toolUseId": "call_abc123__thought__dGhpcy1pcy1hLXNpZw==",
+        "name": "get_weather",
+        "input": {"city": "Seattle"},
+        "reasoningSignature": "dGhpcy1pcy1hLXNpZw==",
+    }
+
+    result = LiteLLMModel.format_request_message_tool_call(tool_use)
+
+    # Should NOT double-encode
+    assert result["id"] == "call_abc123__thought__dGhpcy1pcy1hLXNpZw=="
+
+
+def test_format_request_message_tool_call_no_reasoning_signature():
+    """Test that format_request_message_tool_call works normally without reasoningSignature."""
+    tool_use = {
+        "toolUseId": "call_plain123",
+        "name": "get_weather",
+        "input": {"city": "Seattle"},
+    }
+
+    result = LiteLLMModel.format_request_message_tool_call(tool_use)
+
+    assert result["id"] == "call_plain123"
+    assert "__thought__" not in result["id"]
+
+
+def test_thought_signature_round_trip():
+    """Test that thought signature is preserved through a full response -> internal -> request cycle."""
+    model = LiteLLMModel(model_id="test")
+    signature = "R2VtaW5pVGhvdWdodFNpZw=="
+    tool_call_id = f"call_xyz789__thought__{signature}"
+
+    # 1. Response path: format_chunk extracts the signature
+    mock_data = unittest.mock.Mock()
+    mock_data.id = tool_call_id
+    mock_data.function = unittest.mock.Mock()
+    mock_data.function.name = "current_time"
+    mock_data.provider_specific_fields = None
+    mock_data.function.provider_specific_fields = None
+
+    event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_data}
+    chunk = model.format_chunk(event)
+    tool_use_data = chunk["contentBlockStart"]["start"]["toolUse"]
+    assert tool_use_data["reasoningSignature"] == signature
+
+    # 2. Simulate internal storage (streaming layer stores reasoningSignature)
+    internal_tool_use = {
+        "toolUseId": tool_use_data["toolUseId"],
+        "name": tool_use_data["name"],
+        "input": {"timezone": "UTC"},
+        "reasoningSignature": tool_use_data["reasoningSignature"],
+    }
+
+    # 3. Request path: format_request_message_tool_call re-encodes the signature
+    tool_call = LiteLLMModel.format_request_message_tool_call(internal_tool_use)
+    assert "__thought__" in tool_call["id"]
+    assert signature in tool_call["id"]