Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/anthropic/lib/bedrock/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,27 @@ def _infer_region() -> str:


class BaseBedrockClient(BaseClient[_HttpxClientT, _DefaultStreamT]):
@override
def _should_retry(self, response: httpx.Response) -> bool:
    """Decide whether a failed request should be attempted again.

    The base-class policy is consulted first. If it declines, an HTTP 400
    response is still retried when its ``x-amzn-errortype`` header names one
    of the error types this client treats as transient.

    Args:
        response: The HTTP response received for the failed request.

    Returns:
        ``True`` if the request should be retried, ``False`` otherwise.
    """
    if super()._should_retry(response):
        return True

    # Only 400 responses get the extra header-based inspection below.
    if response.status_code != 400:
        return False

    error_type = response.headers.get("x-amzn-errortype", "")
    transient_markers = (
        "ThrottlingException",
        "TooManyRequestsException",
        "ModelTimeoutException",
        "ServiceUnavailableException",
    )
    # Substring match rather than equality, so a header value carrying extra
    # text around the exception name still matches.
    for marker in transient_markers:
        if marker in error_type:
            log.debug("Retrying due to Bedrock transient error: %s", error_type)
            return True

    return False

@override
def _make_status_error(
self,
Expand Down
93 changes: 92 additions & 1 deletion tests/lib/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
from respx import MockRouter

from anthropic import AnthropicBedrock, AsyncAnthropicBedrock
from anthropic import BadRequestError, AnthropicBedrock, AsyncAnthropicBedrock

sync_client = AnthropicBedrock(
aws_region="us-east-1",
Expand Down Expand Up @@ -195,3 +195,94 @@ def test_region_infer_from_specified_profile(
client = AnthropicBedrock()

assert client.aws_region == next(profile for profile in profiles if profile["name"] == aws_profile)["region"]


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
@pytest.mark.respx()
def test_retries_on_bedrock_throttling_error(respx_mock: MockRouter) -> None:
    """A 400 tagged ``ThrottlingException`` is retried and the follow-up 200 is accepted."""
    invoke_url = re.compile(r"https://bedrock-runtime\.us-east-1\.amazonaws\.com/model/.*/invoke")

    # First attempt is throttled; the second attempt succeeds.
    throttled = httpx.Response(
        400,
        json={"message": "Too many requests, please wait before trying again."},
        headers={"x-amzn-errortype": "ThrottlingException", "retry-after-ms": "10"},
    )
    succeeded = httpx.Response(200, json={"foo": "bar"})
    respx_mock.post(invoke_url).mock(side_effect=[throttled, succeeded])

    user_message = {
        "role": "user",
        "content": "Say hello there!",
    }
    sync_client.messages.create(
        max_tokens=1024,
        messages=[user_message],
        model="anthropic.claude-3-5-sonnet-20241022-v2:0",
    )

    recorded = cast("list[MockRequestCall]", respx_mock.calls)

    # One throttled request plus one successful retry.
    assert len(recorded) == 2


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
@pytest.mark.respx()
@pytest.mark.asyncio()
async def test_retries_on_bedrock_throttling_error_async(respx_mock: MockRouter) -> None:
    """Async client: a 400 tagged ``ThrottlingException`` is retried and the follow-up 200 is accepted."""
    invoke_url = re.compile(r"https://bedrock-runtime\.us-east-1\.amazonaws\.com/model/.*/invoke")

    # First attempt is throttled; the second attempt succeeds.
    throttled = httpx.Response(
        400,
        json={"message": "Too many requests, please wait before trying again."},
        headers={"x-amzn-errortype": "ThrottlingException", "retry-after-ms": "10"},
    )
    succeeded = httpx.Response(200, json={"foo": "bar"})
    respx_mock.post(invoke_url).mock(side_effect=[throttled, succeeded])

    user_message = {
        "role": "user",
        "content": "Say hello there!",
    }
    await async_client.messages.create(
        max_tokens=1024,
        messages=[user_message],
        model="anthropic.claude-3-5-sonnet-20241022-v2:0",
    )

    recorded = cast("list[MockRequestCall]", respx_mock.calls)

    # One throttled request plus one successful retry.
    assert len(recorded) == 2


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
@pytest.mark.respx()
def test_no_retry_on_bedrock_validation_error(respx_mock: MockRouter) -> None:
    """A 400 tagged ``ValidationException`` raises ``BadRequestError`` after a single request."""
    invoke_url = re.compile(r"https://bedrock-runtime\.us-east-1\.amazonaws\.com/model/.*/invoke")

    # Only one mocked response is queued: the request must not be retried.
    rejected = httpx.Response(
        400,
        json={"message": "Invalid input"},
        headers={"x-amzn-errortype": "ValidationException"},
    )
    respx_mock.post(invoke_url).mock(side_effect=[rejected])

    user_message = {
        "role": "user",
        "content": "Say hello there!",
    }
    with pytest.raises(BadRequestError):
        sync_client.messages.create(
            max_tokens=1024,
            messages=[user_message],
            model="anthropic.claude-3-5-sonnet-20241022-v2:0",
        )

    recorded = cast("list[MockRequestCall]", respx_mock.calls)

    assert len(recorded) == 1