Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 93 additions & 49 deletions python/packages/bedrock/agent_framework_bedrock/_chat_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.

# type: ignore
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The header # type: ignore does not disable type checking for the module in mypy/pyright (mypy uses # mypy: ignore-errors, pyright uses # pyright: ignore). As written, this comment is misleading and may not achieve the intended suppression; consider removing it or switching to the correct per-tool directive / targeted ignores.

Suggested change
# type: ignore
# mypy: ignore-errors
# pyright: ignore

Copilot uses AI. Check for mistakes.
# Because the Bedrock client does not have typing, we are ignoring type issues in this module.
from __future__ import annotations

import asyncio
Expand Down Expand Up @@ -288,14 +289,16 @@ class MyOptions(BedrockChatOptions, total=False):
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
if not settings.get("region"):
settings["region"] = DEFAULT_REGION
region = settings.get("region") or DEFAULT_REGION
chat_model_id = settings.get("chat_model_id")

if client is None:
if client:
self._bedrock_client = client
else:
session = boto3_session or self._create_session(settings)
Comment on lines +295 to 298
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if client: relies on truthiness; a valid injected client/stub that defines __bool__/__len__ as falsey would be ignored and a boto3 client would be created instead. Prefer an explicit if client is not None: check (consistent with other modules in this repo).

Suggested change
if client:
self._bedrock_client = client
else:
session = boto3_session or self._create_session(settings)
if client is not None:
self._bedrock_client = client
else:
if boto3_session is not None:
session = boto3_session
else:
session = self._create_session(settings)

Copilot uses AI. Check for mistakes.
client = session.client(
self._bedrock_client = session.client(
"bedrock-runtime",
region_name=settings["region"],
region_name=region,
config=BotoConfig(user_agent_extra=AGENT_FRAMEWORK_USER_AGENT),
)

Expand All @@ -304,20 +307,28 @@ class MyOptions(BedrockChatOptions, total=False):
function_invocation_configuration=function_invocation_configuration,
**kwargs,
)
self._bedrock_client = client
self.model_id = settings["chat_model_id"]
self.region = settings["region"]
self.model_id = chat_model_id
self.region = region

@staticmethod
def _create_session(settings: BedrockSettings) -> Boto3Session:
session_kwargs: dict[str, Any] = {"region_name": settings.get("region") or DEFAULT_REGION}
if settings.get("access_key") and settings.get("secret_key"):
session_kwargs["aws_access_key_id"] = settings["access_key"].get_secret_value() # type: ignore[union-attr]
session_kwargs["aws_secret_access_key"] = settings["secret_key"].get_secret_value() # type: ignore[union-attr]
if settings.get("session_token"):
session_kwargs["aws_session_token"] = settings["session_token"].get_secret_value() # type: ignore[union-attr]
access_key = settings.get("access_key")
secret_key = settings.get("secret_key")
session_token = settings.get("session_token")
if access_key is not None and secret_key is not None:
session_kwargs["aws_access_key_id"] = access_key.get_secret_value()
session_kwargs["aws_secret_access_key"] = secret_key.get_secret_value()
if session_token is not None:
session_kwargs["aws_session_token"] = session_token.get_secret_value()
return Boto3Session(**session_kwargs)

def _invoke_converse(self, request: Mapping[str, Any]) -> dict[str, Any]:
response = self._bedrock_client.converse(**request)
if not isinstance(response, Mapping):
raise ChatClientInvalidResponseException("Bedrock converse response must be a mapping.")
return response

@override
def _inner_get_response(
self,
Expand All @@ -332,24 +343,28 @@ def _inner_get_response(
if stream:
# Streaming mode - simulate streaming by yielding a single update
async def _stream() -> AsyncIterable[ChatResponseUpdate]:
response = await asyncio.to_thread(self._bedrock_client.converse, **request)
response = await asyncio.to_thread(self._invoke_converse, request)
parsed_response = self._process_converse_response(response)
contents = list(parsed_response.messages[0].contents if parsed_response.messages else [])
if parsed_response.usage_details:
contents.append(Content.from_usage(usage_details=parsed_response.usage_details)) # type: ignore[arg-type]
raw_finish_reason = (
parsed_response.finish_reason if isinstance(parsed_response.finish_reason, str) else None
)
finish_reason = self._map_finish_reason(raw_finish_reason)
yield ChatResponseUpdate(
response_id=parsed_response.response_id,
contents=contents,
model_id=parsed_response.model_id,
finish_reason=parsed_response.finish_reason,
finish_reason=finish_reason,
raw_representation=parsed_response.raw_representation,
Comment on lines +351 to 360
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the streaming path, parsed_response.finish_reason is already normalized by _process_converse_response (via _map_finish_reason). Mapping it again with _map_finish_reason turns values like "stop" into None, so streamed ChatResponseUpdate.finish_reason will be incorrect/missing. Use parsed_response.finish_reason directly (or capture the raw Bedrock completion reason before mapping).

Copilot uses AI. Check for mistakes.
)

return self._build_response_stream(_stream())

# Non-streaming mode
async def _get_response() -> ChatResponse:
raw_response = await asyncio.to_thread(self._bedrock_client.converse, **request)
raw_response = await asyncio.to_thread(self._invoke_converse, request)
return self._process_converse_response(raw_response)

return _get_response()
Expand Down Expand Up @@ -390,11 +405,16 @@ def _prepare_options(

tool_config = self._prepare_tools(options.get("tools"))
if tool_mode := validate_tool_mode(options.get("tool_choice")):
tool_config = tool_config or {}
match tool_mode.get("mode"):
case "auto" | "none":
tool_config["toolChoice"] = {tool_mode.get("mode"): {}}
case "none":
# Bedrock doesn't support toolChoice "none".
# Omit toolConfig entirely so the model won't attempt tool calls.
tool_config = None
case "auto":
tool_config = tool_config or {}
tool_config["toolChoice"] = {"auto": {}}
case "required":
tool_config = tool_config or {}
if required_name := tool_mode.get("required_function_name"):
tool_config["toolChoice"] = {"tool": {"name": required_name}}
else:
Expand Down Expand Up @@ -529,25 +549,25 @@ def _convert_content_to_bedrock_block(self, content: Content) -> dict[str, Any]
def _convert_tool_result_to_blocks(self, result: Any) -> list[dict[str, Any]]:
prepared_result = result if isinstance(result, str) else FunctionTool.parse_result(result)
try:
parsed_result = json.loads(prepared_result)
parsed_result: object = json.loads(prepared_result)
except json.JSONDecodeError:
return [{"text": prepared_result}]

return self._convert_prepared_tool_result_to_blocks(parsed_result)

def _convert_prepared_tool_result_to_blocks(self, value: Any) -> list[dict[str, Any]]:
if isinstance(value, list):
def _convert_prepared_tool_result_to_blocks(self, value: object) -> list[dict[str, Any]]:
if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
blocks: list[dict[str, Any]] = []
for item in value:
blocks.extend(self._convert_prepared_tool_result_to_blocks(item))
return blocks or [{"text": ""}]
return [self._normalize_tool_result_value(value)]

def _normalize_tool_result_value(self, value: Any) -> dict[str, Any]:
def _normalize_tool_result_value(self, value: object) -> dict[str, Any]:
if isinstance(value, dict):
return {"json": value}
if isinstance(value, (list, tuple)):
return {"json": list(value)}
if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
return {"json": [item for item in value]}
if isinstance(value, str):
return {"text": value}
if isinstance(value, (int, float, bool)) or value is None:
Expand Down Expand Up @@ -586,12 +606,14 @@ def _generate_tool_call_id() -> str:
return f"tool-call-{uuid4().hex}"

def _process_converse_response(self, response: dict[str, Any]) -> ChatResponse:
output = response.get("output", {})
message = output.get("message", {})
content_blocks = message.get("content", []) or []
"""Convert Bedrock Converse API response to ChatResponse."""
output = response.get("output") or {}
message = output.get("message") or {}
content_blocks = message.get("content") or []
contents = self._parse_message_contents(content_blocks)
chat_message = Message(role="assistant", contents=contents, raw_representation=message)
usage_details = self._parse_usage(response.get("usage") or output.get("usage"))
usage_source = response.get("usage") or output.get("usage")
usage_details = self._parse_usage(usage_source)
finish_reason = self._map_finish_reason(output.get("completionReason") or response.get("stopReason"))
response_id = response.get("responseId") or message.get("id")
model_id = response.get("modelId") or output.get("modelId") or self.model_id
Expand All @@ -616,7 +638,7 @@ def _parse_usage(self, usage: dict[str, Any] | None) -> UsageDetails | None:
details["total_token_count"] = total_tokens
return details

def _parse_message_contents(self, content_blocks: Sequence[MutableMapping[str, Any]]) -> list[Any]:
def _parse_message_contents(self, content_blocks: Sequence[dict[str, Any]]) -> list[Any]:
contents: list[Any] = []
for block in content_blocks:
if text_value := block.get("text"):
Expand All @@ -625,32 +647,50 @@ def _parse_message_contents(self, content_blocks: Sequence[MutableMapping[str, A
if (json_value := block.get("json")) is not None:
contents.append(Content.from_text(text=json.dumps(json_value), raw_representation=block))
continue
tool_use = block.get("toolUse")
if isinstance(tool_use, MutableMapping):
tool_name = tool_use.get("name")
tool_use_value = block.get("toolUse")
tool_use = (
tool_use_value
if isinstance(tool_use_value, dict)
else dict(tool_use_value)
if isinstance(tool_use_value, Mapping)
else None
)
if tool_use is not None:
tool_name_value = tool_use.get("name")
tool_name = tool_name_value if isinstance(tool_name_value, str) else None
if not tool_name:
raise ChatClientInvalidResponseException(
"Bedrock response missing required tool name in toolUse block."
)
tool_use_id = tool_use.get("toolUseId")
contents.append(
Content.from_function_call(
call_id=tool_use.get("toolUseId") or self._generate_tool_call_id(),
call_id=tool_use_id if isinstance(tool_use_id, str) else self._generate_tool_call_id(),
name=tool_name,
arguments=tool_use.get("input"),
raw_representation=block,
)
)
continue
tool_result = block.get("toolResult")
if isinstance(tool_result, MutableMapping):
status = (tool_result.get("status") or "success").lower()
tool_result_value = block.get("toolResult")
tool_result = (
tool_result_value
if isinstance(tool_result_value, dict)
else dict(tool_result_value)
if isinstance(tool_result_value, Mapping)
else None
)
if tool_result is not None:
status_value = tool_result.get("status")
status = (status_value if isinstance(status_value, str) else "success").lower()
exception = None
if status not in {"success", "ok"}:
exception = RuntimeError(f"Bedrock tool result status: {status}")
result_value = self._convert_bedrock_tool_result_to_value(tool_result.get("content"))
tool_use_id = tool_result.get("toolUseId")
contents.append(
Content.from_function_result(
call_id=tool_result.get("toolUseId") or self._generate_tool_call_id(),
call_id=tool_use_id if isinstance(tool_use_id, str) else self._generate_tool_call_id(),
result=result_value,
exception=str(exception) if exception else None, # type: ignore[arg-type]
raw_representation=block,
Expand All @@ -673,24 +713,28 @@ def service_url(self) -> str:
"""
return f"https://bedrock-runtime.{self.region}.amazonaws.com"

def _convert_bedrock_tool_result_to_value(self, content: Any) -> Any:
def _convert_bedrock_tool_result_to_value(self, content: object) -> object:
if not content:
return None
if isinstance(content, Sequence) and not isinstance(content, (str, bytes, bytearray)):
values: list[Any] = []
values: list[object] = []
for item in content:
if isinstance(item, MutableMapping):
if (text_value := item.get("text")) is not None:
item_dict = item if isinstance(item, dict) else dict(item) if isinstance(item, Mapping) else None
if item_dict is not None:
text_value = item_dict.get("text")
if isinstance(text_value, str):
values.append(text_value)
continue
if "json" in item:
values.append(item["json"])
if "json" in item_dict:
values.append(item_dict["json"])
continue
values.append(item)
return values[0] if len(values) == 1 else values
if isinstance(content, MutableMapping):
if (text_value := content.get("text")) is not None:
content_dict = content if isinstance(content, dict) else dict(content) if isinstance(content, Mapping) else None
if content_dict is not None:
text_value = content_dict.get("text")
if isinstance(text_value, str):
return text_value
if "json" in content:
return content["json"]
if "json" in content_dict:
return content_dict["json"]
return content
72 changes: 72 additions & 0 deletions python/packages/bedrock/tests/test_bedrock_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@ def converse(self, **kwargs: Any) -> dict[str, Any]:
}


def _make_client() -> BedrockChatClient:
"""Create a BedrockChatClient with a stub runtime for unit tests."""
return BedrockChatClient(
model_id="amazon.titan-text",
region="us-west-2",
client=_StubBedrockRuntime(),
)


async def test_get_response_invokes_bedrock_runtime() -> None:
stub = _StubBedrockRuntime()
client = BedrockChatClient(
Expand Down Expand Up @@ -65,3 +74,66 @@ def test_build_request_requires_non_system_messages() -> None:

with pytest.raises(ValueError):
client._prepare_options(messages, {})


def test_prepare_options_tool_choice_none_omits_tool_config() -> None:
"""When tool_choice='none', toolConfig must be omitted entirely.

Bedrock's Converse API only accepts 'auto', 'any', or 'tool' as valid
toolChoice keys. Sending {"none": {}} causes a ParamValidationError.
The fix omits toolConfig so the model won't attempt tool calls.

Fixes #4529.
"""
client = _make_client()
messages = [Message(role="user", contents=[Content.from_text(text="hello")])]

# Even when tools are provided, tool_choice="none" should strip toolConfig
options: dict[str, Any] = {
"tool_choice": "none",
"tools": [
{"toolSpec": {"name": "get_weather", "description": "Get weather", "inputSchema": {"json": {}}}},
],
}

request = client._prepare_options(messages, options)

assert "toolConfig" not in request, (
f"toolConfig should be omitted when tool_choice='none', got: {request.get('toolConfig')}"
)


def test_prepare_options_tool_choice_auto_includes_tool_config() -> None:
"""When tool_choice='auto', toolConfig.toolChoice should be {'auto': {}}."""
client = _make_client()
messages = [Message(role="user", contents=[Content.from_text(text="hello")])]

options: dict[str, Any] = {
"tool_choice": "auto",
"tools": [
{"toolSpec": {"name": "get_weather", "description": "Get weather", "inputSchema": {"json": {}}}},
],
}

request = client._prepare_options(messages, options)

assert "toolConfig" in request
assert request["toolConfig"]["toolChoice"] == {"auto": {}}


def test_prepare_options_tool_choice_required_includes_any() -> None:
"""When tool_choice='required' (no specific function), toolChoice should be {'any': {}}."""
client = _make_client()
messages = [Message(role="user", contents=[Content.from_text(text="hello")])]

options: dict[str, Any] = {
"tool_choice": "required",
"tools": [
{"toolSpec": {"name": "get_weather", "description": "Get weather", "inputSchema": {"json": {}}}},
],
}

request = client._prepare_options(messages, options)

assert "toolConfig" in request
assert request["toolConfig"]["toolChoice"] == {"any": {}}
Loading