-
Notifications
You must be signed in to change notification settings - Fork 104
feat(strands-memory): add event metadata support to AgentCoreMemorySessionManager #339
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
8372765
feefb4d
d00706e
4fabb95
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,8 +1,13 @@ | ||
| """Configuration for AgentCore Memory Session Manager.""" | ||
|
|
||
| from typing import Dict, Optional | ||
| from typing import Any, Callable, Dict, Optional | ||
|
|
||
| from pydantic import BaseModel, Field | ||
| from pydantic import BaseModel, Field, field_validator | ||
|
|
||
|
|
||
| def normalize_metadata(raw: Dict[str, Any]) -> Dict[str, Any]: | ||
| """Normalize metadata values: plain strings become {"stringValue": value}.""" | ||
| return {k: {"stringValue": v} if isinstance(v, str) else v for k, v in raw.items()} | ||
|
|
||
|
|
||
| class RetrievalConfig(BaseModel): | ||
|
|
@@ -38,6 +43,14 @@ class AgentCoreMemoryConfig(BaseModel): | |
| Default is "user_context". | ||
| filter_restored_tool_context: When True, strip historical toolUse/toolResult blocks from | ||
| restored messages before loading them into Strands runtime memory. Default is False. | ||
| default_metadata: Optional default metadata key-value pairs to attach to every message event. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does this actually solve the customer's ask?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Are we sure this is the interface the customer is looking for? Could we ask them to send an example code block of the support they want?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, it solves the ask. Per-turn metadata works out of the box. The customer is talking about STM events, not LTM records (those are extracted asynchronously and don't carry event metadata).
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes — the customer's use case is tagging events with a |
||
| Merged with any per-call metadata. Maximum 15 total keys per event (including internal keys). | ||
| Accepts plain strings (auto-wrapped) or explicit MetadataValue dicts. | ||
| Example: {"location": "NYC"} or {"location": {"stringValue": "NYC"}} | ||
| metadata_provider: Optional callable that returns metadata key-value pairs. Called at each | ||
| event creation, so it can return dynamic values (e.g. current traceId). The returned | ||
| dict is merged after default_metadata but before per-call metadata. | ||
| Accepts plain strings (auto-wrapped) or explicit MetadataValue dicts. | ||
| """ | ||
|
|
||
| memory_id: str = Field(min_length=1) | ||
|
|
@@ -48,3 +61,12 @@ class AgentCoreMemoryConfig(BaseModel): | |
| flush_interval_seconds: Optional[float] = Field(default=None, gt=0) | ||
| context_tag: str = Field(default="user_context", min_length=1) | ||
| filter_restored_tool_context: bool = Field(default=False) | ||
| default_metadata: Optional[Dict[str, Any]] = None | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is there a reason we need
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Tried switching to |
||
| metadata_provider: Optional[Callable[[], Dict[str, Any]]] = None | ||
|
|
||
| @field_validator("default_metadata", mode="before") | ||
| @classmethod | ||
| def _normalize_default_metadata(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: | ||
| if v is None: | ||
| return None | ||
| return normalize_metadata(v) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,7 +6,7 @@ | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | ||
| from datetime import datetime, timedelta, timezone | ||
| from enum import Enum | ||
| from typing import TYPE_CHECKING, Any, Optional | ||
| from typing import TYPE_CHECKING, Any, Dict, NamedTuple, Optional | ||
|
|
||
| import boto3 | ||
| from botocore.config import Config as BotocoreConfig | ||
|
|
@@ -23,12 +23,13 @@ | |
| from bedrock_agentcore.memory.models.filters import ( | ||
| EventMetadataFilter, | ||
| LeftExpression, | ||
| MetadataValue, | ||
| OperatorType, | ||
| RightExpression, | ||
| ) | ||
|
|
||
| from .bedrock_converter import AgentCoreMemoryConverter | ||
| from .config import AgentCoreMemoryConfig, RetrievalConfig | ||
| from .config import AgentCoreMemoryConfig, RetrievalConfig, normalize_metadata | ||
| from .converters import MemoryConverter | ||
|
|
||
| if TYPE_CHECKING: | ||
|
|
@@ -46,6 +47,22 @@ | |
| STATE_TYPE_KEY = "stateType" | ||
| AGENT_ID_KEY = "agentId" | ||
|
|
||
| # Maximum metadata key-value pairs per event (API limit) | ||
| MAX_METADATA_KEYS = 15 | ||
|
|
||
| # Reserved internal metadata keys that users cannot override | ||
| RESERVED_METADATA_KEYS = frozenset({STATE_TYPE_KEY, AGENT_ID_KEY}) | ||
|
|
||
|
|
||
| class BufferedMessage(NamedTuple): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nice — I agree with the decision to add some structure here. |
||
| """A pre-processed message waiting to be flushed to AgentCore Memory.""" | ||
|
|
||
| session_id: str | ||
| messages: list[tuple[str, str]] | ||
| is_blob: bool | ||
| timestamp: datetime | ||
| metadata: Optional[Dict[str, MetadataValue]] = None | ||
|
|
||
|
|
||
| class StateType(Enum): | ||
| """State type for distinguishing session and agent metadata in events.""" | ||
|
|
@@ -129,8 +146,8 @@ def __init__( | |
| session = boto_session or boto3.Session(region_name=region_name) | ||
| self.has_existing_agent = False | ||
|
|
||
| # Batching support - stores pre-processed messages: (session_id, messages, is_blob, timestamp) | ||
| self._message_buffer: list[tuple[str, list[tuple[str, str]], bool, datetime]] = [] | ||
| # Batching support - stores pre-processed messages | ||
| self._message_buffer: list[BufferedMessage] = [] | ||
| self._message_lock = threading.Lock() | ||
|
|
||
| # Agent state buffering - stores all agent state updates: (session_id, agent) | ||
|
|
@@ -169,6 +186,55 @@ def __init__( | |
| if self.config.flush_interval_seconds: | ||
| self._start_flush_timer() | ||
|
|
||
| def _build_metadata( | ||
| self, | ||
| internal_metadata: Optional[Dict[str, MetadataValue]] = None, | ||
| per_call_metadata: Optional[Dict[str, MetadataValue]] = None, | ||
| ) -> Optional[Dict[str, MetadataValue]]: | ||
| """Build merged metadata from config defaults, provider, per-call overrides, and internal keys. | ||
| Merge precedence (highest wins): | ||
tejaskash marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| 1. internal_metadata (stateType, agentId) — always wins | ||
| 2. per_call_metadata (passed via **kwargs) | ||
| 3. metadata_provider() (called at event creation time for dynamic values) | ||
| 4. self.config.default_metadata (set at config construction time) | ||
| Args: | ||
| internal_metadata: System-reserved metadata (e.g. stateType, agentId). | ||
| per_call_metadata: Caller-supplied metadata for a single operation. | ||
| Returns: | ||
| Merged metadata dict, or None if empty. | ||
| Raises: | ||
| ValueError: If user metadata contains reserved keys or total keys exceed MAX_METADATA_KEYS. | ||
| """ | ||
| merged: Dict[str, MetadataValue] = {} | ||
|
|
||
| if self.config.default_metadata: | ||
| merged.update(self.config.default_metadata) | ||
|
|
||
| if self.config.metadata_provider: | ||
| merged.update(normalize_metadata(self.config.metadata_provider())) | ||
|
|
||
| if per_call_metadata: | ||
| merged.update(per_call_metadata) | ||
|
|
||
| # Validate user-supplied keys before merging internal keys | ||
| user_reserved = RESERVED_METADATA_KEYS & merged.keys() | ||
| if user_reserved: | ||
| raise ValueError( | ||
| f"Metadata keys {user_reserved} are reserved for internal use. Reserved keys: {RESERVED_METADATA_KEYS}" | ||
| ) | ||
|
|
||
| if internal_metadata: | ||
| merged.update(internal_metadata) | ||
|
|
||
| if len(merged) > MAX_METADATA_KEYS: | ||
| raise ValueError(f"Combined metadata has {len(merged)} keys, exceeding the maximum of {MAX_METADATA_KEYS}.") | ||
|
|
||
| return merged or None | ||
|
|
||
| # region SessionRepository interface implementation | ||
| def create_session(self, session: Session, **kwargs: Any) -> Session: | ||
| """Create a new session in AgentCore Memory. | ||
|
|
@@ -482,6 +548,9 @@ def create_message( | |
|
|
||
| is_blob = self.converter.exceeds_conversational_limit(messages[0]) | ||
|
|
||
| # Build merged metadata from config defaults + per-call overrides | ||
| merged_metadata = self._build_metadata(per_call_metadata=kwargs.get("metadata")) | ||
|
|
||
| # Parse the original timestamp and use it as desired timestamp | ||
| original_timestamp = datetime.fromisoformat(session_message.created_at.replace("Z", "+00:00")) | ||
| monotonic_timestamp = self._get_monotonic_timestamp(original_timestamp) | ||
|
|
@@ -490,7 +559,15 @@ def create_message( | |
| # Buffer the pre-processed message | ||
| should_flush = False | ||
| with self._message_lock: | ||
| self._message_buffer.append((session_id, messages, is_blob, monotonic_timestamp)) | ||
| self._message_buffer.append( | ||
| BufferedMessage( | ||
| session_id=session_id, | ||
| messages=messages, | ||
| is_blob=is_blob, | ||
| timestamp=monotonic_timestamp, | ||
| metadata=merged_metadata, | ||
| ) | ||
| ) | ||
| should_flush = len(self._message_buffer) >= self.config.batch_size | ||
|
|
||
| # Flush only messages outside the lock to prevent deadlock | ||
|
|
@@ -508,17 +585,19 @@ def create_message( | |
| session_id=session_id, | ||
| messages=messages, | ||
| event_timestamp=monotonic_timestamp, | ||
| metadata=merged_metadata, | ||
| ) | ||
| else: | ||
| event = self.memory_client.gmdp_client.create_event( | ||
| memoryId=self.config.memory_id, | ||
| actorId=self.config.actor_id, | ||
| sessionId=session_id, | ||
| payload=[ | ||
| {"blob": json.dumps(messages[0])}, | ||
| ], | ||
| eventTimestamp=monotonic_timestamp, | ||
| ) | ||
| create_event_kwargs: dict[str, Any] = { | ||
| "memoryId": self.config.memory_id, | ||
| "actorId": self.config.actor_id, | ||
| "sessionId": session_id, | ||
| "payload": [{"blob": json.dumps(messages[0])}], | ||
| "eventTimestamp": monotonic_timestamp, | ||
| } | ||
| if merged_metadata: | ||
| create_event_kwargs["metadata"] = merged_metadata | ||
| event = self.memory_client.gmdp_client.create_event(**create_event_kwargs) | ||
| logger.debug("Created event: %s for message: %s", event.get("eventId"), session_message.message_id) | ||
| return event | ||
| except Exception as e: | ||
|
|
@@ -790,39 +869,45 @@ def _flush_messages_only(self) -> list[dict[str, Any]]: | |
| return [] | ||
|
|
||
| # Group all messages by session_id, combining conversational and blob messages | ||
| # Structure: {session_id: {"payload": [...], "timestamp": latest_timestamp}} | ||
| # Structure: {session_id: {"payload": [...], "timestamp": latest_timestamp, "metadata": {...}}} | ||
| session_groups: dict[str, dict[str, Any]] = {} | ||
|
|
||
| for session_id, messages, is_blob, monotonic_timestamp in messages_to_send: | ||
| if session_id not in session_groups: | ||
| session_groups[session_id] = {"payload": [], "timestamp": monotonic_timestamp} | ||
| for buffered_msg in messages_to_send: | ||
| sid = buffered_msg.session_id | ||
| if sid not in session_groups: | ||
| session_groups[sid] = {"payload": [], "timestamp": buffered_msg.timestamp, "metadata": {}} | ||
|
|
||
| if is_blob: | ||
| # Add blob messages to payload | ||
| for msg in messages: | ||
| session_groups[session_id]["payload"].append({"blob": json.dumps(msg)}) | ||
| if buffered_msg.is_blob: | ||
| for msg in buffered_msg.messages: | ||
| session_groups[sid]["payload"].append({"blob": json.dumps(msg)}) | ||
| else: | ||
| # Add conversational messages to payload | ||
| for text, role in messages: | ||
| session_groups[session_id]["payload"].append( | ||
| for text, role in buffered_msg.messages: | ||
| session_groups[sid]["payload"].append( | ||
| {"conversational": {"content": {"text": text}, "role": role.upper()}} | ||
| ) | ||
|
|
||
| # Use the latest timestamp for the combined event | ||
| if monotonic_timestamp > session_groups[session_id]["timestamp"]: | ||
| session_groups[session_id]["timestamp"] = monotonic_timestamp | ||
| if buffered_msg.timestamp > session_groups[sid]["timestamp"]: | ||
| session_groups[sid]["timestamp"] = buffered_msg.timestamp | ||
|
|
||
| # Merge metadata (later entries override earlier for same key) | ||
| if buffered_msg.metadata: | ||
| session_groups[sid]["metadata"].update(buffered_msg.metadata) | ||
|
|
||
| results = [] | ||
| try: | ||
| # Send one create_event per session_id with all messages (conversational + blob) | ||
| for session_id, group in session_groups.items(): | ||
| event = self.memory_client.gmdp_client.create_event( | ||
| memoryId=self.config.memory_id, | ||
| actorId=self.config.actor_id, | ||
| sessionId=session_id, | ||
| payload=group["payload"], | ||
| eventTimestamp=group["timestamp"], | ||
| ) | ||
| create_event_kwargs: dict[str, Any] = { | ||
| "memoryId": self.config.memory_id, | ||
| "actorId": self.config.actor_id, | ||
| "sessionId": session_id, | ||
| "payload": group["payload"], | ||
| "eventTimestamp": group["timestamp"], | ||
| } | ||
| if group["metadata"]: | ||
| create_event_kwargs["metadata"] = group["metadata"] | ||
| event = self.memory_client.gmdp_client.create_event(**create_event_kwargs) | ||
| results.append(event) | ||
| logger.debug( | ||
| "Flushed batched event for session %s with %d messages: %s", | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what happens when we flush messages in a batch and the metadata is different on each message? Does all the metadata get merged?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When messages in a batch have different metadata, the metadata dicts are merged (a later message's keys override earlier ones for the same key). So with
`batch_size > 1`, the last value for each key wins in the combined event. This is documented in the batching tradeoff — with `batch_size=1` (the default), each turn gets its own event with its own metadata, so no merging occurs.