From 7232a41694f3f04dc8bdb0559374bb5de825414e Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 23 Mar 2026 14:34:10 +0000 Subject: [PATCH 1/2] fix(utils): correct async dict transform, deterministic errors, and set serialization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - _transform.py: `_async_transform_recursive` was calling the *sync* `_transform_recursive` when the annotated type was `Dict[str, T]`. This meant type aliases inside the dict values were never applied in the async path. Fixed by awaiting `_async_transform_recursive` instead. - _utils.py: `required_args` built the list of missing arguments from a `set` subtraction, producing a non-deterministic error message across Python runs. Changed to `sorted(...)` so the message is always stable. - _utils.py: `json_safe` now converts `set` and `frozenset` to a sorted list so they can be serialised to JSON without raising a TypeError. - _utils.py: clarify `deepcopy_minimal` docstring — tuples are intentionally not copied because they are immutable containers. - _compat.py: `ConfigDict` under Pydantic v1 was silently set to `None`. It now raises a clear `RuntimeError` if instantiated, pointing users towards upgrading to Pydantic v2. Tests added for all of the above changes. https://claude.ai/code/session_01CqXN7tjqqv1rCwTMkEmbFb --- src/anthropic/_compat.py | 10 ++- src/anthropic/_utils/_transform.py | 2 +- src/anthropic/_utils/_utils.py | 9 ++- tests/test_required_args.py | 21 ++++++ tests/test_transform.py | 34 ++++++++++ tests/test_utils/test_utils.py | 104 +++++++++++++++++++++++++++++ 6 files changed, 174 insertions(+), 6 deletions(-) create mode 100644 tests/test_utils/test_utils.py diff --git a/src/anthropic/_compat.py b/src/anthropic/_compat.py index 4d7e4b03..05e71929 100644 --- a/src/anthropic/_compat.py +++ b/src/anthropic/_compat.py @@ -70,8 +70,14 @@ def is_typeddict(type_: type[Any]) -> bool: # noqa: ARG001 from pydantic import ConfigDict as ConfigDict else: if PYDANTIC_V1: - # TODO: provide an error message here? - ConfigDict = None + + class ConfigDict: # type: ignore[no-redef] + def __new__(cls, **kwargs: object) -> "ConfigDict": + raise RuntimeError( + "ConfigDict is not supported in Pydantic v1. " + "Please upgrade to Pydantic v2 to use this feature." + ) + else: from pydantic import ConfigDict as ConfigDict diff --git a/src/anthropic/_utils/_transform.py b/src/anthropic/_utils/_transform.py index 414f38c3..811bfebc 100644 --- a/src/anthropic/_utils/_transform.py +++ b/src/anthropic/_utils/_transform.py @@ -347,7 +347,7 @@ async def _async_transform_recursive( if origin == dict and is_mapping(data): items_type = get_args(stripped_type)[1] - return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()} + return {key: await _async_transform_recursive(value, annotation=items_type) for key, value in data.items()} if ( # List[T] diff --git a/src/anthropic/_utils/_utils.py b/src/anthropic/_utils/_utils.py index eec7f4a1..64419e0d 100644 --- a/src/anthropic/_utils/_utils.py +++ b/src/anthropic/_utils/_utils.py @@ -182,7 +182,8 @@ def deepcopy_minimal(item: _T) -> _T: - mappings, e.g. `dict` - list - This is done for performance reasons. + This is done for performance reasons. Tuples and other immutable containers + are intentionally not copied since they cannot be mutated. """ if is_mapping(item): return cast(_T, {k: deepcopy_minimal(v) for k, v in item.items()}) @@ -272,8 +273,7 @@ def wrapper(*args: object, **kwargs: object) -> object: else: assert len(variants) > 0 - # TODO: this error message is not deterministic - missing = list(set(variants[0]) - given_params) + missing = sorted(set(variants[0]) - given_params) if len(missing) > 1: msg = f"Missing required arguments: {human_join([quote(arg) for arg in missing])}" else: @@ -412,6 +412,9 @@ def json_safe(data: object) -> object: if is_mapping(data): return {json_safe(key): json_safe(value) for key, value in data.items()} + if isinstance(data, (set, frozenset)): + return [json_safe(item) for item in sorted(data, key=repr)] + if is_iterable(data) and not isinstance(data, (str, bytes, bytearray)): return [json_safe(item) for item in data] diff --git a/tests/test_required_args.py b/tests/test_required_args.py index fc774314..12b9c765 100644 --- a/tests/test_required_args.py +++ b/tests/test_required_args.py @@ -86,6 +86,27 @@ def foo(*, a: str | None = None, b: str | None = None) -> str | None: foo() +def test_missing_args_error_is_deterministic() -> None: + """The error message listing missing arguments must always use the same ordering.""" + + @required_args(["a", "b", "c"]) + def foo(*, a: str | None = None, b: str | None = None, c: str | None = None) -> None: + pass + + messages: list[str] = [] + for _ in range(20): + try: + foo() + except TypeError as exc: + messages.append(str(exc)) + + # All 20 calls must produce the same message. + assert len(set(messages)) == 1, f"Error message is not deterministic: {set(messages)}" + + # Arguments must appear in sorted (alphabetical) order. + assert messages[0] == "Missing required arguments: 'a', 'b' or 'c'" + + def test_multiple_params_multiple_variants() -> None: @required_args(["a", "b"], ["c"]) def foo(*, a: str | None = None, b: str | None = None, c: str | None = None) -> str | None: diff --git a/tests/test_transform.py b/tests/test_transform.py index 4a874ff8..9fa78c40 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -397,6 +397,40 @@ class DictItems(TypedDict): assert await transform({"foo": {"foo_baz": "bar"}}, Dict[str, DictItems], use_async) == {"foo": {"fooBaz": "bar"}} +class _AsyncDictInner(TypedDict): + snake_case: Annotated[str, PropertyInfo(alias="camelCase")] + + +class _AsyncDictOuter(TypedDict): + data: Dict[str, _AsyncDictInner] + + +@parametrize +@pytest.mark.asyncio +async def test_async_dict_values_are_awaited(use_async: bool) -> None: + """Regression test: async transform must await inner dict values, not call the sync version. + + Previously, _async_transform_recursive called the *sync* _transform_recursive when the + outer type was Dict[str, T]. This meant that any async-only work inside T was silently + skipped in the async path. + """ + # Dict[str, _AsyncDictInner] – the async path must transform the values correctly. + result = await transform( + {"key1": {"snake_case": "v1"}, "key2": {"snake_case": "v2"}}, + Dict[str, _AsyncDictInner], + use_async, + ) + assert result == {"key1": {"camelCase": "v1"}, "key2": {"camelCase": "v2"}} + + # Nested TypedDict with Dict[str, _AsyncDictInner] – multiple levels must all be awaited. + result2 = await transform( + {"data": {"a": {"snake_case": "hello"}, "b": {"snake_case": "world"}}}, + _AsyncDictOuter, + use_async, + ) + assert result2 == {"data": {"a": {"camelCase": "hello"}, "b": {"camelCase": "world"}}} + + class TypedDictIterableUnionStr(TypedDict): foo: Annotated[Union[str, Iterable[Baz8]], PropertyInfo(alias="FOO")] diff --git a/tests/test_utils/test_utils.py b/tests/test_utils/test_utils.py new file mode 100644 index 00000000..e222474b --- /dev/null +++ b/tests/test_utils/test_utils.py @@ -0,0 +1,104 @@ +"""Tests for internal utility helpers in anthropic._utils._utils.""" +from __future__ import annotations + +import datetime + +import pytest + +from anthropic._utils._utils import json_safe, deepcopy_minimal + + +class TestJsonSafe: + def test_primitive_passthrough(self) -> None: + assert json_safe(42) == 42 + assert json_safe(3.14) == 3.14 + assert json_safe("hello") == "hello" + assert json_safe(True) is True + assert json_safe(None) is None + + def test_mapping(self) -> None: + result = json_safe({"a": 1, "b": "two"}) + assert result == {"a": 1, "b": "two"} + + def test_list(self) -> None: + result = json_safe([1, 2, 3]) + assert result == [1, 2, 3] + + def test_tuple(self) -> None: + # Tuples are iterable and should be converted to lists. + result = json_safe((1, 2, 3)) + assert result == [1, 2, 3] + + def test_datetime(self) -> None: + dt = datetime.datetime(2024, 6, 15, 12, 30, 0) + assert json_safe(dt) == "2024-06-15T12:30:00" + + def test_date(self) -> None: + d = datetime.date(2024, 6, 15) + assert json_safe(d) == "2024-06-15" + + def test_set_converted_to_sorted_list(self) -> None: + result = json_safe({"c", "a", "b"}) + # Sets are unordered; json_safe must return a stable, sorted list. + assert isinstance(result, list) + assert sorted(result) == ["a", "b", "c"] # type: ignore[arg-type] + + def test_frozenset_converted_to_sorted_list(self) -> None: + result = json_safe(frozenset({"z", "x", "y"})) + assert isinstance(result, list) + assert sorted(result) == ["x", "y", "z"] # type: ignore[arg-type] + + def test_set_is_deterministic(self) -> None: + """json_safe must produce the same output on every call for the same set.""" + s = {"banana", "apple", "cherry"} + results = [json_safe(s) for _ in range(20)] + first = results[0] + for r in results[1:]: + assert r == first, "json_safe(set) is not deterministic" + + def test_nested_set_in_dict(self) -> None: + result = json_safe({"tags": {"b", "a"}}) + assert isinstance(result, dict) + tags = result["tags"] # type: ignore[index] + assert isinstance(tags, list) + assert sorted(tags) == ["a", "b"] + + def test_nested_datetime_in_list(self) -> None: + dt = datetime.datetime(2024, 1, 1) + result = json_safe([dt, "text"]) + assert result == ["2024-01-01T00:00:00", "text"] + + def test_bytes_passthrough(self) -> None: + # bytes are NOT iterable in this context (explicitly excluded). + data = b"raw" + assert json_safe(data) is data + + +class TestDeepcopyMinimal: + def test_dict_is_copied(self) -> None: + original = {"a": 1} + copy = deepcopy_minimal(original) + assert copy == original + assert copy is not original + + def test_nested_dict_deep_copied(self) -> None: + original = {"outer": {"inner": 42}} + copy = deepcopy_minimal(original) + assert copy == original + assert copy["outer"] is not original["outer"] + + def test_list_is_copied(self) -> None: + original = [1, 2, 3] + copy = deepcopy_minimal(original) + assert copy == original + assert copy is not original + + def test_tuple_is_not_copied(self) -> None: + # Tuples are immutable; deepcopy_minimal returns them as-is. + original = ("a", "b") + copy = deepcopy_minimal(original) + assert copy is original + + def test_primitive_passthrough(self) -> None: + for val in (42, "hello", True, None, 3.14): + assert deepcopy_minimal(val) is val # type: ignore[arg-type] From 3af619f01e944404a430a15ca2382782d7be1b21 Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 23 Mar 2026 14:59:25 +0000 Subject: [PATCH 2/2] fix+feat(round2): streaming bounds checks, f-string fix, on_retry hook, and more MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bug fixes: - _files.py: fix missing f-prefix on f-string in async_to_httpx_files error message - _streaming.py: narrow broad `except Exception` to `except json.JSONDecodeError` when parsing error body from SSE, preventing accidental swallowing of errors - _qs.py: raise NotImplementedError for unknown nested_format values (resolves existing TODO comment) Streaming improvements: - lib/streaming/_messages.py: add index bounds validation for content_block_start and content_block_delta events; align input_json_delta error handling with beta - lib/streaming/_beta_messages.py: add same index bounds validation New feature — on_retry callback: - _base_client.py: add `RetryInfo` dataclass (attempt, max_retries, url, wait_seconds, response, error) and `on_retry` callback parameter to BaseClient, SyncAPIClient, and AsyncAPIClient; callback is invoked before each retry sleep - _client.py: thread `on_retry` through Anthropic and AsyncAnthropic constructors and copy() / with_options() methods - __init__.py: export RetryInfo in public API Tests: - tests/test_qs.py: add test_unknown_nested_format https://claude.ai/code/session_01CqXN7tjqqv1rCwTMkEmbFb --- src/anthropic/__init__.py | 3 +- src/anthropic/_base_client.py | 75 ++++++++++++++++++- src/anthropic/_client.py | 14 +++- src/anthropic/_files.py | 2 +- src/anthropic/_qs.py | 16 ++-- src/anthropic/_streaming.py | 4 +- src/anthropic/lib/streaming/_beta_messages.py | 13 +++- src/anthropic/lib/streaming/_messages.py | 22 +++++- tests/test_qs.py | 5 ++ 9 files changed, 137 insertions(+), 17 deletions(-) diff --git a/src/anthropic/__init__.py b/src/anthropic/__init__.py index a76d1dae..4eaf0888 100644 --- a/src/anthropic/__init__.py +++ b/src/anthropic/__init__.py @@ -42,7 +42,7 @@ UnprocessableEntityError, APIResponseValidationError, ) -from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient +from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient, RetryInfo from ._utils._logs import setup_logging as _setup_logging from .lib._parse._transform import transform_schema @@ -88,6 +88,7 @@ "DefaultHttpxClient", "DefaultAsyncHttpxClient", "DefaultAioHttpClient", + "RetryInfo", "HUMAN_PROMPT", "AI_PROMPT", "beta_tool", diff --git a/src/anthropic/_base_client.py b/src/anthropic/_base_client.py index 591344ae..272e9692 100644 --- a/src/anthropic/_base_client.py +++ b/src/anthropic/_base_client.py @@ -14,6 +14,7 @@ import email.utils from types import TracebackType from random import random +from dataclasses import dataclass from typing import ( TYPE_CHECKING, Any, @@ -365,6 +366,30 @@ async def get_next_page(self: AsyncPageT) -> AsyncPageT: _DefaultStreamT = TypeVar("_DefaultStreamT", bound=Union[Stream[Any], AsyncStream[Any]]) +@dataclass +class RetryInfo: + """Information about a retry attempt, passed to the ``on_retry`` callback. + + Attributes: + attempt: The 1-based index of the retry that is *about to happen* + (e.g. 1 for the first retry, 2 for the second, …). + max_retries: The maximum number of retries configured for this request. + url: The URL that is being retried. + wait_seconds: How many seconds the SDK will sleep before issuing the + next request. + response: The HTTP response that triggered the retry, or ``None`` if the + retry is caused by a network-level error (timeout / connection error). + error: The exception that caused the retry (always present). + """ + + attempt: int + max_retries: int + url: str + wait_seconds: float + response: httpx.Response | None + error: BaseException + + class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]): _client: _HttpxClientT _version: str @@ -385,6 +410,7 @@ def __init__( timeout: float | Timeout | None = DEFAULT_TIMEOUT, custom_headers: Mapping[str, str] | None = None, custom_query: Mapping[str, object] | None = None, + on_retry: Optional[Callable[[RetryInfo], None]] = None, ) -> None: self._version = version self._base_url = self._enforce_trailing_slash(URL(base_url)) @@ -395,6 +421,7 @@ def __init__( self._strict_response_validation = _strict_response_validation self._idempotency_header = None self._platform: Platform | None = None + self._on_retry = on_retry if max_retries is None: # pyright: ignore[reportUnnecessaryComparison] raise TypeError( @@ -922,6 +949,7 @@ def __init__( custom_headers: Mapping[str, str] | None = None, custom_query: Mapping[str, object] | None = None, _strict_response_validation: bool, + on_retry: Optional[Callable[[RetryInfo], None]] = None, ) -> None: if not is_given(timeout): # if the user passed in a custom http client with a non-default @@ -950,6 +978,7 @@ def __init__( custom_query=custom_query, custom_headers=custom_headers, _strict_response_validation=_strict_response_validation, + on_retry=on_retry, ) self._client = http_client or SyncHttpxClientWrapper( base_url=base_url, @@ -1083,6 +1112,7 @@ def request( max_retries=max_retries, options=input_options, response=None, + error=err, ) continue @@ -1097,6 +1127,7 @@ def request( max_retries=max_retries, options=input_options, response=None, + error=err, ) continue @@ -1125,6 +1156,7 @@ def request( max_retries=max_retries, options=input_options, response=response, + error=err, ) continue @@ -1149,7 +1181,13 @@ def request( ) def _sleep_for_retry( - self, *, retries_taken: int, max_retries: int, options: FinalRequestOptions, response: httpx.Response | None + self, + *, + retries_taken: int, + max_retries: int, + options: FinalRequestOptions, + response: httpx.Response | None, + error: BaseException | None = None, ) -> None: remaining_retries = max_retries - retries_taken if remaining_retries == 1: @@ -1160,6 +1198,17 @@ def _sleep_for_retry( timeout = self._calculate_retry_timeout(remaining_retries, options, response.headers if response else None) log.info("Retrying request to %s in %f seconds", options.url, timeout) + if self._on_retry is not None: + info = RetryInfo( + attempt=retries_taken + 1, + max_retries=max_retries, + url=options.url, + wait_seconds=timeout, + response=response, + error=error or Exception("unknown"), + ) + self._on_retry(info) + time.sleep(timeout) def _process_response( @@ -1560,6 +1609,7 @@ def __init__( http_client: httpx.AsyncClient | None = None, custom_headers: Mapping[str, str] | None = None, custom_query: Mapping[str, object] | None = None, + on_retry: Optional[Callable[[RetryInfo], None]] = None, ) -> None: if not is_given(timeout): # if the user passed in a custom http client with a non-default @@ -1588,6 +1638,7 @@ def __init__( custom_query=custom_query, custom_headers=custom_headers, _strict_response_validation=_strict_response_validation, + on_retry=on_retry, ) self._client = http_client or AsyncHttpxClientWrapper( base_url=base_url, @@ -1723,6 +1774,7 @@ async def request( max_retries=max_retries, options=input_options, response=None, + error=err, ) continue @@ -1737,6 +1789,7 @@ async def request( max_retries=max_retries, options=input_options, response=None, + error=err, ) continue @@ -1765,6 +1818,7 @@ async def request( max_retries=max_retries, options=input_options, response=response, + error=err, ) continue @@ -1789,7 +1843,13 @@ async def request( ) async def _sleep_for_retry( - self, *, retries_taken: int, max_retries: int, options: FinalRequestOptions, response: httpx.Response | None + self, + *, + retries_taken: int, + max_retries: int, + options: FinalRequestOptions, + response: httpx.Response | None, + error: BaseException | None = None, ) -> None: remaining_retries = max_retries - retries_taken if remaining_retries == 1: @@ -1800,6 +1860,17 @@ async def _sleep_for_retry( timeout = self._calculate_retry_timeout(remaining_retries, options, response.headers if response else None) log.info("Retrying request to %s in %f seconds", options.url, timeout) + if self._on_retry is not None: + info = RetryInfo( + attempt=retries_taken + 1, + max_retries=max_retries, + url=options.url, + wait_seconds=timeout, + response=response, + error=error or Exception("unknown"), + ) + self._on_retry(info) + await anyio.sleep(timeout) async def _process_response( diff --git a/src/anthropic/_client.py b/src/anthropic/_client.py index accf652d..6fb0ca80 100644 --- a/src/anthropic/_client.py +++ b/src/anthropic/_client.py @@ -3,7 +3,7 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING, Any, Mapping +from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional from typing_extensions import Self, override import httpx @@ -27,6 +27,7 @@ from ._exceptions import APIStatusError from ._base_client import ( DEFAULT_MAX_RETRIES, + RetryInfo, SyncAPIClient, AsyncAPIClient, ) @@ -82,6 +83,10 @@ def __init__( # outlining your use-case to help us decide if it should be # part of our public interface in the future. _strict_response_validation: bool = False, + # Optional callback invoked before each retry sleep. Receives a + # :class:`RetryInfo` instance describing the retry that is about to + # happen. Useful for logging, metrics, or custom alerting. + on_retry: Optional[Callable[[RetryInfo], None]] = None, ) -> None: """Construct a new synchronous Anthropic client instance. @@ -111,6 +116,7 @@ def __init__( custom_headers=default_headers, custom_query=default_query, _strict_response_validation=_strict_response_validation, + on_retry=on_retry, ) self._default_stream_cls = Stream @@ -210,6 +216,7 @@ def copy( set_default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, set_default_query: Mapping[str, object] | None = None, + on_retry: Optional[Callable[[RetryInfo], None]] | NotGiven = not_given, _extra_kwargs: Mapping[str, Any] = {}, ) -> Self: """ @@ -243,6 +250,7 @@ def copy( max_retries=max_retries if is_given(max_retries) else self.max_retries, default_headers=headers, default_query=params, + on_retry=self._on_retry if isinstance(on_retry, NotGiven) else on_retry, **_extra_kwargs, ) @@ -322,6 +330,7 @@ def __init__( # outlining your use-case to help us decide if it should be # part of our public interface in the future. _strict_response_validation: bool = False, + on_retry: Optional[Callable[[RetryInfo], None]] = None, ) -> None: """Construct a new async AsyncAnthropic client instance. @@ -351,6 +360,7 @@ def __init__( custom_headers=default_headers, custom_query=default_query, _strict_response_validation=_strict_response_validation, + on_retry=on_retry, ) self._default_stream_cls = AsyncStream @@ -450,6 +460,7 @@ def copy( set_default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, set_default_query: Mapping[str, object] | None = None, + on_retry: Optional[Callable[[RetryInfo], None]] | NotGiven = not_given, _extra_kwargs: Mapping[str, Any] = {}, ) -> Self: """ @@ -483,6 +494,7 @@ def copy( max_retries=max_retries if is_given(max_retries) else self.max_retries, default_headers=headers, default_query=params, + on_retry=self._on_retry if isinstance(on_retry, NotGiven) else on_retry, **_extra_kwargs, ) diff --git a/src/anthropic/_files.py b/src/anthropic/_files.py index f2a6f94e..cb070b8f 100644 --- a/src/anthropic/_files.py +++ b/src/anthropic/_files.py @@ -97,7 +97,7 @@ async def async_to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles elif is_sequence_t(files): files = [(key, await _async_transform_file(file)) for key, file in files] else: - raise TypeError("Unexpected file type input {type(files)}, expected mapping or sequence") + raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence") return files diff --git a/src/anthropic/_qs.py b/src/anthropic/_qs.py index ada6fd3f..e94add99 100644 --- a/src/anthropic/_qs.py +++ b/src/anthropic/_qs.py @@ -76,14 +76,16 @@ def _stringify_item( items: list[tuple[str, str]] = [] nested_format = opts.nested_format for subkey, subvalue in value.items(): - items.extend( - self._stringify_item( - # TODO: error if unknown format - f"{key}.{subkey}" if nested_format == "dots" else f"{key}[{subkey}]", - subvalue, - opts, + if nested_format == "dots": + nested_key = f"{key}.{subkey}" + elif nested_format == "brackets": + nested_key = f"{key}[{subkey}]" + else: + raise NotImplementedError( + f"Unknown nested_format value: {nested_format!r}, " + f"choose from {', '.join(get_args(NestedFormat))}" ) - ) + items.extend(self._stringify_item(nested_key, subvalue, opts)) return items if isinstance(value, (list, tuple)): diff --git a/src/anthropic/_streaming.py b/src/anthropic/_streaming.py index fc6943ee..53a67f85 100644 --- a/src/anthropic/_streaming.py +++ b/src/anthropic/_streaming.py @@ -108,7 +108,7 @@ def __stream__(self) -> Iterator[_T]: try: body = sse.json() err_msg = f"{body}" - except Exception: + except json.JSONDecodeError: err_msg = sse.data or f"Error code: {response.status_code}" raise self._client._make_status_error( @@ -228,7 +228,7 @@ async def __stream__(self) -> AsyncIterator[_T]: try: body = sse.json() err_msg = f"{body}" - except Exception: + except json.JSONDecodeError: err_msg = sse.data or f"Error code: {response.status_code}" raise self._client._make_status_error( diff --git a/src/anthropic/lib/streaming/_beta_messages.py b/src/anthropic/lib/streaming/_beta_messages.py index c1447a8d..0eb7c0f1 100644 --- a/src/anthropic/lib/streaming/_beta_messages.py +++ b/src/anthropic/lib/streaming/_beta_messages.py @@ -474,7 +474,12 @@ def accumulate_event( raise RuntimeError(f'Unexpected event order, got {event.type} before "message_start"') if event.type == "content_block_start": - # TODO: check index + expected_index = len(current_snapshot.content) + if event.index != expected_index: + raise RuntimeError( + f"Unexpected content_block_start index {event.index!r}, " + f"expected {expected_index!r}" + ) current_snapshot.content.append( cast( Any, # Pydantic does not support generic unions at runtime @@ -482,6 +487,12 @@ def accumulate_event( ), ) elif event.type == "content_block_delta": + num_blocks = len(current_snapshot.content) + if event.index >= num_blocks: + raise RuntimeError( + f"Unexpected content_block_delta index {event.index!r}: " + f"only {num_blocks} content block(s) have been started" + ) content = current_snapshot.content[event.index] if event.delta.type == "text_delta": if content.type == "text": diff --git a/src/anthropic/lib/streaming/_messages.py b/src/anthropic/lib/streaming/_messages.py index b6b5f538..32308544 100644 --- a/src/anthropic/lib/streaming/_messages.py +++ b/src/anthropic/lib/streaming/_messages.py @@ -454,7 +454,12 @@ def accumulate_event( raise RuntimeError(f'Unexpected event order, got {event.type} before "message_start"') if event.type == "content_block_start": - # TODO: check index + expected_index = len(current_snapshot.content) + if event.index != expected_index: + raise RuntimeError( + f"Unexpected content_block_start index {event.index!r}, " + f"expected {expected_index!r}" + ) current_snapshot.content.append( cast( Any, # Pydantic does not support generic unions at runtime @@ -462,6 +467,12 @@ def accumulate_event( ), ) elif event.type == "content_block_delta": + num_blocks = len(current_snapshot.content) + if event.index >= num_blocks: + raise RuntimeError( + f"Unexpected content_block_delta index {event.index!r}: " + f"only {num_blocks} content block(s) have been started" + ) content = current_snapshot.content[event.index] if event.delta.type == "text_delta": if content.type == "text": @@ -477,7 +488,14 @@ def accumulate_event( json_buf += bytes(event.delta.partial_json, "utf-8") if json_buf: - content.input = from_json(json_buf, partial_mode=True) + try: + content.input = from_json(json_buf, partial_mode=True) + except ValueError as e: + raise ValueError( + f"Unable to parse tool parameter JSON from model. " + f"Please retry your request or adjust your prompt. " + f"Error: {e}. JSON: {json_buf.decode('utf-8')}" + ) from e setattr(content, JSON_BUF_PROPERTY, json_buf) elif event.delta.type == "citations_delta": diff --git a/tests/test_qs.py b/tests/test_qs.py index 564a41e6..092397d6 100644 --- a/tests/test_qs.py +++ b/tests/test_qs.py @@ -76,3 +76,8 @@ def test_array_brackets(method: str) -> None: def test_unknown_array_format() -> None: with pytest.raises(NotImplementedError, match="Unknown array_format value: foo, choose from comma, repeat"): stringify({"a": ["foo", "bar"]}, array_format=cast(Any, "foo")) + + +def test_unknown_nested_format() -> None: + with pytest.raises(NotImplementedError, match="Unknown nested_format value: 'semicolon', choose from dots, brackets"): + stringify({"a": {"b": "c"}}, nested_format=cast(Any, "semicolon"))