Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/google/adk/models/lite_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,8 @@ def _part_has_payload(part: types.Part) -> bool:
return True
if part.file_data and (part.file_data.file_uri or part.file_data.data):
return True
if part.function_response:
return True
return False


Expand Down
41 changes: 41 additions & 0 deletions tests/unittests/models/test_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,6 +990,47 @@ async def test_generate_content_async_adds_fallback_user_message(
)


@pytest.mark.asyncio
async def test_generate_content_async_no_fallback_for_function_response(
mock_acompletion, lite_llm_instance
):
"""Tests that no fallback message is added for a user message with a function response."""
llm_request = LlmRequest(
contents=[
types.Content(
role="user",
parts=[
types.Part.from_function_response(
name="test_function",
response={"result": "test_result"},
)
],
)
]
)

# Run generate_content_async which calls _append_fallback_user_content_if_missing
async for _ in lite_llm_instance.generate_content_async(llm_request):
pass

# Verify that the fallback message was NOT added to the llm_request
assert len(llm_request.contents) == 1
assert len(llm_request.contents[0].parts) == 1
assert llm_request.contents[0].parts[0].function_response is not None

# Verify that the message sent to litellm does not contain the fallback text
mock_acompletion.assert_called_once()
_, kwargs = mock_acompletion.call_args
user_messages = [
message for message in kwargs["messages"] if message["role"] == "user"
]
assert not any(
message.get("content")
== "Handle the requests as specified in the System Instruction."
for message in user_messages
)


litellm_append_user_content_test_cases = [
pytest.param(
LlmRequest(
Expand Down
Loading