Skip to content

Commit 721dd43

Browse files
committed
fix(stdio): bound EOF drain wait
1 parent 5332b0e commit 721dd43

3 files changed

Lines changed: 89 additions & 2 deletions

File tree

src/mcp/server/lowlevel/server.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ async def main():
7272

7373
logger = logging.getLogger(__name__)
7474

75+
DEFAULT_READ_EOF_DRAIN_TIMEOUT_SECONDS = 1.0
76+
7577
LifespanResultT = TypeVar("LifespanResultT", default=Any)
7678

7779

@@ -351,6 +353,9 @@ async def run(
351353
# to drain their responses via the still-open write stream (e.g. stdio
352354
# with bash-redirected stdin).
353355
drain_on_read_close: bool = False,
356+
# Maximum time to wait for in-flight handlers to drain after read EOF.
357+
# None means wait indefinitely.
358+
read_eof_drain_timeout_seconds: float | None = DEFAULT_READ_EOF_DRAIN_TIMEOUT_SECONDS,
354359
):
355360
async with AsyncExitStack() as stack:
356361
lifespan_context = await stack.enter_async_context(self.lifespan(self))
@@ -383,7 +388,14 @@ async def run(
383388
raise_exceptions,
384389
)
385390
finally:
386-
if not drain_on_read_close:
391+
if drain_on_read_close:
392+
if read_eof_drain_timeout_seconds is not None:
393+
with anyio.move_on_after(read_eof_drain_timeout_seconds) as drain_scope:
394+
while session.has_in_flight_requests:
395+
await anyio.sleep(0.01)
396+
if drain_scope.cancelled_caught:
397+
tg.cancel_scope.cancel()
398+
else:
387399
# Transport closed: cancel in-flight handlers. Without this the
388400
# TG join waits for them, and when they eventually try to
389401
# respond they hit a closed write stream (the session's

src/mcp/shared/session.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,11 @@ def __init__(
209209
self._exit_stack.push_async_callback(self._read_stream.aclose)
210210
self._exit_stack.push_async_callback(self._write_stream.aclose)
211211

212+
@property
213+
def has_in_flight_requests(self) -> bool:
214+
"""Whether any received requests are still awaiting a response."""
215+
return bool(self._in_flight)
216+
212217
async def __aenter__(self) -> Self:
213218
self._task_group = anyio.create_task_group()
214219
await self._task_group.__aenter__()

tests/server/test_cancel_handling.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,13 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar
120120
server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10)
121121

122122
async def run_server():
123-
await server.run(server_read, server_write, server.create_initialization_options(), drain_on_read_close=True)
123+
await server.run(
124+
server_read,
125+
server_write,
126+
server.create_initialization_options(),
127+
drain_on_read_close=True,
128+
read_eof_drain_timeout_seconds=None,
129+
)
124130
server_run_returned.set()
125131

126132
init_req = JSONRPCRequest(
@@ -166,6 +172,70 @@ async def run_server():
166172
await server_run_returned.wait()
167173

168174

175+
@pytest.mark.anyio
176+
async def test_server_bounds_drain_on_read_eof_when_handler_never_finishes():
177+
handler_started = anyio.Event()
178+
handler_cancelled = anyio.Event()
179+
server_run_returned = anyio.Event()
180+
181+
async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
182+
handler_started.set()
183+
try:
184+
await anyio.sleep_forever()
185+
finally:
186+
handler_cancelled.set()
187+
raise AssertionError # pragma: no cover
188+
189+
server = Server("test", on_call_tool=handle_call_tool)
190+
191+
to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10)
192+
server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10)
193+
194+
async def run_server():
195+
await server.run(
196+
server_read,
197+
server_write,
198+
server.create_initialization_options(),
199+
drain_on_read_close=True,
200+
read_eof_drain_timeout_seconds=0.05,
201+
)
202+
server_run_returned.set()
203+
204+
init_req = JSONRPCRequest(
205+
jsonrpc="2.0",
206+
id=1,
207+
method="initialize",
208+
params=InitializeRequestParams(
209+
protocol_version=LATEST_PROTOCOL_VERSION,
210+
capabilities=ClientCapabilities(),
211+
client_info=Implementation(name="test", version="1.0"),
212+
).model_dump(by_alias=True, mode="json", exclude_none=True),
213+
)
214+
initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")
215+
call_req = JSONRPCRequest(
216+
jsonrpc="2.0",
217+
id=2,
218+
method="tools/call",
219+
params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"),
220+
)
221+
222+
with anyio.fail_after(2):
223+
async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server:
224+
tg.start_soon(run_server)
225+
226+
await to_server.send(SessionMessage(init_req))
227+
await from_server.receive() # init response
228+
await to_server.send(SessionMessage(initialized))
229+
await to_server.send(SessionMessage(call_req))
230+
231+
await handler_started.wait()
232+
await to_server.aclose()
233+
234+
await server_run_returned.wait()
235+
236+
assert handler_cancelled.is_set()
237+
238+
169239
@pytest.mark.anyio
170240
async def test_server_reraises_handler_cancellation_when_server_is_cancelled():
171241
"""If the server task is cancelled (e.g. KeyboardInterrupt), in-flight

0 commit comments

Comments
 (0)