From bd427fdf429dd0da0f6f8bcc4184a084d2ca6253 Mon Sep 17 00:00:00 2001 From: Piyush Date: Thu, 5 Mar 2026 12:16:01 +0530 Subject: [PATCH 1/2] test: add unit test for MCP sampling callback support --- .../mcp_tool/test_mcp_sampling_callback.py | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 tests/unittests/tools/mcp_tool/test_mcp_sampling_callback.py diff --git a/tests/unittests/tools/mcp_tool/test_mcp_sampling_callback.py b/tests/unittests/tools/mcp_tool/test_mcp_sampling_callback.py new file mode 100644 index 0000000000..604a6c2018 --- /dev/null +++ b/tests/unittests/tools/mcp_tool/test_mcp_sampling_callback.py @@ -0,0 +1,45 @@ +import pytest +from fastmcp.client.sampling import SamplingMessage +from google.adk.tools.mcp_tool.mcp_toolset import McpToolset +# from google.adk.tools.mcp_tool.mcp_toolset import StreamableHTTPConnectionParams +from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams + + +@pytest.mark.asyncio +async def test_sampling_callback_invoked(): + + called = {"value": False} + + async def mock_sampling_handler(messages, params=None, context=None): + called["value"] = True + + assert isinstance(messages, list) + assert messages[0].role == "user" + + return { + "model": "test-model", + "role": "assistant", + "content": {"type": "text", "text": "sampling response"}, + "stopReason": "endTurn", + } + + toolset = McpToolset( + connection_params=StreamableHTTPConnectionParams( + url="http://localhost:9999", + timeout=10, + ), + sampling_callback=mock_sampling_handler, + ) + + messages = [ + SamplingMessage( + role="user", + content={"type": "text", "text": "hello"}, + ) + ] + + result = await toolset._sampling_callback(messages) + + assert called["value"] is True + assert result["role"] == "assistant" + assert result["content"]["text"] == "sampling response" \ No newline at end of file From cc510eb1cdda2d4acc0d9004edc360b6e8f4187b Mon Sep 17 00:00:00 2001 From: Piyush Date: Thu, 5 Mar 2026 12:44:58 +0530 Subject: [PATCH 2/2] feat(mcp): add sampling callback support for MCP sessions --- src/google/adk/tools/mcp_tool/mcp_session_manager.py | 7 +++++++ src/google/adk/tools/mcp_tool/mcp_toolset.py | 7 +++++++ src/google/adk/tools/mcp_tool/session_context.py | 8 ++++++++ 3 files changed, 22 insertions(+) diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index f4339f8678..cb8a2185ca 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -195,6 +195,8 @@ def __init__( StreamableHTTPConnectionParams, ], errlog: TextIO = sys.stderr, + sampling_callback: Optional[Any] = None, + sampling_capabilities: Optional[Any] = None, ): """Initializes the MCP session manager. @@ -205,6 +207,9 @@ def __init__( errlog: (Optional) TextIO stream for error logging. Use only for initializing a local stdio MCP session. """ + self._sampling_callback = sampling_callback + self._sampling_capabilities = sampling_capabilities + if isinstance(connection_params, StdioServerParameters): # So far timeout is not configurable. Given MCP is still evolving, we # would expect stdio_client to evolve to accept timeout parameter like @@ -475,6 +480,8 @@ async def create_session( timeout=timeout_in_seconds, sse_read_timeout=sse_read_timeout_in_seconds, is_stdio=is_stdio, + sampling_callback=self._sampling_callback, + sampling_capabilities=self._sampling_capabilities, ) ), timeout=timeout_in_seconds, diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index fb4e992dfd..c515beb78c 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -114,6 +114,8 @@ def __init__( Union[ProgressFnT, ProgressCallbackFactory] ] = None, use_mcp_resources: Optional[bool] = False, + sampling_callback: Optional[SamplingFnT] = None, + sampling_capabilities: Optional[Any] = None, ): """Initializes the McpToolset. @@ -154,6 +156,9 @@ def __init__( super().__init__(tool_filter=tool_filter, tool_name_prefix=tool_name_prefix) + self._sampling_callback = sampling_callback + self._sampling_capabilities = sampling_capabilities + if not connection_params: raise ValueError("Missing connection params in McpToolset.") @@ -166,6 +171,8 @@ def __init__( self._mcp_session_manager = MCPSessionManager( connection_params=self._connection_params, errlog=self._errlog, + sampling_callback=self._sampling_callback, + sampling_capabilities=self._sampling_capabilities, ) self._auth_scheme = auth_scheme self._auth_credential = auth_credential diff --git a/src/google/adk/tools/mcp_tool/session_context.py b/src/google/adk/tools/mcp_tool/session_context.py index ca637d0489..7ca007d220 100644 --- a/src/google/adk/tools/mcp_tool/session_context.py +++ b/src/google/adk/tools/mcp_tool/session_context.py @@ -54,6 +54,8 @@ def __init__( timeout: Optional[float], sse_read_timeout: Optional[float], is_stdio: bool = False, + sampling_callback: Optional[Any] = None, + sampling_capabilities: Optional[Any] = None, ): """ Args: @@ -73,6 +75,8 @@ def __init__( self._close_event = asyncio.Event() self._task: Optional[asyncio.Task] = None self._task_lock = asyncio.Lock() + self._sampling_callback = sampling_callback + self._sampling_capabilities = sampling_capabilities @property def session(self) -> Optional[ClientSession]: @@ -165,6 +169,8 @@ async def _run(self): read_timeout_seconds=timedelta(seconds=self._timeout) if self._timeout is not None else None, + sampling_callback=self._sampling_callback, + sampling_capabilities=self._sampling_capabilities, ) ) else: @@ -176,6 +182,8 @@ async def _run(self): read_timeout_seconds=timedelta(seconds=self._sse_read_timeout) if self._sse_read_timeout is not None else None, + sampling_callback=self._sampling_callback, + sampling_capabilities=self._sampling_capabilities, ) ) await asyncio.wait_for(session.initialize(), timeout=self._timeout)