23 changes: 21 additions & 2 deletions src/google/adk/models/lite_llm.py
@@ -494,13 +494,32 @@ def _model_response_to_generate_content_response(
"""

message = None
if response.get("choices", None):
message = response["choices"][0].get("message", None)
finish_reason = None
if choices := response.get("choices"):
Comment thread
aperepel marked this conversation as resolved.
Outdated
first_choice = choices[0]
message = first_choice.get("message", None)
finish_reason = first_choice.get("finish_reason", None)

if not message:
raise ValueError("No message in response")

llm_response = _message_to_generate_content_response(message)
if finish_reason:
# Map LiteLLM finish_reason strings to FinishReason enum
# This provides type consistency with Gemini native responses and avoids warnings
finish_reason_str = str(finish_reason).lower()
if finish_reason_str == "length":
Comment thread
aperepel marked this conversation as resolved.
Outdated
llm_response.finish_reason = types.FinishReason.MAX_TOKENS
elif finish_reason_str == "stop":
llm_response.finish_reason = types.FinishReason.STOP
elif "tool" in finish_reason_str or "function" in finish_reason_str:
# Handle tool_calls, function_call variants
llm_response.finish_reason = types.FinishReason.STOP
elif finish_reason_str == "content_filter":
llm_response.finish_reason = types.FinishReason.SAFETY
else:
# For unknown reasons, use OTHER
llm_response.finish_reason = types.FinishReason.OTHER
if response.get("usage", None):
llm_response.usage_metadata = types.GenerateContentResponseUsageMetadata(
prompt_token_count=response["usage"].get("prompt_tokens", 0),
Expand Down
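For readers skimming the diff, the mapping above can be exercised in isolation. A minimal sketch of the same string-to-enum translation (the standalone helper map_finish_reason is hypothetical, for illustration only; the PR keeps the logic inline):

from google.genai import types

def map_finish_reason(finish_reason) -> types.FinishReason:
  # Normalize whatever LiteLLM returns (string or enum-like) to lowercase.
  reason = str(finish_reason).lower()
  if reason == "length":
    return types.FinishReason.MAX_TOKENS
  if reason == "stop":
    return types.FinishReason.STOP
  if "tool" in reason or "function" in reason:
    # tool_calls / function_call variants still end a normal turn.
    return types.FinishReason.STOP
  if reason == "content_filter":
    return types.FinishReason.SAFETY
  return types.FinishReason.OTHER

assert map_finish_reason("tool_calls") is types.FinishReason.STOP
assert map_finish_reason("weird_reason") is types.FinishReason.OTHER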
8 changes: 6 additions & 2 deletions src/google/adk/models/llm_response.py
@@ -16,6 +16,7 @@
 
 from typing import Any
 from typing import Optional
+from typing import Union
 
 from google.genai import types
 from pydantic import alias_generators
@@ -77,8 +78,11 @@ class LlmResponse(BaseModel):
   Only used for streaming mode.
   """
 
-  finish_reason: Optional[types.FinishReason] = None
-  """The finish reason of the response."""
+  finish_reason: Optional[Union[types.FinishReason, str]] = None
+  """The finish reason of the response.
+
+  Can be either a types.FinishReason enum (from Gemini) or a string (from LiteLLM).
+  """
 
   error_code: Optional[str] = None
   """Error code if the response is an error. Code varies by model."""
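Since LlmResponse is a pydantic model, the widened annotation accepts both branches of the Union. A quick illustration (hypothetical values, and assuming the model's other fields all default to None, which the class definition suggests):

from google.adk.models.llm_response import LlmResponse
from google.genai import types

# Enum member from the Gemini-native path.
gemini_style = LlmResponse(finish_reason=types.FinishReason.MAX_TOKENS)
# Raw string, e.g. if a LiteLLM value ever bypasses the enum mapping.
litellm_style = LlmResponse(finish_reason="length")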
7 changes: 6 additions & 1 deletion src/google/adk/telemetry/tracing.py
@@ -281,9 +281,14 @@ def trace_call_llm(
         llm_response.usage_metadata.candidates_token_count,
     )
   if llm_response.finish_reason:
+    if isinstance(llm_response.finish_reason, types.FinishReason):
+      finish_reason_str = llm_response.finish_reason.name.lower()
+    else:
+      # Fallback for string values (should not occur with LiteLLM after enum mapping).
+      finish_reason_str = str(llm_response.finish_reason).lower()
     span.set_attribute(
         'gen_ai.response.finish_reasons',
-        [llm_response.finish_reason.value.lower()],
+        [finish_reason_str],
     )
 
 
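The isinstance branch means the span attribute is always a plain lowercase string. A minimal check, assuming an enum member (a standalone sketch, not part of the PR):

from google.genai import types

finish_reason = types.FinishReason.MAX_TOKENS
# .name is the symbolic member name, so this stays stable even if the
# enum's backing value were ever to differ from the member name.
assert finish_reason.name.lower() == "max_tokens"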
119 changes: 119 additions & 0 deletions tests/unittests/models/test_litellm.py
@@ -1903,3 +1903,122 @@ def test_non_gemini_litellm_no_warning():
     # Test with non-Gemini model
     LiteLlm(model="openai/gpt-4o")
     assert len(w) == 0
+
+
+@pytest.mark.parametrize(
+    "finish_reason,response_content,expected_content,has_tool_calls",
+    [
+        ("length", "Test response", "Test response", False),
+        ("stop", "Complete response", "Complete response", False),
+        (
+            "tool_calls",
+            "",
+            "",
+            True,
+        ),
+        ("content_filter", "", "", False),
+    ],
+    ids=["length", "stop", "tool_calls", "content_filter"],
+)
+@pytest.mark.asyncio
+async def test_finish_reason_propagation(
+    mock_acompletion,
+    lite_llm_instance,
+    finish_reason,
+    response_content,
+    expected_content,
+    has_tool_calls,
+):
+  """Test that finish_reason is properly propagated from LiteLLM responses."""
+  tool_calls = None
+  if has_tool_calls:
+    tool_calls = [
+        ChatCompletionMessageToolCall(
+            type="function",
+            id="test_id",
+            function=Function(
+                name="test_function",
+                arguments='{"arg": "value"}',
+            ),
+        )
+    ]
+
+  mock_response = ModelResponse(
+      choices=[
+          Choices(
+              message=ChatCompletionAssistantMessage(
+                  role="assistant",
+                  content=response_content,
+                  tool_calls=tool_calls,
+              ),
+              finish_reason=finish_reason,
+          )
+      ]
+  )
+  mock_acompletion.return_value = mock_response
+
+  llm_request = LlmRequest(
+      contents=[
+          types.Content(
+              role="user", parts=[types.Part.from_text(text="Test prompt")]
+          )
+      ],
+  )
+
+  async for response in lite_llm_instance.generate_content_async(llm_request):
+    assert response.content.role == "model"
+    # Verify finish_reason is mapped to a FinishReason enum, not a raw string.
+    assert isinstance(response.finish_reason, types.FinishReason)
+    # Verify the correct enum mapping.
+    if finish_reason == "length":
+      assert response.finish_reason == types.FinishReason.MAX_TOKENS
+    elif finish_reason == "stop":
+      assert response.finish_reason == types.FinishReason.STOP
+    elif finish_reason == "tool_calls":
+      assert response.finish_reason == types.FinishReason.STOP
+    elif finish_reason == "content_filter":
+      assert response.finish_reason == types.FinishReason.SAFETY
+    if expected_content:
+      assert response.content.parts[0].text == expected_content
+    if has_tool_calls:
+      assert len(response.content.parts) > 0
+      assert response.content.parts[-1].function_call.name == "test_function"
+
+  mock_acompletion.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_finish_reason_unknown_maps_to_other(
+    mock_acompletion, lite_llm_instance
+):
+  """Test that unknown finish_reason values map to FinishReason.OTHER."""
+  mock_response = ModelResponse(
+      choices=[
+          Choices(
+              message=ChatCompletionAssistantMessage(
+                  role="assistant",
+                  content="Test response",
+              ),
+              finish_reason="unknown_reason_type",
+          )
+      ]
+  )
+  mock_acompletion.return_value = mock_response
+
+  llm_request = LlmRequest(
+      contents=[
+          types.Content(
+              role="user", parts=[types.Part.from_text(text="Test prompt")]
+          )
+      ],
+  )
+
+  async for response in lite_llm_instance.generate_content_async(llm_request):
+    assert response.content.role == "model"
+    # Unknown finish_reason values should map to OTHER.
+    assert isinstance(response.finish_reason, types.FinishReason)
+    assert response.finish_reason == types.FinishReason.OTHER
+
+  mock_acompletion.assert_called_once()
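To run just these additions locally, something like the following should work (the exact invocation depends on the repo's test setup):

pytest tests/unittests/models/test_litellm.py -k "finish_reason"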