diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 04aed345e..55c5ee729 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -809,7 +809,7 @@ async def _validate_request_headers(self, request: Request, send: Send) -> bool: async def _validate_session(self, request: Request, send: Send) -> bool: """Validate the session ID in the request.""" - if not self.mcp_session_id: # pragma: no cover + if not self.mcp_session_id: # If we're not using session IDs, return True return True @@ -842,7 +842,7 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER) # If no protocol version provided, assume default version - if protocol_version is None: # pragma: no cover + if protocol_version is None: protocol_version = DEFAULT_NEGOTIATED_VERSION # Check if the protocol version is supported diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index c25314eab..8c97ea6cd 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -90,6 +90,9 @@ def __init__( self._session_creation_lock = anyio.Lock() self._server_instances: dict[str, StreamableHTTPServerTransport] = {} + # Track in-flight stateless transports for graceful shutdown + self._stateless_transports: set[StreamableHTTPServerTransport] = set() + # The task group will be set during lifespan self._task_group = None # Thread-safe tracking of run() calls @@ -130,11 +133,28 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]: yield # Let the application run finally: logger.info("StreamableHTTP session manager shutting down") + + # Terminate all active transports before cancelling the task + # group. This closes their in-memory streams, which lets + # EventSourceResponse send a final ``more_body=False`` chunk + # — a clean HTTP close instead of a connection reset. + for transport in list(self._server_instances.values()): + try: + await transport.terminate() + except Exception: # pragma: no cover + logger.debug("Error terminating transport during shutdown", exc_info=True) + for transport in list(self._stateless_transports): + try: + await transport.terminate() + except Exception: # pragma: no cover + logger.debug("Error terminating stateless transport during shutdown", exc_info=True) + # Cancel task group to stop all spawned tasks tg.cancel_scope.cancel() self._task_group = None # Clear any remaining server instances self._server_instances.clear() + self._stateless_transports.clear() async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: """Process ASGI request with proper session handling and transport setup. @@ -161,6 +181,9 @@ async def _handle_stateless_request(self, scope: Scope, receive: Receive, send: security_settings=self.security_settings, ) + # Track for graceful shutdown + self._stateless_transports.add(http_transport) + # Start server in a new task async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED): async with http_transport.connect() as streams: @@ -173,7 +196,7 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA self.app.create_initialization_options(), stateless=True, ) - except Exception: # pragma: no cover + except Exception: # pragma: lax no cover logger.exception("Stateless session crashed") # Assert task group is not None for type checking @@ -181,8 +204,11 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA # Start the server task await self._task_group.start(run_stateless_server) - # Handle the HTTP request and return the response - await http_transport.handle_request(scope, receive, send) + try: + # Handle the HTTP request and return the response + await http_transport.handle_request(scope, receive, send) + finally: + self._stateless_transports.discard(http_transport) # Terminate the transport after the request is handled await http_transport.terminate() diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 47cfbf14a..ad1425d64 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -10,11 +10,11 @@ import pytest from starlette.types import Message -from mcp import Client +from mcp import Client, types from mcp.client.streamable_http import streamable_http_client from mcp.server import Server, ServerRequestContext, streamable_http_manager from mcp.server.streamable_http import MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport -from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager from mcp.types import INVALID_REQUEST, ListToolsResult, PaginatedRequestParams @@ -413,3 +413,165 @@ def test_session_idle_timeout_rejects_non_positive(): def test_session_idle_timeout_rejects_stateless(): with pytest.raises(RuntimeError, match="not supported in stateless"): StreamableHTTPSessionManager(app=Server("test"), session_idle_timeout=30, stateless=True) + + +MCP_HEADERS = { + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", +} + +_INITIALIZE_REQUEST = { + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2025-03-26", + "capabilities": {}, + "clientInfo": {"name": "test", "version": "0.1"}, + }, +} + +_INITIALIZED_NOTIFICATION = { + "jsonrpc": "2.0", + "method": "notifications/initialized", +} + +_TOOL_CALL_REQUEST = { + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": {"name": "slow_tool", "arguments": {"message": "hello"}}, +} + + +def _make_slow_tool_server() -> tuple[Server, anyio.Event]: + """Create an MCP server with a tool that blocks forever, returning + the server and an event that fires when the tool starts executing.""" + tool_started = anyio.Event() + + async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + tool_started.set() + await anyio.sleep_forever() + return types.CallToolResult( # pragma: no cover + content=[types.TextContent(type="text", text="never reached")] + ) + + async def handle_list_tools( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListToolsResult: # pragma: no cover + return ListToolsResult( + tools=[ + types.Tool( + name="slow_tool", + description="A tool that blocks forever", + input_schema={"type": "object", "properties": {"message": {"type": "string"}}}, + ) + ] + ) + + app = Server("test-graceful-shutdown", on_call_tool=handle_call_tool, on_list_tools=handle_list_tools) + return app, tool_started + + +class SSECloseTracker: + """ASGI middleware that tracks whether SSE responses close cleanly. + + In HTTP, a clean close means sending a final empty chunk (``0\\r\\n\\r\\n``). + At the ASGI protocol level this corresponds to a + ``{"type": "http.response.body", "more_body": False}`` message. + + Without graceful drain, the server task is cancelled but nothing closes + the stateless transport's streams — the SSE response hangs indefinitely + and never sends the final body. A reverse proxy (e.g. nginx) would log + "upstream prematurely closed connection while reading upstream". + """ + + def __init__(self, app: StreamableHTTPASGIApp) -> None: + self.app = app + self.sse_streams_opened = 0 + self.sse_streams_closed_cleanly = 0 + + async def __call__(self, scope: dict[str, Any], receive: Any, send: Any) -> None: + is_sse = False + + async def tracking_send(message: dict[str, Any]) -> None: + nonlocal is_sse + if message["type"] == "http.response.start": + for name, value in message.get("headers", []): + if name == b"content-type" and b"text/event-stream" in value: + is_sse = True + self.sse_streams_opened += 1 + break + elif message["type"] == "http.response.body" and is_sse: + if not message.get("more_body", False): + self.sse_streams_closed_cleanly += 1 + await send(message) + + await self.app(scope, receive, tracking_send) + + +@pytest.mark.anyio +async def test_graceful_shutdown_closes_sse_streams_cleanly(): + """Verify that shutting down the session manager closes in-flight SSE + streams with a proper ``more_body=False`` ASGI message. + + This is the ASGI equivalent of sending the final HTTP chunk — the signal + that reverse proxies like nginx use to distinguish a clean close from a + connection reset ("upstream prematurely closed connection"). + + Without the graceful-drain fix, stateless transports are not tracked by + the session manager. On shutdown nothing calls ``terminate()`` on them, + so SSE responses hang indefinitely and never send the final body. With + the fix, ``run()``'s finally block iterates ``_stateless_transports`` and + terminates each one, closing the underlying memory streams and letting + ``EventSourceResponse`` complete normally. + """ + app, tool_started = _make_slow_tool_server() + manager = StreamableHTTPSessionManager(app=app, stateless=True) + + tracker = SSECloseTracker(StreamableHTTPASGIApp(manager)) + + manager_ready = anyio.Event() + + with anyio.fail_after(10): + async with anyio.create_task_group() as tg: + + async def run_lifespan_and_shutdown() -> None: + async with manager.run(): + manager_ready.set() + with anyio.fail_after(5): + await tool_started.wait() + # manager.run() exits — graceful shutdown runs here + + async def make_requests() -> None: + with anyio.fail_after(5): + await manager_ready.wait() + async with ( + httpx.ASGITransport(tracker, raise_app_exceptions=False) as transport, + httpx.AsyncClient(transport=transport, base_url="http://testserver") as client, + ): + # Initialize + resp = await client.post("/mcp/", json=_INITIALIZE_REQUEST, headers=MCP_HEADERS) + resp.raise_for_status() + + # Send initialized notification + resp = await client.post("/mcp/", json=_INITIALIZED_NOTIFICATION, headers=MCP_HEADERS) + assert resp.status_code == 202 + + # Send slow tool call — returns an SSE stream that blocks + # until shutdown terminates it + await client.post( + "/mcp/", + json=_TOOL_CALL_REQUEST, + headers=MCP_HEADERS, + timeout=httpx.Timeout(10, connect=5), + ) + + tg.start_soon(run_lifespan_and_shutdown) + tg.start_soon(make_requests) + + assert tracker.sse_streams_opened > 0, "Test should have opened at least one SSE stream" + assert tracker.sse_streams_closed_cleanly == tracker.sse_streams_opened, ( + f"All {tracker.sse_streams_opened} SSE stream(s) should have closed with " + f"more_body=False, but only {tracker.sse_streams_closed_cleanly} did" + )