73 changes: 70 additions & 3 deletions src/strands/models/litellm.py
@@ -19,12 +19,16 @@
from ..types.event_loop import Usage
from ..types.exceptions import ContextWindowOverflowException
from ..types.streaming import MetadataEvent, StreamEvent
from ..types.tools import ToolChoice, ToolSpec, ToolUse
from ._validation import validate_config_keys
from .openai import OpenAIModel

logger = logging.getLogger(__name__)

# Separator used by LiteLLM to embed thought signatures inside tool call IDs.
# See: https://ai.google.dev/gemini-api/docs/thought-signatures
_THOUGHT_SIGNATURE_SEPARATOR = "__thought__"
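# Example of an encoded ID (illustrative only):
#   "call_abc123__thought__<base64url thought signature>"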

T = TypeVar("T", bound=BaseModel)


@@ -114,6 +118,31 @@ def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) ->

return super().format_request_message_content(content)

@classmethod
@override
def format_request_message_tool_call(cls, tool_use: ToolUse, **kwargs: Any) -> dict[str, Any]:
"""Format a LiteLLM compatible tool call, encoding thought signatures into the tool call ID.

Gemini thinking models attach a thought_signature to each function call. LiteLLM's OpenAI-compatible
interface embeds this signature inside the tool call ID using the ``__thought__`` separator. When
``reasoningSignature`` is present and the tool call ID does not already contain the separator, this
method encodes it so LiteLLM can reconstruct the Gemini-native format on the next request.

Args:
tool_use: Tool use requested by the model.
**kwargs: Additional keyword arguments for future extensibility.

Returns:
LiteLLM compatible tool call dict with thought signature encoded in the ID when present.
"""
tool_call = super().format_request_message_tool_call(tool_use, **kwargs)

reasoning_signature = tool_use.get("reasoningSignature")
if reasoning_signature and _THOUGHT_SIGNATURE_SEPARATOR not in tool_call["id"]:
tool_call["id"] = f"{tool_call['id']}{_THOUGHT_SIGNATURE_SEPARATOR}{reasoning_signature}"

return tool_call
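Note for reviewers: below is a minimal, self-contained sketch of the `__thought__` encoding scheme that this method and `format_chunk` rely on. The helper names are illustrative only, not part of the module, and `format_chunk` deliberately keeps the full encoded string as `toolUseId` rather than stripping it.

```python
# Hypothetical helpers sketching the encoding scheme; not part of the module.
SEP = "__thought__"


def encode_id(tool_call_id: str, signature: str | None) -> str:
    """Append the signature unless the ID is already encoded."""
    if signature and SEP not in tool_call_id:
        return f"{tool_call_id}{SEP}{signature}"
    return tool_call_id


def split_id(encoded: str) -> tuple[str, str | None]:
    """Recover (base_id, signature) from an encoded ID."""
    if SEP in encoded:
        base, sig = encoded.split(SEP, 1)
        return base, sig
    return encoded, None


assert encode_id("call_abc123", "sig==") == "call_abc123__thought__sig=="
assert split_id("call_abc123__thought__sig==") == ("call_abc123", "sig==")
# Re-encoding an already-encoded ID is a no-op, mirroring the guard above.
assert encode_id("call_abc123__thought__sig==", "sig==") == "call_abc123__thought__sig=="
```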

def _stream_switch_content(self, data_type: str, prev_data_type: str | None) -> tuple[list[StreamEvent], str]:
"""Handle switching to a new content stream.

@@ -200,8 +229,9 @@ def format_request_messages(
def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent:
"""Format a LiteLLM response event into a standardized message chunk.

Extends OpenAI's format_chunk to:
1. Handle metadata with prompt caching support.
2. Extract thought signatures that LiteLLM embeds in tool call IDs for Gemini thinking models.

Args:
event: A response event from the LiteLLM model.
Expand Down Expand Up @@ -237,6 +267,43 @@ def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent:
usage=usage_data,
)
)

# Extract thought signature from tool call content_start events.
# LiteLLM embeds Gemini thought signatures in the tool call ID using the __thought__ separator.
# We extract it into reasoningSignature so the streaming layer can preserve it through to
# the internal ToolUse representation. The full encoded ID is kept in toolUseId so that
# tool result messages (which reference toolUseId) continue to match the assistant message.
if event["chunk_type"] == "content_start" and event.get("data_type") == "tool":
data = event.get("data")
tool_call_id = getattr(data, "id", None) or ""
if not isinstance(tool_call_id, str):
tool_call_id = ""
# Also check provider_specific_fields for the signature (non-streaming responses)
psf = getattr(data, "provider_specific_fields", None) or {}
if isinstance(psf, dict):
psf_signature = psf.get("thought_signature")
else:
psf_signature = None
# Extract from encoded ID as fallback
id_signature = None
if _THOUGHT_SIGNATURE_SEPARATOR in tool_call_id:
_, id_signature = tool_call_id.split(_THOUGHT_SIGNATURE_SEPARATOR, 1)
# Also check function-level provider_specific_fields
func = getattr(data, "function", None)
func_psf = getattr(func, "provider_specific_fields", None) or {}
if isinstance(func_psf, dict):
func_signature = func_psf.get("thought_signature")
else:
func_signature = None

signature = psf_signature or func_signature or id_signature

chunk = super().format_chunk(event, **kwargs)
if signature:
tool_use = chunk.get("contentBlockStart", {}).get("start", {}).get("toolUse", {})
tool_use["reasoningSignature"] = signature
return chunk

# For all other cases, use the parent implementation
return super().format_chunk(event, **kwargs)
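The extraction above consults three sources in precedence order: the top-level `provider_specific_fields`, the function-level `provider_specific_fields`, and finally the signature embedded in the encoded ID. A condensed sketch of that precedence as a standalone function (the helper is hypothetical; the real method inlines this logic):

```python
from typing import Any

_SEP = "__thought__"


def resolve_thought_signature(data: Any) -> str | None:
    """Hypothetical helper mirroring the precedence used in format_chunk."""
    # 1. Top-level provider_specific_fields (seen on non-streaming responses).
    psf = getattr(data, "provider_specific_fields", None) or {}
    psf_sig = psf.get("thought_signature") if isinstance(psf, dict) else None

    # 2. Function-level provider_specific_fields.
    func = getattr(data, "function", None)
    func_psf = getattr(func, "provider_specific_fields", None) or {}
    func_sig = func_psf.get("thought_signature") if isinstance(func_psf, dict) else None

    # 3. Fallback: the signature embedded in the tool call ID.
    tool_call_id = getattr(data, "id", None)
    id_sig = None
    if isinstance(tool_call_id, str) and _SEP in tool_call_id:
        _, id_sig = tool_call_id.split(_SEP, 1)

    return psf_sig or func_sig or id_sig
```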

138 changes: 138 additions & 0 deletions tests/strands/models/test_litellm.py
@@ -848,3 +848,141 @@ def test_format_request_messages_with_tool_calls_no_content():
},
]
assert tru_result == exp_result


# --- Thought Signature Tests ---


def test_format_chunk_tool_start_extracts_thought_signature_from_id():
"""Test that format_chunk extracts thought_signature from LiteLLM-encoded tool call ID."""
model = LiteLLMModel(model_id="test")

mock_data = unittest.mock.Mock()
mock_data.id = "call_abc123__thought__dGhpcy1pcy1hLXNpZw=="
mock_data.function = unittest.mock.Mock()
mock_data.function.name = "get_weather"
mock_data.provider_specific_fields = None

event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_data}
result = model.format_chunk(event)

tool_use = result["contentBlockStart"]["start"]["toolUse"]
assert tool_use["reasoningSignature"] == "dGhpcy1pcy1hLXNpZw=="
# toolUseId keeps the full encoded string so tool result IDs match
assert tool_use["toolUseId"] == "call_abc123__thought__dGhpcy1pcy1hLXNpZw=="


def test_format_chunk_tool_start_extracts_thought_signature_from_provider_specific_fields():
"""Test that format_chunk extracts thought_signature from provider_specific_fields."""
model = LiteLLMModel(model_id="test")

mock_data = unittest.mock.Mock()
mock_data.id = "call_abc123" # No __thought__ in ID
mock_data.function = unittest.mock.Mock()
mock_data.function.name = "get_weather"
mock_data.function.provider_specific_fields = None
mock_data.provider_specific_fields = {"thought_signature": "cHNmLXNpZw=="}

event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_data}
result = model.format_chunk(event)

tool_use = result["contentBlockStart"]["start"]["toolUse"]
assert tool_use["reasoningSignature"] == "cHNmLXNpZw=="
assert tool_use["toolUseId"] == "call_abc123"


def test_format_chunk_tool_start_no_thought_signature():
"""Test that format_chunk works normally when no thought_signature is present."""
model = LiteLLMModel(model_id="test")

mock_data = unittest.mock.Mock()
mock_data.id = "call_plain123"
mock_data.function = unittest.mock.Mock()
mock_data.function.name = "get_weather"
mock_data.provider_specific_fields = None
mock_data.function.provider_specific_fields = None

event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_data}
result = model.format_chunk(event)

tool_use = result["contentBlockStart"]["start"]["toolUse"]
assert tool_use["toolUseId"] == "call_plain123"
assert "reasoningSignature" not in tool_use


def test_format_request_message_tool_call_encodes_thought_signature():
"""Test that format_request_message_tool_call encodes reasoningSignature into the tool call ID."""
tool_use = {
"toolUseId": "call_abc123",
"name": "get_weather",
"input": {"city": "Seattle"},
"reasoningSignature": "dGhpcy1pcy1hLXNpZw==",
}

result = LiteLLMModel.format_request_message_tool_call(tool_use)

assert result["id"] == "call_abc123__thought__dGhpcy1pcy1hLXNpZw=="
assert result["function"]["name"] == "get_weather"
assert result["function"]["arguments"] == '{"city": "Seattle"}'


def test_format_request_message_tool_call_skips_encoding_when_already_present():
"""Test that format_request_message_tool_call does not double-encode the signature."""
tool_use = {
"toolUseId": "call_abc123__thought__dGhpcy1pcy1hLXNpZw==",
"name": "get_weather",
"input": {"city": "Seattle"},
"reasoningSignature": "dGhpcy1pcy1hLXNpZw==",
}

result = LiteLLMModel.format_request_message_tool_call(tool_use)

# Should NOT double-encode
assert result["id"] == "call_abc123__thought__dGhpcy1pcy1hLXNpZw=="


def test_format_request_message_tool_call_no_reasoning_signature():
"""Test that format_request_message_tool_call works normally without reasoningSignature."""
tool_use = {
"toolUseId": "call_plain123",
"name": "get_weather",
"input": {"city": "Seattle"},
}

result = LiteLLMModel.format_request_message_tool_call(tool_use)

assert result["id"] == "call_plain123"
assert "__thought__" not in result["id"]


def test_thought_signature_round_trip():
"""Test that thought signature is preserved through a full response -> internal -> request cycle."""
model = LiteLLMModel(model_id="test")
signature = "R2VtaW5pVGhvdWdodFNpZw=="
tool_call_id = f"call_xyz789__thought__{signature}"

# 1. Response path: format_chunk extracts the signature
mock_data = unittest.mock.Mock()
mock_data.id = tool_call_id
mock_data.function = unittest.mock.Mock()
mock_data.function.name = "current_time"
mock_data.provider_specific_fields = None
mock_data.function.provider_specific_fields = None

event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_data}
chunk = model.format_chunk(event)
tool_use_data = chunk["contentBlockStart"]["start"]["toolUse"]
assert tool_use_data["reasoningSignature"] == signature

# 2. Simulate internal storage (streaming layer stores reasoningSignature)
internal_tool_use = {
"toolUseId": tool_use_data["toolUseId"],
"name": tool_use_data["name"],
"input": {"timezone": "UTC"},
"reasoningSignature": tool_use_data["reasoningSignature"],
}

# 3. Request path: format_request_message_tool_call re-encodes the signature
tool_call = LiteLLMModel.format_request_message_tool_call(internal_tool_use)
assert "__thought__" in tool_call["id"]
assert signature in tool_call["id"]
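For reference, the tool call produced in step 3 should have the OpenAI-style shape sketched below. This is inferred from the assertions in this file; the `type` key is an assumption from the standard OpenAI tool-call format, not asserted above.

```python
expected_tool_call = {
    "id": "call_xyz789__thought__R2VtaW5pVGhvdWdodFNpZw==",
    "type": "function",  # assumed from the OpenAI tool-call shape; not asserted above
    "function": {
        "name": "current_time",
        "arguments": '{"timezone": "UTC"}',
    },
}
```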