From 07f3373032e2dea8ade64843230c5db7fb993e1c Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Mon, 1 Jun 2026 15:17:57 -0700 Subject: [PATCH 1/2] contrib/strands: add cache_tools toggle to TemporalMCPClient Replace worker-startup tool discovery with a per-server {server}-list-tools activity executed from inside the workflow. TemporalMCPClient.cache_tools (default True) lists tools once at the start of the workflow; cache_tools=False re-lists on every agent turn so a mid-workflow MCP server restart is picked up. Strands calls load_tools() once at agent construction on a separate run_async thread with no workflow runtime, so the activity is dispatched from a BeforeModelCallEvent hook (which runs on the workflow loop before the registry is read each turn) that reconciles added/removed/renamed tools. --- temporalio/contrib/strands/README.md | 12 +- temporalio/contrib/strands/_plugin.py | 20 ++- temporalio/contrib/strands/_temporal_agent.py | 47 ++++++ .../contrib/strands/_temporal_mcp_client.py | 158 +++++++++++------- tests/contrib/strands/test_mcp.py | 83 ++++++++- 5 files changed, 244 insertions(+), 76 deletions(-) diff --git a/temporalio/contrib/strands/README.md b/temporalio/contrib/strands/README.md index fc9c6d74f..3f2c49fba 100644 --- a/temporalio/contrib/strands/README.md +++ b/temporalio/contrib/strands/README.md @@ -380,7 +380,7 @@ class ChatWorkflow: ## MCP -`StrandsPlugin(mcp_clients=...)` takes a mapping of `name → MCPClient factory`, mirroring the `models=` pattern. The plugin registers a per-server `{name}-call-tool` activity and connects at worker startup to enumerate tools. Workflow-side, `TemporalMCPClient(server="name")` is a pure handle: it references the server by name and carries the per-call activity options. +`StrandsPlugin(mcp_clients=...)` takes a mapping of `name → MCPClient factory`, mirroring the `models=` pattern. The plugin registers per-server `{name}-call-tool` and `{name}-list-tools` activities. 
Workflow-side, `TemporalMCPClient(server="name")` is a pure handle: it references the server by name, discovers tools by running `{name}-list-tools`, and carries the per-call activity options. ```python from mcp import StdioServerParameters, stdio_client @@ -412,9 +412,15 @@ Worker( ) ``` -Each factory returns a fully configured `MCPClient`, so you can pass options like `tool_filters`, `prefix`, `elicitation_callback`, or `tasks_config` to it. The plugin connects to each MCP server once at worker startup to enumerate tools. The schema is frozen for the worker's lifetime; restart workers to pick up MCP-server changes. If a server is unavailable at startup, the worker fails to start. +Each factory returns a fully configured `MCPClient`, so you can pass options like `tool_filters`, `prefix`, `elicitation_callback`, or `tasks_config` to it. -To amortize connection setup, the `{name}-call-tool` activity keeps a worker-process MCP connection open between calls and reuses it. The connection is disconnected after it sits idle for `mcp_connection_idle_timeout` (default 5 minutes); the timer resets on every reuse: +By default, `TemporalMCPClient` lists the server's tools once at the beginning of the workflow (via `{name}-list-tools`) and reuses that schema for the workflow's lifetime. To pick up an MCP server that is restarted mid-workflow — with tools added, removed, or renamed — set `cache_tools=False`, and the tools are re-listed on every agent turn instead: + +```python +echo = TemporalMCPClient(server="echo", cache_tools=False, start_to_close_timeout=timedelta(seconds=30)) +``` + +To amortize connection setup, the `{name}-call-tool` and `{name}-list-tools` activities share a worker-process MCP connection that is opened lazily and reused across calls. 
The connection is disconnected after it sits idle for `mcp_connection_idle_timeout` (default 5 minutes); the timer resets on every reuse: ```python StrandsPlugin( diff --git a/temporalio/contrib/strands/_plugin.py b/temporalio/contrib/strands/_plugin.py index b6f7db2ff..0f1972666 100644 --- a/temporalio/contrib/strands/_plugin.py +++ b/temporalio/contrib/strands/_plugin.py @@ -17,8 +17,7 @@ from ._temporal_mcp_client import ( _evict_connection, build_call_tool_activity, - clear_cache, - populate_cache, + build_list_tools_activity, ) @@ -31,10 +30,11 @@ class StrandsPlugin(SimplePlugin): on first use, then cached for the worker's lifetime. Use the same name in ``TemporalAgent(model=...)`` inside the workflow. - When ``mcp_clients`` is supplied, registers a per-server - ``{server}-call-tool`` activity for each entry and, at worker startup, - connects to each MCP server to cache its tool list. Workflow-side - ``TemporalMCPClient(server="...").load_tools()`` reads from the cache. + When ``mcp_clients`` is supplied, registers per-server + ``{server}-call-tool`` and ``{server}-list-tools`` activities for each + entry. Workflow-side ``TemporalMCPClient(server="...")`` discovers tools by + running ``{server}-list-tools``; whether it lists once per workflow or once + per agent turn is controlled by its ``cache_tools`` option. 
``mcp_connection_idle_timeout`` controls how long a worker-process MCP connection is kept open between ``call-tool`` activities before it is @@ -69,17 +69,19 @@ def __init__( server, client_factory, mcp_connection_idle_timeout ) ) + activities.append( + build_list_tools_activity( + server, client_factory, mcp_connection_idle_timeout + ) + ) @asynccontextmanager async def run_context() -> AsyncGenerator[None, None]: - for server, client_factory in mcp_clients.items(): - await populate_cache(server, client_factory) try: yield finally: for server in mcp_clients: await _evict_connection(server) - clear_cache(server) super().__init__( "aws.StrandsPlugin", diff --git a/temporalio/contrib/strands/_temporal_agent.py b/temporalio/contrib/strands/_temporal_agent.py index 9bc1beb31..c2f9f14c7 100644 --- a/temporalio/contrib/strands/_temporal_agent.py +++ b/temporalio/contrib/strands/_temporal_agent.py @@ -2,10 +2,12 @@ from typing import Any from strands import Agent +from strands.hooks import BeforeModelCallEvent, HookCallback from temporalio.common import Priority, RetryPolicy from temporalio.workflow import ActivityCancellationType, VersioningIntent +from ._temporal_mcp_client import TemporalMCPClient from ._temporal_model import TemporalModel _SNAPSHOT_DISABLED = ( @@ -76,6 +78,51 @@ def __init__( ) super().__init__(model=temporal_model, **agent_kwargs) + # Strands invokes ToolProvider.load_tools() once at construction on a + # separate run_async thread that has no workflow runtime, so a + # TemporalMCPClient cannot list its tools there. Instead refresh from a + # BeforeModelCallEvent hook, which runs on the workflow loop just before + # the registry is read each turn. cache_tools=True lists once (guarded + # by _fetched); cache_tools=False re-lists every turn. 
+ for provider in self.tool_registry._tool_providers: + if isinstance(provider, TemporalMCPClient): + self.hooks.add_callback( + BeforeModelCallEvent, self._make_mcp_refresh_hook(provider) + ) + + def _make_mcp_refresh_hook( + self, provider: TemporalMCPClient + ) -> HookCallback[BeforeModelCallEvent]: + async def hook(event: BeforeModelCallEvent) -> None: + if provider._cache_tools and provider._fetched: + return + old_names = {tool.tool_name for tool in provider._tools} + await provider._refresh() + self._reconcile_mcp_tools(event, provider, old_names) + + return hook + + def _reconcile_mcp_tools( + self, + event: BeforeModelCallEvent, + provider: TemporalMCPClient, + old_names: set[str], + ) -> None: + reg = event.agent.tool_registry + new = {tool.tool_name: tool for tool in provider._tools} + # Tools the server dropped or renamed since the last listing. There is + # no public unregister, so remove them from the registry directly. + for name in old_names - set(new): + reg.registry.pop(name, None) + reg.dynamic_tools.pop(name, None) + # replace() swaps an existing tool in place (no hot-reload guard); + # register_tool() adds a newly-discovered one. 
+ for name, tool in new.items(): + if name in reg.registry: + reg.replace(tool) + else: + reg.register_tool(tool) + def take_snapshot(self, *_args: Any, **_kwargs: Any) -> Any: """Disabled; Temporal's event history is the source of truth.""" raise NotImplementedError(_SNAPSHOT_DISABLED) diff --git a/temporalio/contrib/strands/_temporal_mcp_client.py b/temporalio/contrib/strands/_temporal_mcp_client.py index 71e1f2f7c..00605ae2c 100644 --- a/temporalio/contrib/strands/_temporal_mcp_client.py +++ b/temporalio/contrib/strands/_temporal_mcp_client.py @@ -13,7 +13,7 @@ from strands.tools.mcp.mcp_types import MCPToolResult from strands.types.tools import AgentTool -from temporalio import activity +from temporalio import activity, workflow from temporalio.common import Priority, RetryPolicy from temporalio.workflow import ActivityCancellationType, VersioningIntent @@ -33,20 +33,21 @@ class _CallToolArgs: tool_use_id: str = "" -# Server name -> cached tool list. Populated by ``_populate_cache`` at worker -# startup and read by ``TemporalMCPClient.load_tools()`` inside the workflow -# sandbox. ``temporalio`` is in the SDK's default sandbox passthrough, so this -# dict is shared between worker process and workflow execution. -_TOOL_CACHE: dict[str, list[_MCPToolInfo]] = {} - - class TemporalMCPClient(ToolProvider): """Workflow-side handle to an MCP server registered on the worker. - The transport factory and tool discovery live worker-side via - ``StrandsPlugin(mcp_clients={"server": lambda: ...})``. This handle only - carries the server name (which selects the registered factory) and the - per-call activity options. + The transport factory lives worker-side via + ``StrandsPlugin(mcp_clients={"server": lambda: ...})``. This handle carries + the server name (which selects the registered factory) and the per-call + activity options. Tool discovery runs as the ``{server}-list-tools`` + activity, dispatched from inside the workflow by ``TemporalAgent`` before + each model call. 
+ + ``cache_tools`` controls how often that listing happens. When ``True`` (the + default) the tools are listed once at the beginning of the workflow and + reused for its lifetime. When ``False`` the tools are re-listed on every + agent turn, so an MCP server restarted mid-workflow (with tools added, + removed, or renamed) is picked up. Construct once at module level and pass to ``TemporalAgent(tools=[...])`` inside the workflow. Multiple handles may reference the same server name @@ -57,6 +58,7 @@ def __init__( self, server: str, *, + cache_tools: bool = True, task_queue: str | None = None, schedule_to_close_timeout: timedelta | None = None, schedule_to_start_timeout: timedelta | None = None, @@ -70,6 +72,9 @@ def __init__( ) -> None: """Configure the server name and activity options.""" self._server = server + self._cache_tools = cache_tools + self._tools: list[AgentTool] = [] + self._fetched = False self._options: dict[str, Any] = { "task_queue": task_queue, "schedule_to_close_timeout": schedule_to_close_timeout, @@ -89,11 +94,33 @@ def server(self) -> str: return self._server async def load_tools(self, **_kwargs: Any) -> Sequence[AgentTool]: - """Return TemporalMCPTool wrappers for tools cached at worker startup.""" + """Return the tools fetched by the most recent ``_refresh``. + + This must stay free of any ``workflow`` API: Strands invokes it once at + ``Agent`` construction on a separate ``run_async`` thread that has no + workflow runtime. ``TemporalAgent`` populates the tools by calling + ``_refresh`` from a ``BeforeModelCallEvent`` hook before the registry is + first read. + """ + return list(self._tools) + + async def _refresh(self) -> None: + """List the server's tools via the ``{server}-list-tools`` activity. + + Runs on the workflow event loop (dispatched from ``TemporalAgent``'s + hook), so the activity result is recorded in history and replay-safe. 
+ """ from ._temporal_mcp_tool import TemporalMCPTool - infos = _TOOL_CACHE.get(self._server, []) - return [TemporalMCPTool(self._server, info, self._options) for info in infos] + infos: list[_MCPToolInfo] = await workflow.execute_activity( + f"{self._server}-list-tools", + result_type=list[_MCPToolInfo], + **self._options, + ) + self._tools = [ + TemporalMCPTool(self._server, info, self._options) for info in infos + ] + self._fetched = True def add_consumer(self, consumer_id: Any, **_kwargs: Any) -> None: """No-op; consumer tracking is handled by the underlying MCP client.""" @@ -104,45 +131,37 @@ def remove_consumer(self, consumer_id: Any, **_kwargs: Any) -> None: return None -# Use MCP sessions directly instead of MCPClient's background-thread helpers. -# Those helpers route calls through cross-loop futures that are unreliable on -# Python 3.10 when invoked from Temporal's async worker/activity event loops. -async def _list_mcp_tools(client: MCPClient) -> Sequence[Tool]: - async with client._transport_callable() as (read_stream, write_stream, *_): - async with ClientSession( - read_stream, - write_stream, - elicitation_callback=client._elicitation_callback, - ) as session: - await session.initialize() - tools: list[Tool] = [] - pagination_token = None - while True: - page = await session.list_tools( - params=PaginatedRequestParams(cursor=pagination_token) - if pagination_token is not None - else None - ) - tools.extend(page.tools) - pagination_token = page.nextCursor - if pagination_token is None: - return tools - - -def _agent_tool_for_filtering(client: MCPClient, tool: Tool) -> MCPAgentTool: - if client._prefix: - return MCPAgentTool(tool, client, name_override=f"{client._prefix}_{tool.name}") - return MCPAgentTool(tool, client) - - -async def populate_cache(server: str, client_factory: Callable[[], MCPClient]) -> None: - """Connect to the MCP server, list tools, fill ``_TOOL_CACHE``.""" - client = client_factory() +# Use the MCP session directly instead of 
MCPClient's background-thread +# helpers. Those helpers route calls through cross-loop futures that are +# unreliable on Python 3.10 when invoked from Temporal's async worker/activity +# event loops. +async def _paginate_list_tools(session: ClientSession) -> list[Tool]: + tools: list[Tool] = [] + pagination_token = None + while True: + page = await session.list_tools( + params=PaginatedRequestParams(cursor=pagination_token) + if pagination_token is not None + else None + ) + tools.extend(page.tools) + pagination_token = page.nextCursor + if pagination_token is None: + return tools + + +def _tool_infos(client: MCPClient, tools: Sequence[Tool]) -> list[_MCPToolInfo]: + """Apply the client's tool filters and project to serializable records.""" infos: list[_MCPToolInfo] = [] - for tool in await _list_mcp_tools(client): + for tool in tools: + if client._prefix: + agent_tool = MCPAgentTool( + tool, client, name_override=f"{client._prefix}_{tool.name}" + ) + else: + agent_tool = MCPAgentTool(tool, client) if not client._should_include_tool_with_filters( - _agent_tool_for_filtering(client, tool), - client._tool_filters, + agent_tool, client._tool_filters ): continue infos.append( @@ -153,12 +172,7 @@ async def populate_cache(server: str, client_factory: Callable[[], MCPClient]) - output_schema=tool.outputSchema, ) ) - _TOOL_CACHE[server] = infos - - -def clear_cache(server: str) -> None: - """Drop the cached tool list for ``server``.""" - _TOOL_CACHE.pop(server, None) + return infos # Default for how long an idle MCP connection stays open before it is @@ -324,3 +338,31 @@ async def call_tool(args: _CallToolArgs) -> MCPToolResult: record.release() return call_tool + + +def build_list_tools_activity( + server: str, + client_factory: Callable[[], MCPClient], + idle_timeout: timedelta | None = None, +) -> Callable: + """Return the per-server ``{server}-list-tools`` activity for registration. 
+ + Lists the server's tools (applying the client's tool filters) and reuses + the same lazily-opened, idle-evicted worker-process MCP session as + ``{server}-call-tool``. + """ + idle = idle_timeout if idle_timeout is not None else _MCP_CONNECTION_IDLE + + @activity.defn(name=f"{server}-list-tools") + async def list_tools() -> list[_MCPToolInfo]: + client, session, record = await get_connection(server, client_factory, idle) + try: + return _tool_infos(client, await _paginate_list_tools(session)) + except Exception: + # The session may be broken; drop it so the next call reconnects. + await _evict_connection(server) + raise + finally: + record.release() + + return list_tools diff --git a/tests/contrib/strands/test_mcp.py b/tests/contrib/strands/test_mcp.py index bde857022..b65683c4b 100644 --- a/tests/contrib/strands/test_mcp.py +++ b/tests/contrib/strands/test_mcp.py @@ -90,6 +90,7 @@ async def test_mcp(client: Client): history = await handle.fetch_history() assert get_activities(history) == [ + "echo-list-tools", "invoke_model", "echo-call-tool", "invoke_model", @@ -123,9 +124,9 @@ async def run(self, prompt: str) -> str: async def test_mcp_reuses_connection(client: Client): """Successive MCP tool calls reuse one cached worker-side connection.""" task_queue = "test_mcp_reuses_connection" - # Count how often the worker opens a connection. With caching this is one - # startup-discovery connection plus one cached call connection serving both - # tool calls (2); reconnecting per call would make it 3. + # Count how often the worker opens a connection. One lazily-opened + # connection serves the list-tools discovery and both tool calls (1); + # reconnecting per call would make it more. factory_calls = [0] def counting_factory() -> MCPClient: @@ -163,10 +164,11 @@ def counting_factory() -> MCPClient: # The worker context has exited, so its run_context finally evicted the # cached connection. 
assert "echo_cached" not in _temporal_mcp_client._CONNECTIONS - assert factory_calls[0] == 2 + assert factory_calls[0] == 1 history = await handle.fetch_history() assert get_activities(history) == [ + "echo_cached-list-tools", "invoke_model", "echo_cached-call-tool", "invoke_model", @@ -239,8 +241,11 @@ def counting_factory() -> MCPClient: ) assert await handle.result() == "Done!\n" - # The call opened a second connection (startup discovery was the first). - assert factory_calls[0] == 2 + # A connection was opened lazily (on the first list-tools/call-tool). + # How many times depends on whether the short idle timer fires between + # activities, so this only asserts that at least one was opened; the + # eviction-while-alive behavior is what the polling loop below checks. + assert factory_calls[0] >= 1 # Still inside the worker context: the short idle timer evicts the # cached call connection on its own. Asserting eviction here -- with the @@ -250,3 +255,69 @@ def counting_factory() -> MCPClient: break await asyncio.sleep(0.1) assert "echo_idle" not in _temporal_mcp_client._CONNECTIONS + + +@workflow.defn +class MCPNoCacheWorkflow: + def __init__(self) -> None: + echo = TemporalMCPClient( + server="echo_nocache", + cache_tools=False, + start_to_close_timeout=timedelta(seconds=30), + ) + self.agent = TemporalAgent( + model="mock", + start_to_close_timeout=timedelta(seconds=30), + tools=[echo], + ) + + @workflow.run + async def run(self, prompt: str) -> str: + result = await self.agent.invoke_async(prompt) + return str(result) + + +async def test_mcp_lists_tools_each_turn_when_uncached(client: Client): + """With cache_tools=False the tool list is re-fetched on every model call.""" + task_queue = "test_mcp_lists_tools_each_turn_when_uncached" + plugin = StrandsPlugin( + models={ + "mock": lambda: MockModel( + [ + {"name": "echo", "input": {"message": "one"}}, + {"name": "echo", "input": {"message": "two"}}, + "Done!", + ] + ) + }, + mcp_clients={"echo_nocache": 
_echo_client_factory}, + ) + + async with Worker( + client, + task_queue=task_queue, + workflows=[MCPNoCacheWorkflow], + plugins=[plugin], + max_cached_workflows=0, + ): + handle = await client.start_workflow( + MCPNoCacheWorkflow.run, + "echo twice", + id=f"test_mcp_lists_tools_each_turn_when_uncached_{uuid4()}", + task_queue=task_queue, + ) + assert await handle.result() == "Done!\n" + + history = await handle.fetch_history() + activities = get_activities(history) + # One list-tools per model call -- the tools are re-listed every turn rather + # than once for the workflow. + assert activities.count("echo_nocache-list-tools") == activities.count( + "invoke_model" + ) + assert activities.count("echo_nocache-list-tools") == 3 + + await Replayer( + workflows=[MCPNoCacheWorkflow], + plugins=[plugin], + ).replay_workflow(history) From b91554334981f59b21ad38a8dfca3272269baa67 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Mon, 1 Jun 2026 15:22:50 -0700 Subject: [PATCH 2/2] contrib/strands: default cache_tools to False --- temporalio/contrib/strands/README.md | 4 ++-- temporalio/contrib/strands/_temporal_mcp_client.py | 12 ++++++------ tests/contrib/strands/test_mcp.py | 3 +++ 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/temporalio/contrib/strands/README.md b/temporalio/contrib/strands/README.md index 3f2c49fba..126f4bd95 100644 --- a/temporalio/contrib/strands/README.md +++ b/temporalio/contrib/strands/README.md @@ -414,10 +414,10 @@ Worker( Each factory returns a fully configured `MCPClient`, so you can pass options like `tool_filters`, `prefix`, `elicitation_callback`, or `tasks_config` to it. -By default, `TemporalMCPClient` lists the server's tools once at the beginning of the workflow (via `{name}-list-tools`) and reuses that schema for the workflow's lifetime. 
To pick up an MCP server that is restarted mid-workflow — with tools added, removed, or renamed — set `cache_tools=False`, and the tools are re-listed on every agent turn instead: +By default, `TemporalMCPClient` re-lists the server's tools (via `{name}-list-tools`) on every agent turn, so an MCP server that is restarted mid-workflow — with tools added, removed, or renamed — is picked up. To list the tools just once — right before the agent's first model call — and reuse that schema for the rest of the workflow (one fewer activity per turn), set `cache_tools=True`: ```python -echo = TemporalMCPClient(server="echo", cache_tools=False, start_to_close_timeout=timedelta(seconds=30)) +echo = TemporalMCPClient(server="echo", cache_tools=True, start_to_close_timeout=timedelta(seconds=30)) ``` To amortize connection setup, the `{name}-call-tool` and `{name}-list-tools` activities share a worker-process MCP connection that is opened lazily and reused across calls. The connection is disconnected after it sits idle for `mcp_connection_idle_timeout` (default 5 minutes); the timer resets on every reuse: diff --git a/temporalio/contrib/strands/_temporal_mcp_client.py b/temporalio/contrib/strands/_temporal_mcp_client.py index 00605ae2c..bb096956e 100644 --- a/temporalio/contrib/strands/_temporal_mcp_client.py +++ b/temporalio/contrib/strands/_temporal_mcp_client.py @@ -43,11 +43,11 @@ class TemporalMCPClient(ToolProvider): activity, dispatched from inside the workflow by ``TemporalAgent`` before each model call. - ``cache_tools`` controls how often that listing happens. When ``True`` (the - default) the tools are listed once at the beginning of the workflow and - reused for its lifetime. When ``False`` the tools are re-listed on every - agent turn, so an MCP server restarted mid-workflow (with tools added, - removed, or renamed) is picked up. + ``cache_tools`` controls how often that listing happens.
When ``False`` + (the default) the tools are re-listed on every agent turn, so an MCP server + restarted mid-workflow (with tools added, removed, or renamed) is picked up. + When ``True`` the tools are listed once, right before the agent's first + model call, and reused for the workflow's lifetime. Construct once at module level and pass to ``TemporalAgent(tools=[...])`` inside the workflow. Multiple handles may reference the same server name @@ -58,7 +58,7 @@ def __init__( self, server: str, *, - cache_tools: bool = True, + cache_tools: bool = False, task_queue: str | None = None, schedule_to_close_timeout: timedelta | None = None, schedule_to_start_timeout: timedelta | None = None, diff --git a/tests/contrib/strands/test_mcp.py b/tests/contrib/strands/test_mcp.py index b65683c4b..0f989cd83 100644 --- a/tests/contrib/strands/test_mcp.py +++ b/tests/contrib/strands/test_mcp.py @@ -36,6 +36,7 @@ class MCPWorkflow: def __init__(self) -> None: echo = TemporalMCPClient( server="echo", + cache_tools=True, start_to_close_timeout=timedelta(seconds=30), ) self.agent = TemporalAgent( @@ -107,6 +108,7 @@ class MCPReuseWorkflow: def __init__(self) -> None: echo = TemporalMCPClient( server="echo_cached", + cache_tools=True, start_to_close_timeout=timedelta(seconds=30), ) self.agent = TemporalAgent( @@ -187,6 +189,7 @@ class MCPIdleWorkflow: def __init__(self) -> None: echo = TemporalMCPClient( server="echo_idle", + cache_tools=True, start_to_close_timeout=timedelta(seconds=30), ) self.agent = TemporalAgent(