Skip to content

Commit 0f5b677

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Mark Vertex calls made from non-gemini models
PiperOrigin-RevId: 848140103
1 parent 4f3b733 commit 0f5b677

9 files changed

Lines changed: 234 additions & 41 deletions

File tree

src/google/adk/models/anthropic_llm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from pydantic import BaseModel
3737
from typing_extensions import override
3838

39+
from ..utils._google_client_headers import get_tracking_headers
3940
from .base_llm import BaseLlm
4041
from .llm_response import LlmResponse
4142

@@ -345,4 +346,5 @@ def _anthropic_client(self) -> AsyncAnthropicVertex:
345346
return AsyncAnthropicVertex(
346347
project_id=os.environ["GOOGLE_CLOUD_PROJECT"],
347348
region=os.environ["GOOGLE_CLOUD_LOCATION"],
349+
default_headers=get_tracking_headers(),
348350
)

src/google/adk/models/apigee_llm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from google.genai import types
2626
from typing_extensions import override
2727

28+
from ..utils._google_client_headers import merge_tracking_headers
2829
from ..utils.env_utils import is_env_enabled
2930
from .google_llm import Gemini
3031

@@ -145,7 +146,7 @@ def api_client(self) -> Client:
145146
kwargs_for_http_options['api_version'] = self._api_version
146147
http_options = types.HttpOptions(
147148
base_url=self._proxy_url,
148-
headers=self._merge_tracking_headers(self._custom_headers),
149+
headers=merge_tracking_headers(self._custom_headers),
149150
retry_options=self.retry_options,
150151
**kwargs_for_http_options,
151152
)

src/google/adk/models/google_llm.py

Lines changed: 9 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@
3030
from google.genai.errors import ClientError
3131
from typing_extensions import override
3232

33-
from ..utils._client_labels_utils import get_client_labels
33+
from ..utils._google_client_headers import get_tracking_headers
34+
from ..utils._google_client_headers import merge_tracking_headers
3435
from ..utils.context_utils import Aclosing
3536
from ..utils.streaming_utils import StreamingResponseAggregator
3637
from ..utils.variant_utils import GoogleLLMVariant
@@ -191,7 +192,7 @@ async def generate_content_async(
191192
if llm_request.config:
192193
if not llm_request.config.http_options:
193194
llm_request.config.http_options = types.HttpOptions()
194-
llm_request.config.http_options.headers = self._merge_tracking_headers(
195+
llm_request.config.http_options.headers = merge_tracking_headers(
195196
llm_request.config.http_options.headers
196197
)
197198

@@ -302,7 +303,7 @@ def api_client(self) -> Client:
302303

303304
return Client(
304305
http_options=types.HttpOptions(
305-
headers=self._tracking_headers(),
306+
headers=get_tracking_headers(),
306307
retry_options=self.retry_options,
307308
)
308309
)
@@ -315,15 +316,6 @@ def _api_backend(self) -> GoogleLLMVariant:
315316
else GoogleLLMVariant.GEMINI_API
316317
)
317318

318-
def _tracking_headers(self) -> dict[str, str]:
319-
labels = get_client_labels()
320-
header_value = ' '.join(labels)
321-
tracking_headers = {
322-
'x-goog-api-client': header_value,
323-
'user-agent': header_value,
324-
}
325-
return tracking_headers
326-
327319
@cached_property
328320
def _live_api_version(self) -> str:
329321
if self._api_backend == GoogleLLMVariant.VERTEX_AI:
@@ -339,7 +331,7 @@ def _live_api_client(self) -> Client:
339331

340332
return Client(
341333
http_options=types.HttpOptions(
342-
headers=self._tracking_headers(), api_version=self._live_api_version
334+
headers=get_tracking_headers(), api_version=self._live_api_version
343335
)
344336
)
345337

@@ -362,8 +354,10 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
362354
):
363355
if not llm_request.live_connect_config.http_options.headers:
364356
llm_request.live_connect_config.http_options.headers = {}
365-
llm_request.live_connect_config.http_options.headers.update(
366-
self._tracking_headers()
357+
llm_request.live_connect_config.http_options.headers = (
358+
merge_tracking_headers(
359+
llm_request.live_connect_config.http_options.headers
360+
)
367361
)
368362
llm_request.live_connect_config.http_options.api_version = (
369363
self._live_api_version
@@ -447,23 +441,6 @@ async def _preprocess_request(self, llm_request: LlmRequest) -> None:
447441
llm_request.config.system_instruction = None
448442
await self._adapt_computer_use_tool(llm_request)
449443

450-
def _merge_tracking_headers(self, headers: dict[str, str]) -> dict[str, str]:
451-
"""Merge tracking headers to the given headers."""
452-
headers = headers or {}
453-
for key, tracking_header_value in self._tracking_headers().items():
454-
custom_value = headers.get(key, None)
455-
if not custom_value:
456-
headers[key] = tracking_header_value
457-
continue
458-
459-
# Merge tracking headers with existing headers and avoid duplicates.
460-
value_parts = tracking_header_value.split(' ')
461-
for custom_value_part in custom_value.split(' '):
462-
if custom_value_part not in value_parts:
463-
value_parts.append(custom_value_part)
464-
headers[key] = ' '.join(value_parts)
465-
return headers
466-
467444

468445
def _build_function_declaration_log(
469446
func_decl: types.FunctionDeclaration,

src/google/adk/models/lite_llm.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
from pydantic import Field
5858
from typing_extensions import override
5959

60+
from ..utils._google_client_headers import merge_tracking_headers
6061
from .base_llm import BaseLlm
6162
from .llm_request import LlmRequest
6263
from .llm_response import LlmResponse
@@ -1390,6 +1391,18 @@ def _build_request_log(req: LlmRequest) -> str:
13901391
"""
13911392

13921393

1394+
def _is_litellm_vertex_model(model_string: str) -> bool:
1395+
"""Check if the model is a Vertex AI model accessed via LiteLLM.
1396+
1397+
Args:
1398+
model_string: A LiteLLM model string (e.g., "vertex_ai/gemini-2.5-flash")
1399+
1400+
Returns:
1401+
True if it's a Vertex AI model accessed via LiteLLM, False otherwise
1402+
"""
1403+
return model_string.startswith("vertex_ai/")
1404+
1405+
13931406
def _is_litellm_gemini_model(model_string: str) -> bool:
13941407
"""Check if the model is a Gemini model accessed via LiteLLM.
13951408
@@ -1562,6 +1575,14 @@ async def generate_content_async(
15621575
}
15631576
completion_args.update(self._additional_args)
15641577

1578+
# merge headers
1579+
if _is_litellm_vertex_model(effective_model) or _is_litellm_gemini_model(
1580+
effective_model
1581+
):
1582+
completion_args["headers"] = merge_tracking_headers(
1583+
completion_args.get("headers")
1584+
)
1585+
15651586
if generation_params:
15661587
completion_args.update(generation_params)
15671588

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from ._client_labels_utils import get_client_labels
18+
19+
20+
def get_tracking_headers() -> dict[str, str]:
21+
"""Returns a dictionary of HTTP headers for tracking API requests.
22+
23+
These headers are used to identify HTTP calls made by ADK towards
24+
Vertex AI LLM APIs.
25+
"""
26+
labels = get_client_labels()
27+
header_value = " ".join(labels)
28+
return {
29+
"x-goog-api-client": header_value,
30+
"user-agent": header_value,
31+
}
32+
33+
34+
def merge_tracking_headers(headers: dict[str, str] | None) -> dict[str, str]:
35+
"""Merge tracking headers to the given headers.
36+
37+
Args:
38+
headers: headers to merge tracking headers into.
39+
40+
Returns:
41+
A dictionary of HTTP headers with tracking headers merged.
42+
"""
43+
new_headers = (headers or {}).copy()
44+
for key, tracking_header_value in get_tracking_headers().items():
45+
custom_value = new_headers.get(key, None)
46+
if not custom_value:
47+
new_headers[key] = tracking_header_value
48+
continue
49+
50+
# Merge tracking headers with existing headers and avoid duplicates.
51+
value_parts = tracking_header_value.split(" ")
52+
for custom_value_part in custom_value.split(" "):
53+
if custom_value_part not in value_parts:
54+
value_parts.append(custom_value_part)
55+
new_headers[key] = " ".join(value_parts)
56+
return new_headers

tests/unittests/models/test_anthropic_llm.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,31 @@ async def mock_coro():
391391
assert responses[0].content.parts[0].text == "Hello, how can I help you?"
392392

393393

394+
def test_claude_vertex_client_uses_tracking_headers():
395+
"""Tests that Claude vertex client is called with tracking headers."""
396+
with mock.patch.object(
397+
anthropic_llm, "AsyncAnthropicVertex", autospec=True
398+
) as mock_anthropic_vertex:
399+
with mock.patch.dict(
400+
os.environ,
401+
{
402+
"GOOGLE_CLOUD_PROJECT": "test-project",
403+
"GOOGLE_CLOUD_LOCATION": "us-central1",
404+
},
405+
):
406+
instance = Claude(model="claude-3-5-sonnet-v2@20241022")
407+
_ = instance._anthropic_client
408+
mock_anthropic_vertex.assert_called_once()
409+
_, kwargs = mock_anthropic_vertex.call_args
410+
assert "default_headers" in kwargs
411+
assert "x-goog-api-client" in kwargs["default_headers"]
412+
assert "user-agent" in kwargs["default_headers"]
413+
assert (
414+
f"google-adk/{adk_version.__version__}"
415+
in kwargs["default_headers"]["user-agent"]
416+
)
417+
418+
394419
@pytest.mark.asyncio
395420
async def test_generate_content_async_with_max_tokens(
396421
llm_request, generate_content_response, generate_llm_response

tests/unittests/models/test_google_llm.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from google.adk.models.llm_response import LlmResponse
3232
from google.adk.utils._client_labels_utils import _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME
3333
from google.adk.utils._client_labels_utils import _AGENT_ENGINE_TELEMETRY_TAG
34+
from google.adk.utils._google_client_headers import get_tracking_headers
3435
from google.adk.utils.variant_utils import GoogleLLMVariant
3536
from google.genai import types
3637
from google.genai.errors import ClientError
@@ -469,7 +470,7 @@ async def test_generate_content_async_with_custom_headers(
469470
"""Test that tracking headers are updated when custom headers are provided."""
470471
# Add custom headers to the request config
471472
custom_headers = {"custom-header": "custom-value"}
472-
tracking_headers = gemini_llm._tracking_headers()
473+
tracking_headers = get_tracking_headers()
473474
for key in tracking_headers:
474475
custom_headers[key] = "custom " + tracking_headers[key]
475476
llm_request.config.http_options = types.HttpOptions(headers=custom_headers)
@@ -494,7 +495,7 @@ async def mock_coro():
494495
config_arg = call_args.kwargs["config"]
495496

496497
for key, value in config_arg.http_options.headers.items():
497-
tracking_headers = gemini_llm._tracking_headers()
498+
tracking_headers = get_tracking_headers()
498499
if key in tracking_headers:
499500
assert value == tracking_headers[key] + " custom"
500501
else:
@@ -545,7 +546,7 @@ async def mock_coro():
545546
config_arg = call_args.kwargs["config"]
546547

547548
expected_headers = custom_headers.copy()
548-
expected_headers.update(gemini_llm._tracking_headers())
549+
expected_headers.update(get_tracking_headers())
549550
assert config_arg.http_options.headers == expected_headers
550551

551552
assert len(responses) == 2
@@ -599,7 +600,7 @@ async def mock_coro():
599600
assert final_config.http_options is not None
600601
assert (
601602
final_config.http_options.headers["x-goog-api-client"]
602-
== gemini_llm._tracking_headers()["x-goog-api-client"]
603+
== get_tracking_headers()["x-goog-api-client"]
603604
)
604605

605606
assert len(responses) == 2 if stream else 1
@@ -633,7 +634,7 @@ def test_live_api_client_properties(gemini_llm):
633634
assert http_options.api_version == "v1beta1"
634635

635636
# Check that tracking headers are included
636-
tracking_headers = gemini_llm._tracking_headers()
637+
tracking_headers = get_tracking_headers()
637638
for key, value in tracking_headers.items():
638639
assert key in http_options.headers
639640
assert value in http_options.headers[key]
@@ -671,7 +672,7 @@ async def __aexit__(self, *args):
671672

672673
# Verify that tracking headers were merged with custom headers
673674
expected_headers = custom_headers.copy()
674-
expected_headers.update(gemini_llm._tracking_headers())
675+
expected_headers.update(get_tracking_headers())
675676
assert config_arg.http_options.headers == expected_headers
676677

677678
# Verify that API version was set

tests/unittests/models/test_litellm.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2447,11 +2447,12 @@ def test_model_response_to_chunk(
24472447
async def test_acompletion_additional_args(mock_acompletion, mock_client):
24482448
lite_llm_instance = LiteLlm(
24492449
# valid args
2450-
model="test_model",
2450+
model="vertex_ai/test_model",
24512451
llm_client=mock_client,
24522452
api_key="test_key",
24532453
api_base="some://url",
24542454
api_version="2024-09-12",
2455+
headers={"custom": "header"}, # Add custom header to test merge
24552456
# invalid args (ignored)
24562457
stream=True,
24572458
messages=[{"role": "invalid", "content": "invalid"}],
@@ -2478,13 +2479,43 @@ async def test_acompletion_additional_args(mock_acompletion, mock_client):
24782479

24792480
_, kwargs = mock_acompletion.call_args
24802481

2481-
assert kwargs["model"] == "test_model"
2482+
assert kwargs["model"] == "vertex_ai/test_model"
24822483
assert kwargs["messages"][0]["role"] == "user"
24832484
assert kwargs["messages"][0]["content"] == "Test prompt"
24842485
assert kwargs["tools"][0]["function"]["name"] == "test_function"
24852486
assert "stream" not in kwargs
24862487
assert "llm_client" not in kwargs
24872488
assert kwargs["api_base"] == "some://url"
2489+
assert "headers" in kwargs
2490+
assert kwargs["headers"]["custom"] == "header"
2491+
assert "x-goog-api-client" in kwargs["headers"]
2492+
assert "user-agent" in kwargs["headers"]
2493+
2494+
2495+
@pytest.mark.asyncio
2496+
async def test_acompletion_additional_args_non_vertex(
2497+
mock_acompletion, mock_client
2498+
):
2499+
"""Test that tracking headers are not added for non-Vertex AI models."""
2500+
lite_llm_instance = LiteLlm(
2501+
model="openai/gpt-4o",
2502+
llm_client=mock_client,
2503+
api_key="test_key",
2504+
headers={"custom": "header"},
2505+
)
2506+
2507+
async for _ in lite_llm_instance.generate_content_async(
2508+
LLM_REQUEST_WITH_FUNCTION_DECLARATION
2509+
):
2510+
pass
2511+
2512+
mock_acompletion.assert_called_once()
2513+
_, kwargs = mock_acompletion.call_args
2514+
assert kwargs["model"] == "openai/gpt-4o"
2515+
assert "headers" in kwargs
2516+
assert kwargs["headers"]["custom"] == "header"
2517+
assert "x-goog-api-client" not in kwargs["headers"]
2518+
assert "user-agent" not in kwargs["headers"]
24882519

24892520

24902521
@pytest.mark.asyncio

0 commit comments

Comments
 (0)