Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 30 additions & 9 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -1064,15 +1064,36 @@ async def structured_output(
"""
tool_spec = convert_pydantic_to_tool_spec(output_model)

response = self.stream(
messages=prompt,
tool_specs=[tool_spec],
system_prompt=system_prompt,
tool_choice=cast(ToolChoice, {"any": {}}),
**kwargs,
)
async for event in streaming.process_stream(response):
yield event
tool_choice: ToolChoice = cast(ToolChoice, {"any": {}})
try:
response = self.stream(
messages=prompt,
tool_specs=[tool_spec],
system_prompt=system_prompt,
tool_choice=tool_choice,
**kwargs,
)
async for event in streaming.process_stream(response):
yield event
except ClientError as e:
error_message = str(e)
if "toolChoice.any" not in error_message and "toolChoice" not in error_message:
raise

logger.debug(
"model_id=<%s> | toolChoice.any not supported, falling back to toolChoice.auto",
self.config.get("model_id"),
)
tool_choice = cast(ToolChoice, {"auto": {}})
response = self.stream(
messages=prompt,
tool_specs=[tool_spec],
system_prompt=system_prompt,
tool_choice=tool_choice,
**kwargs,
)
async for event in streaming.process_stream(response):
yield event

stop_reason, messages, _, _ = event["stop"]

Expand Down
40 changes: 39 additions & 1 deletion tests/strands/models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -1419,7 +1419,45 @@ async def test_structured_output(bedrock_client, model, test_output_model_cls, a
assert tru_output == exp_output


@pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)")
@pytest.mark.asyncio
async def test_structured_output_fallback_tool_choice_auto(bedrock_client, model, test_output_model_cls, alist):
"""When toolChoice.any is not supported, structured_output falls back to toolChoice.auto."""
messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]

success_stream = {
"stream": [
{"messageStart": {"role": "assistant"}},
{"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "TestOutputModel"}}}},
{"contentBlockDelta": {"delta": {"toolUse": {"input": '{"name": "John", "age": 30}'}}}},
{"contentBlockStop": {}},
{"messageStop": {"stopReason": "tool_use"}},
]
}

error_response = {
"Error": {
"Code": "ValidationException",
"Message": "This model doesn't support the toolConfig.toolChoice.any field. "
"Remove toolConfig.toolChoice.any and try again",
}
}

bedrock_client.converse_stream.side_effect = [
ClientError(error_response, "ConverseStream"),
success_stream,
]

stream = model.structured_output(test_output_model_cls, messages)
events = await alist(stream)

assert events[-1] == {"output": test_output_model_cls(name="John", age=30)}
assert bedrock_client.converse_stream.call_count == 2

first_call = bedrock_client.converse_stream.call_args_list[0]
assert first_call.kwargs["toolConfig"]["toolChoice"] == {"any": {}}

second_call = bedrock_client.converse_stream.call_args_list[1]
assert second_call.kwargs["toolConfig"]["toolChoice"] == {"auto": {}}
@pytest.mark.asyncio
async def test_add_note_on_client_error(bedrock_client, model, alist, messages):
"""Test that add_note is called on ClientError with region and model ID information."""
Expand Down