diff --git a/src/mcp/client/_memory.py b/src/mcp/client/_memory.py index e6e938673..968855e3d 100644 --- a/src/mcp/client/_memory.py +++ b/src/mcp/client/_memory.py @@ -7,11 +7,10 @@ from types import TracebackType from typing import Any -import anyio - from mcp.client._transport import TransportStreams from mcp.server import Server from mcp.server.mcpserver import MCPServer +from mcp.shared._exception_utils import create_task_group as _create_task_group from mcp.shared.memory import create_client_server_memory_streams @@ -48,7 +47,7 @@ async def _connect(self) -> AsyncIterator[TransportStreams]: client_read, client_write = client_streams server_read, server_write = server_streams - async with anyio.create_task_group() as tg: + async with _create_task_group() as tg: # Start server in background tg.start_soon( lambda: actual_server.run( diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 61026aa0c..92b7e779e 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -12,6 +12,7 @@ from httpx_sse._exceptions import SSEError from mcp import types +from mcp.shared._exception_utils import create_task_group from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client from mcp.shared.message import SessionMessage @@ -60,7 +61,7 @@ async def sse_client( read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - async with anyio.create_task_group() as tg: + async with create_task_group() as tg: try: logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}") async with httpx_client_factory( diff --git a/src/mcp/client/stdio.py b/src/mcp/client/stdio.py index 902dc8576..b5e4f8315 100644 --- a/src/mcp/client/stdio.py +++ b/src/mcp/client/stdio.py @@ -20,6 +20,7 @@ get_windows_executable_command, terminate_windows_process_tree, ) +from mcp.shared._exception_utils import create_task_group from mcp.shared.message import SessionMessage logger = logging.getLogger(__name__) @@ -177,7 +178,7 @@ async def stdin_writer(): except anyio.ClosedResourceError: # pragma: no cover await anyio.lowlevel.checkpoint() - async with anyio.create_task_group() as tg, process: + async with create_task_group() as tg, process: tg.start_soon(stdout_reader) tg.start_soon(stdin_writer) try: diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 9f3dd5e0b..3baca052c 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -16,6 +16,7 @@ from pydantic import ValidationError from mcp.client._transport import TransportStreams +from mcp.shared._exception_utils import create_task_group from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.types import ( @@ -546,7 +547,7 @@ async def streamable_http_client( transport = StreamableHTTPTransport(url) - async with anyio.create_task_group() as tg: + async with create_task_group() as tg: try: logger.debug(f"Connecting to StreamableHTTP endpoint: {url}") diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index 79e75fad1..a9ecbeaa4 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -9,6 +9,7 @@ from websockets.typing import Subprotocol from mcp import types +from mcp.shared._exception_utils import create_task_group from mcp.shared.message import SessionMessage @@ -68,7 +69,7 @@ async def ws_writer(): msg_dict = session_message.message.model_dump(by_alias=True, mode="json", exclude_unset=True) await ws.send(json.dumps(msg_dict)) - async with anyio.create_task_group() as tg: + async with create_task_group() as tg: # Start reader and writer tasks tg.start_soon(ws_reader) tg.start_soon(ws_writer) diff --git a/src/mcp/shared/_exception_utils.py b/src/mcp/shared/_exception_utils.py new file mode 100644 index 000000000..72a101f02 --- /dev/null +++ b/src/mcp/shared/_exception_utils.py @@ -0,0 +1,62 @@ +"""Utilities for collapsing ExceptionGroups from anyio task group cancellations. + +When a task group has one real failure and N cancelled siblings, anyio wraps them +all in a BaseExceptionGroup. This makes it hard for callers to classify the root +cause. These utilities extract the single real error when possible. +""" + +from __future__ import annotations + +import sys +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager + +import anyio +from anyio.abc import TaskGroup + +if sys.version_info < (3, 11): # pragma: lax no cover + from exceptiongroup import BaseExceptionGroup # pragma: lax no cover + + +def collapse_exception_group(exc_group: BaseExceptionGroup[BaseException]) -> BaseException: + """Collapse a single-error exception group into the underlying exception. + + When a task in an anyio task group fails, sibling tasks are cancelled, + producing ``Cancelled`` exceptions. The task group then wraps everything + in a ``BaseExceptionGroup``. If there is exactly one non-cancellation + error, this function returns it directly so callers can handle it without + unwrapping. + + Args: + exc_group: The exception group to collapse. + + Returns: + The single non-cancellation exception if there is exactly one, + otherwise the original exception group unchanged. + """ + cancelled_class = anyio.get_cancelled_exc_class() + real_errors: list[BaseException] = [exc for exc in exc_group.exceptions if not isinstance(exc, cancelled_class)] + + if len(real_errors) == 1: + return real_errors[0] + + return exc_group + + +@asynccontextmanager +async def create_task_group() -> AsyncIterator[TaskGroup]: + """Create an anyio task group that collapses single-error exception groups. + + Drop-in replacement for ``anyio.create_task_group()`` that automatically + unwraps ``BaseExceptionGroup`` when there is exactly one non-cancellation + error. This makes error handling transparent for callers — they receive + the original exception instead of a wrapped group. + """ + try: + async with anyio.create_task_group() as tg: + yield tg + except BaseExceptionGroup as eg: + collapsed = collapse_exception_group(eg) + if collapsed is not eg: + raise collapsed from eg + raise diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index b617d702f..952a9a847 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import sys from collections.abc import Callable from contextlib import AsyncExitStack from types import TracebackType @@ -11,6 +12,10 @@ from pydantic import BaseModel, TypeAdapter from typing_extensions import Self +if sys.version_info < (3, 11): # pragma: lax no cover + from exceptiongroup import BaseExceptionGroup # pragma: lax no cover + +from mcp.shared._exception_utils import collapse_exception_group from mcp.shared.exceptions import MCPError from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage from mcp.shared.response_router import ResponseRouter @@ -228,7 +233,13 @@ async def __aexit__( # would be very surprising behavior), so make sure to cancel the tasks # in the task group. self._task_group.cancel_scope.cancel() - return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + try: + return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + except BaseExceptionGroup as eg: + collapsed = collapse_exception_group(eg) + if collapsed is not eg: + raise collapsed from eg + raise # pragma: no cover async def send_request( self, diff --git a/tests/shared/test_exception_utils.py b/tests/shared/test_exception_utils.py new file mode 100644 index 000000000..2b4651b67 --- /dev/null +++ b/tests/shared/test_exception_utils.py @@ -0,0 +1,132 @@ +"""Tests for exception group collapsing utilities.""" + +import sys + +import anyio +import pytest + +if sys.version_info < (3, 11): # pragma: lax no cover + from exceptiongroup import BaseExceptionGroup # pragma: lax no cover + +from mcp.shared._exception_utils import collapse_exception_group, create_task_group + + +class TestCollapseExceptionGroup: + """Tests for the collapse_exception_group function.""" + + @pytest.mark.anyio + async def test_single_real_error_with_cancelled(self) -> None: + """A single real error alongside Cancelled exceptions should be extracted.""" + real_error = RuntimeError("connection failed") + cancelled = anyio.get_cancelled_exc_class()() + + group = BaseExceptionGroup("test", [real_error, cancelled]) + result = collapse_exception_group(group) + + assert result is real_error + + @pytest.mark.anyio + async def test_single_real_error_only(self) -> None: + """A single real error without Cancelled should be extracted.""" + real_error = ValueError("bad value") + + group = BaseExceptionGroup("test", [real_error]) + result = collapse_exception_group(group) + + assert result is real_error + + @pytest.mark.anyio + async def test_multiple_real_errors_preserved(self) -> None: + """Multiple non-cancellation errors should keep the group intact.""" + err1 = RuntimeError("first") + err2 = ValueError("second") + + group = BaseExceptionGroup("test", [err1, err2]) + result = collapse_exception_group(group) + + assert result is group + + @pytest.mark.anyio + async def test_all_cancelled_preserved(self) -> None: + """All-cancelled groups should be returned as-is.""" + cancelled_class = anyio.get_cancelled_exc_class() + group = BaseExceptionGroup("test", [cancelled_class(), cancelled_class()]) + result = collapse_exception_group(group) + + assert result is group + + @pytest.mark.anyio + async def test_multiple_cancelled_one_real(self) -> None: + """One real error with multiple Cancelled should extract the real error.""" + cancelled_class = anyio.get_cancelled_exc_class() + real_error = ConnectionError("lost connection") + + group = BaseExceptionGroup("test", [cancelled_class(), real_error, cancelled_class()]) + result = collapse_exception_group(group) + + assert result is real_error + + +class TestCreateTaskGroup: + """Tests for the create_task_group context manager.""" + + @pytest.mark.anyio + async def test_single_failure_unwrapped(self) -> None: + """A single task failure should propagate the original exception, not a group.""" + with pytest.raises(RuntimeError, match="task failed"): + async with create_task_group() as tg: + + async def failing_task() -> None: + raise RuntimeError("task failed") + + async def long_task() -> None: + await anyio.sleep(100) + + tg.start_soon(failing_task) + tg.start_soon(long_task) + + @pytest.mark.anyio + async def test_no_failure_clean_exit(self) -> None: + """Task group with no failures should exit cleanly.""" + results: list[int] = [] + async with create_task_group() as tg: + + async def worker(n: int) -> None: + results.append(n) + + tg.start_soon(worker, 1) + tg.start_soon(worker, 2) + + assert sorted(results) == [1, 2] + + @pytest.mark.anyio + async def test_chained_cause(self) -> None: + """The collapsed exception should chain to the original group via __cause__.""" + with pytest.raises(RuntimeError) as exc_info: + async with create_task_group() as tg: + + async def failing_task() -> None: + raise RuntimeError("root cause") + + async def long_task() -> None: + await anyio.sleep(100) + + tg.start_soon(failing_task) + tg.start_soon(long_task) + + assert isinstance(exc_info.value.__cause__, BaseExceptionGroup) + + @pytest.mark.anyio + async def test_multiple_failures_raises_group(self) -> None: + """Multiple real task failures should raise as a BaseExceptionGroup.""" + with pytest.raises(BaseExceptionGroup): + async with create_task_group() as tg: + + async def fail_a() -> None: + raise RuntimeError("error A") + + async def fail_b() -> None: + raise ValueError("error B") + + tg.start_soon(fail_a) + tg.start_soon(fail_b)