forked from google/adk-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmcp_tool.py
More file actions
395 lines (340 loc) · 12.7 KB
/
mcp_tool.py
File metadata and controls
395 lines (340 loc) · 12.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import base64
import inspect
import logging
import sys
from typing import Any
from typing import Callable
from typing import Dict
from typing import Optional
from typing import Union
import warnings
from fastapi.openapi.models import APIKeyIn
from google.genai.types import FunctionDeclaration
from typing_extensions import override
from ...agents.readonly_context import ReadonlyContext
from .._gemini_schema_util import _to_gemini_schema
from .mcp_session_manager import MCPSessionManager
from .mcp_session_manager import retry_on_closed_resource
# Import the MCP Tool type from the MCP library. The error message below states
# that MCP Tool requires Python 3.10+, so on older interpreters the failure is
# re-raised with a hint to upgrade.
try:
  from mcp.types import Tool as McpBaseTool
except ImportError as e:
  if sys.version_info >= (3, 10):
    # The interpreter is new enough, so the import failed for some other
    # reason; surface the original error unchanged.
    raise e
  raise ImportError(
      "MCP Tool requires Python 3.10 or above. Please upgrade your Python"
      " version."
  ) from e
from ...auth.auth_credential import AuthCredential
from ...auth.auth_schemes import AuthScheme
from ...auth.auth_tool import AuthConfig
from ..base_authenticated_tool import BaseAuthenticatedTool
# import
from ..tool_context import ToolContext
# Module logger, namespaced under the "google_adk." hierarchy.
logger = logging.getLogger(f"google_adk.{__name__}")
class McpTool(BaseAuthenticatedTool):
  """Turns an MCP Tool into an ADK Tool.

  Internally, the tool initializes from a MCP Tool, and uses the MCP Session to
  call the tool.

  Note: For API key authentication, only header-based API keys are supported.
  Query and cookie-based API keys will result in authentication errors.
  """

  def __init__(
      self,
      *,
      mcp_tool: McpBaseTool,
      mcp_session_manager: MCPSessionManager,
      auth_scheme: Optional[AuthScheme] = None,
      auth_credential: Optional[AuthCredential] = None,
      require_confirmation: Union[bool, Callable[..., bool]] = False,
      header_provider: Optional[
          Callable[[ReadonlyContext], Dict[str, str]]
      ] = None,
  ):
    """Initializes an McpTool.

    This tool wraps an MCP Tool interface and uses a session manager to
    communicate with the MCP server.

    Args:
      mcp_tool: The MCP tool to wrap.
      mcp_session_manager: The MCP session manager to use for communication.
      auth_scheme: The authentication scheme to use. When not provided, no
        auth config is attached to the tool.
      auth_credential: The authentication credential to use. Only used when
        `auth_scheme` is provided.
      require_confirmation: Whether this tool requires confirmation. A boolean
        or a callable (sync or async) that takes the function's arguments as
        keyword arguments and returns a boolean. If the callable returns a
        truthy value, the tool will require confirmation from the user.
      header_provider: Optional callable that receives a ReadonlyContext for
        the current invocation and returns extra headers for the MCP session.
        These headers take precedence over auth-derived headers with the same
        name.

    Raises:
      ValueError: If mcp_tool or mcp_session_manager is None.
    """
    # Enforce the documented contract up front instead of failing later with
    # an AttributeError (mcp_tool) or at call time (mcp_session_manager).
    if mcp_tool is None:
      raise ValueError("mcp_tool cannot be None")
    if mcp_session_manager is None:
      raise ValueError("mcp_session_manager cannot be None")
    super().__init__(
        name=mcp_tool.name,
        description=mcp_tool.description or "",
        auth_config=(
            AuthConfig(
                auth_scheme=auth_scheme, raw_auth_credential=auth_credential
            )
            if auth_scheme
            else None
        ),
    )
    self._mcp_tool = mcp_tool
    self._mcp_session_manager = mcp_session_manager
    self._require_confirmation = require_confirmation
    self._header_provider = header_provider

  @override
  def _get_declaration(self) -> FunctionDeclaration:
    """Gets the function declaration for the tool.

    Returns:
      FunctionDeclaration: The Gemini function declaration for the tool, with
      parameters converted from the MCP tool's input schema.
    """
    input_schema = self._mcp_tool.inputSchema
    parameters = _to_gemini_schema(input_schema)
    function_decl = FunctionDeclaration(
        name=self.name,
        description=self.description,
        parameters=parameters,
    )
    return function_decl

  @property
  def raw_mcp_tool(self) -> McpBaseTool:
    """Returns the raw MCP tool."""
    return self._mcp_tool

  async def _invoke_callable(
      self, target: Callable[..., Any], args_to_call: dict[str, Any]
  ) -> Any:
    """Calls `target` with keyword arguments, awaiting the result if needed.

    Works for coroutine functions, objects whose `__call__` is a coroutine
    function, and sync callables that return an awaitable.

    Args:
      target: The callable to invoke.
      args_to_call: Keyword arguments passed to `target`.

    Returns:
      The (awaited, if awaitable) return value of `target`.
    """
    result = target(**args_to_call)
    # Inspecting the result, not the callable, also covers wrappers such as
    # functools.partial around an async function, which a coroutine-function
    # check can miss; an un-awaited coroutine would otherwise be returned and
    # treated as truthy.
    if inspect.isawaitable(result):
      result = await result
    return result

  @override
  async def run_async(
      self, *, args: dict[str, Any], tool_context: ToolContext
  ) -> Any:
    """Runs the tool after applying the confirmation policy.

    If confirmation is required and none has been given yet, a confirmation
    request is issued on `tool_context` and an error dict is returned. If the
    confirmation was rejected, an error dict is returned. Otherwise the call
    is delegated to the base class.

    Args:
      args: The arguments for the tool call.
      tool_context: The tool context of the current invocation.

    Returns:
      The tool result, or an error dict when confirmation is pending or
      rejected.
    """
    if callable(self._require_confirmation):
      require_confirmation = await self._invoke_callable(
          self._require_confirmation, args
      )
    else:
      require_confirmation = bool(self._require_confirmation)

    if require_confirmation:
      if not tool_context.tool_confirmation:
        tool_context.request_confirmation(
            hint=(
                f"Please approve or reject the tool call {self.name}() by"
                " responding with a FunctionResponse with an expected"
                " ToolConfirmation payload."
            ),
        )
        return {
            "error": (
                "This tool call requires confirmation, please approve or"
                " reject."
            )
        }
      if not tool_context.tool_confirmation.confirmed:
        return {"error": "This tool call is rejected."}

    return await super().run_async(args=args, tool_context=tool_context)

  @retry_on_closed_resource
  @override
  async def _run_async_impl(
      self, *, args, tool_context: ToolContext, credential: AuthCredential
  ) -> Dict[str, Any]:
    """Runs the tool asynchronously.

    Args:
      args: The arguments as a dict to pass to the tool.
      tool_context: The tool context of the current invocation.
      credential: The authentication credential used to build auth headers.

    Returns:
      The MCP call result dumped to a JSON-compatible dict with None fields
      omitted.
    """
    # Extract headers from credential for session pooling.
    auth_headers = await self._get_headers(tool_context, credential)

    dynamic_headers = None
    if self._header_provider:
      dynamic_headers = self._header_provider(
          ReadonlyContext(tool_context._invocation_context)
      )

    # Dynamic headers are applied last so they override auth headers.
    headers: Dict[str, str] = {}
    if auth_headers:
      headers.update(auth_headers)
    if dynamic_headers:
      headers.update(dynamic_headers)
    final_headers = headers or None

    # Get the session from the session manager.
    session = await self._mcp_session_manager.create_session(
        headers=final_headers
    )

    # Transform arguments to match MCP schema.
    transformed_args = self._transform_args_to_mcp_format(
        args, self._mcp_tool.inputSchema
    )

    response = await session.call_tool(
        self._mcp_tool.name, arguments=transformed_args
    )
    return response.model_dump(exclude_none=True, mode="json")

  async def _get_headers(
      self, tool_context: ToolContext, credential: AuthCredential
  ) -> Optional[dict[str, str]]:
    """Extracts authentication headers from credentials.

    Args:
      tool_context: The tool context of the current invocation.
      credential: The authentication credential to process.

    Returns:
      Dictionary of headers to add to the request, or None if no auth header
      can be derived.

    Raises:
      ValueError: If an API key credential has no configured auth scheme, or
        the scheme places the key somewhere other than a header.
    """
    headers: Optional[dict[str, str]] = None
    if not credential:
      return headers

    if credential.oauth2:
      # An OAuth2 credential without an access token would otherwise produce
      # the malformed header "Bearer None".
      if credential.oauth2.access_token:
        headers = {"Authorization": f"Bearer {credential.oauth2.access_token}"}
      else:
        logger.warning(
            "OAuth2 credential for MCP tool %s has no access token; no"
            " Authorization header will be sent.",
            self.name,
        )
    elif credential.http:
      # Handle HTTP authentication schemes.
      scheme = credential.http.scheme
      http_credentials = credential.http.credentials
      if scheme.lower() == "bearer" and http_credentials.token:
        headers = {"Authorization": f"Bearer {http_credentials.token}"}
      elif scheme.lower() == "basic":
        # Basic auth: base64("username:password").
        if http_credentials.username and http_credentials.password:
          user_pass = f"{http_credentials.username}:{http_credentials.password}"
          encoded_credentials = base64.b64encode(user_pass.encode()).decode()
          headers = {"Authorization": f"Basic {encoded_credentials}"}
      elif http_credentials.token:
        # Other HTTP schemes carrying a token.
        headers = {"Authorization": f"{scheme} {http_credentials.token}"}
    elif credential.api_key:
      auth_config = (
          self._credentials_manager._auth_config
          if self._credentials_manager
          else None
      )
      if not auth_config:
        # Do not interpolate the credential: its repr contains the secret key.
        error_msg = (
            "Cannot find corresponding auth scheme for API key credential"
        )
        logger.error(error_msg)
        raise ValueError(error_msg)
      if auth_config.auth_scheme.in_ != APIKeyIn.header:
        error_msg = (
            "McpTool only supports header-based API key authentication."
            " Configured location:"
            f" {auth_config.auth_scheme.in_}"
        )
        logger.error(error_msg)
        raise ValueError(error_msg)
      headers = {auth_config.auth_scheme.name: credential.api_key}
    elif credential.service_account:
      # Service accounts should be exchanged for access tokens before
      # reaching this point.
      logger.warning(
          "Service account credentials should be exchanged before MCP"
          " session creation"
      )
    return headers

  def _transform_args_to_mcp_format(
      self, args: Dict[str, Any], mcp_schema: Dict[str, Any]
  ) -> Dict[str, Any]:
    """Transform arguments to match MCP schema.

    Handles cases where model output simplifies array-of-objects to
    array-of-primitives for schemas with single-property objects. Only
    top-level properties are inspected.

    Args:
      args: Tool arguments from model output.
      mcp_schema: MCP tool input schema.

    Returns:
      Transformed arguments matching schema, or the original `args` object if
      no transformation applies.
    """
    if not args or not isinstance(mcp_schema, dict) or not mcp_schema:
      return args

    properties = mcp_schema.get("properties")
    if not isinstance(properties, dict) or not properties:
      return args

    return {
        key: (
            self._transform_value_to_schema(value, properties[key])
            if key in properties
            else value
        )
        for key, value in args.items()
    }

  def _transform_value_to_schema(
      self, value: Any, schema: Dict[str, Any]
  ) -> Any:
    """Transform value to match schema.

    A non-empty list of non-dict items is wrapped as `[{prop: item}, ...]`
    when the schema is an array whose items are objects with exactly one
    property. Anything else is returned unchanged.

    Args:
      value: Value to transform.
      schema: JSON schema for the value. Non-dict schemas (e.g. boolean
        schemas) are left alone.

    Returns:
      Transformed value or original if no transformation needed.
    """
    if value is None or not isinstance(schema, dict) or not schema:
      return value
    if schema.get("type") != "array" or not isinstance(value, list) or not value:
      return value

    # `items` may also be a list (tuple validation) or a boolean; only a
    # single object schema is handled here.
    items_schema = schema.get("items")
    if (
        not isinstance(items_schema, dict)
        or items_schema.get("type") != "object"
    ):
      return value

    # Already in object form: nothing to wrap.
    if isinstance(value[0], dict):
      return value

    if any(isinstance(item, dict) for item in value):
      logger.warning("Mixed types in array for MCP tool %s", self.name)
      return value

    item_properties = items_schema.get("properties")
    if not isinstance(item_properties, dict) or len(item_properties) != 1:
      return value

    (property_name,) = item_properties
    logger.debug(
        "Transforming array for MCP tool %s with property '%s'",
        self.name,
        property_name,
    )
    return [{property_name: item} for item in value]
class MCPTool(McpTool):
  """Deprecated name, use `McpTool` instead.

  Accepts exactly the same arguments as `McpTool`; constructing it emits a
  DeprecationWarning before delegating to `McpTool.__init__`.
  """

  def __init__(self, *args, **kwargs):
    deprecation_message = "MCPTool class is deprecated, use `McpTool` instead."
    # stacklevel=2 attributes the warning to the code that constructs the tool.
    warnings.warn(deprecation_message, category=DeprecationWarning, stacklevel=2)
    super().__init__(*args, **kwargs)