Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions src/anthropic/lib/streaming/_beta_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,25 @@
)
from ..._streaming import Stream, AsyncStream
from ...types.beta import BetaRawMessageStreamEvent
from ...types.beta.beta_raw_message_start_event import BetaRawMessageStartEvent
from ...types.beta.beta_raw_message_delta_event import BetaRawMessageDeltaEvent
from ...types.beta.beta_raw_message_stop_event import BetaRawMessageStopEvent
from ...types.beta.beta_raw_content_block_start_event import BetaRawContentBlockStartEvent
from ...types.beta.beta_raw_content_block_delta_event import BetaRawContentBlockDeltaEvent
from ...types.beta.beta_raw_content_block_stop_event import BetaRawContentBlockStopEvent
from ..._utils._utils import is_given
from .._parse._response import ResponseFormatT, parse_text
from ...types.beta.parsed_beta_message import ParsedBetaMessage, ParsedBetaContentBlock

_BETA_RAW_EVENT_TYPE_MAP: dict[str, type[BaseModel]] = {
"message_start": BetaRawMessageStartEvent,
"message_delta": BetaRawMessageDeltaEvent,
"message_stop": BetaRawMessageStopEvent,
"content_block_start": BetaRawContentBlockStartEvent,
"content_block_delta": BetaRawContentBlockDeltaEvent,
"content_block_stop": BetaRawContentBlockStopEvent,
}


class BetaMessageStream(Generic[ResponseFormatT]):
text_stream: Iterator[str]
Expand Down Expand Up @@ -461,9 +476,20 @@ def accumulate_event(
),
)
if not isinstance(cast(Any, event), BaseModel):
raise TypeError(
f"Unexpected event runtime type, after deserialising twice - {event} - {builtins.type(event)}"
)
# Union discriminator deserialization silently returned the raw dict in some
# environments (e.g. older pydantic versions). Fall back to a direct type-map
# lookup using the 'type' field so that well-formed events (including
# content_block_delta) are always promoted to the correct BaseModel. See #941.
raw = cast(Any, event)
if isinstance(raw, dict):
event_type = raw.get("type")
target_cls = _BETA_RAW_EVENT_TYPE_MAP.get(event_type) if isinstance(event_type, str) else None
if target_cls is not None:
event = cast(BetaRawMessageStreamEvent, target_cls.model_construct(**raw))
if not isinstance(cast(Any, event), BaseModel):
raise TypeError(
f"Unexpected event runtime type, after deserialising twice - {event} - {builtins.type(event)}"
)

if current_snapshot is None:
if event.type == "message_start":
Expand Down
28 changes: 27 additions & 1 deletion src/anthropic/lib/streaming/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@
ParsedContentBlockStopEvent,
)
from ...types import RawMessageStreamEvent
from ...types.raw_message_start_event import RawMessageStartEvent
from ...types.raw_message_delta_event import RawMessageDeltaEvent
from ...types.raw_message_stop_event import RawMessageStopEvent
from ...types.raw_content_block_start_event import RawContentBlockStartEvent
from ...types.raw_content_block_delta_event import RawContentBlockDeltaEvent
from ...types.raw_content_block_stop_event import RawContentBlockStopEvent
from ..._types import NOT_GIVEN, NotGiven
from ..._utils import consume_sync_iterator, consume_async_iterator
from ..._models import build, construct_type, construct_type_unchecked
Expand All @@ -29,6 +35,15 @@
from .._parse._response import ResponseFormatT, parse_text
from ...types.parsed_message import ParsedMessage, ParsedContentBlock

_RAW_EVENT_TYPE_MAP: dict[str, type[BaseModel]] = {
"message_start": RawMessageStartEvent,
"message_delta": RawMessageDeltaEvent,
"message_stop": RawMessageStopEvent,
"content_block_start": RawContentBlockStartEvent,
"content_block_delta": RawContentBlockDeltaEvent,
"content_block_stop": RawContentBlockStopEvent,
}


class MessageStream(Generic[ResponseFormatT]):
text_stream: Iterator[str]
Expand Down Expand Up @@ -445,7 +460,18 @@ def accumulate_event(
),
)
if not isinstance(cast(Any, event), BaseModel):
raise TypeError(f"Unexpected event runtime type, after deserialising twice - {event} - {type(event)}")
# Union discriminator deserialization silently returned the raw dict in some
# environments (e.g. older pydantic versions). Fall back to a direct type-map
# lookup using the 'type' field so that well-formed events (including
# content_block_delta) are always promoted to the correct BaseModel. See #941.
raw = cast(Any, event)
if isinstance(raw, dict):
event_type = raw.get("type")
target_cls = _RAW_EVENT_TYPE_MAP.get(event_type) if isinstance(event_type, str) else None
if target_cls is not None:
event = cast(RawMessageStreamEvent, target_cls.model_construct(**raw))
if not isinstance(cast(Any, event), BaseModel):
raise TypeError(f"Unexpected event runtime type, after deserialising twice - {event} - {type(event)}")

if current_snapshot is None:
if event.type == "message_start":
Expand Down
46 changes: 45 additions & 1 deletion tests/lib/streaming/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
from anthropic.lib.streaming import ParsedMessageStreamEvent
from anthropic.types.message import Message
from anthropic.resources.messages import DEPRECATED_MODELS
from anthropic.lib.streaming._messages import TRACKS_TOOL_INPUT
from anthropic.lib.streaming._messages import TRACKS_TOOL_INPUT, accumulate_event
from anthropic.types.raw_message_stream_event import RawMessageStreamEvent
from anthropic._models import construct_type

from .helpers import get_response, to_async_iter

Expand Down Expand Up @@ -336,3 +338,45 @@ def test_tracks_tool_input_type_alias_is_up_to_date() -> None:
f"ContentBlock type {block_type.__name__} has an input property, "
f"but is not included in TRACKS_TOOL_INPUT. You probably need to update the TRACKS_TOOL_INPUT type alias."
)


def test_accumulate_event_handles_raw_dict_content_block_delta() -> None:
"""Regression test for #941.

When the union discriminator deserialization silently returns a raw dict
(e.g. in older pydantic versions or specific environments), accumulate_event
must fall back to the direct type-map lookup and successfully process
content_block_delta events without raising TypeError.
"""
# Build a valid message snapshot via normal deserialization
msg_start = construct_type(
type_=RawMessageStreamEvent,
value={
"type": "message_start",
"message": {
"id": "msg_test941",
"type": "message",
"role": "assistant",
"content": [],
"model": "claude-3-5-sonnet-20241022",
"stop_reason": None,
"stop_sequence": None,
"usage": {"input_tokens": 5, "output_tokens": 0},
},
},
)
snapshot = accumulate_event(event=msg_start, current_snapshot=None)

cbs = construct_type(
type_=RawMessageStreamEvent,
value={"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}},
)
snapshot = accumulate_event(event=cbs, current_snapshot=snapshot)

# Simulate the bug: pass a raw dict directly instead of a typed BaseModel
# (as if _process_response_data / construct_type silently returned the dict)
raw_delta: Any = {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "hello"}}
snapshot = accumulate_event(event=raw_delta, current_snapshot=snapshot)

assert snapshot.content[0].type == "text"
assert snapshot.content[0].text == "hello"