diff --git a/src/anthropic/lib/bedrock/_client.py b/src/anthropic/lib/bedrock/_client.py
index 013d2702..f53a8aa4 100644
--- a/src/anthropic/lib/bedrock/_client.py
+++ b/src/anthropic/lib/bedrock/_client.py
@@ -91,6 +91,35 @@ def _infer_region() -> str:
 
 
 class BaseBedrockClient(BaseClient[_HttpxClientT, _DefaultStreamT]):
+    @override
+    def _should_retry(self, response: httpx.Response) -> bool:
+        """Return True when the failed request is worth sending again.
+
+        The parent implementation is consulted first. A 400 response that the
+        parent declined is still retried when its `x-amzn-errortype` header
+        contains one of the transient Bedrock error names listed below.
+        """
+        if super()._should_retry(response):
+            return True
+
+        if response.status_code != 400:
+            return False
+
+        transient_markers = (
+            "ThrottlingException",
+            "TooManyRequestsException",
+            "ModelTimeoutException",
+            "ServiceUnavailableException",
+        )
+        error_type = response.headers.get("x-amzn-errortype", "")
+        # Substring match: a marker anywhere in the header value counts.
+        for marker in transient_markers:
+            if marker in error_type:
+                log.debug("Retrying due to Bedrock transient error: %s", error_type)
+                return True
+
+        return False
+
     @override
     def _make_status_error(
         self,
diff --git a/tests/lib/test_bedrock.py b/tests/lib/test_bedrock.py
index fe62da43..ea7bb565 100644
--- a/tests/lib/test_bedrock.py
+++ b/tests/lib/test_bedrock.py
@@ -8,7 +8,7 @@
 import pytest
 from respx import MockRouter
 
-from anthropic import AnthropicBedrock, AsyncAnthropicBedrock
+from anthropic import BadRequestError, AnthropicBedrock, AsyncAnthropicBedrock
 
 sync_client = AnthropicBedrock(
     aws_region="us-east-1",
@@ -195,3 +195,76 @@ def test_region_infer_from_specified_profile(
     client = AnthropicBedrock()
 
     assert client.aws_region == next(profile for profile in profiles if profile["name"] == aws_profile)["region"]
+
+
+@pytest.mark.filterwarnings("ignore::DeprecationWarning")
+@pytest.mark.respx()
+def test_retries_on_bedrock_throttling_error(respx_mock: MockRouter) -> None:
+    """A 400 tagged `ThrottlingException` is retried by the sync client, so two requests are sent."""
+    throttled = httpx.Response(
+        400,
+        json={"message": "Too many requests, please wait before trying again."},
+        headers={"x-amzn-errortype": "ThrottlingException", "retry-after-ms": "10"},
+    )
+    succeeded = httpx.Response(200, json={"foo": "bar"})
+    invoke_url = re.compile(r"https://bedrock-runtime\.us-east-1\.amazonaws\.com/model/.*/invoke")
+    respx_mock.post(invoke_url).mock(side_effect=[throttled, succeeded])
+
+    sync_client.messages.create(
+        max_tokens=1024,
+        messages=[{"role": "user", "content": "Say hello there!"}],
+        model="anthropic.claude-3-5-sonnet-20241022-v2:0",
+    )
+
+    calls = cast("list[MockRequestCall]", respx_mock.calls)
+
+    assert len(calls) == 2
+
+
+@pytest.mark.filterwarnings("ignore::DeprecationWarning")
+@pytest.mark.respx()
+@pytest.mark.asyncio()
+async def test_retries_on_bedrock_throttling_error_async(respx_mock: MockRouter) -> None:
+    """A 400 tagged `ThrottlingException` is retried by the async client, so two requests are sent."""
+    throttled = httpx.Response(
+        400,
+        json={"message": "Too many requests, please wait before trying again."},
+        headers={"x-amzn-errortype": "ThrottlingException", "retry-after-ms": "10"},
+    )
+    succeeded = httpx.Response(200, json={"foo": "bar"})
+    invoke_url = re.compile(r"https://bedrock-runtime\.us-east-1\.amazonaws\.com/model/.*/invoke")
+    respx_mock.post(invoke_url).mock(side_effect=[throttled, succeeded])
+
+    await async_client.messages.create(
+        max_tokens=1024,
+        messages=[{"role": "user", "content": "Say hello there!"}],
+        model="anthropic.claude-3-5-sonnet-20241022-v2:0",
+    )
+
+    calls = cast("list[MockRequestCall]", respx_mock.calls)
+
+    assert len(calls) == 2
+
+
+@pytest.mark.filterwarnings("ignore::DeprecationWarning")
+@pytest.mark.respx()
+def test_no_retry_on_bedrock_validation_error(respx_mock: MockRouter) -> None:
+    """A 400 tagged `ValidationException` raises `BadRequestError` after a single request."""
+    rejected = httpx.Response(
+        400,
+        json={"message": "Invalid input"},
+        headers={"x-amzn-errortype": "ValidationException"},
+    )
+    invoke_url = re.compile(r"https://bedrock-runtime\.us-east-1\.amazonaws\.com/model/.*/invoke")
+    respx_mock.post(invoke_url).mock(side_effect=[rejected])
+
+    with pytest.raises(BadRequestError):
+        sync_client.messages.create(
+            max_tokens=1024,
+            messages=[{"role": "user", "content": "Say hello there!"}],
+            model="anthropic.claude-3-5-sonnet-20241022-v2:0",
+        )
+
+    calls = cast("list[MockRequestCall]", respx_mock.calls)
+
+    assert len(calls) == 1