Skip to content

Commit 871571d

Browse files
google-genai-bot authored and copybara-github committed
feat: Mark Vertex calls made from non-gemini models
PiperOrigin-RevId: 848159669
1 parent 0f5b677 commit 871571d

9 files changed

Lines changed: 41 additions & 234 deletions

File tree

src/google/adk/models/anthropic_llm.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
from pydantic import BaseModel
3737
from typing_extensions import override
3838

39-
from ..utils._google_client_headers import get_tracking_headers
4039
from .base_llm import BaseLlm
4140
from .llm_response import LlmResponse
4241

@@ -346,5 +345,4 @@ def _anthropic_client(self) -> AsyncAnthropicVertex:
346345
return AsyncAnthropicVertex(
347346
project_id=os.environ["GOOGLE_CLOUD_PROJECT"],
348347
region=os.environ["GOOGLE_CLOUD_LOCATION"],
349-
default_headers=get_tracking_headers(),
350348
)

src/google/adk/models/apigee_llm.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from google.genai import types
2626
from typing_extensions import override
2727

28-
from ..utils._google_client_headers import merge_tracking_headers
2928
from ..utils.env_utils import is_env_enabled
3029
from .google_llm import Gemini
3130

@@ -146,7 +145,7 @@ def api_client(self) -> Client:
146145
kwargs_for_http_options['api_version'] = self._api_version
147146
http_options = types.HttpOptions(
148147
base_url=self._proxy_url,
149-
headers=merge_tracking_headers(self._custom_headers),
148+
headers=self._merge_tracking_headers(self._custom_headers),
150149
retry_options=self.retry_options,
151150
**kwargs_for_http_options,
152151
)

src/google/adk/models/google_llm.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@
3030
from google.genai.errors import ClientError
3131
from typing_extensions import override
3232

33-
from ..utils._google_client_headers import get_tracking_headers
34-
from ..utils._google_client_headers import merge_tracking_headers
33+
from ..utils._client_labels_utils import get_client_labels
3534
from ..utils.context_utils import Aclosing
3635
from ..utils.streaming_utils import StreamingResponseAggregator
3736
from ..utils.variant_utils import GoogleLLMVariant
@@ -192,7 +191,7 @@ async def generate_content_async(
192191
if llm_request.config:
193192
if not llm_request.config.http_options:
194193
llm_request.config.http_options = types.HttpOptions()
195-
llm_request.config.http_options.headers = merge_tracking_headers(
194+
llm_request.config.http_options.headers = self._merge_tracking_headers(
196195
llm_request.config.http_options.headers
197196
)
198197

@@ -303,7 +302,7 @@ def api_client(self) -> Client:
303302

304303
return Client(
305304
http_options=types.HttpOptions(
306-
headers=get_tracking_headers(),
305+
headers=self._tracking_headers(),
307306
retry_options=self.retry_options,
308307
)
309308
)
@@ -316,6 +315,15 @@ def _api_backend(self) -> GoogleLLMVariant:
316315
else GoogleLLMVariant.GEMINI_API
317316
)
318317

318+
def _tracking_headers(self) -> dict[str, str]:
319+
labels = get_client_labels()
320+
header_value = ' '.join(labels)
321+
tracking_headers = {
322+
'x-goog-api-client': header_value,
323+
'user-agent': header_value,
324+
}
325+
return tracking_headers
326+
319327
@cached_property
320328
def _live_api_version(self) -> str:
321329
if self._api_backend == GoogleLLMVariant.VERTEX_AI:
@@ -331,7 +339,7 @@ def _live_api_client(self) -> Client:
331339

332340
return Client(
333341
http_options=types.HttpOptions(
334-
headers=get_tracking_headers(), api_version=self._live_api_version
342+
headers=self._tracking_headers(), api_version=self._live_api_version
335343
)
336344
)
337345

@@ -354,10 +362,8 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
354362
):
355363
if not llm_request.live_connect_config.http_options.headers:
356364
llm_request.live_connect_config.http_options.headers = {}
357-
llm_request.live_connect_config.http_options.headers = (
358-
merge_tracking_headers(
359-
llm_request.live_connect_config.http_options.headers
360-
)
365+
llm_request.live_connect_config.http_options.headers.update(
366+
self._tracking_headers()
361367
)
362368
llm_request.live_connect_config.http_options.api_version = (
363369
self._live_api_version
@@ -441,6 +447,23 @@ async def _preprocess_request(self, llm_request: LlmRequest) -> None:
441447
llm_request.config.system_instruction = None
442448
await self._adapt_computer_use_tool(llm_request)
443449

450+
def _merge_tracking_headers(self, headers: dict[str, str]) -> dict[str, str]:
451+
"""Merge tracking headers to the given headers."""
452+
headers = headers or {}
453+
for key, tracking_header_value in self._tracking_headers().items():
454+
custom_value = headers.get(key, None)
455+
if not custom_value:
456+
headers[key] = tracking_header_value
457+
continue
458+
459+
# Merge tracking headers with existing headers and avoid duplicates.
460+
value_parts = tracking_header_value.split(' ')
461+
for custom_value_part in custom_value.split(' '):
462+
if custom_value_part not in value_parts:
463+
value_parts.append(custom_value_part)
464+
headers[key] = ' '.join(value_parts)
465+
return headers
466+
444467

445468
def _build_function_declaration_log(
446469
func_decl: types.FunctionDeclaration,

src/google/adk/models/lite_llm.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@
5757
from pydantic import Field
5858
from typing_extensions import override
5959

60-
from ..utils._google_client_headers import merge_tracking_headers
6160
from .base_llm import BaseLlm
6261
from .llm_request import LlmRequest
6362
from .llm_response import LlmResponse
@@ -1391,18 +1390,6 @@ def _build_request_log(req: LlmRequest) -> str:
13911390
"""
13921391

13931392

1394-
def _is_litellm_vertex_model(model_string: str) -> bool:
1395-
"""Check if the model is a Vertex AI model accessed via LiteLLM.
1396-
1397-
Args:
1398-
model_string: A LiteLLM model string (e.g., "vertex_ai/gemini-2.5-flash")
1399-
1400-
Returns:
1401-
True if it's a Vertex AI model accessed via LiteLLM, False otherwise
1402-
"""
1403-
return model_string.startswith("vertex_ai/")
1404-
1405-
14061393
def _is_litellm_gemini_model(model_string: str) -> bool:
14071394
"""Check if the model is a Gemini model accessed via LiteLLM.
14081395
@@ -1575,14 +1562,6 @@ async def generate_content_async(
15751562
}
15761563
completion_args.update(self._additional_args)
15771564

1578-
# merge headers
1579-
if _is_litellm_vertex_model(effective_model) or _is_litellm_gemini_model(
1580-
effective_model
1581-
):
1582-
completion_args["headers"] = merge_tracking_headers(
1583-
completion_args.get("headers")
1584-
)
1585-
15861565
if generation_params:
15871566
completion_args.update(generation_params)
15881567

src/google/adk/utils/_google_client_headers.py

Lines changed: 0 additions & 56 deletions
This file was deleted.

tests/unittests/models/test_anthropic_llm.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -391,31 +391,6 @@ async def mock_coro():
391391
assert responses[0].content.parts[0].text == "Hello, how can I help you?"
392392

393393

394-
def test_claude_vertex_client_uses_tracking_headers():
395-
"""Tests that Claude vertex client is called with tracking headers."""
396-
with mock.patch.object(
397-
anthropic_llm, "AsyncAnthropicVertex", autospec=True
398-
) as mock_anthropic_vertex:
399-
with mock.patch.dict(
400-
os.environ,
401-
{
402-
"GOOGLE_CLOUD_PROJECT": "test-project",
403-
"GOOGLE_CLOUD_LOCATION": "us-central1",
404-
},
405-
):
406-
instance = Claude(model="claude-3-5-sonnet-v2@20241022")
407-
_ = instance._anthropic_client
408-
mock_anthropic_vertex.assert_called_once()
409-
_, kwargs = mock_anthropic_vertex.call_args
410-
assert "default_headers" in kwargs
411-
assert "x-goog-api-client" in kwargs["default_headers"]
412-
assert "user-agent" in kwargs["default_headers"]
413-
assert (
414-
f"google-adk/{adk_version.__version__}"
415-
in kwargs["default_headers"]["user-agent"]
416-
)
417-
418-
419394
@pytest.mark.asyncio
420395
async def test_generate_content_async_with_max_tokens(
421396
llm_request, generate_content_response, generate_llm_response

tests/unittests/models/test_google_llm.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from google.adk.models.llm_response import LlmResponse
3232
from google.adk.utils._client_labels_utils import _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME
3333
from google.adk.utils._client_labels_utils import _AGENT_ENGINE_TELEMETRY_TAG
34-
from google.adk.utils._google_client_headers import get_tracking_headers
3534
from google.adk.utils.variant_utils import GoogleLLMVariant
3635
from google.genai import types
3736
from google.genai.errors import ClientError
@@ -470,7 +469,7 @@ async def test_generate_content_async_with_custom_headers(
470469
"""Test that tracking headers are updated when custom headers are provided."""
471470
# Add custom headers to the request config
472471
custom_headers = {"custom-header": "custom-value"}
473-
tracking_headers = get_tracking_headers()
472+
tracking_headers = gemini_llm._tracking_headers()
474473
for key in tracking_headers:
475474
custom_headers[key] = "custom " + tracking_headers[key]
476475
llm_request.config.http_options = types.HttpOptions(headers=custom_headers)
@@ -495,7 +494,7 @@ async def mock_coro():
495494
config_arg = call_args.kwargs["config"]
496495

497496
for key, value in config_arg.http_options.headers.items():
498-
tracking_headers = get_tracking_headers()
497+
tracking_headers = gemini_llm._tracking_headers()
499498
if key in tracking_headers:
500499
assert value == tracking_headers[key] + " custom"
501500
else:
@@ -546,7 +545,7 @@ async def mock_coro():
546545
config_arg = call_args.kwargs["config"]
547546

548547
expected_headers = custom_headers.copy()
549-
expected_headers.update(get_tracking_headers())
548+
expected_headers.update(gemini_llm._tracking_headers())
550549
assert config_arg.http_options.headers == expected_headers
551550

552551
assert len(responses) == 2
@@ -600,7 +599,7 @@ async def mock_coro():
600599
assert final_config.http_options is not None
601600
assert (
602601
final_config.http_options.headers["x-goog-api-client"]
603-
== get_tracking_headers()["x-goog-api-client"]
602+
== gemini_llm._tracking_headers()["x-goog-api-client"]
604603
)
605604

606605
assert len(responses) == 2 if stream else 1
@@ -634,7 +633,7 @@ def test_live_api_client_properties(gemini_llm):
634633
assert http_options.api_version == "v1beta1"
635634

636635
# Check that tracking headers are included
637-
tracking_headers = get_tracking_headers()
636+
tracking_headers = gemini_llm._tracking_headers()
638637
for key, value in tracking_headers.items():
639638
assert key in http_options.headers
640639
assert value in http_options.headers[key]
@@ -672,7 +671,7 @@ async def __aexit__(self, *args):
672671

673672
# Verify that tracking headers were merged with custom headers
674673
expected_headers = custom_headers.copy()
675-
expected_headers.update(get_tracking_headers())
674+
expected_headers.update(gemini_llm._tracking_headers())
676675
assert config_arg.http_options.headers == expected_headers
677676

678677
# Verify that API version was set

tests/unittests/models/test_litellm.py

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2447,12 +2447,11 @@ def test_model_response_to_chunk(
24472447
async def test_acompletion_additional_args(mock_acompletion, mock_client):
24482448
lite_llm_instance = LiteLlm(
24492449
# valid args
2450-
model="vertex_ai/test_model",
2450+
model="test_model",
24512451
llm_client=mock_client,
24522452
api_key="test_key",
24532453
api_base="some://url",
24542454
api_version="2024-09-12",
2455-
headers={"custom": "header"}, # Add custom header to test merge
24562455
# invalid args (ignored)
24572456
stream=True,
24582457
messages=[{"role": "invalid", "content": "invalid"}],
@@ -2479,43 +2478,13 @@ async def test_acompletion_additional_args(mock_acompletion, mock_client):
24792478

24802479
_, kwargs = mock_acompletion.call_args
24812480

2482-
assert kwargs["model"] == "vertex_ai/test_model"
2481+
assert kwargs["model"] == "test_model"
24832482
assert kwargs["messages"][0]["role"] == "user"
24842483
assert kwargs["messages"][0]["content"] == "Test prompt"
24852484
assert kwargs["tools"][0]["function"]["name"] == "test_function"
24862485
assert "stream" not in kwargs
24872486
assert "llm_client" not in kwargs
24882487
assert kwargs["api_base"] == "some://url"
2489-
assert "headers" in kwargs
2490-
assert kwargs["headers"]["custom"] == "header"
2491-
assert "x-goog-api-client" in kwargs["headers"]
2492-
assert "user-agent" in kwargs["headers"]
2493-
2494-
2495-
@pytest.mark.asyncio
2496-
async def test_acompletion_additional_args_non_vertex(
2497-
mock_acompletion, mock_client
2498-
):
2499-
"""Test that tracking headers are not added for non-Vertex AI models."""
2500-
lite_llm_instance = LiteLlm(
2501-
model="openai/gpt-4o",
2502-
llm_client=mock_client,
2503-
api_key="test_key",
2504-
headers={"custom": "header"},
2505-
)
2506-
2507-
async for _ in lite_llm_instance.generate_content_async(
2508-
LLM_REQUEST_WITH_FUNCTION_DECLARATION
2509-
):
2510-
pass
2511-
2512-
mock_acompletion.assert_called_once()
2513-
_, kwargs = mock_acompletion.call_args
2514-
assert kwargs["model"] == "openai/gpt-4o"
2515-
assert "headers" in kwargs
2516-
assert kwargs["headers"]["custom"] == "header"
2517-
assert "x-goog-api-client" not in kwargs["headers"]
2518-
assert "user-agent" not in kwargs["headers"]
25192488

25202489

25212490
@pytest.mark.asyncio

0 commit comments

Comments
 (0)