Skip to content

Commit ae157bd

Browse files
MarkDaoustcopybara-github
authored andcommitted
chore: support credentials in interactions
PiperOrigin-RevId: 875370538
1 parent dd52cc2 commit ae157bd

10 files changed

Lines changed: 736 additions & 179 deletions

File tree

google/genai/_api_client.py

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -702,7 +702,21 @@ def __init__(
702702
)
703703
self._http_options.api_version = 'v1beta1'
704704
else: # Implicit initialization or missing arguments.
705-
if not self.api_key:
705+
if env_api_key and api_key:
706+
# Explicit api_key takes precedence over implicit api_key.
707+
logger.info(
708+
'The client initializer api_key argument takes '
709+
'precedence over the API key from the environment variable.'
710+
)
711+
if credentials:
712+
if env_api_key:
713+
logger.info(
714+
'The user `credentials` argument will take precedence over the'
715+
' api key from the environment variables.'
716+
)
717+
self.api_key = None
718+
719+
if not self.api_key and not credentials:
706720
raise ValueError(
707721
'No API key was provided. Please pass a valid API key. Learn how to'
708722
' create an API key at'
@@ -1175,20 +1189,21 @@ def _request_once(
11751189
stream: bool = False,
11761190
) -> HttpResponse:
11771191
data: Optional[Union[str, bytes]] = None
1178-
# If using proj/location, fetch ADC
1179-
if self.vertexai and (self.project or self.location):
1192+
1193+
uses_vertex_creds = self.vertexai and (self.project or self.location)
1194+
uses_mldev_creds = not self.vertexai and self._credentials
1195+
if (uses_vertex_creds or uses_mldev_creds):
11801196
http_request.headers['Authorization'] = f'Bearer {self._access_token()}'
11811197
if self._credentials and self._credentials.quota_project_id:
11821198
http_request.headers['x-goog-user-project'] = (
11831199
self._credentials.quota_project_id
11841200
)
1185-
data = json.dumps(http_request.data) if http_request.data else None
1186-
else:
1187-
if http_request.data:
1188-
if not isinstance(http_request.data, bytes):
1189-
data = json.dumps(http_request.data) if http_request.data else None
1190-
else:
1191-
data = http_request.data
1201+
1202+
if http_request.data:
1203+
if not isinstance(http_request.data, bytes):
1204+
data = json.dumps(http_request.data)
1205+
else:
1206+
data = http_request.data
11921207

11931208
if stream:
11941209
httpx_request = self._httpx_client.build_request(
@@ -1241,22 +1256,22 @@ async def _async_request_once(
12411256
) -> HttpResponse:
12421257
data: Optional[Union[str, bytes]] = None
12431258

1244-
# If using proj/location, fetch ADC
1245-
if self.vertexai and (self.project or self.location):
1259+
uses_vertex_creds = self.vertexai and (self.project or self.location)
1260+
uses_mldev_creds = not self.vertexai and self._credentials
1261+
if (uses_vertex_creds or uses_mldev_creds):
12461262
http_request.headers['Authorization'] = (
12471263
f'Bearer {await self._async_access_token()}'
12481264
)
12491265
if self._credentials and self._credentials.quota_project_id:
12501266
http_request.headers['x-goog-user-project'] = (
12511267
self._credentials.quota_project_id
12521268
)
1253-
data = json.dumps(http_request.data) if http_request.data else None
1254-
else:
1255-
if http_request.data:
1256-
if not isinstance(http_request.data, bytes):
1257-
data = json.dumps(http_request.data) if http_request.data else None
1258-
else:
1259-
data = http_request.data
1269+
1270+
if http_request.data:
1271+
if not isinstance(http_request.data, bytes):
1272+
data = json.dumps(http_request.data)
1273+
else:
1274+
data = http_request.data
12601275

12611276
if stream:
12621277
if self._use_aiohttp():

google/genai/_extra_utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,20 @@
1616
"""Extra utils depending on types that are shared between sync and async modules."""
1717

1818
import asyncio
19+
from collections.abc import Callable, MutableMapping
1920
import inspect
2021
import io
2122
import logging
2223
import sys
2324
import typing
24-
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
25+
from typing import Any, Optional, Union, get_args, get_origin
2526
import mimetypes
2627
import os
2728
import pydantic
2829

30+
import google.auth.transport.requests
31+
32+
2933
from . import _common
3034
from . import _mcp_utils
3135
from . import _transformers as t
@@ -677,3 +681,18 @@ def prepare_resumable_upload(
677681
http_options.headers = {}
678682
http_options.headers['X-Goog-Upload-File-Name'] = os.path.basename(file)
679683
return http_options, size_bytes, mime_type
684+
685+
686+
async def _maybe_update_and_insert_auth_token(
687+
headers:MutableMapping[str, str],
688+
creds: google.auth.credentials.Credentials) -> None:
689+
# Refresh credentials to ensure token is valid
690+
if not (creds.token and creds.valid):
691+
try:
692+
auth_req = google.auth.transport.requests.Request() # type: ignore[no-untyped-call]
693+
await asyncio.to_thread(creds.refresh, auth_req)
694+
except Exception as e:
695+
raise ConnectionError(f"Failed to refresh credentials") from e
696+
697+
if not headers.get('Authorization'):
698+
headers['Authorization'] = f'Bearer {creds.token}'

google/genai/_interactions/_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None:
178178

179179
@override
180180
def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
181-
if not self.client_adapter or not self.client_adapter.is_vertex_ai():
181+
if not self.client_adapter:
182182
return options
183183

184184
headers = options.headers or {}
@@ -400,7 +400,7 @@ def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None:
400400

401401
@override
402402
async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
403-
if not self.client_adapter or not self.client_adapter.is_vertex_ai():
403+
if not self.client_adapter:
404404
return options
405405

406406
headers = options.headers or {}

google/genai/client.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -174,14 +174,12 @@ def _nextgen_client(self) -> AsyncGeminiNextGenAPIClient:
174174
# uSDk expects ms, nextgen uses a httpx Timeout -> expects seconds.
175175
timeout=http_opts.timeout / 1000 if http_opts.timeout else None,
176176
max_retries=max_retries,
177-
client_adapter=AsyncGeminiNextGenAPIClientAdapter(self._api_client)
177+
client_adapter=AsyncGeminiNextGenAPIClientAdapter(self._api_client),
178178
)
179179

180-
client = self._nextgen_client_instance
181-
if self._api_client.vertexai:
182-
client._is_vertex = True
183-
client._vertex_project = self._api_client.project
184-
client._vertex_location = self._api_client.location
180+
self._nextgen_client_instance._is_vertex = self._api_client.vertexai or False
181+
self._nextgen_client_instance._vertex_project = self._api_client.project
182+
self._nextgen_client_instance._vertex_location = self._api_client.location
185183

186184
return self._nextgen_client_instance
187185

@@ -525,11 +523,9 @@ def _nextgen_client(self) -> GeminiNextGenAPIClient:
525523
client_adapter=GeminiNextGenAPIClientAdapter(self._api_client),
526524
)
527525

528-
client = self._nextgen_client_instance
529-
if self._api_client.vertexai:
530-
client._is_vertex = True
531-
client._vertex_project = self._api_client.project
532-
client._vertex_location = self._api_client.location
526+
self._nextgen_client_instance._is_vertex = self._api_client.vertexai or False
527+
self._nextgen_client_instance._vertex_project = self._api_client.project
528+
self._nextgen_client_instance._vertex_location = self._api_client.location
533529

534530
return self._nextgen_client_instance
535531

0 commit comments

Comments
 (0)