From 412f15238871ef0812beaf68d9f230d5f7e2f9b9 Mon Sep 17 00:00:00 2001 From: AnnasMazhar Date: Sat, 9 May 2026 23:45:47 +0100 Subject: [PATCH] fix(bedrock): consume orphaned task exception on stream cancellation When a consumer cancels, breaks from, or times out on BedrockModel.stream, the internal asyncio.Task wrapping asyncio.to_thread is never awaited. If boto3 eventually raises, asyncio emits 'Task exception was never retrieved'. Add a done-callback on the unhappy path that retrieves the exception, silencing the warning without interrupting the background thread. Resolves: #2266 --- src/strands/models/bedrock.py | 25 ++++--- tests/strands/models/test_bedrock.py | 101 +++++++++++++++++++++++++++ 2 files changed, 118 insertions(+), 8 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index c74a63a3b..4a086b566 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -64,6 +64,12 @@ def _clear_unsupported_count_tokens_cache() -> None: _UNSUPPORTED_COUNT_TOKENS_MODELS.clear() +def _suppress_task_exception(task: "asyncio.Task[None]") -> None: + """Consume exception from orphaned stream task to silence 'never retrieved' warning.""" + if not task.cancelled(): + task.exception() + + T = TypeVar("T", bound=BaseModel) DEFAULT_READ_TIMEOUT = 120 @@ -898,14 +904,17 @@ def callback(event: StreamEvent | None = None) -> None: thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt_content, tool_choice) task = asyncio.create_task(thread) - while True: - event = await queue.get() - if event is None: - break - - yield event - - await task + try: + while True: + event = await queue.get() + if event is None: + break + + yield event + await task + except BaseException: + task.add_done_callback(_suppress_task_exception) + raise def _stream( self, diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 2f1f7d1f1..81e999003 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -1,3 +1,4 @@ +import asyncio import copy import logging import os @@ -20,6 +21,7 @@ DEFAULT_BEDROCK_REGION, DEFAULT_READ_TIMEOUT, _clear_unsupported_count_tokens_cache, + _suppress_task_exception, ) from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException from strands.types.tools import ToolSpec @@ -3495,3 +3497,102 @@ async def test_skip_native_api_when_use_native_token_count_false(self, bedrock_c bedrock_client.count_tokens.assert_not_called() assert isinstance(result, int) assert result >= 0 + + +@pytest.mark.asyncio +async def test_suppress_task_exception(bedrock_client, model, messages): + """_suppress_task_exception consumes exception from a failed task without re-raising.""" + + async def fail() -> None: + raise RuntimeError("inner task failure") + + task = asyncio.create_task(fail()) + await asyncio.sleep(0) # let the task complete with exception + + assert task.done() + assert task.exception() is not None + + # Calling the helper should not raise — it simply retrieves the exception + _suppress_task_exception(task) + + +@pytest.mark.asyncio +async def test_suppress_task_exception_skips_cancelled(): + """_suppress_task_exception is a no-op for cancelled tasks.""" + + async def hang() -> None: + await asyncio.sleep(999) + + task = asyncio.create_task(hang()) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + # Should not raise — cancelled tasks are skipped + _suppress_task_exception(task) + + +@pytest.mark.asyncio +async def test_stream_break_does_not_leak_task_exception(bedrock_client, model, messages, caplog, alist): + """Breaking from an async-for on BedrockModel.stream must not leak the inner task's exception.""" + caplog.set_level(logging.WARNING, logger="asyncio") + + # Mock converse_stream to yield one event then raise — simulates e.g. ReadTimeoutError + # in the boto3 thread *after* the consumer has disconnected. + + def stream_with_error(): + yield {"messageStart": {"role": "assistant"}} + raise RuntimeError("simulated boto3 timeout after consumer disconnect") + + bedrock_client.converse_stream.return_value = {"stream": stream_with_error()} + + stream = model.stream(messages) + collected: list = [] + async for event in stream: + collected.append(event) + break # disconnect before the generator raises + + # Let the event loop process the done-callback and the thread task + await asyncio.sleep(0.01) + + # Verify we got the event before breaking + assert len(collected) == 1 + + # The critical assertion: no "Task exception was never retrieved" warning + assert "Task exception was never retrieved" not in caplog.text + # Also ensure no exception propagates to consumer + assert "exception was never retrieved" not in caplog.text.lower() + + +@pytest.mark.asyncio +async def test_stream_timeout_cancellation_does_not_leak( + bedrock_client, + model, + messages, + caplog, +): + """Applying asyncio.wait_for on BedrockModel.stream must not leak the inner task's exception.""" + caplog.set_level(logging.WARNING, logger="asyncio") + + # Make converse_stream yield slowly so wait_for fires first + import time + + def slow_stream(): + time.sleep(0.05) # simulate a slow network call + yield {"messageStart": {"role": "assistant"}} + time.sleep(0.05) + raise RuntimeError("boto3 timeout after consumer disconnected") + + bedrock_client.converse_stream.return_value = {"stream": slow_stream()} + + stream = model.stream(messages) + with pytest.raises(TimeoutError): + # Very short timeout — fires before the slow stream finishes + await asyncio.wait_for(stream.__anext__(), timeout=0.001) + + # Let event loop settle + await asyncio.sleep(0.01) + + # Critical: no orphaned-task warning + assert "Task exception was never retrieved" not in caplog.text + assert "exception was never retrieved" not in caplog.text.lower()