Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 117 additions & 85 deletions src/anthropic/_utils/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,63 @@ def _maybe_transform_key(key: str, type_: type) -> str:


def _no_transform_needed(annotation: type) -> bool:
return annotation == float or annotation == int
return annotation == float or annotation == int or annotation == str or annotation == bool


# Dispatch kinds for _cached_transform_dispatch
_DISPATCH_TYPEDDICT = 0  # annotation is a TypedDict; type_arg is the stripped TypedDict type
_DISPATCH_DICT = 1  # annotation is dict[K, V]; type_arg is V (or None if unparameterized)
_DISPATCH_SEQUENCE = 2  # list/iterable/sequence annotation; type_arg is (item_type, no_transform_needed)
_DISPATCH_UNION = 3  # union annotation; type_arg is the tuple of subtypes
_DISPATCH_OTHER = 4  # no structural transform applies; type_arg is None


@lru_cache(maxsize=8096)
def _cached_transform_dispatch(inner_type: type) -> tuple[int, Any]:
    """Precompute the transform dispatch for a type annotation.

    Memoises the type-introspection work (strip_annotated_type, get_origin,
    is_typeddict, is_list_type, is_iterable_type, is_sequence_type and
    is_union_type) so repeated calls for the same annotation resolve via a
    single cached lookup instead of re-running introspection.

    Returns a ``(dispatch_kind, type_arg)`` pair where the meaning of
    ``type_arg`` depends on the dispatch kind.
    """
    bare = strip_annotated_type(inner_type)

    if is_typeddict(bare):
        return (_DISPATCH_TYPEDDICT, bare)

    # get_origin() is only needed for the dict check, so compute it here.
    if (get_origin(bare) or bare) == dict:
        type_args = get_args(bare)
        value_type = type_args[1] if len(type_args) > 1 else None
        return (_DISPATCH_DICT, value_type)

    if is_list_type(bare) or is_iterable_type(bare) or is_sequence_type(bare):
        element_type = extract_type_arg(bare, 0)
        return (_DISPATCH_SEQUENCE, (element_type, _no_transform_needed(element_type)))

    if is_union_type(bare):
        return (_DISPATCH_UNION, get_args(bare))

    return (_DISPATCH_OTHER, None)


@lru_cache(maxsize=8096)
def _get_field_key_map(expected_type: type) -> dict[str, str]:
    """Precompute the key mapping for a TypedDict type.

    Builds a mapping from each original field name to its transformed key
    (honouring PropertyInfo alias annotations). Fields whose name does not
    change are left out, so a missing entry means "use the original name".
    """
    hints = get_type_hints(expected_type, include_extras=True)
    return {
        name: alias
        for name, annotation in hints.items()
        if (alias := _maybe_transform_key(name, annotation)) != name
    }


def _transform_recursive(
Expand All @@ -174,63 +230,54 @@ def _transform_recursive(
if inner_type is None:
inner_type = annotation

stripped_type = strip_annotated_type(inner_type)
origin = get_origin(stripped_type) or stripped_type
if is_typeddict(stripped_type) and is_mapping(data):
return _transform_typeddict(data, stripped_type)

if origin == dict and is_mapping(data):
items_type = get_args(stripped_type)[1]
return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()}

if (
# List[T]
(is_list_type(stripped_type) and is_list(data))
# Iterable[T]
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
# Sequence[T]
or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str))
):
# dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually
# intended as an iterable, so we don't transform it.
if isinstance(data, dict):
return cast(object, data)

inner_type = extract_type_arg(stripped_type, 0)
if _no_transform_needed(inner_type):
# for some types there is no need to transform anything, so we can get a small
# perf boost from skipping that work.
#
# but we still need to convert to a list to ensure the data is json-serializable
if is_list(data):
return data
return list(data)

return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]

if is_union_type(stripped_type):
dispatch, type_arg = _cached_transform_dispatch(inner_type)

if dispatch == _DISPATCH_TYPEDDICT:
if is_mapping(data):
return _transform_typeddict(data, type_arg)

elif dispatch == _DISPATCH_DICT:
if is_mapping(data) and type_arg is not None:
return {key: _transform_recursive(value, annotation=type_arg) for key, value in data.items()}

elif dispatch == _DISPATCH_SEQUENCE:
if not isinstance(data, (str, dict)):
if is_list(data) or is_iterable(data) or is_sequence(data):
item_type, skip_items = type_arg
if skip_items:
# for some types there is no need to transform anything, so we can get a small
# perf boost from skipping that work.
#
# but we still need to convert to a list to ensure the data is json-serializable
if is_list(data):
return data
return list(data)
return [_transform_recursive(d, annotation=annotation, inner_type=item_type) for d in data]

elif dispatch == _DISPATCH_UNION:
# For union types we run the transformation against all subtypes to ensure that everything is transformed.
#
# TODO: there may be edge cases where the same normalized field name will transform to two different names
# in different subtypes.
for subtype in get_args(stripped_type):
for subtype in type_arg:
data = _transform_recursive(data, annotation=annotation, inner_type=subtype)
return data

if isinstance(data, pydantic.BaseModel):
return model_dump(data, exclude_unset=True, mode="json", exclude=getattr(data, "__api_exclude__", None))

unchanged = cast(object, data)
annotated_type = _get_annotated_type(annotation)
if annotated_type is None:
return data
return unchanged

# ignore the first argument as it is the actual type
annotations = get_args(annotated_type)[1:]
for annotation in annotations:
if isinstance(annotation, PropertyInfo) and annotation.format is not None:
return _format_data(data, annotation.format, annotation.format_template)
return _format_data(unchanged, annotation.format, annotation.format_template)

return data
return unchanged


def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
Expand Down Expand Up @@ -266,6 +313,7 @@ def _transform_typeddict(
) -> Mapping[str, object]:
result: dict[str, object] = {}
annotations = get_type_hints(expected_type, include_extras=True)
key_map = _get_field_key_map(expected_type)
for key, value in data.items():
if not is_given(value):
# we don't need to include omitted values here as they'll
Expand All @@ -277,7 +325,7 @@ def _transform_typeddict(
# we do not have a type annotation for this field, leave it as is
result[key] = value
else:
result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_)
result[key_map.get(key, key)] = _transform_recursive(value, annotation=type_)
return result


Expand Down Expand Up @@ -340,63 +388,46 @@ async def _async_transform_recursive(
if inner_type is None:
inner_type = annotation

stripped_type = strip_annotated_type(inner_type)
origin = get_origin(stripped_type) or stripped_type
if is_typeddict(stripped_type) and is_mapping(data):
return await _async_transform_typeddict(data, stripped_type)

if origin == dict and is_mapping(data):
items_type = get_args(stripped_type)[1]
return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()}

if (
# List[T]
(is_list_type(stripped_type) and is_list(data))
# Iterable[T]
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
# Sequence[T]
or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str))
):
# dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually
# intended as an iterable, so we don't transform it.
if isinstance(data, dict):
return cast(object, data)

inner_type = extract_type_arg(stripped_type, 0)
if _no_transform_needed(inner_type):
# for some types there is no need to transform anything, so we can get a small
# perf boost from skipping that work.
#
# but we still need to convert to a list to ensure the data is json-serializable
if is_list(data):
return data
return list(data)

return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]

if is_union_type(stripped_type):
# For union types we run the transformation against all subtypes to ensure that everything is transformed.
#
# TODO: there may be edge cases where the same normalized field name will transform to two different names
# in different subtypes.
for subtype in get_args(stripped_type):
dispatch, type_arg = _cached_transform_dispatch(inner_type)

if dispatch == _DISPATCH_TYPEDDICT:
if is_mapping(data):
return await _async_transform_typeddict(data, type_arg)

elif dispatch == _DISPATCH_DICT:
if is_mapping(data) and type_arg is not None:
return {key: await _async_transform_recursive(value, annotation=type_arg) for key, value in data.items()}

elif dispatch == _DISPATCH_SEQUENCE:
if not isinstance(data, (str, dict)):
if is_list(data) or is_iterable(data) or is_sequence(data):
item_type, skip_items = type_arg
if skip_items:
if is_list(data):
return data
return list(data)
return [await _async_transform_recursive(d, annotation=annotation, inner_type=item_type) for d in data]

elif dispatch == _DISPATCH_UNION:
for subtype in type_arg:
data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype)
return data

if isinstance(data, pydantic.BaseModel):
return model_dump(data, exclude_unset=True, mode="json")

unchanged = cast(object, data)
annotated_type = _get_annotated_type(annotation)
if annotated_type is None:
return data
return unchanged

# ignore the first argument as it is the actual type
annotations = get_args(annotated_type)[1:]
for annotation in annotations:
if isinstance(annotation, PropertyInfo) and annotation.format is not None:
return await _async_format_data(data, annotation.format, annotation.format_template)
return await _async_format_data(unchanged, annotation.format, annotation.format_template)

return data
return unchanged


async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
Expand Down Expand Up @@ -432,6 +463,7 @@ async def _async_transform_typeddict(
) -> Mapping[str, object]:
result: dict[str, object] = {}
annotations = get_type_hints(expected_type, include_extras=True)
key_map = _get_field_key_map(expected_type)
for key, value in data.items():
if not is_given(value):
# we don't need to include omitted values here as they'll
Expand All @@ -443,7 +475,7 @@ async def _async_transform_typeddict(
# we do not have a type annotation for this field, leave it as is
result[key] = value
else:
result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_)
result[key_map.get(key, key)] = await _async_transform_recursive(value, annotation=type_)
return result


Expand Down
76 changes: 39 additions & 37 deletions src/anthropic/resources/beta/messages/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -3387,44 +3387,46 @@ def stream(

merged_output_config = _merge_output_configs(output_config, transformed_output_format)

request = self._post(
"/v1/messages?beta=true",
body=maybe_transform(
{
"max_tokens": max_tokens,
"messages": messages,
"model": model,
"cache_control": cache_control,
"metadata": metadata,
"output_config": merged_output_config,
"output_format": omit,
"container": container,
"context_management": context_management,
"inference_geo": inference_geo,
"mcp_servers": mcp_servers,
"service_tier": service_tier,
"speed": speed,
"stop_sequences": stop_sequences,
"system": system,
"temperature": temperature,
"thinking": thinking,
"top_k": top_k,
"top_p": top_p,
"tools": tools,
"tool_choice": tool_choice,
"stream": True,
},
message_create_params.MessageCreateParams,
),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=BetaMessage,
stream=True,
stream_cls=AsyncStream[BetaRawMessageStreamEvent],
)
async def make_request() -> AsyncStream[BetaRawMessageStreamEvent]:
return await self._post(
"/v1/messages?beta=true",
body=await async_maybe_transform(
{
"max_tokens": max_tokens,
"messages": messages,
"model": model,
"cache_control": cache_control,
"metadata": metadata,
"output_config": merged_output_config,
"output_format": omit,
"container": container,
"context_management": context_management,
"inference_geo": inference_geo,
"mcp_servers": mcp_servers,
"service_tier": service_tier,
"speed": speed,
"stop_sequences": stop_sequences,
"system": system,
"temperature": temperature,
"thinking": thinking,
"top_k": top_k,
"top_p": top_p,
"tools": tools,
"tool_choice": tool_choice,
"stream": True,
},
message_create_params.MessageCreateParams,
),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=BetaMessage,
stream=True,
stream_cls=AsyncStream[BetaRawMessageStreamEvent],
)

return BetaAsyncMessageStreamManager(
request,
make_request(),
output_format=NOT_GIVEN if is_dict(output_format) else cast(ResponseFormatT, output_format),
)

Expand Down
Loading