From 349661c14f0ea4a3b2995db77fa9c04e14e5ce83 Mon Sep 17 00:00:00 2001 From: Varun Sharma Date: Sat, 7 Mar 2026 15:46:20 +0530 Subject: [PATCH 1/5] fix: collapse single-exception ExceptionGroups from task groups Replace all 16 anyio.create_task_group() calls with create_mcp_task_group() which automatically unwraps BaseExceptionGroups containing a single exception. This allows callers to catch specific error types (e.g. except ConnectionError) instead of having to handle ExceptionGroup wrapping. - Add src/mcp/shared/_task_group.py with collapse_exception_group() utility and _CollapsingTaskGroup wrapper class - Update all client transports (sse, stdio, websocket, streamable_http, memory) - Update all server transports (sse, stdio, websocket, streamable_http) - Update shared session, session_group, lowlevel server, task_support, task_result_handler, and streamable_http_manager - Add builtins config for BaseExceptionGroup/ExceptionGroup in ruff - Add 12 comprehensive tests covering collapse logic and integration Closes #2114 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyproject.toml | 1 + src/mcp/client/_memory.py | 5 +- src/mcp/client/session_group.py | 4 +- src/mcp/client/sse.py | 3 +- src/mcp/client/stdio.py | 3 +- src/mcp/client/streamable_http.py | 3 +- src/mcp/client/websocket.py | 3 +- .../experimental/task_result_handler.py | 5 +- src/mcp/server/experimental/task_support.py | 4 +- src/mcp/server/lowlevel/server.py | 3 +- src/mcp/server/sse.py | 3 +- src/mcp/server/stdio.py | 3 +- src/mcp/server/streamable_http.py | 5 +- src/mcp/server/streamable_http_manager.py | 3 +- src/mcp/server/websocket.py | 3 +- src/mcp/shared/_task_group.py | 89 ++++++++++ src/mcp/shared/session.py | 3 +- tests/shared/test_task_group.py | 158 ++++++++++++++++++ 18 files changed, 279 insertions(+), 22 deletions(-) create mode 100644 src/mcp/shared/_task_group.py create mode 100644 tests/shared/test_task_group.py diff --git a/pyproject.toml b/pyproject.toml index 737839a23..2766017c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -126,6 +126,7 @@ executionEnvironments = [ line-length = 120 target-version = "py310" extend-exclude = ["README.md", "README.v2.md"] +builtins = ["BaseExceptionGroup", "ExceptionGroup"] [tool.ruff.lint] select = [ diff --git a/src/mcp/client/_memory.py b/src/mcp/client/_memory.py index e6e938673..49e0849a8 100644 --- a/src/mcp/client/_memory.py +++ b/src/mcp/client/_memory.py @@ -7,11 +7,10 @@ from types import TracebackType from typing import Any -import anyio - from mcp.client._transport import TransportStreams from mcp.server import Server from mcp.server.mcpserver import MCPServer +from mcp.shared._task_group import create_mcp_task_group from mcp.shared.memory import create_client_server_memory_streams @@ -48,7 +47,7 @@ async def _connect(self) -> AsyncIterator[TransportStreams]: client_read, client_write = client_streams server_read, server_write = server_streams - async with anyio.create_task_group() as tg: + async with create_mcp_task_group() as tg: # Start server in background tg.start_soon( lambda: actual_server.run( diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 961021264..59ddae954 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -13,7 +13,6 @@ from types import TracebackType from typing import Any, TypeAlias -import anyio import httpx from pydantic import BaseModel, Field from typing_extensions import Self @@ -25,6 +24,7 @@ from mcp.client.stdio import StdioServerParameters from mcp.client.streamable_http import streamable_http_client from mcp.shared._httpx_utils import create_mcp_http_client +from mcp.shared._task_group import create_mcp_task_group from mcp.shared.exceptions import MCPError from mcp.shared.session import ProgressFnT @@ -166,7 +166,7 @@ async def __aexit__( await self._exit_stack.aclose() # Concurrently close session stacks. - async with anyio.create_task_group() as tg: + async with create_mcp_task_group() as tg: for exit_stack in self._session_exit_stacks.values(): tg.start_soon(exit_stack.aclose) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 61026aa0c..23f591b79 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -13,6 +13,7 @@ from mcp import types from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client +from mcp.shared._task_group import create_mcp_task_group from mcp.shared.message import SessionMessage logger = logging.getLogger(__name__) @@ -60,7 +61,7 @@ async def sse_client( read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - async with anyio.create_task_group() as tg: + async with create_mcp_task_group() as tg: try: logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}") async with httpx_client_factory( diff --git a/src/mcp/client/stdio.py b/src/mcp/client/stdio.py index 902dc8576..913c3327e 100644 --- a/src/mcp/client/stdio.py +++ b/src/mcp/client/stdio.py @@ -20,6 +20,7 @@ get_windows_executable_command, terminate_windows_process_tree, ) +from mcp.shared._task_group import create_mcp_task_group from mcp.shared.message import SessionMessage logger = logging.getLogger(__name__) @@ -177,7 +178,7 @@ async def stdin_writer(): except anyio.ClosedResourceError: # pragma: no cover await anyio.lowlevel.checkpoint() - async with anyio.create_task_group() as tg, process: + async with create_mcp_task_group() as tg, process: tg.start_soon(stdout_reader) tg.start_soon(stdin_writer) try: diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 9f3dd5e0b..a2c0464a4 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -17,6 +17,7 @@ from mcp.client._transport import TransportStreams from mcp.shared._httpx_utils import create_mcp_http_client +from mcp.shared._task_group import create_mcp_task_group from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.types import ( INTERNAL_ERROR, @@ -546,7 +547,7 @@ async def streamable_http_client( transport = StreamableHTTPTransport(url) - async with anyio.create_task_group() as tg: + async with create_mcp_task_group() as tg: try: logger.debug(f"Connecting to StreamableHTTP endpoint: {url}") diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index 79e75fad1..ea6d265b8 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -9,6 +9,7 @@ from websockets.typing import Subprotocol from mcp import types +from mcp.shared._task_group import create_mcp_task_group from mcp.shared.message import SessionMessage @@ -68,7 +69,7 @@ async def ws_writer(): msg_dict = session_message.message.model_dump(by_alias=True, mode="json", exclude_unset=True) await ws.send(json.dumps(msg_dict)) - async with anyio.create_task_group() as tg: + async with create_mcp_task_group() as tg: # Start reader and writer tasks tg.start_soon(ws_reader) tg.start_soon(ws_writer) diff --git a/src/mcp/server/experimental/task_result_handler.py b/src/mcp/server/experimental/task_result_handler.py index b2268bc1c..c804fd345 100644 --- a/src/mcp/server/experimental/task_result_handler.py +++ b/src/mcp/server/experimental/task_result_handler.py @@ -12,9 +12,8 @@ import logging from typing import Any -import anyio - from mcp.server.session import ServerSession +from mcp.shared._task_group import create_mcp_task_group from mcp.shared.exceptions import MCPError from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY, is_terminal from mcp.shared.experimental.tasks.message_queue import TaskMessageQueue @@ -162,7 +161,7 @@ async def _wait_for_task_update(self, task_id: str) -> None: Races between store update and queue message - first one wins. """ - async with anyio.create_task_group() as tg: + async with create_mcp_task_group() as tg: async def wait_for_store() -> None: try: diff --git a/src/mcp/server/experimental/task_support.py b/src/mcp/server/experimental/task_support.py index b54219504..3acde331e 100644 --- a/src/mcp/server/experimental/task_support.py +++ b/src/mcp/server/experimental/task_support.py @@ -8,11 +8,11 @@ from contextlib import asynccontextmanager from dataclasses import dataclass, field -import anyio from anyio.abc import TaskGroup from mcp.server.experimental.task_result_handler import TaskResultHandler from mcp.server.session import ServerSession +from mcp.shared._task_group import create_mcp_task_group from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue, TaskMessageQueue from mcp.shared.experimental.tasks.store import TaskStore @@ -79,7 +79,7 @@ async def run(self) -> AsyncIterator[None]: # Task group is now available ... """ - async with anyio.create_task_group() as tg: + async with create_mcp_task_group() as tg: self._task_group = tg try: yield diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 1c84c8610..b78f54dab 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -65,6 +65,7 @@ async def main(): from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings +from mcp.shared._task_group import create_mcp_task_group from mcp.shared.exceptions import MCPError from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder @@ -386,7 +387,7 @@ async def run( task_support.configure_session(session) await stack.enter_async_context(task_support.run()) - async with anyio.create_task_group() as tg: + async with create_mcp_task_group() as tg: async for message in session.incoming_messages: logger.debug("Received message: %s", message) diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 9dcee67f7..d4865b8a8 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -55,6 +55,7 @@ async def handle_sse(request): TransportSecurityMiddleware, TransportSecuritySettings, ) +from mcp.shared._task_group import create_mcp_task_group from mcp.shared.message import ServerMessageMetadata, SessionMessage logger = logging.getLogger(__name__) @@ -174,7 +175,7 @@ async def sse_writer(): } ) - async with anyio.create_task_group() as tg: + async with create_mcp_task_group() as tg: async def response_wrapper(scope: Scope, receive: Receive, send: Send): """The EventSourceResponse returning signals a client close / disconnect. diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index e526bab56..3288cf90d 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -26,6 +26,7 @@ async def run_server(): from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp import types +from mcp.shared._task_group import create_mcp_task_group from mcp.shared.message import SessionMessage @@ -77,7 +78,7 @@ async def stdout_writer(): except anyio.ClosedResourceError: # pragma: no cover await anyio.lowlevel.checkpoint() - async with anyio.create_task_group() as tg: + async with create_mcp_task_group() as tg: tg.start_soon(stdin_reader) tg.start_soon(stdout_writer) yield read_stream, write_stream diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 04aed345e..42407df98 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -25,6 +25,7 @@ from starlette.types import Receive, Scope, Send from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings +from mcp.shared._task_group import create_mcp_task_group from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp.types import ( @@ -614,7 +615,7 @@ async def sse_writer(): # pragma: lax no cover # Start the SSE response (this will send headers immediately) try: # First send the response to establish the SSE connection - async with anyio.create_task_group() as tg: + async with create_mcp_task_group() as tg: tg.start_soon(response, scope, receive, send) # Then send the message to be processed by the server session_message = self._create_session_message(message, request, request_id, protocol_version) @@ -970,7 +971,7 @@ async def connect( self._write_stream = write_stream # Start a task group for message routing - async with anyio.create_task_group() as tg: + async with create_mcp_task_group() as tg: # Create a message router that distributes messages to request streams async def message_router(): try: diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index c25314eab..5e1c2ceca 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -21,6 +21,7 @@ StreamableHTTPServerTransport, ) from mcp.server.transport_security import TransportSecuritySettings +from mcp.shared._task_group import create_mcp_task_group from mcp.types import INVALID_REQUEST, ErrorData, JSONRPCError if TYPE_CHECKING: @@ -122,7 +123,7 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]: ) self._has_started = True - async with anyio.create_task_group() as tg: + async with create_mcp_task_group() as tg: # Store the task group for later use self._task_group = tg logger.info("StreamableHTTP session manager started") diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index 3e675da5f..e45032432 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -7,6 +7,7 @@ from starlette.websockets import WebSocket from mcp import types +from mcp.shared._task_group import create_mcp_task_group from mcp.shared.message import SessionMessage @@ -52,7 +53,7 @@ async def ws_writer(): except anyio.ClosedResourceError: await websocket.close() - async with anyio.create_task_group() as tg: + async with create_mcp_task_group() as tg: tg.start_soon(ws_reader) tg.start_soon(ws_writer) yield (read_stream, write_stream) diff --git a/src/mcp/shared/_task_group.py b/src/mcp/shared/_task_group.py new file mode 100644 index 000000000..6924abb41 --- /dev/null +++ b/src/mcp/shared/_task_group.py @@ -0,0 +1,89 @@ +"""Task group wrapper that collapses single-exception ExceptionGroups. + +When an anyio task group contains tasks and one fails, the exception is +always wrapped in an ExceptionGroup — even if there is only one real +exception. This makes it impossible for callers to catch specific error +types with ``except SomeError:``. + +This module provides a drop-in replacement for ``anyio.create_task_group()`` +that automatically unwraps single-exception groups so callers receive the +original exception directly. +""" + +from __future__ import annotations + +from types import TracebackType + +import anyio +from anyio.abc import TaskGroup + + +def collapse_exception_group(exc: BaseExceptionGroup) -> BaseException: # type: ignore[type-arg] + """Unwrap nested single-exception BaseExceptionGroups. + + If the group (and any nested groups) each contain exactly one exception, + return the innermost real exception. Otherwise return *exc* unchanged. + """ + while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1: + exc = exc.exceptions[0] # type: ignore[assignment] + return exc + + +class _CollapsingTaskGroup: + """A thin wrapper around an anyio ``TaskGroup`` that collapses exceptions. + + On ``__aexit__``, if the task group raises a ``BaseExceptionGroup`` that + contains only a single exception, that inner exception is re-raised + directly so callers can ``except`` it by its concrete type. + + The wrapper delegates ``start_soon``, ``start``, and ``cancel_scope`` to + the underlying task group. + """ + + def __init__(self) -> None: + self._task_group: TaskGroup | None = None + + def _tg(self) -> TaskGroup: + if self._task_group is None: + raise RuntimeError("Task group has not been entered") + return self._task_group + + @property + def cancel_scope(self) -> anyio.CancelScope: + return self._tg().cancel_scope + + def start_soon(self, *args: object, **kwargs: object) -> None: + self._tg().start_soon(*args, **kwargs) # type: ignore[arg-type] + + async def start(self, *args: object, **kwargs: object) -> object: + return await self._tg().start(*args, **kwargs) # type: ignore[arg-type] + + async def __aenter__(self) -> _CollapsingTaskGroup: + self._task_group = anyio.create_task_group() + await self._task_group.__aenter__() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + try: + return await self._tg().__aexit__(exc_type, exc_val, exc_tb) + except BaseExceptionGroup as eg: + collapsed = collapse_exception_group(eg) + if collapsed is not eg: + raise collapsed from eg + raise + + +def create_mcp_task_group() -> _CollapsingTaskGroup: + """Create an anyio task group that collapses single-exception groups. + + Use this as a drop-in replacement for ``anyio.create_task_group()``:: + + async with create_mcp_task_group() as tg: + tg.start_soon(some_task) + """ + return _CollapsingTaskGroup() diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index b617d702f..2eb8e5118 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -11,6 +11,7 @@ from pydantic import BaseModel, TypeAdapter from typing_extensions import Self +from mcp.shared._task_group import create_mcp_task_group from mcp.shared.exceptions import MCPError from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage from mcp.shared.response_router import ResponseRouter @@ -212,7 +213,7 @@ def add_response_router(self, router: ResponseRouter) -> None: self._response_routers.append(router) async def __aenter__(self) -> Self: - self._task_group = anyio.create_task_group() + self._task_group = create_mcp_task_group() await self._task_group.__aenter__() self._task_group.start_soon(self._receive_loop) return self diff --git a/tests/shared/test_task_group.py b/tests/shared/test_task_group.py new file mode 100644 index 000000000..0defe3305 --- /dev/null +++ b/tests/shared/test_task_group.py @@ -0,0 +1,158 @@ +"""Tests for mcp.shared._task_group — collapsing ExceptionGroup wrapper.""" + +import anyio +import pytest + +from mcp.shared._task_group import ( + _CollapsingTaskGroup, + collapse_exception_group, + create_mcp_task_group, +) + +# --------------------------------------------------------------------------- +# collapse_exception_group unit tests +# --------------------------------------------------------------------------- + + +def test_collapse_single_exception() -> None: + """A group containing one exception is unwrapped.""" + inner = ConnectionError("boom") + group = ExceptionGroup("g", [inner]) + assert collapse_exception_group(group) is inner + + +def test_collapse_nested_single() -> None: + """Recursively unwraps nested single-exception groups.""" + inner = ValueError("deep") + group = ExceptionGroup("outer", [ExceptionGroup("inner", [inner])]) + assert collapse_exception_group(group) is inner + + +def test_collapse_multiple_exceptions_unchanged() -> None: + """Groups with >1 exception are returned unchanged.""" + exc_a = TypeError("a") + exc_b = RuntimeError("b") + group = ExceptionGroup("g", [exc_a, exc_b]) + assert collapse_exception_group(group) is group + + +def test_collapse_base_exception_group() -> None: + """Works with BaseExceptionGroup (e.g. containing KeyboardInterrupt).""" + inner = KeyboardInterrupt() + group = BaseExceptionGroup("g", [inner]) + assert collapse_exception_group(group) is inner + + +# --------------------------------------------------------------------------- +# _CollapsingTaskGroup integration tests +# --------------------------------------------------------------------------- + + +@pytest.mark.anyio +async def test_single_task_failure_is_unwrapped() -> None: + """A single failing task raises its exception directly, not wrapped.""" + with pytest.raises(ConnectionError, match="server down"): + async with create_mcp_task_group() as tg: + + async def failing() -> None: + raise ConnectionError("server down") + + tg.start_soon(failing) + + +@pytest.mark.anyio +async def test_single_task_failure_with_cancelled_sibling() -> None: + """When one task fails and another is cancelled, the real error surfaces.""" + with pytest.raises(ConnectionError, match="oops"): + async with create_mcp_task_group() as tg: + + async def failing() -> None: + raise ConnectionError("oops") + + async def long_running() -> None: + await anyio.sleep(999) + + tg.start_soon(failing) + tg.start_soon(long_running) + + +@pytest.mark.anyio +async def test_multiple_failures_stay_grouped() -> None: + """When multiple tasks fail, an ExceptionGroup is raised.""" + with pytest.raises(BaseExceptionGroup): + async with create_mcp_task_group() as tg: + ready = anyio.Event() + + async def fail_a() -> None: + await ready.wait() + raise ConnectionError("a") + + async def fail_b() -> None: + ready.set() + raise ValueError("b") + + tg.start_soon(fail_a) + tg.start_soon(fail_b) + + +@pytest.mark.anyio +async def test_no_failure_passes_cleanly() -> None: + """Normal execution does not raise.""" + results: list[int] = [] + async with create_mcp_task_group() as tg: + + async def worker(n: int) -> None: + results.append(n) + + tg.start_soon(worker, 1) + tg.start_soon(worker, 2) + + assert sorted(results) == [1, 2] + + +@pytest.mark.anyio +async def test_cancel_scope_is_delegated() -> None: + """cancel_scope is accessible and works.""" + async with create_mcp_task_group() as tg: + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_start_delegates_to_task_group() -> None: + """start() delegates to the underlying task group.""" + + async def task_with_status(*, task_status: anyio.abc.TaskStatus[str] = anyio.TASK_STATUS_IGNORED) -> None: + task_status.started("ready") + await anyio.sleep(999) + + async with create_mcp_task_group() as tg: + result = await tg.start(task_with_status) + assert result == "ready" + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_task_group_not_entered_raises() -> None: + """Accessing methods before __aenter__ raises RuntimeError.""" + ctg = _CollapsingTaskGroup() + with pytest.raises(RuntimeError, match="not been entered"): + ctg.cancel_scope + with pytest.raises(RuntimeError, match="not been entered"): + ctg.start_soon(lambda: None) + + +@pytest.mark.anyio +async def test_collapsed_exception_preserves_cause_chain() -> None: + """The collapsed exception has the original ExceptionGroup as __cause__.""" + try: + async with create_mcp_task_group() as tg: + + async def failing() -> None: + raise RuntimeError("root cause") + + tg.start_soon(failing) + except RuntimeError as exc: + assert isinstance(exc.__cause__, BaseExceptionGroup) + assert str(exc) == "root cause" + else: + pytest.fail("Expected RuntimeError") From 4943417b65ea5fae8874e3d3d1d1a00221c116be Mon Sep 17 00:00:00 2001 From: Varun Sharma Date: Sat, 7 Mar 2026 16:06:33 +0530 Subject: [PATCH 2/5] fix: add Python 3.10 compat for BaseExceptionGroup import On Python < 3.11, BaseExceptionGroup and ExceptionGroup are not builtins. Import them from the exceptiongroup backport package conditionally. Also fix coverage miss on unreachable pytest.fail line and revert unnecessary builtins ruff config. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyproject.toml | 1 - src/mcp/shared/_task_group.py | 4 ++++ tests/shared/test_task_group.py | 17 ++++++++++------- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2766017c4..737839a23 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -126,7 +126,6 @@ executionEnvironments = [ line-length = 120 target-version = "py310" extend-exclude = ["README.md", "README.v2.md"] -builtins = ["BaseExceptionGroup", "ExceptionGroup"] [tool.ruff.lint] select = [ diff --git a/src/mcp/shared/_task_group.py b/src/mcp/shared/_task_group.py index 6924abb41..e4d1ea20a 100644 --- a/src/mcp/shared/_task_group.py +++ b/src/mcp/shared/_task_group.py @@ -12,11 +12,15 @@ from __future__ import annotations +import sys from types import TracebackType import anyio from anyio.abc import TaskGroup +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup + def collapse_exception_group(exc: BaseExceptionGroup) -> BaseException: # type: ignore[type-arg] """Unwrap nested single-exception BaseExceptionGroups. diff --git a/tests/shared/test_task_group.py b/tests/shared/test_task_group.py index 0defe3305..ad169bef4 100644 --- a/tests/shared/test_task_group.py +++ b/tests/shared/test_task_group.py @@ -1,7 +1,10 @@ """Tests for mcp.shared._task_group — collapsing ExceptionGroup wrapper.""" +import sys + import anyio import pytest +from anyio.abc import TaskStatus from mcp.shared._task_group import ( _CollapsingTaskGroup, @@ -9,6 +12,9 @@ create_mcp_task_group, ) +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup, ExceptionGroup + # --------------------------------------------------------------------------- # collapse_exception_group unit tests # --------------------------------------------------------------------------- @@ -121,7 +127,7 @@ async def test_cancel_scope_is_delegated() -> None: async def test_start_delegates_to_task_group() -> None: """start() delegates to the underlying task group.""" - async def task_with_status(*, task_status: anyio.abc.TaskStatus[str] = anyio.TASK_STATUS_IGNORED) -> None: + async def task_with_status(*, task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED) -> None: task_status.started("ready") await anyio.sleep(999) @@ -144,15 +150,12 @@ async def test_task_group_not_entered_raises() -> None: @pytest.mark.anyio async def test_collapsed_exception_preserves_cause_chain() -> None: """The collapsed exception has the original ExceptionGroup as __cause__.""" - try: + with pytest.raises(RuntimeError, match="root cause") as exc_info: async with create_mcp_task_group() as tg: async def failing() -> None: raise RuntimeError("root cause") tg.start_soon(failing) - except RuntimeError as exc: - assert isinstance(exc.__cause__, BaseExceptionGroup) - assert str(exc) == "root cause" - else: - pytest.fail("Expected RuntimeError") + + assert isinstance(exc_info.value.__cause__, BaseExceptionGroup) From dc630221b7c9661be0e419be4d61987a98ec1a47 Mon Sep 17 00:00:00 2001 From: Varun Sharma Date: Sat, 7 Mar 2026 16:14:30 +0530 Subject: [PATCH 3/5] fix: resolve pyright and coverage CI failures - Add pragma no cover to version-conditional imports (Python 3.10 compat) - Add type: ignore for pyright false positives (isinstance in loop, duck-type assignment) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/mcp/server/experimental/task_support.py | 2 +- src/mcp/shared/_task_group.py | 4 ++-- tests/shared/test_task_group.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mcp/server/experimental/task_support.py b/src/mcp/server/experimental/task_support.py index 3acde331e..a71d7f774 100644 --- a/src/mcp/server/experimental/task_support.py +++ b/src/mcp/server/experimental/task_support.py @@ -80,7 +80,7 @@ async def run(self) -> AsyncIterator[None]: ... """ async with create_mcp_task_group() as tg: - self._task_group = tg + self._task_group = tg # type: ignore[assignment] try: yield finally: diff --git a/src/mcp/shared/_task_group.py b/src/mcp/shared/_task_group.py index e4d1ea20a..f2b0f03ac 100644 --- a/src/mcp/shared/_task_group.py +++ b/src/mcp/shared/_task_group.py @@ -18,7 +18,7 @@ import anyio from anyio.abc import TaskGroup -if sys.version_info < (3, 11): +if sys.version_info < (3, 11): # pragma: no cover from exceptiongroup import BaseExceptionGroup @@ -28,7 +28,7 @@ def collapse_exception_group(exc: BaseExceptionGroup) -> BaseException: # type: If the group (and any nested groups) each contain exactly one exception, return the innermost real exception. Otherwise return *exc* unchanged. """ - while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1: + while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1: # type: ignore[reportUnnecessaryIsInstance] exc = exc.exceptions[0] # type: ignore[assignment] return exc diff --git a/tests/shared/test_task_group.py b/tests/shared/test_task_group.py index ad169bef4..e6095cd50 100644 --- a/tests/shared/test_task_group.py +++ b/tests/shared/test_task_group.py @@ -12,7 +12,7 @@ create_mcp_task_group, ) -if sys.version_info < (3, 11): +if sys.version_info < (3, 11): # pragma: no cover from exceptiongroup import BaseExceptionGroup, ExceptionGroup # --------------------------------------------------------------------------- From d549d59dad50ce01e11a0c7c71545fd66e752ad6 Mon Sep 17 00:00:00 2001 From: Varun Sharma Date: Sat, 7 Mar 2026 16:21:06 +0530 Subject: [PATCH 4/5] fix: use pragma lax no cover for version-conditional imports, add regression tests - Replace 'pragma: no cover' with 'pragma: lax no cover' (matches repo convention for lines that may/may not be covered per environment) - Add test_start_failure_is_unwrapped for start() error path - Add test_issue_2114_except_specific_error_type regression test Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Reported-by: maxisbey --- src/mcp/shared/_task_group.py | 4 +-- tests/shared/test_task_group.py | 49 +++++++++++++++++++++++++++++++-- 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/src/mcp/shared/_task_group.py b/src/mcp/shared/_task_group.py index f2b0f03ac..5f53a8e47 100644 --- a/src/mcp/shared/_task_group.py +++ b/src/mcp/shared/_task_group.py @@ -18,8 +18,8 @@ import anyio from anyio.abc import TaskGroup -if sys.version_info < (3, 11): # pragma: no cover - from exceptiongroup import BaseExceptionGroup +if sys.version_info < (3, 11): # pragma: lax no cover + from exceptiongroup import BaseExceptionGroup # pragma: lax no cover def collapse_exception_group(exc: BaseExceptionGroup) -> BaseException: # type: ignore[type-arg] diff --git a/tests/shared/test_task_group.py b/tests/shared/test_task_group.py index e6095cd50..00eac4beb 100644 --- a/tests/shared/test_task_group.py +++ b/tests/shared/test_task_group.py @@ -12,8 +12,8 @@ create_mcp_task_group, ) -if sys.version_info < (3, 11): # pragma: no cover - from exceptiongroup import BaseExceptionGroup, ExceptionGroup +if sys.version_info < (3, 11): # pragma: lax no cover + from exceptiongroup import BaseExceptionGroup, ExceptionGroup # pragma: lax no cover # --------------------------------------------------------------------------- # collapse_exception_group unit tests @@ -159,3 +159,48 @@ async def failing() -> None: tg.start_soon(failing) assert isinstance(exc_info.value.__cause__, BaseExceptionGroup) + + +@pytest.mark.anyio +async def test_start_failure_is_unwrapped() -> None: + """An exception from start() is also unwrapped.""" + + async def fail_on_start(*, task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED) -> None: + raise ConnectionError("startup failed") + + with pytest.raises(ConnectionError, match="startup failed"): + async with create_mcp_task_group() as tg: + await tg.start(fail_on_start) + + +# --------------------------------------------------------------------------- +# Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/2114 +# --------------------------------------------------------------------------- + + +@pytest.mark.anyio +async def test_issue_2114_except_specific_error_type() -> None: + """Callers can catch specific exception types without ExceptionGroup wrapping. + + Before the fix, anyio task groups always wrapped exceptions in + ExceptionGroup, making ``except ConnectionError:`` impossible. + """ + caught: BaseException | None = None + try: + async with create_mcp_task_group() as tg: + + async def background() -> None: + await anyio.sleep(999) + + async def connect() -> None: + raise ConnectionError("connection refused") + + tg.start_soon(background) + tg.start_soon(connect) + + except ConnectionError as exc: + caught = exc + + assert caught is not None + assert str(caught) == "connection refused" + assert isinstance(caught.__cause__, BaseExceptionGroup) From 851f41ad3a87cb86d423f721284d71b2790f617b Mon Sep 17 00:00:00 2001 From: Varun Sharma Date: Sat, 7 Mar 2026 16:30:19 +0530 Subject: [PATCH 5/5] test: add start() not-entered guard to completeness test Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/shared/test_task_group.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/shared/test_task_group.py b/tests/shared/test_task_group.py index 00eac4beb..493af562f 100644 --- a/tests/shared/test_task_group.py +++ b/tests/shared/test_task_group.py @@ -145,6 +145,8 @@ async def test_task_group_not_entered_raises() -> None: ctg.cancel_scope with pytest.raises(RuntimeError, match="not been entered"): ctg.start_soon(lambda: None) + with pytest.raises(RuntimeError, match="not been entered"): + await ctg.start(lambda: None) @pytest.mark.anyio