Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ def _clear_unsupported_count_tokens_cache() -> None:
_UNSUPPORTED_COUNT_TOKENS_MODELS.clear()


def _suppress_task_exception(task: "asyncio.Task[None]") -> None:
    """Retrieve and discard the exception of an orphaned stream task.

    Reading the exception marks it as retrieved, which silences asyncio's
    "Task exception was never retrieved" warning for a task nobody awaits.

    Args:
        task: The task whose exception, if any, should be consumed.
    """
    if task.cancelled():
        # Asking a cancelled task for its exception would raise CancelledError.
        return
    # The return value is intentionally ignored; only the retrieval matters.
    task.exception()


T = TypeVar("T", bound=BaseModel)

DEFAULT_READ_TIMEOUT = 120
Expand Down Expand Up @@ -898,14 +904,17 @@ def callback(event: StreamEvent | None = None) -> None:
thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt_content, tool_choice)
task = asyncio.create_task(thread)

while True:
event = await queue.get()
if event is None:
break

yield event

await task
try:
while True:
event = await queue.get()
if event is None:
break

yield event
await task
except BaseException:
task.add_done_callback(_suppress_task_exception)
raise

def _stream(
self,
Expand Down
101 changes: 101 additions & 0 deletions tests/strands/models/test_bedrock.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import copy
import logging
import os
Expand All @@ -20,6 +21,7 @@
DEFAULT_BEDROCK_REGION,
DEFAULT_READ_TIMEOUT,
_clear_unsupported_count_tokens_cache,
_suppress_task_exception,
)
from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException
from strands.types.tools import ToolSpec
Expand Down Expand Up @@ -3495,3 +3497,102 @@ async def test_skip_native_api_when_use_native_token_count_false(self, bedrock_c
bedrock_client.count_tokens.assert_not_called()
assert isinstance(result, int)
assert result >= 0


@pytest.mark.asyncio
async def test_suppress_task_exception(bedrock_client, model, messages):
    """The helper retrieves a failed task's exception and does not re-raise it."""

    async def always_fails() -> None:
        raise RuntimeError("inner task failure")

    failed_task = asyncio.create_task(always_fails())
    # One trip through the event loop is enough for the coroutine to run and fail.
    await asyncio.sleep(0)

    assert failed_task.done()
    assert failed_task.exception() is not None

    # Must return quietly: the helper only reads the exception, never raises it.
    _suppress_task_exception(failed_task)


@pytest.mark.asyncio
async def test_suppress_task_exception_skips_cancelled():
    """The helper does nothing (and does not raise) when handed a cancelled task."""

    async def never_finishes() -> None:
        await asyncio.sleep(999)

    cancelled_task = asyncio.create_task(never_finishes())
    cancelled_task.cancel()
    # Awaiting drives the cancellation to completion so the task is truly cancelled.
    with pytest.raises(asyncio.CancelledError):
        await cancelled_task

    # Must return quietly: cancelled tasks are skipped rather than queried.
    _suppress_task_exception(cancelled_task)


@pytest.mark.asyncio
async def test_stream_break_does_not_leak_task_exception(bedrock_client, model, messages, caplog, alist):
    """Abandoning BedrockModel.stream mid-iteration must not leak the inner task's exception.

    The mocked ``converse_stream`` yields one event and then raises, simulating an
    error (e.g. a read timeout) in the boto3 thread after the consumer has stopped
    reading. The test passes when asyncio logs no "exception was never retrieved"
    warning for the background task.
    """
    caplog.set_level(logging.WARNING, logger="asyncio")

    def stream_with_error():
        yield {"messageStart": {"role": "assistant"}}
        raise RuntimeError("simulated boto3 timeout after consumer disconnect")

    bedrock_client.converse_stream.return_value = {"stream": stream_with_error()}

    stream = model.stream(messages)
    collected: list = []
    async for event in stream:
        collected.append(event)
        break  # disconnect before the mocked stream raises

    # `break` only suspends an async generator; while `stream` is still referenced
    # it is never finalized, so the generator's cleanup path would not run before
    # the assertions below. Close it explicitly so the disconnect really happens.
    await stream.aclose()

    # Let the event loop finish the thread task and run its done-callback.
    await asyncio.sleep(0.01)

    # Exactly one event was received before the consumer disconnected.
    assert len(collected) == 1

    # The critical assertion: asyncio reported no orphaned task exception.
    # (Case-insensitive, so it also covers "Task exception was never retrieved".)
    assert "exception was never retrieved" not in caplog.text.lower()


@pytest.mark.asyncio
async def test_stream_timeout_cancellation_does_not_leak(
    bedrock_client,
    model,
    messages,
    caplog,
):
    """Timing out BedrockModel.stream via asyncio.wait_for must not leak the inner task's exception.

    The mocked ``converse_stream`` is slow and eventually raises. The consumer gives
    up first, and the test then waits for the background work to fail and checks
    that asyncio logged no "exception was never retrieved" warning.
    """
    import time

    caplog.set_level(logging.WARNING, logger="asyncio")

    # Delay applied before the first event and again before the failure.
    step_delay = 0.05

    def slow_stream():
        time.sleep(step_delay)  # simulate a slow network call
        yield {"messageStart": {"role": "assistant"}}
        time.sleep(step_delay)
        raise RuntimeError("boto3 timeout after consumer disconnected")

    bedrock_client.converse_stream.return_value = {"stream": slow_stream()}

    stream = model.stream(messages)
    # asyncio.TimeoutError is the portable spelling: it only became an alias of the
    # builtin TimeoutError in Python 3.11, so the builtin would not match on 3.10.
    with pytest.raises(asyncio.TimeoutError):
        # Very short timeout: fires long before the first (delayed) event arrives.
        await asyncio.wait_for(stream.__anext__(), timeout=0.001)

    # Wait past the stream's total duration (two delays) plus a margin, so the
    # background task has actually failed before the log is inspected.
    # Checking earlier would pass trivially while the task is still running.
    await asyncio.sleep(2 * step_delay + 0.1)

    # Critical: asyncio reported no orphaned task exception.
    # (Case-insensitive, so it also covers "Task exception was never retrieved".)
    assert "exception was never retrieved" not in caplog.text.lower()