Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion src/strands/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,34 @@
T = TypeVar("T", bound=BaseModel)


def _serialize_anthropic_event(event: Any) -> dict[str, Any]:
"""Manually serializes Anthropic events to bypass Pydantic warnings on inner blocks.

If an inner block has a .model_dump() method, it is called explicitly during serialization.
"""
if type(event).__name__ in ("Mock", "AsyncMock", "MagicMock"):
return event.model_dump()

if hasattr(event, "model_fields"):
result = {}
for key in getattr(event, "model_fields", {}):
val = getattr(event, key, None)
if val is None:
continue
if key == "message":
result["message"] = {"stop_reason": getattr(val, "stop_reason", None)}
elif hasattr(val, "model_dump"):
result[key] = val.model_dump()
else:
result[key] = val
return result

if hasattr(event, "model_dump"):
return event.model_dump()

return dict(event)


class AnthropicModel(Model):
"""Anthropic model provider implementation."""

Expand Down Expand Up @@ -407,7 +435,9 @@ async def stream(
logger.debug("got response from model")
async for event in stream:
if event.type in AnthropicModel.EVENT_TYPES:
yield self.format_chunk(event.model_dump())
event_dict = _serialize_anthropic_event(event)

yield self.format_chunk(event_dict)

usage = event.message.usage # type: ignore
yield self.format_chunk({"type": "metadata", "usage": usage.model_dump()})
Expand Down
9 changes: 7 additions & 2 deletions src/strands/types/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
- Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html
"""

from typing import Literal
from typing import Annotated, Any, Literal

from pydantic import PlainSerializer
from typing_extensions import NotRequired, TypedDict

from .citations import CitationsContentBlock
Expand Down Expand Up @@ -177,6 +178,10 @@ class ContentBlockStop(TypedDict):
"""


def _serialize_blocks(blocks: list[Any]) -> list[Any]:
return [b.model_dump() if hasattr(b, "model_dump") else b for b in blocks]


class Message(TypedDict):
"""A message in a conversation with the agent.

Expand All @@ -185,7 +190,7 @@ class Message(TypedDict):
role: The role of the message sender.
"""

content: list[ContentBlock]
content: Annotated[list[ContentBlock], PlainSerializer(_serialize_blocks, return_type=list[Any])]
role: Role


Expand Down