Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/anthropic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
UnprocessableEntityError,
APIResponseValidationError,
)
from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient
from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient, RetryInfo
from ._utils._logs import setup_logging as _setup_logging
from .lib._parse._transform import transform_schema

Expand Down Expand Up @@ -88,6 +88,7 @@
"DefaultHttpxClient",
"DefaultAsyncHttpxClient",
"DefaultAioHttpClient",
"RetryInfo",
"HUMAN_PROMPT",
"AI_PROMPT",
"beta_tool",
Expand Down
75 changes: 73 additions & 2 deletions src/anthropic/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import email.utils
from types import TracebackType
from random import random
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -365,6 +366,30 @@ async def get_next_page(self: AsyncPageT) -> AsyncPageT:
_DefaultStreamT = TypeVar("_DefaultStreamT", bound=Union[Stream[Any], AsyncStream[Any]])


@dataclass
class RetryInfo:
"""Information about a retry attempt, passed to the ``on_retry`` callback.

Attributes:
attempt: The 1-based index of the retry that is *about to happen*
(e.g. 1 for the first retry, 2 for the second, …).
max_retries: The maximum number of retries configured for this request.
url: The URL that is being retried.
wait_seconds: How many seconds the SDK will sleep before issuing the
next request.
response: The HTTP response that triggered the retry, or ``None`` if the
retry is caused by a network-level error (timeout / connection error).
error: The exception that caused the retry (always present).
"""

attempt: int
max_retries: int
url: str
wait_seconds: float
response: httpx.Response | None
error: BaseException


class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
_client: _HttpxClientT
_version: str
Expand All @@ -385,6 +410,7 @@ def __init__(
timeout: float | Timeout | None = DEFAULT_TIMEOUT,
custom_headers: Mapping[str, str] | None = None,
custom_query: Mapping[str, object] | None = None,
on_retry: Optional[Callable[[RetryInfo], None]] = None,
) -> None:
self._version = version
self._base_url = self._enforce_trailing_slash(URL(base_url))
Expand All @@ -395,6 +421,7 @@ def __init__(
self._strict_response_validation = _strict_response_validation
self._idempotency_header = None
self._platform: Platform | None = None
self._on_retry = on_retry

if max_retries is None: # pyright: ignore[reportUnnecessaryComparison]
raise TypeError(
Expand Down Expand Up @@ -922,6 +949,7 @@ def __init__(
custom_headers: Mapping[str, str] | None = None,
custom_query: Mapping[str, object] | None = None,
_strict_response_validation: bool,
on_retry: Optional[Callable[[RetryInfo], None]] = None,
) -> None:
if not is_given(timeout):
# if the user passed in a custom http client with a non-default
Expand Down Expand Up @@ -950,6 +978,7 @@ def __init__(
custom_query=custom_query,
custom_headers=custom_headers,
_strict_response_validation=_strict_response_validation,
on_retry=on_retry,
)
self._client = http_client or SyncHttpxClientWrapper(
base_url=base_url,
Expand Down Expand Up @@ -1083,6 +1112,7 @@ def request(
max_retries=max_retries,
options=input_options,
response=None,
error=err,
)
continue

Expand All @@ -1097,6 +1127,7 @@ def request(
max_retries=max_retries,
options=input_options,
response=None,
error=err,
)
continue

Expand Down Expand Up @@ -1125,6 +1156,7 @@ def request(
max_retries=max_retries,
options=input_options,
response=response,
error=err,
)
continue

Expand All @@ -1149,7 +1181,13 @@ def request(
)

def _sleep_for_retry(
self, *, retries_taken: int, max_retries: int, options: FinalRequestOptions, response: httpx.Response | None
self,
*,
retries_taken: int,
max_retries: int,
options: FinalRequestOptions,
response: httpx.Response | None,
error: BaseException | None = None,
) -> None:
remaining_retries = max_retries - retries_taken
if remaining_retries == 1:
Expand All @@ -1160,6 +1198,17 @@ def _sleep_for_retry(
timeout = self._calculate_retry_timeout(remaining_retries, options, response.headers if response else None)
log.info("Retrying request to %s in %f seconds", options.url, timeout)

if self._on_retry is not None:
info = RetryInfo(
attempt=retries_taken + 1,
max_retries=max_retries,
url=options.url,
wait_seconds=timeout,
response=response,
error=error or Exception("unknown"),
)
self._on_retry(info)

time.sleep(timeout)

def _process_response(
Expand Down Expand Up @@ -1560,6 +1609,7 @@ def __init__(
http_client: httpx.AsyncClient | None = None,
custom_headers: Mapping[str, str] | None = None,
custom_query: Mapping[str, object] | None = None,
on_retry: Optional[Callable[[RetryInfo], None]] = None,
) -> None:
if not is_given(timeout):
# if the user passed in a custom http client with a non-default
Expand Down Expand Up @@ -1588,6 +1638,7 @@ def __init__(
custom_query=custom_query,
custom_headers=custom_headers,
_strict_response_validation=_strict_response_validation,
on_retry=on_retry,
)
self._client = http_client or AsyncHttpxClientWrapper(
base_url=base_url,
Expand Down Expand Up @@ -1723,6 +1774,7 @@ async def request(
max_retries=max_retries,
options=input_options,
response=None,
error=err,
)
continue

Expand All @@ -1737,6 +1789,7 @@ async def request(
max_retries=max_retries,
options=input_options,
response=None,
error=err,
)
continue

Expand Down Expand Up @@ -1765,6 +1818,7 @@ async def request(
max_retries=max_retries,
options=input_options,
response=response,
error=err,
)
continue

Expand All @@ -1789,7 +1843,13 @@ async def request(
)

async def _sleep_for_retry(
self, *, retries_taken: int, max_retries: int, options: FinalRequestOptions, response: httpx.Response | None
self,
*,
retries_taken: int,
max_retries: int,
options: FinalRequestOptions,
response: httpx.Response | None,
error: BaseException | None = None,
) -> None:
remaining_retries = max_retries - retries_taken
if remaining_retries == 1:
Expand All @@ -1800,6 +1860,17 @@ async def _sleep_for_retry(
timeout = self._calculate_retry_timeout(remaining_retries, options, response.headers if response else None)
log.info("Retrying request to %s in %f seconds", options.url, timeout)

if self._on_retry is not None:
info = RetryInfo(
attempt=retries_taken + 1,
max_retries=max_retries,
url=options.url,
wait_seconds=timeout,
response=response,
error=error or Exception("unknown"),
)
self._on_retry(info)

await anyio.sleep(timeout)

async def _process_response(
Expand Down
14 changes: 13 additions & 1 deletion src/anthropic/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Any, Mapping
from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional
from typing_extensions import Self, override

import httpx
Expand All @@ -27,6 +27,7 @@
from ._exceptions import APIStatusError
from ._base_client import (
DEFAULT_MAX_RETRIES,
RetryInfo,
SyncAPIClient,
AsyncAPIClient,
)
Expand Down Expand Up @@ -82,6 +83,10 @@ def __init__(
# outlining your use-case to help us decide if it should be
# part of our public interface in the future.
_strict_response_validation: bool = False,
# Optional callback invoked before each retry sleep. Receives a
# :class:`RetryInfo` instance describing the retry that is about to
# happen. Useful for logging, metrics, or custom alerting.
on_retry: Optional[Callable[[RetryInfo], None]] = None,
) -> None:
"""Construct a new synchronous Anthropic client instance.

Expand Down Expand Up @@ -111,6 +116,7 @@ def __init__(
custom_headers=default_headers,
custom_query=default_query,
_strict_response_validation=_strict_response_validation,
on_retry=on_retry,
)

self._default_stream_cls = Stream
Expand Down Expand Up @@ -210,6 +216,7 @@ def copy(
set_default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
set_default_query: Mapping[str, object] | None = None,
on_retry: Optional[Callable[[RetryInfo], None]] | NotGiven = not_given,
_extra_kwargs: Mapping[str, Any] = {},
) -> Self:
"""
Expand Down Expand Up @@ -243,6 +250,7 @@ def copy(
max_retries=max_retries if is_given(max_retries) else self.max_retries,
default_headers=headers,
default_query=params,
on_retry=self._on_retry if isinstance(on_retry, NotGiven) else on_retry,
**_extra_kwargs,
)

Expand Down Expand Up @@ -322,6 +330,7 @@ def __init__(
# outlining your use-case to help us decide if it should be
# part of our public interface in the future.
_strict_response_validation: bool = False,
on_retry: Optional[Callable[[RetryInfo], None]] = None,
) -> None:
"""Construct a new async AsyncAnthropic client instance.

Expand Down Expand Up @@ -351,6 +360,7 @@ def __init__(
custom_headers=default_headers,
custom_query=default_query,
_strict_response_validation=_strict_response_validation,
on_retry=on_retry,
)

self._default_stream_cls = AsyncStream
Expand Down Expand Up @@ -450,6 +460,7 @@ def copy(
set_default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
set_default_query: Mapping[str, object] | None = None,
on_retry: Optional[Callable[[RetryInfo], None]] | NotGiven = not_given,
_extra_kwargs: Mapping[str, Any] = {},
) -> Self:
"""
Expand Down Expand Up @@ -483,6 +494,7 @@ def copy(
max_retries=max_retries if is_given(max_retries) else self.max_retries,
default_headers=headers,
default_query=params,
on_retry=self._on_retry if isinstance(on_retry, NotGiven) else on_retry,
**_extra_kwargs,
)

Expand Down
10 changes: 8 additions & 2 deletions src/anthropic/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,14 @@ def is_typeddict(type_: type[Any]) -> bool: # noqa: ARG001
from pydantic import ConfigDict as ConfigDict
else:
if PYDANTIC_V1:
# TODO: provide an error message here?
ConfigDict = None

class ConfigDict: # type: ignore[no-redef]
def __new__(cls, **kwargs: object) -> "ConfigDict":
raise RuntimeError(
"ConfigDict is not supported in Pydantic v1. "
"Please upgrade to Pydantic v2 to use this feature."
)

else:
from pydantic import ConfigDict as ConfigDict

Expand Down
2 changes: 1 addition & 1 deletion src/anthropic/_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ async def async_to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles
elif is_sequence_t(files):
files = [(key, await _async_transform_file(file)) for key, file in files]
else:
raise TypeError("Unexpected file type input {type(files)}, expected mapping or sequence")
raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence")

return files

Expand Down
16 changes: 9 additions & 7 deletions src/anthropic/_qs.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,16 @@ def _stringify_item(
items: list[tuple[str, str]] = []
nested_format = opts.nested_format
for subkey, subvalue in value.items():
items.extend(
self._stringify_item(
# TODO: error if unknown format
f"{key}.{subkey}" if nested_format == "dots" else f"{key}[{subkey}]",
subvalue,
opts,
if nested_format == "dots":
nested_key = f"{key}.{subkey}"
elif nested_format == "brackets":
nested_key = f"{key}[{subkey}]"
else:
raise NotImplementedError(
f"Unknown nested_format value: {nested_format!r}, "
f"choose from {', '.join(get_args(NestedFormat))}"
)
)
items.extend(self._stringify_item(nested_key, subvalue, opts))
return items

if isinstance(value, (list, tuple)):
Expand Down
4 changes: 2 additions & 2 deletions src/anthropic/_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __stream__(self) -> Iterator[_T]:
try:
body = sse.json()
err_msg = f"{body}"
except Exception:
except json.JSONDecodeError:
err_msg = sse.data or f"Error code: {response.status_code}"

raise self._client._make_status_error(
Expand Down Expand Up @@ -228,7 +228,7 @@ async def __stream__(self) -> AsyncIterator[_T]:
try:
body = sse.json()
err_msg = f"{body}"
except Exception:
except json.JSONDecodeError:
err_msg = sse.data or f"Error code: {response.status_code}"

raise self._client._make_status_error(
Expand Down
2 changes: 1 addition & 1 deletion src/anthropic/_utils/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ async def _async_transform_recursive(

if origin == dict and is_mapping(data):
items_type = get_args(stripped_type)[1]
return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()}
return {key: await _async_transform_recursive(value, annotation=items_type) for key, value in data.items()}

if (
# List[T]
Expand Down
Loading