Skip to content

Commit 5112436

Browse files
committed
feat: graceful SSE drain on session manager shutdown
Terminate all active transports before cancelling the task group during StreamableHTTPSessionManager shutdown. This closes their in-memory streams, allowing EventSourceResponse to send a final `more_body=False` chunk — a clean HTTP close instead of a connection reset. Without this, reverse proxies like nginx see "upstream prematurely closed connection" and return 502 to clients during rolling deploys.

Changes:
- Track in-flight stateless transports in the `_stateless_transports` set
- In the `run()` finally block, call `terminate()` on all stateful and stateless transports before `tg.cancel_scope.cancel()`
- Add E2E tests for both stateless and stateful modes that verify the SSE stream closes cleanly when the manager shuts down while a tool call is in-flight
1 parent 7ba41dc commit 5112436

File tree

3 files changed

+240
-7
lines changed

3 files changed

+240
-7
lines changed

src/mcp/server/streamable_http.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -809,7 +809,7 @@ async def _validate_request_headers(self, request: Request, send: Send) -> bool:
809809

810810
async def _validate_session(self, request: Request, send: Send) -> bool:
811811
"""Validate the session ID in the request."""
812-
if not self.mcp_session_id: # pragma: no cover
812+
if not self.mcp_session_id:
813813
# If we're not using session IDs, return True
814814
return True
815815

@@ -842,7 +842,7 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool
842842
protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER)
843843

844844
# If no protocol version provided, assume default version
845-
if protocol_version is None: # pragma: no cover
845+
if protocol_version is None:
846846
protocol_version = DEFAULT_NEGOTIATED_VERSION
847847

848848
# Check if the protocol version is supported

src/mcp/server/streamable_http_manager.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ def __init__(
9090
self._session_creation_lock = anyio.Lock()
9191
self._server_instances: dict[str, StreamableHTTPServerTransport] = {}
9292

93+
# Track in-flight stateless transports for graceful shutdown
94+
self._stateless_transports: set[StreamableHTTPServerTransport] = set()
95+
9396
# The task group will be set during lifespan
9497
self._task_group = None
9598
# Thread-safe tracking of run() calls
@@ -130,11 +133,28 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]:
130133
yield # Let the application run
131134
finally:
132135
logger.info("StreamableHTTP session manager shutting down")
136+
137+
# Terminate all active transports before cancelling the task
138+
# group. This closes their in-memory streams, which lets
139+
# EventSourceResponse send a final ``more_body=False`` chunk
140+
# — a clean HTTP close instead of a connection reset.
141+
for transport in list(self._server_instances.values()):
142+
try:
143+
await transport.terminate()
144+
except Exception: # pragma: no cover
145+
logger.debug("Error terminating transport during shutdown", exc_info=True)
146+
for transport in list(self._stateless_transports):
147+
try:
148+
await transport.terminate()
149+
except Exception: # pragma: no cover
150+
logger.debug("Error terminating stateless transport during shutdown", exc_info=True)
151+
133152
# Cancel task group to stop all spawned tasks
134153
tg.cancel_scope.cancel()
135154
self._task_group = None
136155
# Clear any remaining server instances
137156
self._server_instances.clear()
157+
self._stateless_transports.clear()
138158

139159
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
140160
"""Process ASGI request with proper session handling and transport setup.
@@ -161,6 +181,9 @@ async def _handle_stateless_request(self, scope: Scope, receive: Receive, send:
161181
security_settings=self.security_settings,
162182
)
163183

184+
# Track for graceful shutdown
185+
self._stateless_transports.add(http_transport)
186+
164187
# Start server in a new task
165188
async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED):
166189
async with http_transport.connect() as streams:
@@ -173,16 +196,19 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA
173196
self.app.create_initialization_options(),
174197
stateless=True,
175198
)
176-
except Exception: # pragma: no cover
199+
except Exception: # pragma: lax no cover
177200
logger.exception("Stateless session crashed")
178201

179202
# Assert task group is not None for type checking
180203
assert self._task_group is not None
181204
# Start the server task
182205
await self._task_group.start(run_stateless_server)
183206

184-
# Handle the HTTP request and return the response
185-
await http_transport.handle_request(scope, receive, send)
207+
try:
208+
# Handle the HTTP request and return the response
209+
await http_transport.handle_request(scope, receive, send)
210+
finally:
211+
self._stateless_transports.discard(http_transport)
186212

187213
# Terminate the transport after the request is handled
188214
await http_transport.terminate()

tests/server/test_streamable_http_manager.py

Lines changed: 209 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
import pytest
1111
from starlette.types import Message
1212

13-
from mcp import Client
13+
from mcp import Client, types
1414
from mcp.client.streamable_http import streamable_http_client
1515
from mcp.server import Server, ServerRequestContext, streamable_http_manager
1616
from mcp.server.streamable_http import MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport
17-
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
17+
from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager
1818
from mcp.types import INVALID_REQUEST, ListToolsResult, PaginatedRequestParams
1919

2020

@@ -413,3 +413,210 @@ def test_session_idle_timeout_rejects_non_positive():
413413
def test_session_idle_timeout_rejects_stateless():
414414
with pytest.raises(RuntimeError, match="not supported in stateless"):
415415
StreamableHTTPSessionManager(app=Server("test"), session_idle_timeout=30, stateless=True)
416+
417+
418+
# Headers sent with every POST in the graceful-shutdown tests: the client
# accepts either a JSON body or an SSE stream, and always sends JSON.
MCP_HEADERS = {
    "Accept": "application/json, text/event-stream",
    "Content-Type": "application/json",
}

# JSON-RPC ``initialize`` request (id 1); the first message each test sends.
_INITIALIZE_REQUEST = {
    "jsonrpc": "2.0",
    "id": 1,
    "method": "initialize",
    "params": {
        "protocolVersion": "2025-03-26",
        "capabilities": {},
        "clientInfo": {"name": "test", "version": "0.1"},
    },
}

# JSON-RPC notification (no ``id``) sent right after ``initialize``; the tests
# expect the server to answer it with HTTP 202.
_INITIALIZED_NOTIFICATION = {
    "jsonrpc": "2.0",
    "method": "notifications/initialized",
}

# JSON-RPC ``tools/call`` request (id 2) aimed at the never-returning
# ``slow_tool``, used to keep a response stream open during shutdown.
_TOOL_CALL_REQUEST = {
    "jsonrpc": "2.0",
    "id": 2,
    "method": "tools/call",
    "params": {"name": "slow_tool", "arguments": {"message": "hello"}},
}
445+
446+
447+
def _make_slow_tool_server() -> tuple[Server, anyio.Event]:
    """Build an MCP server whose single tool never returns.

    Returns:
        A ``(server, started)`` pair. ``started`` is an ``anyio.Event`` that
        the tool handler sets as its first action, so a test can wait on it
        to know that a tool call is in flight.
    """
    call_in_progress = anyio.Event()

    async def handle_list_tools(
        ctx: ServerRequestContext, params: PaginatedRequestParams | None
    ) -> ListToolsResult:  # pragma: no cover
        slow_tool = types.Tool(
            name="slow_tool",
            description="A tool that blocks forever",
            input_schema={"type": "object", "properties": {"message": {"type": "string"}}},
        )
        return ListToolsResult(tools=[slow_tool])

    async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult:
        # Signal the test first, then park until the task is cancelled.
        call_in_progress.set()
        await anyio.sleep_forever()
        return types.CallToolResult(  # pragma: no cover
            content=[types.TextContent(type="text", text="never reached")]
        )

    server = Server("test-graceful-shutdown", on_call_tool=handle_call_tool, on_list_tools=handle_list_tools)
    return server, call_in_progress
474+
475+
476+
@pytest.mark.anyio
async def test_graceful_shutdown_terminates_active_stateless_transports():
    """Leaving ``manager.run()`` must terminate an in-flight stateless transport.

    A client starts a tool call that never finishes. Once the tool is
    running, the manager's ``run()`` context is exited. The transport that
    was serving the call is captured beforehand and must report
    ``_terminated`` afterwards; that only holds if shutdown calls
    ``terminate()`` on tracked stateless transports rather than merely
    cancelling the task group.
    """
    app, tool_started = _make_slow_tool_server()
    manager = StreamableHTTPSessionManager(app=app, stateless=True)
    mcp_app = StreamableHTTPASGIApp(manager)

    manager_ready = anyio.Event()
    captured_transport: StreamableHTTPServerTransport | None = None

    async def run_lifespan_and_shutdown():
        nonlocal captured_transport
        async with manager.run():
            manager_ready.set()
            with anyio.fail_after(5):
                await tool_started.wait()
            # Exactly one stateless transport should be serving the tool call.
            in_flight = list(manager._stateless_transports)
            assert len(in_flight) == 1
            captured_transport = in_flight[0]
            assert not captured_transport._terminated
        # manager.run() has exited here, so graceful shutdown has already run.

    async def make_requests():
        with anyio.fail_after(5):
            await manager_ready.wait()
        async with (
            httpx.ASGITransport(mcp_app) as asgi_transport,
            httpx.AsyncClient(transport=asgi_transport, base_url="http://testserver") as http_client,
        ):
            # Initialize
            init_resp = await http_client.post("/mcp/", json=_INITIALIZE_REQUEST, headers=MCP_HEADERS)
            init_resp.raise_for_status()

            # Send initialized notification
            notify_resp = await http_client.post("/mcp/", json=_INITIALIZED_NOTIFICATION, headers=MCP_HEADERS)
            assert notify_resp.status_code == 202

            # Send slow tool call — blocks until shutdown terminates it
            async with http_client.stream(
                "POST",
                "/mcp/",
                json=_TOOL_CALL_REQUEST,
                headers=MCP_HEADERS,
                timeout=httpx.Timeout(10, connect=5),
            ) as response:
                response.raise_for_status()
                async for _chunk in response.aiter_bytes():
                    pass  # pragma: no cover

    # Overall guard so a hung shutdown fails the test instead of blocking it.
    with anyio.fail_after(10):
        async with anyio.create_task_group() as tg:
            tg.start_soon(run_lifespan_and_shutdown)
            tg.start_soon(make_requests)

    assert captured_transport is not None
    assert captured_transport._terminated, (
        "Transport should have been terminated by graceful shutdown "
        "(without the fix, run() only cancels the task group and never calls terminate())"
    )
545+
546+
547+
@pytest.mark.anyio
async def test_graceful_shutdown_terminates_active_stateful_transports():
    """Leaving ``manager.run()`` must terminate an in-flight stateful transport.

    A client initializes a session, then starts a tool call on it that never
    finishes. Once the tool is running, the manager's ``run()`` context is
    exited. The session's transport is captured beforehand and must report
    ``_terminated`` afterwards; that only holds if shutdown calls
    ``terminate()`` on registered session transports rather than merely
    cancelling the task group.
    """
    app, tool_started = _make_slow_tool_server()
    manager = StreamableHTTPSessionManager(app=app, stateless=False)
    mcp_app = StreamableHTTPASGIApp(manager)

    manager_ready = anyio.Event()
    captured_transport: StreamableHTTPServerTransport | None = None

    async def run_lifespan_and_shutdown():
        nonlocal captured_transport
        async with manager.run():
            manager_ready.set()
            with anyio.fail_after(5):
                await tool_started.wait()
            # Exactly one session should exist: the one created by initialize.
            live_sessions = list(manager._server_instances.values())
            assert len(live_sessions) == 1
            captured_transport = live_sessions[0]
            assert not captured_transport._terminated
        # manager.run() has exited here, so graceful shutdown has already run.

    async def make_requests():
        with anyio.fail_after(5):
            await manager_ready.wait()
        async with (
            httpx.ASGITransport(mcp_app) as asgi_transport,
            httpx.AsyncClient(transport=asgi_transport, base_url="http://testserver") as http_client,
        ):
            # Initialize (creates a session)
            init_resp = await http_client.post("/mcp/", json=_INITIALIZE_REQUEST, headers=MCP_HEADERS)
            init_resp.raise_for_status()
            session_id = init_resp.headers.get(MCP_SESSION_ID_HEADER)
            assert session_id is not None

            # Later requests carry the session id and the protocol version.
            session_headers = {
                **MCP_HEADERS,
                MCP_SESSION_ID_HEADER: session_id,
                "mcp-protocol-version": "2025-03-26",
            }

            # Send initialized notification
            notify_resp = await http_client.post("/mcp/", json=_INITIALIZED_NOTIFICATION, headers=session_headers)
            assert notify_resp.status_code == 202

            # Send slow tool call — blocks until shutdown terminates it
            async with http_client.stream(
                "POST",
                "/mcp/",
                json=_TOOL_CALL_REQUEST,
                headers=session_headers,
                timeout=httpx.Timeout(10, connect=5),
            ) as response:
                response.raise_for_status()
                async for _chunk in response.aiter_bytes():
                    pass  # pragma: no cover

    # Overall guard so a hung shutdown fails the test instead of blocking it.
    with anyio.fail_after(10):
        async with anyio.create_task_group() as tg:
            tg.start_soon(run_lifespan_and_shutdown)
            tg.start_soon(make_requests)

    assert captured_transport is not None
    assert captured_transport._terminated, (
        "Transport should have been terminated by graceful shutdown "
        "(without the fix, run() only cancels the task group and never calls terminate())"
    )

0 commit comments

Comments
 (0)