Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions temporalio/contrib/strands/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ class ChatWorkflow:

## MCP

`StrandsPlugin(mcp_clients=...)` takes a mapping of `name → MCPClient factory`, mirroring the `models=` pattern. The plugin registers a per-server `{name}-call-tool` activity and connects at worker startup to enumerate tools. Workflow-side, `TemporalMCPClient(server="name")` is a pure handle: it references the server by name and carries the per-call activity options.
`StrandsPlugin(mcp_clients=...)` takes a mapping of `name → MCPClient factory`, mirroring the `models=` pattern. The plugin registers per-server `{name}-call-tool` and `{name}-list-tools` activities. Workflow-side, `TemporalMCPClient(server="name")` is a pure handle: it references the server by name, discovers tools by running `{name}-list-tools`, and carries the per-call activity options.

```python
from mcp import StdioServerParameters, stdio_client
Expand Down Expand Up @@ -412,9 +412,15 @@ Worker(
)
```

Each factory returns a fully configured `MCPClient`, so you can pass options like `tool_filters`, `prefix`, `elicitation_callback`, or `tasks_config` to it. The plugin connects to each MCP server once at worker startup to enumerate tools. The schema is frozen for the worker's lifetime; restart workers to pick up MCP-server changes. If a server is unavailable at startup, the worker fails to start.
Each factory returns a fully configured `MCPClient`, so you can pass options like `tool_filters`, `prefix`, `elicitation_callback`, or `tasks_config` to it.

To amortize connection setup, the `{name}-call-tool` activity keeps a worker-process MCP connection open between calls and reuses it. The connection is disconnected after it sits idle for `mcp_connection_idle_timeout` (default 5 minutes); the timer resets on every reuse:
By default, `TemporalMCPClient` re-lists the server's tools (via `{name}-list-tools`) on every agent turn, so an MCP server that is restarted mid-workflow — with tools added, removed, or renamed — is picked up. To list the tools just once at the beginning of the workflow and reuse that schema for the workflow's lifetime (one fewer activity per turn), set `cache_tools=True`:

```python
echo = TemporalMCPClient(server="echo", cache_tools=True, start_to_close_timeout=timedelta(seconds=30))
```

To amortize connection setup, the `{name}-call-tool` and `{name}-list-tools` activities share a worker-process MCP connection that is opened lazily and reused across calls. The connection is disconnected after it sits idle for `mcp_connection_idle_timeout` (default 5 minutes); the timer resets on every reuse:

```python
StrandsPlugin(
Expand Down
20 changes: 11 additions & 9 deletions temporalio/contrib/strands/_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
from ._temporal_mcp_client import (
_evict_connection,
build_call_tool_activity,
clear_cache,
populate_cache,
build_list_tools_activity,
)


Expand All @@ -31,10 +30,11 @@ class StrandsPlugin(SimplePlugin):
on first use, then cached for the worker's lifetime. Use the same name in
``TemporalAgent(model=...)`` inside the workflow.

When ``mcp_clients`` is supplied, registers a per-server
``{server}-call-tool`` activity for each entry and, at worker startup,
connects to each MCP server to cache its tool list. Workflow-side
``TemporalMCPClient(server="...").load_tools()`` reads from the cache.
When ``mcp_clients`` is supplied, registers per-server
``{server}-call-tool`` and ``{server}-list-tools`` activities for each
entry. Workflow-side ``TemporalMCPClient(server="...")`` discovers tools by
running ``{server}-list-tools``; whether it lists once per workflow or once
per agent turn is controlled by its ``cache_tools`` option.

``mcp_connection_idle_timeout`` controls how long a worker-process MCP
connection is kept open between ``call-tool`` activities before it is
Expand Down Expand Up @@ -69,17 +69,19 @@ def __init__(
server, client_factory, mcp_connection_idle_timeout
)
)
activities.append(
build_list_tools_activity(
server, client_factory, mcp_connection_idle_timeout
)
)

@asynccontextmanager
async def run_context() -> AsyncGenerator[None, None]:
for server, client_factory in mcp_clients.items():
await populate_cache(server, client_factory)
try:
yield
finally:
for server in mcp_clients:
await _evict_connection(server)
clear_cache(server)

super().__init__(
"aws.StrandsPlugin",
Expand Down
47 changes: 47 additions & 0 deletions temporalio/contrib/strands/_temporal_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from typing import Any

from strands import Agent
from strands.hooks import BeforeModelCallEvent, HookCallback

from temporalio.common import Priority, RetryPolicy
from temporalio.workflow import ActivityCancellationType, VersioningIntent

from ._temporal_mcp_client import TemporalMCPClient
from ._temporal_model import TemporalModel

_SNAPSHOT_DISABLED = (
Expand Down Expand Up @@ -76,6 +78,51 @@ def __init__(
)
super().__init__(model=temporal_model, **agent_kwargs)

# Strands invokes ToolProvider.load_tools() once at construction on a
# separate run_async thread that has no workflow runtime, so a
# TemporalMCPClient cannot list its tools there. Instead refresh from a
# BeforeModelCallEvent hook, which runs on the workflow loop just before
# the registry is read each turn. cache_tools=True lists once (guarded
# by _fetched); cache_tools=False re-lists every turn.
for provider in self.tool_registry._tool_providers:
if isinstance(provider, TemporalMCPClient):
self.hooks.add_callback(
BeforeModelCallEvent, self._make_mcp_refresh_hook(provider)
)

def _make_mcp_refresh_hook(
self, provider: TemporalMCPClient
) -> HookCallback[BeforeModelCallEvent]:
async def hook(event: BeforeModelCallEvent) -> None:
if provider._cache_tools and provider._fetched:
return
old_names = {tool.tool_name for tool in provider._tools}
await provider._refresh()
self._reconcile_mcp_tools(event, provider, old_names)

return hook

def _reconcile_mcp_tools(
self,
event: BeforeModelCallEvent,
provider: TemporalMCPClient,
old_names: set[str],
) -> None:
reg = event.agent.tool_registry
new = {tool.tool_name: tool for tool in provider._tools}
# Tools the server dropped or renamed since the last listing. There is
# no public unregister, so remove them from the registry directly.
for name in old_names - set(new):
reg.registry.pop(name, None)
reg.dynamic_tools.pop(name, None)
# replace() swaps an existing tool in place (no hot-reload guard);
# register_tool() adds a newly-discovered one.
for name, tool in new.items():
if name in reg.registry:
reg.replace(tool)
else:
reg.register_tool(tool)

def take_snapshot(self, *_args: Any, **_kwargs: Any) -> Any:
"""Disabled; Temporal's event history is the source of truth."""
raise NotImplementedError(_SNAPSHOT_DISABLED)
Expand Down
158 changes: 100 additions & 58 deletions temporalio/contrib/strands/_temporal_mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from strands.tools.mcp.mcp_types import MCPToolResult
from strands.types.tools import AgentTool

from temporalio import activity
from temporalio import activity, workflow
from temporalio.common import Priority, RetryPolicy
from temporalio.workflow import ActivityCancellationType, VersioningIntent

Expand All @@ -33,20 +33,21 @@ class _CallToolArgs:
tool_use_id: str = ""


# Server name -> cached tool list. Populated by ``_populate_cache`` at worker
# startup and read by ``TemporalMCPClient.load_tools()`` inside the workflow
# sandbox. ``temporalio`` is in the SDK's default sandbox passthrough, so this
# dict is shared between worker process and workflow execution.
_TOOL_CACHE: dict[str, list[_MCPToolInfo]] = {}


class TemporalMCPClient(ToolProvider):
"""Workflow-side handle to an MCP server registered on the worker.

The transport factory and tool discovery live worker-side via
``StrandsPlugin(mcp_clients={"server": lambda: ...})``. This handle only
carries the server name (which selects the registered factory) and the
per-call activity options.
The transport factory lives worker-side via
``StrandsPlugin(mcp_clients={"server": lambda: ...})``. This handle carries
the server name (which selects the registered factory) and the per-call
activity options. Tool discovery runs as the ``{server}-list-tools``
activity, dispatched from inside the workflow by ``TemporalAgent`` before
each model call.

``cache_tools`` controls how often that listing happens. When ``False``
(the default) the tools are re-listed on every agent turn, so an MCP server
restarted mid-workflow (with tools added, removed, or renamed) is picked up.
When ``True`` the tools are listed once at the beginning of the workflow and
reused for its lifetime.

Construct once at module level and pass to ``TemporalAgent(tools=[...])``
inside the workflow. Multiple handles may reference the same server name
Expand All @@ -57,6 +58,7 @@ def __init__(
self,
server: str,
*,
cache_tools: bool = False,
task_queue: str | None = None,
schedule_to_close_timeout: timedelta | None = None,
schedule_to_start_timeout: timedelta | None = None,
Expand All @@ -70,6 +72,9 @@ def __init__(
) -> None:
"""Configure the server name and activity options."""
self._server = server
self._cache_tools = cache_tools
self._tools: list[AgentTool] = []
self._fetched = False
self._options: dict[str, Any] = {
"task_queue": task_queue,
"schedule_to_close_timeout": schedule_to_close_timeout,
Expand All @@ -89,11 +94,33 @@ def server(self) -> str:
return self._server

async def load_tools(self, **_kwargs: Any) -> Sequence[AgentTool]:
"""Return TemporalMCPTool wrappers for tools cached at worker startup."""
"""Return the tools fetched by the most recent ``_refresh``.

This must stay free of any ``workflow`` API: Strands invokes it once at
``Agent`` construction on a separate ``run_async`` thread that has no
workflow runtime. ``TemporalAgent`` populates the tools by calling
``_refresh`` from a ``BeforeModelCallEvent`` hook before the registry is
first read.
"""
return list(self._tools)

async def _refresh(self) -> None:
"""List the server's tools via the ``{server}-list-tools`` activity.

Runs on the workflow event loop (dispatched from ``TemporalAgent``'s
hook), so the activity result is recorded in history and replay-safe.
"""
from ._temporal_mcp_tool import TemporalMCPTool

infos = _TOOL_CACHE.get(self._server, [])
return [TemporalMCPTool(self._server, info, self._options) for info in infos]
infos: list[_MCPToolInfo] = await workflow.execute_activity(
f"{self._server}-list-tools",
result_type=list[_MCPToolInfo],
**self._options,
)
self._tools = [
TemporalMCPTool(self._server, info, self._options) for info in infos
]
self._fetched = True

def add_consumer(self, consumer_id: Any, **_kwargs: Any) -> None:
"""No-op; consumer tracking is handled by the underlying MCP client."""
Expand All @@ -104,45 +131,37 @@ def remove_consumer(self, consumer_id: Any, **_kwargs: Any) -> None:
return None


# Use MCP sessions directly instead of MCPClient's background-thread helpers.
# Those helpers route calls through cross-loop futures that are unreliable on
# Python 3.10 when invoked from Temporal's async worker/activity event loops.
async def _list_mcp_tools(client: MCPClient) -> Sequence[Tool]:
async with client._transport_callable() as (read_stream, write_stream, *_):
async with ClientSession(
read_stream,
write_stream,
elicitation_callback=client._elicitation_callback,
) as session:
await session.initialize()
tools: list[Tool] = []
pagination_token = None
while True:
page = await session.list_tools(
params=PaginatedRequestParams(cursor=pagination_token)
if pagination_token is not None
else None
)
tools.extend(page.tools)
pagination_token = page.nextCursor
if pagination_token is None:
return tools


def _agent_tool_for_filtering(client: MCPClient, tool: Tool) -> MCPAgentTool:
if client._prefix:
return MCPAgentTool(tool, client, name_override=f"{client._prefix}_{tool.name}")
return MCPAgentTool(tool, client)


async def populate_cache(server: str, client_factory: Callable[[], MCPClient]) -> None:
"""Connect to the MCP server, list tools, fill ``_TOOL_CACHE``."""
client = client_factory()
# Use the MCP session directly instead of MCPClient's background-thread
# helpers. Those helpers route calls through cross-loop futures that are
# unreliable on Python 3.10 when invoked from Temporal's async worker/activity
# event loops.
async def _paginate_list_tools(session: ClientSession) -> list[Tool]:
tools: list[Tool] = []
pagination_token = None
while True:
page = await session.list_tools(
params=PaginatedRequestParams(cursor=pagination_token)
if pagination_token is not None
else None
)
tools.extend(page.tools)
pagination_token = page.nextCursor
if pagination_token is None:
return tools


def _tool_infos(client: MCPClient, tools: Sequence[Tool]) -> list[_MCPToolInfo]:
"""Apply the client's tool filters and project to serializable records."""
infos: list[_MCPToolInfo] = []
for tool in await _list_mcp_tools(client):
for tool in tools:
if client._prefix:
agent_tool = MCPAgentTool(
tool, client, name_override=f"{client._prefix}_{tool.name}"
)
else:
agent_tool = MCPAgentTool(tool, client)
if not client._should_include_tool_with_filters(
_agent_tool_for_filtering(client, tool),
client._tool_filters,
agent_tool, client._tool_filters
):
continue
infos.append(
Expand All @@ -153,12 +172,7 @@ async def populate_cache(server: str, client_factory: Callable[[], MCPClient]) -
output_schema=tool.outputSchema,
)
)
_TOOL_CACHE[server] = infos


def clear_cache(server: str) -> None:
"""Drop the cached tool list for ``server``."""
_TOOL_CACHE.pop(server, None)
return infos


# Default for how long an idle MCP connection stays open before it is
Expand Down Expand Up @@ -324,3 +338,31 @@ async def call_tool(args: _CallToolArgs) -> MCPToolResult:
record.release()

return call_tool


def build_list_tools_activity(
server: str,
client_factory: Callable[[], MCPClient],
idle_timeout: timedelta | None = None,
) -> Callable:
"""Return the per-server ``{server}-list-tools`` activity for registration.

Lists the server's tools (applying the client's tool filters) and reuses
the same lazily-opened, idle-evicted worker-process MCP session as
``{server}-call-tool``.
"""
idle = idle_timeout if idle_timeout is not None else _MCP_CONNECTION_IDLE

@activity.defn(name=f"{server}-list-tools")
async def list_tools() -> list[_MCPToolInfo]:
client, session, record = await get_connection(server, client_factory, idle)
try:
return _tool_infos(client, await _paginate_list_tools(session))
except Exception:
# The session may be broken; drop it so the next call reconnects.
await _evict_connection(server)
raise
finally:
record.release()

return list_tools
Loading
Loading