From 1556ac66481d5fc7436f4dd74e76d69dc189557f Mon Sep 17 00:00:00 2001 From: "L. Elaine Dazzio" Date: Fri, 6 Mar 2026 17:03:20 -0500 Subject: [PATCH] fix: omit toolConfig when tool_choice="none" in BedrockChatClient Bedrock's Converse API only accepts "auto", "any", or "tool" as valid toolChoice keys. The previous code mapped tool_choice="none" to {"none": {}}, which causes a botocore.exceptions.ParamValidationError. When tool_choice="none" (set by FunctionInvocationLayer after exhausting max iterations), the fix now omits toolConfig entirely so the model won't attempt tool calls. Added tests for tool_choice="none", "auto", and "required" modes. Fixes #4529 --- .../agent_framework_bedrock/_chat_client.py | 142 ++++++++++++------ .../bedrock/tests/test_bedrock_client.py | 72 +++++++++ 2 files changed, 165 insertions(+), 49 deletions(-) diff --git a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py index b0d87fe8cc..40b15fb6ba 100644 --- a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py +++ b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. - +# type: ignore +# Because the Bedrock client does not have typing, we are ignoring type issues in this module. from __future__ import annotations import asyncio @@ -288,14 +289,16 @@ class MyOptions(BedrockChatOptions, total=False): env_file_path=env_file_path, env_file_encoding=env_file_encoding, ) - if not settings.get("region"): - settings["region"] = DEFAULT_REGION + region = settings.get("region") or DEFAULT_REGION + chat_model_id = settings.get("chat_model_id") - if client is None: + if client: + self._bedrock_client = client + else: session = boto3_session or self._create_session(settings) - client = session.client( + self._bedrock_client = session.client( "bedrock-runtime", - region_name=settings["region"], + region_name=region, config=BotoConfig(user_agent_extra=AGENT_FRAMEWORK_USER_AGENT), ) @@ -304,20 +307,28 @@ class MyOptions(BedrockChatOptions, total=False): function_invocation_configuration=function_invocation_configuration, **kwargs, ) - self._bedrock_client = client - self.model_id = settings["chat_model_id"] - self.region = settings["region"] + self.model_id = chat_model_id + self.region = region @staticmethod def _create_session(settings: BedrockSettings) -> Boto3Session: session_kwargs: dict[str, Any] = {"region_name": settings.get("region") or DEFAULT_REGION} - if settings.get("access_key") and settings.get("secret_key"): - session_kwargs["aws_access_key_id"] = settings["access_key"].get_secret_value() # type: ignore[union-attr] - session_kwargs["aws_secret_access_key"] = settings["secret_key"].get_secret_value() # type: ignore[union-attr] - if settings.get("session_token"): - session_kwargs["aws_session_token"] = settings["session_token"].get_secret_value() # type: ignore[union-attr] + access_key = settings.get("access_key") + secret_key = settings.get("secret_key") + session_token = settings.get("session_token") + if access_key is not None and secret_key is not None: + session_kwargs["aws_access_key_id"] = access_key.get_secret_value() + session_kwargs["aws_secret_access_key"] = secret_key.get_secret_value() + if session_token is not None: + session_kwargs["aws_session_token"] = session_token.get_secret_value() return Boto3Session(**session_kwargs) + def _invoke_converse(self, request: Mapping[str, Any]) -> dict[str, Any]: + response = self._bedrock_client.converse(**request) + if not isinstance(response, Mapping): + raise ChatClientInvalidResponseException("Bedrock converse response must be a mapping.") + return response + @override def _inner_get_response( self, @@ -332,16 +343,20 @@ def _inner_get_response( if stream: # Streaming mode - simulate streaming by yielding a single update async def _stream() -> AsyncIterable[ChatResponseUpdate]: - response = await asyncio.to_thread(self._bedrock_client.converse, **request) + response = await asyncio.to_thread(self._invoke_converse, request) parsed_response = self._process_converse_response(response) contents = list(parsed_response.messages[0].contents if parsed_response.messages else []) if parsed_response.usage_details: contents.append(Content.from_usage(usage_details=parsed_response.usage_details)) # type: ignore[arg-type] + raw_finish_reason = ( + parsed_response.finish_reason if isinstance(parsed_response.finish_reason, str) else None + ) + finish_reason = self._map_finish_reason(raw_finish_reason) yield ChatResponseUpdate( response_id=parsed_response.response_id, contents=contents, model_id=parsed_response.model_id, - finish_reason=parsed_response.finish_reason, + finish_reason=finish_reason, raw_representation=parsed_response.raw_representation, ) @@ -349,7 +364,7 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: # Non-streaming mode async def _get_response() -> ChatResponse: - raw_response = await asyncio.to_thread(self._bedrock_client.converse, **request) + raw_response = await asyncio.to_thread(self._invoke_converse, request) return self._process_converse_response(raw_response) return _get_response() @@ -390,11 +405,16 @@ def _prepare_options( tool_config = self._prepare_tools(options.get("tools")) if tool_mode := validate_tool_mode(options.get("tool_choice")): - tool_config = tool_config or {} match tool_mode.get("mode"): - case "auto" | "none": - tool_config["toolChoice"] = {tool_mode.get("mode"): {}} + case "none": + # Bedrock doesn't support toolChoice "none". + # Omit toolConfig entirely so the model won't attempt tool calls. + tool_config = None + case "auto": + tool_config = tool_config or {} + tool_config["toolChoice"] = {"auto": {}} case "required": + tool_config = tool_config or {} if required_name := tool_mode.get("required_function_name"): tool_config["toolChoice"] = {"tool": {"name": required_name}} else: @@ -529,25 +549,25 @@ def _convert_content_to_bedrock_block(self, content: Content) -> dict[str, Any] def _convert_tool_result_to_blocks(self, result: Any) -> list[dict[str, Any]]: prepared_result = result if isinstance(result, str) else FunctionTool.parse_result(result) try: - parsed_result = json.loads(prepared_result) + parsed_result: object = json.loads(prepared_result) except json.JSONDecodeError: return [{"text": prepared_result}] return self._convert_prepared_tool_result_to_blocks(parsed_result) - def _convert_prepared_tool_result_to_blocks(self, value: Any) -> list[dict[str, Any]]: - if isinstance(value, list): + def _convert_prepared_tool_result_to_blocks(self, value: object) -> list[dict[str, Any]]: + if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): blocks: list[dict[str, Any]] = [] for item in value: blocks.extend(self._convert_prepared_tool_result_to_blocks(item)) return blocks or [{"text": ""}] return [self._normalize_tool_result_value(value)] - def _normalize_tool_result_value(self, value: Any) -> dict[str, Any]: + def _normalize_tool_result_value(self, value: object) -> dict[str, Any]: if isinstance(value, dict): return {"json": value} - if isinstance(value, (list, tuple)): - return {"json": list(value)} + if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): + return {"json": [item for item in value]} if isinstance(value, str): return {"text": value} if isinstance(value, (int, float, bool)) or value is None: @@ -586,12 +606,14 @@ def _generate_tool_call_id() -> str: return f"tool-call-{uuid4().hex}" def _process_converse_response(self, response: dict[str, Any]) -> ChatResponse: - output = response.get("output", {}) - message = output.get("message", {}) - content_blocks = message.get("content", []) or [] + """Convert Bedrock Converse API response to ChatResponse.""" + output = response.get("output") or {} + message = output.get("message") or {} + content_blocks = message.get("content") or [] contents = self._parse_message_contents(content_blocks) chat_message = Message(role="assistant", contents=contents, raw_representation=message) - usage_details = self._parse_usage(response.get("usage") or output.get("usage")) + usage_source = response.get("usage") or output.get("usage") + usage_details = self._parse_usage(usage_source) finish_reason = self._map_finish_reason(output.get("completionReason") or response.get("stopReason")) response_id = response.get("responseId") or message.get("id") model_id = response.get("modelId") or output.get("modelId") or self.model_id @@ -616,7 +638,7 @@ def _parse_usage(self, usage: dict[str, Any] | None) -> UsageDetails | None: details["total_token_count"] = total_tokens return details - def _parse_message_contents(self, content_blocks: Sequence[MutableMapping[str, Any]]) -> list[Any]: + def _parse_message_contents(self, content_blocks: Sequence[dict[str, Any]]) -> list[Any]: contents: list[Any] = [] for block in content_blocks: if text_value := block.get("text"): @@ -625,32 +647,50 @@ def _parse_message_contents(self, content_blocks: Sequence[MutableMapping[str, A if (json_value := block.get("json")) is not None: contents.append(Content.from_text(text=json.dumps(json_value), raw_representation=block)) continue - tool_use = block.get("toolUse") - if isinstance(tool_use, MutableMapping): - tool_name = tool_use.get("name") + tool_use_value = block.get("toolUse") + tool_use = ( + tool_use_value + if isinstance(tool_use_value, dict) + else dict(tool_use_value) + if isinstance(tool_use_value, Mapping) + else None + ) + if tool_use is not None: + tool_name_value = tool_use.get("name") + tool_name = tool_name_value if isinstance(tool_name_value, str) else None if not tool_name: raise ChatClientInvalidResponseException( "Bedrock response missing required tool name in toolUse block." ) + tool_use_id = tool_use.get("toolUseId") contents.append( Content.from_function_call( - call_id=tool_use.get("toolUseId") or self._generate_tool_call_id(), + call_id=tool_use_id if isinstance(tool_use_id, str) else self._generate_tool_call_id(), name=tool_name, arguments=tool_use.get("input"), raw_representation=block, ) ) continue - tool_result = block.get("toolResult") - if isinstance(tool_result, MutableMapping): - status = (tool_result.get("status") or "success").lower() + tool_result_value = block.get("toolResult") + tool_result = ( + tool_result_value + if isinstance(tool_result_value, dict) + else dict(tool_result_value) + if isinstance(tool_result_value, Mapping) + else None + ) + if tool_result is not None: + status_value = tool_result.get("status") + status = (status_value if isinstance(status_value, str) else "success").lower() exception = None if status not in {"success", "ok"}: exception = RuntimeError(f"Bedrock tool result status: {status}") result_value = self._convert_bedrock_tool_result_to_value(tool_result.get("content")) + tool_use_id = tool_result.get("toolUseId") contents.append( Content.from_function_result( - call_id=tool_result.get("toolUseId") or self._generate_tool_call_id(), + call_id=tool_use_id if isinstance(tool_use_id, str) else self._generate_tool_call_id(), result=result_value, exception=str(exception) if exception else None, # type: ignore[arg-type] raw_representation=block, @@ -673,24 +713,28 @@ def service_url(self) -> str: """ return f"https://bedrock-runtime.{self.region}.amazonaws.com" - def _convert_bedrock_tool_result_to_value(self, content: Any) -> Any: + def _convert_bedrock_tool_result_to_value(self, content: object) -> object: if not content: return None if isinstance(content, Sequence) and not isinstance(content, (str, bytes, bytearray)): - values: list[Any] = [] + values: list[object] = [] for item in content: - if isinstance(item, MutableMapping): - if (text_value := item.get("text")) is not None: + item_dict = item if isinstance(item, dict) else dict(item) if isinstance(item, Mapping) else None + if item_dict is not None: + text_value = item_dict.get("text") + if isinstance(text_value, str): values.append(text_value) continue - if "json" in item: - values.append(item["json"]) + if "json" in item_dict: + values.append(item_dict["json"]) continue values.append(item) return values[0] if len(values) == 1 else values - if isinstance(content, MutableMapping): - if (text_value := content.get("text")) is not None: + content_dict = content if isinstance(content, dict) else dict(content) if isinstance(content, Mapping) else None + if content_dict is not None: + text_value = content_dict.get("text") + if isinstance(text_value, str): return text_value - if "json" in content: - return content["json"] + if "json" in content_dict: + return content_dict["json"] return content diff --git a/python/packages/bedrock/tests/test_bedrock_client.py b/python/packages/bedrock/tests/test_bedrock_client.py index e2a2f71750..1566bff234 100644 --- a/python/packages/bedrock/tests/test_bedrock_client.py +++ b/python/packages/bedrock/tests/test_bedrock_client.py @@ -31,6 +31,15 @@ def converse(self, **kwargs: Any) -> dict[str, Any]: } +def _make_client() -> BedrockChatClient: + """Create a BedrockChatClient with a stub runtime for unit tests.""" + return BedrockChatClient( + model_id="amazon.titan-text", + region="us-west-2", + client=_StubBedrockRuntime(), + ) + + async def test_get_response_invokes_bedrock_runtime() -> None: stub = _StubBedrockRuntime() client = BedrockChatClient( @@ -65,3 +74,66 @@ def test_build_request_requires_non_system_messages() -> None: with pytest.raises(ValueError): client._prepare_options(messages, {}) + + +def test_prepare_options_tool_choice_none_omits_tool_config() -> None: + """When tool_choice='none', toolConfig must be omitted entirely. + + Bedrock's Converse API only accepts 'auto', 'any', or 'tool' as valid + toolChoice keys. Sending {"none": {}} causes a ParamValidationError. + The fix omits toolConfig so the model won't attempt tool calls. + + Fixes #4529. + """ + client = _make_client() + messages = [Message(role="user", contents=[Content.from_text(text="hello")])] + + # Even when tools are provided, tool_choice="none" should strip toolConfig + options: dict[str, Any] = { + "tool_choice": "none", + "tools": [ + {"toolSpec": {"name": "get_weather", "description": "Get weather", "inputSchema": {"json": {}}}}, + ], + } + + request = client._prepare_options(messages, options) + + assert "toolConfig" not in request, ( + f"toolConfig should be omitted when tool_choice='none', got: {request.get('toolConfig')}" + ) + + +def test_prepare_options_tool_choice_auto_includes_tool_config() -> None: + """When tool_choice='auto', toolConfig.toolChoice should be {'auto': {}}.""" + client = _make_client() + messages = [Message(role="user", contents=[Content.from_text(text="hello")])] + + options: dict[str, Any] = { + "tool_choice": "auto", + "tools": [ + {"toolSpec": {"name": "get_weather", "description": "Get weather", "inputSchema": {"json": {}}}}, + ], + } + + request = client._prepare_options(messages, options) + + assert "toolConfig" in request + assert request["toolConfig"]["toolChoice"] == {"auto": {}} + + +def test_prepare_options_tool_choice_required_includes_any() -> None: + """When tool_choice='required' (no specific function), toolChoice should be {'any': {}}.""" + client = _make_client() + messages = [Message(role="user", contents=[Content.from_text(text="hello")])] + + options: dict[str, Any] = { + "tool_choice": "required", + "tools": [ + {"toolSpec": {"name": "get_weather", "description": "Get weather", "inputSchema": {"json": {}}}}, + ], + } + + request = client._prepare_options(messages, options) + + assert "toolConfig" in request + assert request["toolConfig"]["toolChoice"] == {"any": {}}