diff --git a/src/anthropic/lib/streaming/_beta_messages.py b/src/anthropic/lib/streaming/_beta_messages.py index c1447a8d..13dfd65e 100644 --- a/src/anthropic/lib/streaming/_beta_messages.py +++ b/src/anthropic/lib/streaming/_beta_messages.py @@ -28,10 +28,25 @@ ) from ..._streaming import Stream, AsyncStream from ...types.beta import BetaRawMessageStreamEvent +from ...types.beta.beta_raw_message_start_event import BetaRawMessageStartEvent +from ...types.beta.beta_raw_message_delta_event import BetaRawMessageDeltaEvent +from ...types.beta.beta_raw_message_stop_event import BetaRawMessageStopEvent +from ...types.beta.beta_raw_content_block_start_event import BetaRawContentBlockStartEvent +from ...types.beta.beta_raw_content_block_delta_event import BetaRawContentBlockDeltaEvent +from ...types.beta.beta_raw_content_block_stop_event import BetaRawContentBlockStopEvent from ..._utils._utils import is_given from .._parse._response import ResponseFormatT, parse_text from ...types.beta.parsed_beta_message import ParsedBetaMessage, ParsedBetaContentBlock +_BETA_RAW_EVENT_TYPE_MAP: dict[str, type[BaseModel]] = { + "message_start": BetaRawMessageStartEvent, + "message_delta": BetaRawMessageDeltaEvent, + "message_stop": BetaRawMessageStopEvent, + "content_block_start": BetaRawContentBlockStartEvent, + "content_block_delta": BetaRawContentBlockDeltaEvent, + "content_block_stop": BetaRawContentBlockStopEvent, +} + class BetaMessageStream(Generic[ResponseFormatT]): text_stream: Iterator[str] @@ -461,9 +476,20 @@ def accumulate_event( ), ) if not isinstance(cast(Any, event), BaseModel): - raise TypeError( - f"Unexpected event runtime type, after deserialising twice - {event} - {builtins.type(event)}" - ) + # Union discriminator deserialization silently returned the raw dict in some + # environments (e.g. older pydantic versions). Fall back to a direct type-map + # lookup using the 'type' field so that well-formed events (including + # content_block_delta) are always promoted to the correct BaseModel. See #941. + raw = cast(Any, event) + if isinstance(raw, dict): + event_type = raw.get("type") + target_cls = _BETA_RAW_EVENT_TYPE_MAP.get(event_type) if isinstance(event_type, str) else None + if target_cls is not None: + event = cast(BetaRawMessageStreamEvent, target_cls.model_construct(**raw)) + if not isinstance(cast(Any, event), BaseModel): + raise TypeError( + f"Unexpected event runtime type, after deserialising twice - {event} - {builtins.type(event)}" + ) if current_snapshot is None: if event.type == "message_start": diff --git a/src/anthropic/lib/streaming/_messages.py b/src/anthropic/lib/streaming/_messages.py index b6b5f538..0f34531a 100644 --- a/src/anthropic/lib/streaming/_messages.py +++ b/src/anthropic/lib/streaming/_messages.py @@ -21,6 +21,12 @@ ParsedContentBlockStopEvent, ) from ...types import RawMessageStreamEvent +from ...types.raw_message_start_event import RawMessageStartEvent +from ...types.raw_message_delta_event import RawMessageDeltaEvent +from ...types.raw_message_stop_event import RawMessageStopEvent +from ...types.raw_content_block_start_event import RawContentBlockStartEvent +from ...types.raw_content_block_delta_event import RawContentBlockDeltaEvent +from ...types.raw_content_block_stop_event import RawContentBlockStopEvent from ..._types import NOT_GIVEN, NotGiven from ..._utils import consume_sync_iterator, consume_async_iterator from ..._models import build, construct_type, construct_type_unchecked @@ -29,6 +35,15 @@ from .._parse._response import ResponseFormatT, parse_text from ...types.parsed_message import ParsedMessage, ParsedContentBlock +_RAW_EVENT_TYPE_MAP: dict[str, type[BaseModel]] = { + "message_start": RawMessageStartEvent, + "message_delta": RawMessageDeltaEvent, + "message_stop": RawMessageStopEvent, + "content_block_start": RawContentBlockStartEvent, + "content_block_delta": RawContentBlockDeltaEvent, + "content_block_stop": RawContentBlockStopEvent, +} + class MessageStream(Generic[ResponseFormatT]): text_stream: Iterator[str] @@ -445,7 +460,18 @@ def accumulate_event( ), ) if not isinstance(cast(Any, event), BaseModel): - raise TypeError(f"Unexpected event runtime type, after deserialising twice - {event} - {type(event)}") + # Union discriminator deserialization silently returned the raw dict in some + # environments (e.g. older pydantic versions). Fall back to a direct type-map + # lookup using the 'type' field so that well-formed events (including + # content_block_delta) are always promoted to the correct BaseModel. See #941. + raw = cast(Any, event) + if isinstance(raw, dict): + event_type = raw.get("type") + target_cls = _RAW_EVENT_TYPE_MAP.get(event_type) if isinstance(event_type, str) else None + if target_cls is not None: + event = cast(RawMessageStreamEvent, target_cls.model_construct(**raw)) + if not isinstance(cast(Any, event), BaseModel): + raise TypeError(f"Unexpected event runtime type, after deserialising twice - {event} - {type(event)}") if current_snapshot is None: if event.type == "message_start": diff --git a/tests/lib/streaming/test_messages.py b/tests/lib/streaming/test_messages.py index d3c959dd..74db6380 100644 --- a/tests/lib/streaming/test_messages.py +++ b/tests/lib/streaming/test_messages.py @@ -13,7 +13,9 @@ from anthropic.lib.streaming import ParsedMessageStreamEvent from anthropic.types.message import Message from anthropic.resources.messages import DEPRECATED_MODELS -from anthropic.lib.streaming._messages import TRACKS_TOOL_INPUT +from anthropic.lib.streaming._messages import TRACKS_TOOL_INPUT, accumulate_event +from anthropic.types.raw_message_stream_event import RawMessageStreamEvent +from anthropic._models import construct_type from .helpers import get_response, to_async_iter @@ -336,3 +338,45 @@ def test_tracks_tool_input_type_alias_is_up_to_date() -> None: f"ContentBlock type {block_type.__name__} has an input property, " f"but is not included in TRACKS_TOOL_INPUT. You probably need to update the TRACKS_TOOL_INPUT type alias." ) + + +def test_accumulate_event_handles_raw_dict_content_block_delta() -> None: + """Regression test for #941. + + When the union discriminator deserialization silently returns a raw dict + (e.g. in older pydantic versions or specific environments), accumulate_event + must fall back to the direct type-map lookup and successfully process + content_block_delta events without raising TypeError. + """ + # Build a valid message snapshot via normal deserialization + msg_start = construct_type( + type_=RawMessageStreamEvent, + value={ + "type": "message_start", + "message": { + "id": "msg_test941", + "type": "message", + "role": "assistant", + "content": [], + "model": "claude-3-5-sonnet-20241022", + "stop_reason": None, + "stop_sequence": None, + "usage": {"input_tokens": 5, "output_tokens": 0}, + }, + }, + ) + snapshot = accumulate_event(event=msg_start, current_snapshot=None) + + cbs = construct_type( + type_=RawMessageStreamEvent, + value={"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}}, + ) + snapshot = accumulate_event(event=cbs, current_snapshot=snapshot) + + # Simulate the bug: pass a raw dict directly instead of a typed BaseModel + # (as if _process_response_data / construct_type silently returned the dict) + raw_delta: Any = {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "hello"}} + snapshot = accumulate_event(event=raw_delta, current_snapshot=snapshot) + + assert snapshot.content[0].type == "text" + assert snapshot.content[0].text == "hello"