|
8 | 8 | from typing import TYPE_CHECKING, Any |
9 | 9 |
|
10 | 10 | import anyio |
| 11 | +from anyio.abc import CancelScope |
11 | 12 | from mcp.types import ( |
12 | 13 | CallToolRequest, |
13 | 14 | CallToolRequestParams, |
@@ -113,6 +114,15 @@ def __init__( |
113 | 114 | float(os.environ.get("CLAUDE_CODE_STREAM_CLOSE_TIMEOUT", "60000")) / 1000.0 |
114 | 115 | ) # Convert ms to seconds |
115 | 116 |
|
| 117 | + # Cancel scope for the reader task - can be cancelled from any task context |
| 118 | + # This fixes the RuntimeError when async generator cleanup happens in a different task |
| 119 | + self._reader_cancel_scope: CancelScope | None = None |
| 120 | + self._reader_task_started = anyio.Event() |
| 121 | + |
| 122 | + # Track whether we entered the task group in this task |
| 123 | + # Used to determine if we can safely call __aexit__() |
| 124 | + self._tg_entered_in_current_task = False |
| 125 | + |
116 | 126 | async def initialize(self) -> dict[str, Any] | None: |
117 | 127 | """Initialize control protocol if in streaming mode. |
118 | 128 |
|
@@ -158,11 +168,33 @@ async def initialize(self) -> dict[str, Any] | None: |
158 | 168 | return response |
159 | 169 |
|
160 | 170 | async def start(self) -> None: |
161 | | - """Start reading messages from transport.""" |
| 171 | + """Start reading messages from transport. |
| 172 | +
|
| 173 | + This method starts background tasks for reading messages. The task lifecycle |
| 174 | + is managed using a CancelScope that can be safely cancelled from any async |
| 175 | + task context, avoiding the RuntimeError that occurs when task group |
| 176 | + __aexit__() is called from a different task than __aenter__(). |
| 177 | + """ |
162 | 178 | if self._tg is None: |
| 179 | + # Create a task group for spawning background tasks |
163 | 180 | self._tg = anyio.create_task_group() |
164 | 181 | await self._tg.__aenter__() |
165 | | - self._tg.start_soon(self._read_messages) |
| 182 | + self._tg_entered_in_current_task = True |
| 183 | + |
| 184 | + # Start the reader with its own cancel scope that can be cancelled safely |
| 185 | + self._tg.start_soon(self._read_messages_with_cancel_scope) |
| 186 | + |
| 187 | + async def _read_messages_with_cancel_scope(self) -> None: |
| 188 | + """Wrapper for _read_messages that sets up a cancellable scope. |
| 189 | +
|
| 190 | + This wrapper creates a CancelScope that can be cancelled from any task |
| 191 | + context, solving the issue where async generator cleanup happens in a |
| 192 | + different task than where the task group was entered. |
| 193 | + """ |
| 194 | + self._reader_cancel_scope = anyio.CancelScope() |
| 195 | + self._reader_task_started.set() |
| 196 | + with self._reader_cancel_scope: |
| 197 | + await self._read_messages() |
166 | 198 |
|
167 | 199 | async def _read_messages(self) -> None: |
168 | 200 | """Read messages from transport and route them.""" |
@@ -604,15 +636,66 @@ async def receive_messages(self) -> AsyncIterator[dict[str, Any]]: |
604 | 636 | yield message |
605 | 637 |
|
606 | 638 | async def close(self) -> None: |
607 | | - """Close the query and transport.""" |
| 639 | + """Close the query and transport. |
| 640 | +
|
| 641 | + This method safely cleans up resources, handling the case where cleanup |
| 642 | + happens in a different async task context than where start() was called. |
| 643 | + This commonly occurs during async generator cleanup (e.g., when breaking |
| 644 | + out of an `async for` loop or when asyncio.run() shuts down). |
| 645 | +
|
| 646 | + The fix uses two mechanisms: |
| 647 | + 1. A CancelScope for the reader task that can be cancelled from any context |
| 648 | + 2. Suppressing the RuntimeError that occurs when task group __aexit__() |
| 649 | + is called from a different task than __aenter__() |
| 650 | + """ |
| 651 | + if self._closed: |
| 652 | + return |
608 | 653 | self._closed = True |
609 | | - if self._tg: |
| 654 | + |
| 655 | + # Cancel the reader task via its cancel scope (safe from any task context) |
| 656 | + if self._reader_cancel_scope is not None: |
| 657 | + self._reader_cancel_scope.cancel() |
| 658 | + |
| 659 | + # Handle task group cleanup |
| 660 | + if self._tg is not None: |
| 661 | + # Always cancel the task group's scope to stop any running tasks |
610 | 662 | self._tg.cancel_scope.cancel() |
611 | | - # Wait for task group to complete cancellation |
612 | | - with suppress(anyio.get_cancelled_exc_class()): |
613 | | - await self._tg.__aexit__(None, None, None) |
| 663 | + |
| 664 | + # Try to properly exit the task group, but handle the case where |
| 665 | + # we're in a different task context than where __aenter__() was called |
| 666 | + try: |
| 667 | + with suppress(anyio.get_cancelled_exc_class()): |
| 668 | + await self._tg.__aexit__(None, None, None) |
| 669 | + except RuntimeError as e: |
| 670 | + # Handle "Attempted to exit cancel scope in a different task" |
| 671 | + # This happens during async generator cleanup when Python's GC |
| 672 | + # runs the finally block in a different task context. |
| 673 | + if "different task" in str(e): |
| 674 | + logger.debug( |
| 675 | + "Task group cleanup skipped due to cross-task context " |
| 676 | + "(this is expected during async generator cleanup)" |
| 677 | + ) |
| 678 | + else: |
| 679 | + raise |
| 680 | + finally: |
| 681 | + self._tg = None |
| 682 | + self._tg_entered_in_current_task = False |
| 683 | + |
614 | 684 | await self.transport.close() |
615 | 685 |
|
| 686 | + # Make Query an async context manager |
| 687 | + async def __aenter__(self) -> "Query": |
| 688 | + """Enter async context - starts reading messages.""" |
| 689 | + await self.start() |
| 690 | + return self |
| 691 | + |
| 692 | + async def __aexit__( |
| 693 | + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any |
| 694 | + ) -> bool: |
| 695 | + """Exit async context - closes the query.""" |
| 696 | + await self.close() |
| 697 | + return False |
| 698 | + |
616 | 699 | # Make Query an async iterator |
617 | 700 | def __aiter__(self) -> AsyncIterator[dict[str, Any]]: |
618 | 701 | """Return async iterator for messages.""" |
|
0 commit comments