diff --git a/src/basic_memory/mcp/async_client.py b/src/basic_memory/mcp/async_client.py index 76580f05..421e658b 100644 --- a/src/basic_memory/mcp/async_client.py +++ b/src/basic_memory/mcp/async_client.py @@ -128,11 +128,14 @@ async def get_cloud_control_plane_client( yield client -# Optional factory override for dependency injection -_client_factory: Optional[Callable[[], AbstractAsyncContextManager[AsyncClient]]] = None +# Optional factory override for dependency injection. +# The factory accepts an optional workspace keyword argument so that MCP tools +# can route individual requests to a different workspace than the one set at +# connection time. See basic-memory-cloud main.py tenant_asgi_client_factory. +_client_factory: Optional[Callable[..., AbstractAsyncContextManager[AsyncClient]]] = None -def set_client_factory(factory: Callable[[], AbstractAsyncContextManager[AsyncClient]]) -> None: +def set_client_factory(factory: Callable[..., AbstractAsyncContextManager[AsyncClient]]) -> None: """Override the default client factory (for cloud app, testing, etc).""" global _client_factory _client_factory = factory @@ -173,7 +176,7 @@ async def get_client( 4. Local ASGI transport by default. """ if _client_factory: - async with _client_factory() as client: + async with _client_factory(workspace=workspace) as client: yield client return diff --git a/src/basic_memory/mcp/project_context.py b/src/basic_memory/mcp/project_context.py index c29fc152..2c2f58df 100644 --- a/src/basic_memory/mcp/project_context.py +++ b/src/basic_memory/mcp/project_context.py @@ -622,10 +622,10 @@ async def get_project_client( # Step 1b: Factory injection (in-process cloud server) # Trigger: set_client_factory() was called (e.g., by cloud MCP server) - # Why: the transport layer already resolved workspace and tenant context; - # attempting cloud workspace resolution here would call the production - # control-plane API with no valid credentials and fail with 401 - # Outcome: use the factory client directly, skip workspace resolution + # Why: the factory's transport layer handles auth and tenant resolution; + # we pass workspace through so the transport can route to the correct + # workspace when the tool specifies one different from the connection default + # Outcome: factory client with optional workspace override via inner request headers if is_factory_mode(): route_mode = "factory" with telemetry.scope( @@ -635,7 +635,7 @@ async def get_project_client( workspace_id=workspace, ): logger.debug("Using injected client factory for project routing") - async with get_client() as client: + async with get_client(workspace=workspace) as client: active_project = await get_active_project(client, resolved_project, context) yield client, active_project return diff --git a/tests/mcp/test_async_client_modes.py b/tests/mcp/test_async_client_modes.py index 5a0d65f9..ba63f5ba 100644 --- a/tests/mcp/test_async_client_modes.py +++ b/tests/mcp/test_async_client_modes.py @@ -28,7 +28,7 @@ async def test_get_client_uses_injected_factory(monkeypatch): seen = {"used": False} @asynccontextmanager - async def factory(): + async def factory(workspace=None): seen["used"] = True async with httpx.AsyncClient(base_url="https://example.test") as client: yield client @@ -185,7 +185,7 @@ async def test_get_client_factory_overrides_per_project_routing(config_manager): config_manager.save_config(cfg) @asynccontextmanager - async def factory(): + async def factory(workspace=None): async with httpx.AsyncClient(base_url="https://factory.test") as client: yield client diff --git a/tests/mcp/test_project_context.py b/tests/mcp/test_project_context.py index 15c1dac1..94a5c94f 100644 --- a/tests/mcp/test_project_context.py +++ b/tests/mcp/test_project_context.py @@ -709,7 +709,7 @@ async def test_factory_mode_skips_workspace_resolution(self, config_manager, mon # Set up a factory (simulates what cloud MCP server does) @asynccontextmanager - async def fake_factory(): + async def fake_factory(workspace=None): from httpx import ASGITransport, AsyncClient from basic_memory.api.app import app as fastapi_app