diff --git a/src/google/adk/models/apigee_llm.py b/src/google/adk/models/apigee_llm.py index fc4928cb25..4b6a2f1d2e 100644 --- a/src/google/adk/models/apigee_llm.py +++ b/src/google/adk/models/apigee_llm.py @@ -29,6 +29,7 @@ from typing import Optional from typing import TYPE_CHECKING +import google.auth from google.adk import version as adk_version from google.genai import types import httpx @@ -52,6 +53,11 @@ _PROJECT_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_PROJECT' _LOCATION_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_LOCATION' +_APIGEE_SCOPES = [ + 'https://www.googleapis.com/auth/cloud-platform', + 'https://www.googleapis.com/auth/userinfo.email', +] + _CUSTOM_METADATA_FIELDS = ( 'id', 'created', @@ -232,6 +238,8 @@ def api_client(self) -> Client: **kwargs_for_http_options, ) + credentials, _ = google.auth.default(scopes=_APIGEE_SCOPES) + kwargs_for_client = {} kwargs_for_client['vertexai'] = self._isvertexai if self._isvertexai: @@ -239,6 +247,7 @@ def api_client(self) -> Client: kwargs_for_client['location'] = self._location return Client( + credentials=credentials, http_options=http_options, **kwargs_for_client, ) diff --git a/tests/unittests/models/test_apigee_llm.py b/tests/unittests/models/test_apigee_llm.py index c57bc9fcb8..7a959829e8 100644 --- a/tests/unittests/models/test_apigee_llm.py +++ b/tests/unittests/models/test_apigee_llm.py @@ -18,6 +18,7 @@ from unittest import mock from unittest.mock import AsyncMock +from google.adk.models.apigee_llm import _APIGEE_SCOPES from google.adk.models.apigee_llm import ApigeeLlm from google.adk.models.apigee_llm import CompletionsHTTPClient from google.adk.models.llm_request import LlmRequest @@ -627,3 +628,42 @@ async def test_api_key_injection_openai(model): ) client = apigee_llm._completions_http_client assert client._headers['Authorization'] == 'Bearer sk-test-key' + + +@pytest.mark.asyncio +@mock.patch('google.genai.Client') +@mock.patch('google.adk.models.apigee_llm.google.auth.default') +async def test_api_client_requests_userinfo_email_scope( + mock_auth_default, mock_client_constructor, llm_request +): + """Tests that api_client requests userinfo.email scope for Apigee Gateway tokeninfo.""" + mock_credentials = mock.Mock() + mock_auth_default.return_value = (mock_credentials, 'test-project') + + mock_client_instance = mock.Mock() + mock_client_instance.aio.models.generate_content = AsyncMock( + return_value=types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=Content( + parts=[Part.from_text(text='Test response')], + role='model', + ) + ) + ] + ) + ) + mock_client_constructor.return_value = mock_client_instance + + apigee_llm = ApigeeLlm( + model=APIGEE_GEMINI_MODEL_ID, + proxy_url=PROXY_URL, + ) + _ = [resp async for resp in apigee_llm.generate_content_async(llm_request)] + + mock_auth_default.assert_called_once_with(scopes=_APIGEE_SCOPES) + assert 'https://www.googleapis.com/auth/userinfo.email' in _APIGEE_SCOPES + assert 'https://www.googleapis.com/auth/cloud-platform' in _APIGEE_SCOPES + + _, kwargs = mock_client_constructor.call_args + assert kwargs['credentials'] is mock_credentials