Skip to content

Commit 4d43209

Browse files
committed
fix(adk): resolve OAuth persistence, ID token extraction, and auth flow crashes
This PR addresses several authentication bugs within the `toolbox-adk` wrapper that were preventing successful OIDC flows, causing repeated login prompts, and crashing during tool execution. `adk-python` natively drops the `id_token` during OAuth2 exchange (`update_credential_with_tokens`), which is strictly required by the MCP Toolbox for `USER_IDENTITY` authentication. #### Solution Added a temporary monkey patch to hook into the `adk-python` credential update process and explicitly preserve `tokens.get("id_token")`. > [!NOTE] > Tracked by `TODO`s to remove once upstream PR google/adk-python#4402 is merged. Tools repeatedly prompted users to authenticate because the exchanged credentials were not being appropriately persisted across sessions (`exchanged_auth_credential` was left unpopulated/None). Explicitly assigned the newly fetched user credentials to `auth_config_adk.exchanged_auth_credential` and synced them to storage using `tool_context._invocation_context.credential_service.save_credential()`. Passing an empty list of scopes triggered an `AttributeError` deep inside the `adk-python` `auth_handler.py` due to a falsy chain evaluation on empty dictionaries. Added default OIDC fallback scopes (`["openid", "profile", "email"]`) for `USER_IDENTITY` flows when no explicit scopes are provided. If a tool required the same authentication service across multiple parameters, the `add_auth_token_getter` logic would attempt to register it twice, causing a `ValueError: already registered`. Collected `needed_services` into a deduplicated `set` and added a protective check before registration (`if not hasattr(...) or s not in ...`) to ensure idempotency. - [x] Verified that submitting empty scopes successfully defaults to `openid profile email` and resolves without crashing `adk-python`. - [x] Verified that the `id_token` is successfully propagated bounds and retrieved via `getattr(creds.oauth2, "id_token", ...)` during tool execution. - [x] Verified that subsequent tool executions load the saved credentials seamlessly from the `CredentialService` without triggering continuous Google OAuth consent screens.
1 parent 4f11229 commit 4d43209

3 files changed

Lines changed: 164 additions & 58 deletions

File tree

packages/toolbox-adk/src/toolbox_adk/tool.py

Lines changed: 104 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,22 @@
3636
from .client import USER_TOKEN_CONTEXT_VAR
3737
from .credentials import CredentialConfig, CredentialType
3838

39+
# --- Monkey Patch ADK OAuth2 Exchange to Retain ID Tokens ---
40+
# TODO(id_token): Remove this monkey patch once the PR https://github.com/google/adk-python/pull/4402 is merged.
41+
# Google's ID Token is required by MCP Toolbox but ADK's `update_credential_with_tokens` natively drops the `id_token`.
42+
import google.adk.auth.oauth2_credential_util as oauth2_credential_util
43+
import google.adk.auth.exchanger.oauth2_credential_exchanger as oauth2_credential_exchanger
44+
_orig_update_cred = oauth2_credential_util.update_credential_with_tokens
45+
46+
def _patched_update_credential_with_tokens(auth_credential, tokens):
47+
_orig_update_cred(auth_credential, tokens)
48+
if tokens and "id_token" in tokens and auth_credential and auth_credential.oauth2:
49+
# Pydantic's `extra="allow"` config preserves this dynamically set attribute
50+
setattr(auth_credential.oauth2, "id_token", tokens["id_token"])
51+
52+
oauth2_credential_util.update_credential_with_tokens = _patched_update_credential_with_tokens
53+
oauth2_credential_exchanger.update_credential_with_tokens = _patched_update_credential_with_tokens
54+
# -------------------------------------------------------------
3955

4056
class ToolboxTool(BaseTool):
4157
"""
@@ -135,56 +151,98 @@ async def run_async(
135151
reset_token = None
136152

137153
if self._auth_config and self._auth_config.type == CredentialType.USER_IDENTITY:
138-
if not self._auth_config.client_id or not self._auth_config.client_secret:
139-
raise ValueError("USER_IDENTITY requires client_id and client_secret")
140-
141-
# Construct ADK AuthConfig
142-
scopes = self._auth_config.scopes or [
143-
"https://www.googleapis.com/auth/cloud-platform"
144-
]
145-
scope_dict = {s: "" for s in scopes}
146-
147-
auth_config_adk = AuthConfig(
148-
auth_scheme=OAuth2(
149-
flows=OAuthFlows(
150-
authorizationCode=OAuthFlowAuthorizationCode(
151-
authorizationUrl="https://accounts.google.com/o/oauth2/auth",
152-
tokenUrl="https://oauth2.googleapis.com/token",
153-
scopes=scope_dict,
154+
requires_auth = (
155+
len(self._core_tool._required_authn_params) > 0
156+
or len(self._core_tool._required_authz_tokens) > 0
157+
)
158+
159+
if requires_auth:
160+
if not self._auth_config.client_id or not self._auth_config.client_secret:
161+
raise ValueError("USER_IDENTITY requires client_id and client_secret")
162+
163+
# Construct ADK AuthConfig
164+
scopes = self._auth_config.scopes or ["openid", "profile", "email"]
165+
scope_dict = {s: "" for s in scopes}
166+
167+
auth_config_adk = AuthConfig(
168+
auth_scheme=OAuth2(
169+
flows=OAuthFlows(
170+
authorizationCode=OAuthFlowAuthorizationCode(
171+
authorizationUrl="https://accounts.google.com/o/oauth2/auth",
172+
tokenUrl="https://oauth2.googleapis.com/token",
173+
scopes=scope_dict,
174+
)
154175
)
155-
)
156-
),
157-
raw_auth_credential=AuthCredential(
158-
auth_type=AuthCredentialTypes.OAUTH2,
159-
oauth2=OAuth2Auth(
160-
client_id=self._auth_config.client_id,
161-
client_secret=self._auth_config.client_secret,
162176
),
163-
),
164-
)
177+
raw_auth_credential=AuthCredential(
178+
auth_type=AuthCredentialTypes.OAUTH2,
179+
oauth2=OAuth2Auth(
180+
client_id=self._auth_config.client_id,
181+
client_secret=self._auth_config.client_secret,
182+
),
183+
),
184+
)
165185

166-
# Check if we already have credentials from a previous exchange
167-
try:
168-
# get_auth_response returns AuthCredential if found
169-
creds = tool_context.get_auth_response(auth_config_adk)
170-
if creds and creds.oauth2 and creds.oauth2.access_token:
171-
reset_token = USER_TOKEN_CONTEXT_VAR.set(creds.oauth2.access_token)
172-
else:
173-
# Request credentials and pause execution
186+
# Check if we already have credentials from a previous exchange
187+
try:
188+
# Try to load credential from credential service first (persists across sessions)
189+
creds = None
190+
try:
191+
if tool_context._invocation_context.credential_service:
192+
creds = await tool_context._invocation_context.credential_service.load_credential(
193+
auth_config=auth_config_adk,
194+
callback_context=tool_context
195+
)
196+
except ValueError:
197+
# Credential service might not be initialized
198+
pass
199+
200+
if not creds:
201+
# Fallback to session state (get_auth_response returns AuthCredential if found)
202+
creds = tool_context.get_auth_response(auth_config_adk)
203+
204+
if creds and creds.oauth2 and creds.oauth2.access_token:
205+
reset_token = USER_TOKEN_CONTEXT_VAR.set(creds.oauth2.access_token)
206+
207+
# Bind the token to the underlying core_tool so it constructs headers properly
208+
needed_services = set()
209+
for requested_service in (list(self._core_tool._required_authn_params.values()) + list(self._core_tool._required_authz_tokens)):
210+
if isinstance(requested_service, list):
211+
needed_services.update(requested_service)
212+
else:
213+
needed_services.add(requested_service)
214+
215+
for s in needed_services:
216+
# Only add if not already registered (prevents ValueError on duplicate params or subsequent runs)
217+
if not hasattr(self._core_tool, '_auth_token_getters') or s not in self._core_tool._auth_token_getters:
218+
# TODO(id_token): Uncomment this line and remove the `getattr` fallback below once PR https://github.com/google/adk-python/pull/4402 is merged.
219+
# self._core_tool = self._core_tool.add_auth_token_getter(s, lambda t=creds.oauth2.id_token or creds.oauth2.access_token: t)
220+
self._core_tool = self._core_tool.add_auth_token_getter(s, lambda t=getattr(creds.oauth2, "id_token", creds.oauth2.access_token): t)
221+
# Once we use it from get_auth_response, save it to the auth service for future use
222+
try:
223+
if tool_context._invocation_context.credential_service:
224+
auth_config_adk.exchanged_auth_credential = creds
225+
await tool_context._invocation_context.credential_service.save_credential(
226+
auth_config=auth_config_adk,
227+
callback_context=tool_context
228+
)
229+
except Exception as e:
230+
logging.debug(f"Failed to save credential to service: {e}")
231+
else:
232+
tool_context.request_credential(auth_config_adk)
233+
return {"error": f"OAuth2 Credentials required for {self.name}. A consent link has been generated for the user. Do NOT attempt to run this tool again until the user confirms they have logged in."}
234+
except Exception as e:
235+
if "credential" in str(e).lower() or isinstance(e, ValueError):
236+
raise e
237+
238+
logging.warning(
239+
f"Unexpected error in get_auth_response during User Identity (OAuth2) retrieval: {e}. "
240+
"Falling back to request_credential.",
241+
exc_info=True
242+
)
243+
# Fallback to request logic
174244
tool_context.request_credential(auth_config_adk)
175-
return None
176-
except Exception as e:
177-
if "credential" in str(e).lower() or isinstance(e, ValueError):
178-
raise e
179-
180-
logging.warning(
181-
f"Unexpected error in get_auth_response during User Identity (OAuth2) retrieval: {e}. "
182-
"Falling back to request_credential.",
183-
exc_info=True
184-
)
185-
# Fallback to request logic
186-
tool_context.request_credential(auth_config_adk)
187-
return None
245+
return {"error": f"OAuth2 Credentials required for {self.name}. A consent link has been generated for the user. Do NOT attempt to run this tool again until the user confirms they have logged in."}
188246

189247
result: Optional[Any] = None
190248
error: Optional[Exception] = None

packages/toolbox-adk/tests/integration/test_integration.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,30 +203,41 @@ async def test_3lo_flow_simulation(self):
203203
# The wrapper should catch the missing creds and request them.
204204
assert result_first is None, "Tool should return None sig for auth requirement"
205205
mock_ctx_first.request_credential.assert_called_once()
206-
206+
207207
# Inspect the requested config
208208
auth_config = mock_ctx_first.request_credential.call_args[0][0]
209209
assert auth_config.raw_auth_credential.oauth2.client_id == "test-client-id"
210+
# Verify the default fallback scopes were assigned correctly to avoid upstream crashes
211+
assert auth_config.auth_scheme.flows.authorizationCode.scopes == {"openid": "", "profile": "", "email": ""}
210212

211213
mock_ctx_second = MagicMock()
212214

213215
# Simulate "Auth Response Found"
214216
mock_creds = AuthCredential(
215217
auth_type=AuthCredentialTypes.OAUTH2,
216-
oauth2=OAuth2Auth(access_token="fake-access-token"),
218+
oauth2=OAuth2Auth(access_token="fake-access-token", id_token="fake-id-token"),
217219
)
218220
mock_ctx_second.get_auth_response.return_value = mock_creds
219221

222+
# Setup the credential service mock to verify credential persistence across sessions
223+
mock_cred_service = AsyncMock()
224+
mock_ctx_second._invocation_context = MagicMock()
225+
mock_ctx_second._invocation_context.credential_service = mock_cred_service
226+
220227
print("Running tool second time (expecting success or server error)...")
221228

222229
try:
223230
result_second = await tool.run_async({"num_rows": "1"}, mock_ctx_second)
224231
assert result_second is not None
232+
# Verify that the tool saved the credentials to the storage service backends locally
233+
mock_cred_service.save_credential.assert_called_once()
225234
except Exception as e:
226235
mock_ctx_second.request_credential.assert_not_called()
227236
err_msg = str(e).lower()
228237
assert any(x in err_msg for x in ["401", "403", "unauthorized", "forbidden"]), f"Caught UNEXPECTED exception: {type(e).__name__}: {e}"
229238
print(f"Caught expected server exception with fake token: {e}")
239+
# Verify that the tool AT LEAST triggered save_credential before failing via core_tool inner HTTP req
240+
mock_cred_service.save_credential.assert_called_once()
230241

231242
finally:
232243
await toolset.close()

packages/toolbox-adk/tests/unit/test_tool.py

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,8 @@ async def test_3lo_missing_client_secret(self):
155155
core_tool = AsyncMock()
156156
core_tool.__name__ = "mock_tool"
157157
core_tool.__doc__ = "mock doc"
158+
core_tool._required_authn_params = {"mock_param": "mock_service"}
159+
core_tool._required_authz_tokens = []
158160
auth_config = CredentialConfig(type=CredentialType.USER_IDENTITY)
159161
# Missing client_id/secret
160162

@@ -174,6 +176,8 @@ async def test_3lo_request_credential_when_missing(self):
174176
core_tool.__doc__ = "mock"
175177
core_tool.__name__ = "mock_tool"
176178
core_tool.__doc__ = "mock doc"
179+
core_tool._required_authn_params = {"mock_param": "mock_service"}
180+
core_tool._required_authz_tokens = []
177181

178182
auth_config = CredentialConfig(
179183
type=CredentialType.USER_IDENTITY, client_id="cid", client_secret="csec"
@@ -187,8 +191,8 @@ async def test_3lo_request_credential_when_missing(self):
187191

188192
result = await tool.run_async({}, ctx)
189193

190-
# Verify result is None (signal pause)
191-
assert result is None
194+
# Verify result is error/stop
195+
assert isinstance(result, dict) and "error" in result
192196
# Verify request_credential was called
193197
ctx.request_credential.assert_called_once()
194198
# Verify core tool was NOT called
@@ -198,10 +202,12 @@ async def test_3lo_request_credential_when_missing(self):
198202
async def test_3lo_uses_existing_credential(self):
199203
# Test that if creds exist, they are used and injected
200204
core_tool = AsyncMock(return_value="success")
201-
core_tool.__name__ = "mock"
202-
core_tool.__doc__ = "mock"
203205
core_tool.__name__ = "mock_tool"
204206
core_tool.__doc__ = "mock doc"
207+
# Setup overlapping needed services to test deduplication
208+
core_tool._required_authn_params = {"mock_param": "mock_service", "another_param": "mock_service"}
209+
core_tool._required_authz_tokens = ["mock_service"]
210+
core_tool.add_auth_token_getter = MagicMock(return_value=core_tool)
205211

206212
auth_config = CredentialConfig(
207213
type=CredentialType.USER_IDENTITY, client_id="cid", client_secret="csec"
@@ -210,11 +216,19 @@ async def test_3lo_uses_existing_credential(self):
210216
tool = ToolboxTool(core_tool, auth_config=auth_config)
211217

212218
ctx = MagicMock()
213-
# Mock get_auth_response returning valid creds
219+
# Mock get_auth_response returning valid creds with both access & id tokens
214220
mock_creds = MagicMock()
215-
mock_creds.oauth2.access_token = "valid_token"
221+
mock_creds.oauth2.access_token = "valid_access_token"
222+
mock_creds.oauth2.id_token = "valid_id_token"
216223
ctx.get_auth_response.return_value = mock_creds
217224

225+
# Set up invocation context and credential service mock to verify saving and avoid await errors
226+
mock_cred_service = MagicMock()
227+
mock_cred_service.load_credential = AsyncMock(return_value=None)
228+
mock_cred_service.save_credential = AsyncMock(return_value=None)
229+
ctx._invocation_context = MagicMock()
230+
ctx._invocation_context.credential_service = mock_cred_service
231+
218232
result = await tool.run_async({}, ctx)
219233

220234
# Verify result is success
@@ -223,22 +237,43 @@ async def test_3lo_uses_existing_credential(self):
223237
ctx.request_credential.assert_not_called()
224238
# Verify core tool WAS called
225239
core_tool.assert_called_once()
240+
241+
# Verify deduplication: add_auth_token_getter should only be called ONCE for "mock_service"
242+
core_tool.add_auth_token_getter.assert_called_once()
243+
call_args_getter = core_tool.add_auth_token_getter.call_args[0]
244+
assert call_args_getter[0] == "mock_service"
245+
# Evaluate the getter lambda to ensure it prefers id_token
246+
token_getter_lambda = call_args_getter[1]
247+
assert token_getter_lambda() == "valid_id_token"
248+
249+
# Verify save_credential was called with the exchanged credential
250+
mock_cred_service.save_credential.assert_called_once()
251+
call_args = mock_cred_service.save_credential.call_args[1]
252+
assert call_args["auth_config"].exchanged_auth_credential == mock_creds
253+
254+
# Verify safe scope fallback to ["openid", "profile", "email"] when scopes is None
255+
assert call_args["auth_config"].auth_scheme.flows.authorizationCode.scopes == {"openid": "", "profile": "", "email": ""}
226256

227257
@pytest.mark.asyncio
228258
async def test_3lo_exception_reraise(self):
229259
# Test that specific credential errors are re-raised
230260
core_tool = AsyncMock()
231-
core_tool.__name__ = "mock"
232-
core_tool.__doc__ = "mock"
233261
core_tool.__name__ = "mock_tool"
234262
core_tool.__doc__ = "mock doc"
263+
core_tool._required_authn_params = {"mock_param": "mock_service"}
264+
core_tool._required_authz_tokens = []
235265

236266
auth_config = CredentialConfig(
237267
type=CredentialType.USER_IDENTITY, client_id="cid", client_secret="csec"
238268
)
239269
tool = ToolboxTool(core_tool, auth_config=auth_config)
240270
ctx = MagicMock()
241271

272+
mock_cred_service = MagicMock()
273+
mock_cred_service.load_credential = AsyncMock(return_value=None)
274+
ctx._invocation_context = MagicMock()
275+
ctx._invocation_context.credential_service = mock_cred_service
276+
242277
# Mock get_auth_response raising ValueError
243278
ctx.get_auth_response.side_effect = ValueError("Invalid Credential")
244279

@@ -253,6 +288,8 @@ async def test_3lo_exception_fallback(self):
253288
core_tool.__doc__ = "mock"
254289
core_tool.__name__ = "mock_tool"
255290
core_tool.__doc__ = "mock doc"
291+
core_tool._required_authn_params = {"mock_param": "mock_service"}
292+
core_tool._required_authz_tokens = []
256293

257294
auth_config = CredentialConfig(
258295
type=CredentialType.USER_IDENTITY, client_id="cid", client_secret="csec"
@@ -265,8 +302,8 @@ async def test_3lo_exception_fallback(self):
265302

266303
result = await tool.run_async({}, ctx)
267304

268-
# Should catch RuntimeError, call request_credential, and return None
269-
assert result is None
305+
# Should catch RuntimeError, call request_credential, and return error map
306+
assert isinstance(result, dict) and "error" in result
270307
ctx.request_credential.assert_called_once()
271308

272309
def test_param_type_to_schema_type(self):

0 commit comments

Comments
 (0)