Skip to content

Commit 467aa0c

Browse files
committed
fix(adk): resolve OAuth persistence, ID token extraction, and auth flow crashes
This PR addresses several authentication bugs within the `toolbox-adk` wrapper that were preventing successful OIDC flows, causing repeated login prompts, and crashing during tool execution. `adk-python` natively drops the `id_token` during OAuth2 exchange (`update_credential_with_tokens`), which is strictly required by the MCP Toolbox for `USER_IDENTITY` authentication. #### Solution Added a temporary monkey patch to hook into the `adk-python` credential update process and explicitly preserve `tokens.get("id_token")`. > [!NOTE] > Tracked by `TODO`s to remove once upstream PR google/adk-python#4402 is merged. Tools repeatedly prompted users to authenticate because the exchanged credentials were not being appropriately persisted across sessions (`exchanged_auth_credential` was left unpopulated/None). Explicitly assigned the newly fetched user credentials to `auth_config_adk.exchanged_auth_credential` and synced them to storage using `tool_context._invocation_context.credential_service.save_credential()`. Passing an empty list of scopes triggered an `AttributeError` deep inside the `adk-python` `auth_handler.py` due to a falsy chain evaluation on empty dictionaries. Added default OIDC fallback scopes (`["openid", "profile", "email"]`) for `USER_IDENTITY` flows when no explicit scopes are provided. If a tool required the same authentication service across multiple parameters, the `add_auth_token_getter` logic would attempt to register it twice, causing a `ValueError: already registered`. Collected `needed_services` into a deduplicated `set` and added a protective check before registration (`if not hasattr(...) or s not in ...`) to ensure idempotency. - [x] Verified that submitting empty scopes successfully defaults to `openid profile email` and resolves without crashing `adk-python`. - [x] Verified that the `id_token` is successfully propagated bounds and retrieved via `getattr(creds.oauth2, "id_token", ...)` during tool execution. - [x] Verified that subsequent tool executions load the saved credentials seamlessly from the `CredentialService` without triggering continuous Google OAuth consent screens.
1 parent 850d006 commit 467aa0c

3 files changed

Lines changed: 164 additions & 58 deletions

File tree

packages/toolbox-adk/src/toolbox_adk/tool.py

Lines changed: 104 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,22 @@
3636
from .client import USER_TOKEN_CONTEXT_VAR
3737
from .credentials import CredentialConfig, CredentialType
3838

39+
# --- Monkey Patch ADK OAuth2 Exchange to Retain ID Tokens ---
40+
# TODO(id_token): Remove this monkey patch once the PR https://github.com/google/adk-python/pull/4402 is merged.
41+
# Google's ID Token is required by MCP Toolbox but ADK's `update_credential_with_tokens` natively drops the `id_token`.
42+
import google.adk.auth.oauth2_credential_util as oauth2_credential_util
43+
import google.adk.auth.exchanger.oauth2_credential_exchanger as oauth2_credential_exchanger
44+
_orig_update_cred = oauth2_credential_util.update_credential_with_tokens
45+
46+
def _patched_update_credential_with_tokens(auth_credential, tokens):
47+
_orig_update_cred(auth_credential, tokens)
48+
if tokens and "id_token" in tokens and auth_credential and auth_credential.oauth2:
49+
# Pydantic's `extra="allow"` config preserves this dynamically set attribute
50+
setattr(auth_credential.oauth2, "id_token", tokens["id_token"])
51+
52+
oauth2_credential_util.update_credential_with_tokens = _patched_update_credential_with_tokens
53+
oauth2_credential_exchanger.update_credential_with_tokens = _patched_update_credential_with_tokens
54+
# -------------------------------------------------------------
3955

4056
class ToolboxTool(BaseTool):
4157
"""
@@ -133,56 +149,98 @@ async def run_async(
133149
reset_token = None
134150

135151
if self._auth_config and self._auth_config.type == CredentialType.USER_IDENTITY:
136-
if not self._auth_config.client_id or not self._auth_config.client_secret:
137-
raise ValueError("USER_IDENTITY requires client_id and client_secret")
138-
139-
# Construct ADK AuthConfig
140-
scopes = self._auth_config.scopes or [
141-
"https://www.googleapis.com/auth/cloud-platform"
142-
]
143-
scope_dict = {s: "" for s in scopes}
144-
145-
auth_config_adk = AuthConfig(
146-
auth_scheme=OAuth2(
147-
flows=OAuthFlows(
148-
authorizationCode=OAuthFlowAuthorizationCode(
149-
authorizationUrl="https://accounts.google.com/o/oauth2/auth",
150-
tokenUrl="https://oauth2.googleapis.com/token",
151-
scopes=scope_dict,
152+
requires_auth = (
153+
len(self._core_tool._required_authn_params) > 0
154+
or len(self._core_tool._required_authz_tokens) > 0
155+
)
156+
157+
if requires_auth:
158+
if not self._auth_config.client_id or not self._auth_config.client_secret:
159+
raise ValueError("USER_IDENTITY requires client_id and client_secret")
160+
161+
# Construct ADK AuthConfig
162+
scopes = self._auth_config.scopes or ["openid", "profile", "email"]
163+
scope_dict = {s: "" for s in scopes}
164+
165+
auth_config_adk = AuthConfig(
166+
auth_scheme=OAuth2(
167+
flows=OAuthFlows(
168+
authorizationCode=OAuthFlowAuthorizationCode(
169+
authorizationUrl="https://accounts.google.com/o/oauth2/auth",
170+
tokenUrl="https://oauth2.googleapis.com/token",
171+
scopes=scope_dict,
172+
)
152173
)
153-
)
154-
),
155-
raw_auth_credential=AuthCredential(
156-
auth_type=AuthCredentialTypes.OAUTH2,
157-
oauth2=OAuth2Auth(
158-
client_id=self._auth_config.client_id,
159-
client_secret=self._auth_config.client_secret,
160174
),
161-
),
162-
)
175+
raw_auth_credential=AuthCredential(
176+
auth_type=AuthCredentialTypes.OAUTH2,
177+
oauth2=OAuth2Auth(
178+
client_id=self._auth_config.client_id,
179+
client_secret=self._auth_config.client_secret,
180+
),
181+
),
182+
)
163183

164-
# Check if we already have credentials from a previous exchange
165-
try:
166-
# get_auth_response returns AuthCredential if found
167-
creds = tool_context.get_auth_response(auth_config_adk)
168-
if creds and creds.oauth2 and creds.oauth2.access_token:
169-
reset_token = USER_TOKEN_CONTEXT_VAR.set(creds.oauth2.access_token)
170-
else:
171-
# Request credentials and pause execution
184+
# Check if we already have credentials from a previous exchange
185+
try:
186+
# Try to load credential from credential service first (persists across sessions)
187+
creds = None
188+
try:
189+
if tool_context._invocation_context.credential_service:
190+
creds = await tool_context._invocation_context.credential_service.load_credential(
191+
auth_config=auth_config_adk,
192+
callback_context=tool_context
193+
)
194+
except ValueError:
195+
# Credential service might not be initialized
196+
pass
197+
198+
if not creds:
199+
# Fallback to session state (get_auth_response returns AuthCredential if found)
200+
creds = tool_context.get_auth_response(auth_config_adk)
201+
202+
if creds and creds.oauth2 and creds.oauth2.access_token:
203+
reset_token = USER_TOKEN_CONTEXT_VAR.set(creds.oauth2.access_token)
204+
205+
# Bind the token to the underlying core_tool so it constructs headers properly
206+
needed_services = set()
207+
for requested_service in (list(self._core_tool._required_authn_params.values()) + list(self._core_tool._required_authz_tokens)):
208+
if isinstance(requested_service, list):
209+
needed_services.update(requested_service)
210+
else:
211+
needed_services.add(requested_service)
212+
213+
for s in needed_services:
214+
# Only add if not already registered (prevents ValueError on duplicate params or subsequent runs)
215+
if not hasattr(self._core_tool, '_auth_token_getters') or s not in self._core_tool._auth_token_getters:
216+
# TODO(id_token): Uncomment this line and remove the `getattr` fallback below once PR https://github.com/google/adk-python/pull/4402 is merged.
217+
# self._core_tool = self._core_tool.add_auth_token_getter(s, lambda t=creds.oauth2.id_token or creds.oauth2.access_token: t)
218+
self._core_tool = self._core_tool.add_auth_token_getter(s, lambda t=getattr(creds.oauth2, "id_token", creds.oauth2.access_token): t)
219+
# Once we use it from get_auth_response, save it to the auth service for future use
220+
try:
221+
if tool_context._invocation_context.credential_service:
222+
auth_config_adk.exchanged_auth_credential = creds
223+
await tool_context._invocation_context.credential_service.save_credential(
224+
auth_config=auth_config_adk,
225+
callback_context=tool_context
226+
)
227+
except Exception as e:
228+
logging.debug(f"Failed to save credential to service: {e}")
229+
else:
230+
tool_context.request_credential(auth_config_adk)
231+
return {"error": f"OAuth2 Credentials required for {self.name}. A consent link has been generated for the user. Do NOT attempt to run this tool again until the user confirms they have logged in."}
232+
except Exception as e:
233+
if "credential" in str(e).lower() or isinstance(e, ValueError):
234+
raise e
235+
236+
logging.warning(
237+
f"Unexpected error in get_auth_response during User Identity (OAuth2) retrieval: {e}. "
238+
"Falling back to request_credential.",
239+
exc_info=True
240+
)
241+
# Fallback to request logic
172242
tool_context.request_credential(auth_config_adk)
173-
return None
174-
except Exception as e:
175-
if "credential" in str(e).lower() or isinstance(e, ValueError):
176-
raise e
177-
178-
logging.warning(
179-
f"Unexpected error in get_auth_response during User Identity (OAuth2) retrieval: {e}. "
180-
"Falling back to request_credential.",
181-
exc_info=True
182-
)
183-
# Fallback to request logic
184-
tool_context.request_credential(auth_config_adk)
185-
return None
243+
return {"error": f"OAuth2 Credentials required for {self.name}. A consent link has been generated for the user. Do NOT attempt to run this tool again until the user confirms they have logged in."}
186244

187245
result: Optional[Any] = None
188246
error: Optional[Exception] = None

packages/toolbox-adk/tests/integration/test_integration.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,30 +203,41 @@ async def test_3lo_flow_simulation(self):
203203
# The wrapper should catch the missing creds and request them.
204204
assert result_first is None, "Tool should return None sig for auth requirement"
205205
mock_ctx_first.request_credential.assert_called_once()
206-
206+
207207
# Inspect the requested config
208208
auth_config = mock_ctx_first.request_credential.call_args[0][0]
209209
assert auth_config.raw_auth_credential.oauth2.client_id == "test-client-id"
210+
# Verify the default fallback scopes were assigned correctly to avoid upstream crashes
211+
assert auth_config.auth_scheme.flows.authorizationCode.scopes == {"openid": "", "profile": "", "email": ""}
210212

211213
mock_ctx_second = MagicMock()
212214

213215
# Simulate "Auth Response Found"
214216
mock_creds = AuthCredential(
215217
auth_type=AuthCredentialTypes.OAUTH2,
216-
oauth2=OAuth2Auth(access_token="fake-access-token"),
218+
oauth2=OAuth2Auth(access_token="fake-access-token", id_token="fake-id-token"),
217219
)
218220
mock_ctx_second.get_auth_response.return_value = mock_creds
219221

222+
# Setup the credential service mock to verify credential persistence across sessions
223+
mock_cred_service = AsyncMock()
224+
mock_ctx_second._invocation_context = MagicMock()
225+
mock_ctx_second._invocation_context.credential_service = mock_cred_service
226+
220227
print("Running tool second time (expecting success or server error)...")
221228

222229
try:
223230
result_second = await tool.run_async({"num_rows": "1"}, mock_ctx_second)
224231
assert result_second is not None
232+
# Verify that the tool saved the credentials to the storage service backends locally
233+
mock_cred_service.save_credential.assert_called_once()
225234
except Exception as e:
226235
mock_ctx_second.request_credential.assert_not_called()
227236
err_msg = str(e).lower()
228237
assert any(x in err_msg for x in ["401", "403", "unauthorized", "forbidden"]), f"Caught UNEXPECTED exception: {type(e).__name__}: {e}"
229238
print(f"Caught expected server exception with fake token: {e}")
239+
# Verify that the tool AT LEAST triggered save_credential before failing via core_tool inner HTTP req
240+
mock_cred_service.save_credential.assert_called_once()
230241

231242
finally:
232243
await toolset.close()

packages/toolbox-adk/tests/unit/test_tool.py

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,8 @@ async def test_3lo_missing_client_secret(self):
154154
core_tool = AsyncMock()
155155
core_tool.__name__ = "mock_tool"
156156
core_tool.__doc__ = "mock doc"
157+
core_tool._required_authn_params = {"mock_param": "mock_service"}
158+
core_tool._required_authz_tokens = []
157159
auth_config = CredentialConfig(type=CredentialType.USER_IDENTITY)
158160
# Missing client_id/secret
159161

@@ -173,6 +175,8 @@ async def test_3lo_request_credential_when_missing(self):
173175
core_tool.__doc__ = "mock"
174176
core_tool.__name__ = "mock_tool"
175177
core_tool.__doc__ = "mock doc"
178+
core_tool._required_authn_params = {"mock_param": "mock_service"}
179+
core_tool._required_authz_tokens = []
176180

177181
auth_config = CredentialConfig(
178182
type=CredentialType.USER_IDENTITY, client_id="cid", client_secret="csec"
@@ -186,8 +190,8 @@ async def test_3lo_request_credential_when_missing(self):
186190

187191
result = await tool.run_async({}, ctx)
188192

189-
# Verify result is None (signal pause)
190-
assert result is None
193+
# Verify result is error/stop
194+
assert isinstance(result, dict) and "error" in result
191195
# Verify request_credential was called
192196
ctx.request_credential.assert_called_once()
193197
# Verify core tool was NOT called
@@ -197,10 +201,12 @@ async def test_3lo_request_credential_when_missing(self):
197201
async def test_3lo_uses_existing_credential(self):
198202
# Test that if creds exist, they are used and injected
199203
core_tool = AsyncMock(return_value="success")
200-
core_tool.__name__ = "mock"
201-
core_tool.__doc__ = "mock"
202204
core_tool.__name__ = "mock_tool"
203205
core_tool.__doc__ = "mock doc"
206+
# Setup overlapping needed services to test deduplication
207+
core_tool._required_authn_params = {"mock_param": "mock_service", "another_param": "mock_service"}
208+
core_tool._required_authz_tokens = ["mock_service"]
209+
core_tool.add_auth_token_getter = MagicMock(return_value=core_tool)
204210

205211
auth_config = CredentialConfig(
206212
type=CredentialType.USER_IDENTITY, client_id="cid", client_secret="csec"
@@ -209,11 +215,19 @@ async def test_3lo_uses_existing_credential(self):
209215
tool = ToolboxTool(core_tool, auth_config=auth_config)
210216

211217
ctx = MagicMock()
212-
# Mock get_auth_response returning valid creds
218+
# Mock get_auth_response returning valid creds with both access & id tokens
213219
mock_creds = MagicMock()
214-
mock_creds.oauth2.access_token = "valid_token"
220+
mock_creds.oauth2.access_token = "valid_access_token"
221+
mock_creds.oauth2.id_token = "valid_id_token"
215222
ctx.get_auth_response.return_value = mock_creds
216223

224+
# Set up invocation context and credential service mock to verify saving and avoid await errors
225+
mock_cred_service = MagicMock()
226+
mock_cred_service.load_credential = AsyncMock(return_value=None)
227+
mock_cred_service.save_credential = AsyncMock(return_value=None)
228+
ctx._invocation_context = MagicMock()
229+
ctx._invocation_context.credential_service = mock_cred_service
230+
217231
result = await tool.run_async({}, ctx)
218232

219233
# Verify result is success
@@ -222,22 +236,43 @@ async def test_3lo_uses_existing_credential(self):
222236
ctx.request_credential.assert_not_called()
223237
# Verify core tool WAS called
224238
core_tool.assert_called_once()
239+
240+
# Verify deduplication: add_auth_token_getter should only be called ONCE for "mock_service"
241+
core_tool.add_auth_token_getter.assert_called_once()
242+
call_args_getter = core_tool.add_auth_token_getter.call_args[0]
243+
assert call_args_getter[0] == "mock_service"
244+
# Evaluate the getter lambda to ensure it prefers id_token
245+
token_getter_lambda = call_args_getter[1]
246+
assert token_getter_lambda() == "valid_id_token"
247+
248+
# Verify save_credential was called with the exchanged credential
249+
mock_cred_service.save_credential.assert_called_once()
250+
call_args = mock_cred_service.save_credential.call_args[1]
251+
assert call_args["auth_config"].exchanged_auth_credential == mock_creds
252+
253+
# Verify safe scope fallback to ["openid", "profile", "email"] when scopes is None
254+
assert call_args["auth_config"].auth_scheme.flows.authorizationCode.scopes == {"openid": "", "profile": "", "email": ""}
225255

226256
@pytest.mark.asyncio
227257
async def test_3lo_exception_reraise(self):
228258
# Test that specific credential errors are re-raised
229259
core_tool = AsyncMock()
230-
core_tool.__name__ = "mock"
231-
core_tool.__doc__ = "mock"
232260
core_tool.__name__ = "mock_tool"
233261
core_tool.__doc__ = "mock doc"
262+
core_tool._required_authn_params = {"mock_param": "mock_service"}
263+
core_tool._required_authz_tokens = []
234264

235265
auth_config = CredentialConfig(
236266
type=CredentialType.USER_IDENTITY, client_id="cid", client_secret="csec"
237267
)
238268
tool = ToolboxTool(core_tool, auth_config=auth_config)
239269
ctx = MagicMock()
240270

271+
mock_cred_service = MagicMock()
272+
mock_cred_service.load_credential = AsyncMock(return_value=None)
273+
ctx._invocation_context = MagicMock()
274+
ctx._invocation_context.credential_service = mock_cred_service
275+
241276
# Mock get_auth_response raising ValueError
242277
ctx.get_auth_response.side_effect = ValueError("Invalid Credential")
243278

@@ -252,6 +287,8 @@ async def test_3lo_exception_fallback(self):
252287
core_tool.__doc__ = "mock"
253288
core_tool.__name__ = "mock_tool"
254289
core_tool.__doc__ = "mock doc"
290+
core_tool._required_authn_params = {"mock_param": "mock_service"}
291+
core_tool._required_authz_tokens = []
255292

256293
auth_config = CredentialConfig(
257294
type=CredentialType.USER_IDENTITY, client_id="cid", client_secret="csec"
@@ -264,8 +301,8 @@ async def test_3lo_exception_fallback(self):
264301

265302
result = await tool.run_async({}, ctx)
266303

267-
# Should catch RuntimeError, call request_credential, and return None
268-
assert result is None
304+
# Should catch RuntimeError, call request_credential, and return error map
305+
assert isinstance(result, dict) and "error" in result
269306
ctx.request_credential.assert_called_once()
270307

271308
def test_param_type_to_schema_type(self):

0 commit comments

Comments
 (0)