diff --git a/src/anthropic/_utils/_transform.py b/src/anthropic/_utils/_transform.py index 414f38c3..b9dc2b13 100644 --- a/src/anthropic/_utils/_transform.py +++ b/src/anthropic/_utils/_transform.py @@ -148,7 +148,63 @@ def _maybe_transform_key(key: str, type_: type) -> str: def _no_transform_needed(annotation: type) -> bool: - return annotation == float or annotation == int + return annotation == float or annotation == int or annotation == str or annotation == bool + + +# Dispatch kinds for _cached_transform_dispatch +_DISPATCH_TYPEDDICT = 0 +_DISPATCH_DICT = 1 +_DISPATCH_SEQUENCE = 2 +_DISPATCH_UNION = 3 +_DISPATCH_OTHER = 4 + + +@lru_cache(maxsize=8096) +def _cached_transform_dispatch(inner_type: type) -> tuple[int, Any]: + """Precompute the transform dispatch for a type annotation. + + Caches the result of strip_annotated_type, get_origin, is_typeddict, + is_list_type, is_iterable_type, is_sequence_type, and is_union_type so + that repeated calls with the same type are O(1) dict lookups instead of + re-running type introspection. + + Returns (dispatch_kind, type_arg) where type_arg depends on the kind. + """ + stripped = strip_annotated_type(inner_type) + origin = get_origin(stripped) or stripped + + if is_typeddict(stripped): + return (_DISPATCH_TYPEDDICT, stripped) + + if origin == dict: + args = get_args(stripped) + return (_DISPATCH_DICT, args[1] if len(args) > 1 else None) + + if is_list_type(stripped) or is_iterable_type(stripped) or is_sequence_type(stripped): + item_type = extract_type_arg(stripped, 0) + return (_DISPATCH_SEQUENCE, (item_type, _no_transform_needed(item_type))) + + if is_union_type(stripped): + return (_DISPATCH_UNION, get_args(stripped)) + + return (_DISPATCH_OTHER, None) + + +@lru_cache(maxsize=8096) +def _get_field_key_map(expected_type: type) -> dict[str, str]: + """Precompute the key mapping for a TypedDict type. + + Returns a dict mapping original field names to their transformed keys + (applying PropertyInfo alias annotations). 
Only includes entries where + the key actually changes, so a missing key means "use the original name". + """ + annotations = get_type_hints(expected_type, include_extras=True) + key_map: dict[str, str] = {} + for key, type_ in annotations.items(): + transformed = _maybe_transform_key(key, type_) + if transformed != key: + key_map[key] = transformed + return key_map def _transform_recursive( @@ -174,46 +230,36 @@ def _transform_recursive( if inner_type is None: inner_type = annotation - stripped_type = strip_annotated_type(inner_type) - origin = get_origin(stripped_type) or stripped_type - if is_typeddict(stripped_type) and is_mapping(data): - return _transform_typeddict(data, stripped_type) - - if origin == dict and is_mapping(data): - items_type = get_args(stripped_type)[1] - return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()} - - if ( - # List[T] - (is_list_type(stripped_type) and is_list(data)) - # Iterable[T] - or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str)) - # Sequence[T] - or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str)) - ): - # dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually - # intended as an iterable, so we don't transform it. - if isinstance(data, dict): - return cast(object, data) - - inner_type = extract_type_arg(stripped_type, 0) - if _no_transform_needed(inner_type): - # for some types there is no need to transform anything, so we can get a small - # perf boost from skipping that work. 
- # - # but we still need to convert to a list to ensure the data is json-serializable - if is_list(data): - return data - return list(data) - - return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data] - - if is_union_type(stripped_type): + dispatch, type_arg = _cached_transform_dispatch(inner_type) + + if dispatch == _DISPATCH_TYPEDDICT: + if is_mapping(data): + return _transform_typeddict(data, type_arg) + + elif dispatch == _DISPATCH_DICT: + if is_mapping(data) and type_arg is not None: + return {key: _transform_recursive(value, annotation=type_arg) for key, value in data.items()} + + elif dispatch == _DISPATCH_SEQUENCE: + if not isinstance(data, (str, dict)): + if is_list(data) or is_iterable(data) or is_sequence(data): + item_type, skip_items = type_arg + if skip_items: + # for some types there is no need to transform anything, so we can get a small + # perf boost from skipping that work. + # + # but we still need to convert to a list to ensure the data is json-serializable + if is_list(data): + return data + return list(data) + return [_transform_recursive(d, annotation=annotation, inner_type=item_type) for d in data] + + elif dispatch == _DISPATCH_UNION: # For union types we run the transformation against all subtypes to ensure that everything is transformed. # # TODO: there may be edge cases where the same normalized field name will transform to two different names # in different subtypes. 
- for subtype in get_args(stripped_type): + for subtype in type_arg: data = _transform_recursive(data, annotation=annotation, inner_type=subtype) return data @@ -266,6 +312,7 @@ def _transform_typeddict( ) -> Mapping[str, object]: result: dict[str, object] = {} annotations = get_type_hints(expected_type, include_extras=True) + key_map = _get_field_key_map(expected_type) for key, value in data.items(): if not is_given(value): # we don't need to include omitted values here as they'll @@ -277,7 +324,7 @@ def _transform_typeddict( # we do not have a type annotation for this field, leave it as is result[key] = value else: - result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_) + result[key_map.get(key, key)] = _transform_recursive(value, annotation=type_) return result @@ -340,46 +387,28 @@ async def _async_transform_recursive( if inner_type is None: inner_type = annotation - stripped_type = strip_annotated_type(inner_type) - origin = get_origin(stripped_type) or stripped_type - if is_typeddict(stripped_type) and is_mapping(data): - return await _async_transform_typeddict(data, stripped_type) - - if origin == dict and is_mapping(data): - items_type = get_args(stripped_type)[1] - return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()} - - if ( - # List[T] - (is_list_type(stripped_type) and is_list(data)) - # Iterable[T] - or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str)) - # Sequence[T] - or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str)) - ): - # dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually - # intended as an iterable, so we don't transform it. 
- if isinstance(data, dict): - return cast(object, data) - - inner_type = extract_type_arg(stripped_type, 0) - if _no_transform_needed(inner_type): - # for some types there is no need to transform anything, so we can get a small - # perf boost from skipping that work. - # - # but we still need to convert to a list to ensure the data is json-serializable - if is_list(data): - return data - return list(data) - - return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data] - - if is_union_type(stripped_type): - # For union types we run the transformation against all subtypes to ensure that everything is transformed. - # - # TODO: there may be edge cases where the same normalized field name will transform to two different names - # in different subtypes. - for subtype in get_args(stripped_type): + dispatch, type_arg = _cached_transform_dispatch(inner_type) + + if dispatch == _DISPATCH_TYPEDDICT: + if is_mapping(data): + return await _async_transform_typeddict(data, type_arg) + + elif dispatch == _DISPATCH_DICT: + if is_mapping(data) and type_arg is not None: + return {key: await _async_transform_recursive(value, annotation=type_arg) for key, value in data.items()} + + elif dispatch == _DISPATCH_SEQUENCE: + if not isinstance(data, (str, dict)): + if is_list(data) or is_iterable(data) or is_sequence(data): + item_type, skip_items = type_arg + if skip_items: + if is_list(data): + return data + return list(data) + return [await _async_transform_recursive(d, annotation=annotation, inner_type=item_type) for d in data] + + elif dispatch == _DISPATCH_UNION: + for subtype in type_arg: data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype) return data @@ -432,6 +461,7 @@ async def _async_transform_typeddict( ) -> Mapping[str, object]: result: dict[str, object] = {} annotations = get_type_hints(expected_type, include_extras=True) + key_map = _get_field_key_map(expected_type) for key, value in 
data.items(): if not is_given(value): # we don't need to include omitted values here as they'll @@ -443,7 +473,7 @@ async def _async_transform_typeddict( # we do not have a type annotation for this field, leave it as is result[key] = value else: - result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_) + result[key_map.get(key, key)] = await _async_transform_recursive(value, annotation=type_) return result diff --git a/tests/test_transform.py b/tests/test_transform.py index 4a874ff8..d13f1e11 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -1,6 +1,7 @@ from __future__ import annotations import io +import time import pathlib from typing import Any, Dict, List, Union, TypeVar, Iterable, Optional, cast from datetime import date, datetime @@ -458,3 +459,153 @@ async def test_strips_notgiven(use_async: bool) -> None: async def test_strips_omit(use_async: bool) -> None: assert await transform({"foo_bar": "bar"}, Foo1, use_async) == {"fooBar": "bar"} assert await transform({"foo_bar": omit}, Foo1, use_async) == {} + + +# --- Dispatch cache and performance tests --- + + +class NoAnnotationDict(TypedDict): + """TypedDict with no PropertyInfo annotations — the common case for Messages API types.""" + + role: str + content: str + + +class NestedNoAnnotation(TypedDict): + messages: List[NoAnnotationDict] + model: str + + +class MixedAnnotations(TypedDict, total=False): + name: Annotated[str, PropertyInfo(alias="fullName")] + age: int + tags: List[str] + + +@parametrize +@pytest.mark.asyncio +async def test_no_annotation_typeddict_passthrough(use_async: bool) -> None: + """TypedDicts with no PropertyInfo should still transform correctly (no key renames).""" + result = await transform({"role": "user", "content": "hello"}, NoAnnotationDict, use_async) + assert result == {"role": "user", "content": "hello"} + + +@parametrize +@pytest.mark.asyncio +async def test_nested_no_annotation(use_async: bool) -> None: + 
"""Nested TypedDicts with no PropertyInfo should preserve structure.""" + data = { + "messages": [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ], + "model": "claude-3", + } + result = await transform(data, NestedNoAnnotation, use_async) + assert result == data + + +@parametrize +@pytest.mark.asyncio +async def test_mixed_annotations_still_work(use_async: bool) -> None: + """Types with PropertyInfo annotations continue to work correctly with caching.""" + result = await transform({"name": "Alice", "age": 30, "tags": ["a", "b"]}, MixedAnnotations, use_async) + assert result == {"fullName": "Alice", "age": 30, "tags": ["a", "b"]} + + +@parametrize +@pytest.mark.asyncio +async def test_str_list_skip_optimization(use_async: bool) -> None: + """Lists of str should be returned as-is (expanded _no_transform_needed).""" + data = ["a", "b", "c"] + result = await transform(data, List[str], use_async) + assert result is data + + +@parametrize +@pytest.mark.asyncio +async def test_bool_list_skip_optimization(use_async: bool) -> None: + """Lists of bool should be returned as-is (expanded _no_transform_needed).""" + data = [True, False, True] + result = await transform(data, List[bool], use_async) + assert result is data + + +@parametrize +@pytest.mark.asyncio +async def test_cached_dispatch_consistency(use_async: bool) -> None: + """Verify that repeated transforms with the same type produce identical results.""" + data = {"name": "Bob", "age": 25} + for _ in range(100): + result = await transform(data, MixedAnnotations, use_async) + assert result == {"fullName": "Bob", "age": 25} + + +@pytest.mark.asyncio +async def test_large_message_list_performance() -> None: + """Verify that transforming large message lists completes in reasonable time. + + This simulates the real-world scenario from issue #1195 where large + conversation histories cause excessive CPU usage in _transform_recursive. 
+ Uses a relative comparison instead of a wall-clock threshold to avoid + flaky failures in slow CI environments. + """ + messages = [{"role": "user" if i % 2 == 0 else "assistant", "content": f"Message {i}"} for i in range(10000)] + data = {"messages": messages, "model": "claude-3"} + + # Measure sync transform + start = time.monotonic() + result = _transform(data, NestedNoAnnotation) + sync_elapsed = time.monotonic() - start + + assert len(result["messages"]) == 10000 + assert result["messages"][0] == {"role": "user", "content": "Message 0"} + + # Measure async transform + start = time.monotonic() + async_result = await _async_transform(data, NestedNoAnnotation) + async_elapsed = time.monotonic() - start + + assert async_result == result + + # Relative check: async should not be more than 10x slower than sync. + # This avoids brittle wall-clock thresholds while still catching regressions. + assert async_elapsed < sync_elapsed * 10, ( + f"Async transform ({async_elapsed:.3f}s) was >10x slower than " + f"sync ({sync_elapsed:.3f}s)" + ) + + +@pytest.mark.asyncio +async def test_dispatch_cache_hit() -> None: + """Verify the dispatch cache is populated after first use.""" + from anthropic._utils._transform import _cached_transform_dispatch + + # Clear cache to get a clean baseline + _cached_transform_dispatch.cache_clear() + + # First call populates the cache + _transform({"role": "user", "content": "hi"}, NoAnnotationDict) + info = _cached_transform_dispatch.cache_info() + assert info.misses > 0, "Expected cache misses on first call" + + # Second call should hit the cache + misses_before = info.misses + _transform({"role": "user", "content": "hello"}, NoAnnotationDict) + info2 = _cached_transform_dispatch.cache_info() + assert info2.hits > 0, "Expected cache hits on second call" + assert info2.misses == misses_before, "Expected no new cache misses on second call" + + +@pytest.mark.asyncio +async def test_field_key_map_cache() -> None: + """Verify that 
_get_field_key_map returns the expected key mapping.""" + from anthropic._utils._transform import _get_field_key_map + + # No aliases + key_map = _get_field_key_map(NoAnnotationDict) + assert key_map == {}, "Expected empty key map for TypedDict with no aliases" + + # With alias + key_map = _get_field_key_map(MixedAnnotations) + assert key_map == {"name": "fullName"}, "Expected alias mapping for 'name' → 'fullName'"