diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index bab4031ed..d3467cbc1 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -1064,15 +1064,36 @@ async def structured_output( """ tool_spec = convert_pydantic_to_tool_spec(output_model) - response = self.stream( - messages=prompt, - tool_specs=[tool_spec], - system_prompt=system_prompt, - tool_choice=cast(ToolChoice, {"any": {}}), - **kwargs, - ) - async for event in streaming.process_stream(response): - yield event + tool_choice: ToolChoice = cast(ToolChoice, {"any": {}}) + try: + response = self.stream( + messages=prompt, + tool_specs=[tool_spec], + system_prompt=system_prompt, + tool_choice=tool_choice, + **kwargs, + ) + async for event in streaming.process_stream(response): + yield event + except ClientError as e: + error_message = str(e) + if "toolChoice.any" not in error_message and "toolChoice" not in error_message: + raise + + logger.debug( + "model_id=<%s> | toolChoice.any not supported, falling back to toolChoice.auto", + self.config.get("model_id"), + ) + tool_choice = cast(ToolChoice, {"auto": {}}) + response = self.stream( + messages=prompt, + tool_specs=[tool_spec], + system_prompt=system_prompt, + tool_choice=tool_choice, + **kwargs, + ) + async for event in streaming.process_stream(response): + yield event stop_reason, messages, _, _ = event["stop"] diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 89c4df70d..f3d73d71f 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -1419,7 +1419,45 @@ async def test_structured_output(bedrock_client, model, test_output_model_cls, a assert tru_output == exp_output -@pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)") +@pytest.mark.asyncio +async def test_structured_output_fallback_tool_choice_auto(bedrock_client, model, test_output_model_cls, alist): + """When toolChoice.any is not supported, structured_output falls back to toolChoice.auto.""" + messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] + + success_stream = { + "stream": [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "TestOutputModel"}}}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"name": "John", "age": 30}'}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + ] + } + + error_response = { + "Error": { + "Code": "ValidationException", + "Message": "This model doesn't support the toolConfig.toolChoice.any field. " + "Remove toolConfig.toolChoice.any and try again", + } + } + + bedrock_client.converse_stream.side_effect = [ + ClientError(error_response, "ConverseStream"), + success_stream, + ] + + stream = model.structured_output(test_output_model_cls, messages) + events = await alist(stream) + + assert events[-1] == {"output": test_output_model_cls(name="John", age=30)} + assert bedrock_client.converse_stream.call_count == 2 + + first_call = bedrock_client.converse_stream.call_args_list[0] + assert first_call.kwargs["toolConfig"]["toolChoice"] == {"any": {}} + + second_call = bedrock_client.converse_stream.call_args_list[1] + assert second_call.kwargs["toolConfig"]["toolChoice"] == {"auto": {}} @pytest.mark.asyncio async def test_add_note_on_client_error(bedrock_client, model, alist, messages): """Test that add_note is called on ClientError with region and model ID information."""