6 changes: 6 additions & 0 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
@@ -296,6 +296,12 @@ def get_author_for_event(llm_response):
invocation_id=invocation_context.invocation_id,
author=get_author_for_event(llm_response),
)
trace_call_llm(
invocation_context,
model_response_event.id,
llm_request,
llm_response,
)
async for event in self._postprocess_live(
invocation_context,
llm_request,
112 changes: 106 additions & 6 deletions src/google/adk/models/gemini_llm_connection.py
@@ -148,7 +148,11 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
content = message.server_content.model_turn
if content and content.parts:
llm_response = LlmResponse(
content=content, interrupted=message.server_content.interrupted
content=content,
interrupted=message.server_content.interrupted,
usage_metadata=self._fix_usage_metadata(
getattr(message, 'usage_metadata', None)
),
)
if content.parts[0].text:
text += content.parts[0].text
@@ -169,7 +173,10 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
)
]
llm_response = LlmResponse(
content=types.Content(role='user', parts=parts)
content=types.Content(role='user', parts=parts),
usage_metadata=self._fix_usage_metadata(
getattr(message, 'usage_metadata', None)
),
)
yield llm_response
if (
@@ -190,7 +197,11 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
)
]
llm_response = LlmResponse(
content=types.Content(role='model', parts=parts), partial=True
content=types.Content(role='model', parts=parts),
partial=True,
usage_metadata=self._fix_usage_metadata(
getattr(message, 'usage_metadata', None)
),
)
yield llm_response

@@ -199,7 +210,11 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
yield self.__build_full_text_response(text)
text = ''
yield LlmResponse(
turn_complete=True, interrupted=message.server_content.interrupted
turn_complete=True,
interrupted=message.server_content.interrupted,
usage_metadata=self._fix_usage_metadata(
getattr(message, 'usage_metadata', None)
),
)
break
# in case of empty content or parts, we still surface it
@@ -209,7 +224,12 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
if message.server_content.interrupted and text:
yield self.__build_full_text_response(text)
text = ''
yield LlmResponse(interrupted=message.server_content.interrupted)
yield LlmResponse(
interrupted=message.server_content.interrupted,
usage_metadata=self._fix_usage_metadata(
getattr(message, 'usage_metadata', None)
),
)
if message.tool_call:
if text:
yield self.__build_full_text_response(text)
@@ -218,15 +238,95 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
types.Part(function_call=function_call)
for function_call in message.tool_call.function_calls
]
yield LlmResponse(content=types.Content(role='model', parts=parts))
yield LlmResponse(
content=types.Content(role='model', parts=parts),
usage_metadata=self._fix_usage_metadata(
getattr(message, 'usage_metadata', None)
),
)
if message.session_resumption_update:
logger.info('Received session resumption message: %s', message)
yield (
LlmResponse(
live_session_resumption_update=message.session_resumption_update,
usage_metadata=self._fix_usage_metadata(
getattr(message, 'usage_metadata', None)
),
)
)

def _fix_usage_metadata(self, usage_metadata):
"""
Fix missing candidates_token_count in Gemini Live API responses.

The Gemini Live API inconsistently returns usage metadata. While it typically
provides total_token_count and prompt_token_count, it often leaves
candidates_token_count as None. This creates incomplete telemetry data which
affects billing reporting and token usage monitoring.

This method calculates the missing candidates_token_count using the formula:
candidates_token_count = total_token_count - prompt_token_count

Args:
usage_metadata: The usage metadata from the Live API response, which may
have missing candidates_token_count.

Returns:
Fixed usage metadata with calculated candidates_token_count, or the
original metadata if no fix is needed/possible.
"""
if not usage_metadata:
return usage_metadata

# Safely get token counts using getattr with defaults
total_tokens = getattr(usage_metadata, 'total_token_count', None)
prompt_tokens = getattr(usage_metadata, 'prompt_token_count', None)
candidates_tokens = getattr(usage_metadata, 'candidates_token_count', None)

# Only fix if we have total and prompt but missing candidates
if (
total_tokens is not None
and prompt_tokens is not None
and candidates_tokens is None
):
# Calculate candidates tokens as: total - prompt
calculated_candidates = total_tokens - prompt_tokens

if calculated_candidates > 0:
# Create a new usage metadata object with the calculated value
from google.genai import types

return types.GenerateContentResponseUsageMetadata(
total_token_count=total_tokens,
prompt_token_count=prompt_tokens,
candidates_token_count=calculated_candidates,
# Copy other fields if they exist
cache_tokens_details=getattr(
usage_metadata, 'cache_tokens_details', None
),
cached_content_token_count=getattr(
usage_metadata, 'cached_content_token_count', None
),
candidates_tokens_details=getattr(
usage_metadata, 'candidates_tokens_details', None
),
prompt_tokens_details=getattr(
usage_metadata, 'prompt_tokens_details', None
),
thoughts_token_count=getattr(
usage_metadata, 'thoughts_token_count', None
),
tool_use_prompt_token_count=getattr(
usage_metadata, 'tool_use_prompt_token_count', None
),
tool_use_prompt_tokens_details=getattr(
usage_metadata, 'tool_use_prompt_tokens_details', None
),
traffic_type=getattr(usage_metadata, 'traffic_type', None),
)

return usage_metadata

async def close(self):
"""Closes the llm server connection."""

46 changes: 35 additions & 11 deletions src/google/adk/telemetry.py
@@ -166,9 +166,31 @@ def trace_call_llm(
llm_request: The LLM request object.
llm_response: The LLM response object.
"""
span = trace.get_current_span()
# Special standard Open Telemetry GenaI attributes that indicate
# that this is a span related to a Generative AI system.
# For live events with usage metadata, create a new span for each event
# For regular events or live events without usage data, use the current span
if (
hasattr(invocation_context, 'live_request_queue')
and invocation_context.live_request_queue
and llm_response.usage_metadata is not None
):
# Live mode with usage data: create new span for each event
span_name = f'llm_call_live_event [{event_id[:8]}]'
with tracer.start_as_current_span(span_name) as span:
_set_llm_span_attributes(
span, invocation_context, event_id, llm_request, llm_response
)
else:
# Regular mode or live mode without usage data: use current span
span = trace.get_current_span()
_set_llm_span_attributes(
span, invocation_context, event_id, llm_request, llm_response
)


def _set_llm_span_attributes(
span, invocation_context, event_id, llm_request, llm_response
):
"""Set LLM span attributes."""
span.set_attribute('gen_ai.system', 'gcp.vertex.agent')
span.set_attribute('gen_ai.request.model', llm_request.model)
span.set_attribute(
@@ -196,14 +218,16 @@ def trace_call_llm(
)

if llm_response.usage_metadata is not None:
span.set_attribute(
'gen_ai.usage.input_tokens',
llm_response.usage_metadata.prompt_token_count,
)
span.set_attribute(
'gen_ai.usage.output_tokens',
llm_response.usage_metadata.candidates_token_count,
)
if llm_response.usage_metadata.prompt_token_count is not None:
span.set_attribute(
'gen_ai.usage.input_tokens',
llm_response.usage_metadata.prompt_token_count,
)
if llm_response.usage_metadata.candidates_token_count is not None:
span.set_attribute(
'gen_ai.usage.output_tokens',
llm_response.usage_metadata.candidates_token_count,
)


def trace_send_data(
143 changes: 143 additions & 0 deletions tests/unittests/flows/llm_flows/test_base_llm_flow_realtime.py
@@ -199,3 +199,146 @@ async def test_send_to_model_with_text_content(mock_llm_connection):
# Verify send_content was called instead of send_realtime
mock_llm_connection.send_content.assert_called_once_with(content)
mock_llm_connection.send_realtime.assert_not_called()


@pytest.mark.asyncio
async def test_receive_from_model_calls_telemetry_trace(monkeypatch):
"""Test that _receive_from_model calls trace_call_llm for telemetry."""
# Mock the trace_call_llm function
mock_trace_call_llm = mock.AsyncMock()
monkeypatch.setattr(
'google.adk.flows.llm_flows.base_llm_flow.trace_call_llm',
mock_trace_call_llm,
)

# Create mock LLM connection that yields responses
mock_llm_connection = mock.AsyncMock()

# Create test LLM response with usage metadata
from google.adk.models.llm_response import LlmResponse

test_llm_response = LlmResponse(
content=types.Content(
role='model', parts=[types.Part.from_text(text='Test response')]
),
usage_metadata=types.GenerateContentResponseUsageMetadata(
total_token_count=100,
prompt_token_count=50,
candidates_token_count=50,
),
)

# Mock the receive method to yield our test response
async def mock_receive():
yield test_llm_response

mock_llm_connection.receive = mock_receive

# Create agent and invocation context
agent = Agent(name='test_agent', model='mock')
invocation_context = await testing_utils.create_invocation_context(
agent=agent, user_content='test message'
)
invocation_context.live_request_queue = LiveRequestQueue()

# Create flow and test data
flow = TestBaseLlmFlow()
event_id = 'test_event_123'
llm_request = LlmRequest()

# Call _receive_from_model and consume the generator
events = []
async for event in flow._receive_from_model(
mock_llm_connection, event_id, invocation_context, llm_request
):
events.append(event)
break # Exit after first event to avoid infinite loop

# Verify trace_call_llm was called
mock_trace_call_llm.assert_called()

# Verify the call arguments
call_args = mock_trace_call_llm.call_args
assert call_args[0][0] == invocation_context # First arg: invocation_context
assert call_args[0][2] == llm_request # Third arg: llm_request
assert call_args[0][3] == test_llm_response # Fourth arg: llm_response

# Second arg should be the event ID from the generated event
assert len(call_args[0][1]) > 0 # Event ID should be non-empty string


@pytest.mark.asyncio
async def test_receive_from_model_telemetry_integration_with_live_queue(
monkeypatch,
):
"""Test telemetry integration in live mode with actual live request queue."""
# Mock the telemetry tracer to capture span creation
mock_tracer = mock.MagicMock()
mock_span = mock.MagicMock()
mock_tracer.start_as_current_span.return_value.__enter__.return_value = (
mock_span
)

monkeypatch.setattr('google.adk.telemetry.tracer', mock_tracer)

# Create mock LLM connection
mock_llm_connection = mock.AsyncMock()

# Create test responses - one with usage metadata, one without
from google.adk.models.llm_response import LlmResponse

response_with_usage = LlmResponse(
content=types.Content(
role='model', parts=[types.Part.from_text(text='Response 1')]
),
usage_metadata=types.GenerateContentResponseUsageMetadata(
total_token_count=100,
prompt_token_count=50,
candidates_token_count=50,
),
)

response_without_usage = LlmResponse(
content=types.Content(
role='model', parts=[types.Part.from_text(text='Response 2')]
),
usage_metadata=None,
)

# Mock receive to yield both responses
async def mock_receive():
yield response_with_usage
yield response_without_usage

mock_llm_connection.receive = mock_receive

# Create agent and invocation context with live request queue
agent = Agent(name='test_agent', model='mock')
invocation_context = await testing_utils.create_invocation_context(
agent=agent, user_content='test message'
)
invocation_context.live_request_queue = LiveRequestQueue()

# Create flow
flow = TestBaseLlmFlow()
event_id = 'test_event_integration'
llm_request = LlmRequest()

# Process events from _receive_from_model
events = []
async for event in flow._receive_from_model(
mock_llm_connection, event_id, invocation_context, llm_request
):
events.append(event)
if len(events) >= 2: # Stop after processing both responses
break

# Verify new spans were created for live events with usage metadata
assert mock_tracer.start_as_current_span.call_count >= 1

# Check that at least one span was created with live event naming
span_calls = mock_tracer.start_as_current_span.call_args_list
live_event_spans = [
call for call in span_calls if 'llm_call_live_event' in call[0][0]
]
assert len(live_event_spans) >= 1, 'Should create live event spans'