From 83727658ab36b0d2973b469684293d768a284bff Mon Sep 17 00:00:00 2001 From: Tejas Kashinath Date: Fri, 13 Mar 2026 16:48:20 -0400 Subject: [PATCH 1/4] feat(strands-memory): add user-supplied event metadata support to AgentCoreMemorySessionManager Allow users to attach custom key-value metadata to conversation events via a new `default_metadata` config field and per-call `metadata` kwarg. Metadata is merged (per-call > config defaults > internal) and validated against reserved keys and the 15-key API limit. Also refactors the internal message buffer from a raw tuple to a `BufferedMessage` NamedTuple for clarity and extensibility. Closes #149 (Phase 1: Metadata) --- .../memory/integrations/strands/README.md | 41 ++++ .../memory/integrations/strands/config.py | 6 +- .../integrations/strands/session_manager.py | 150 ++++++++++---- .../test_agentcore_memory_session_manager.py | 184 ++++++++++++++++-- .../integrations/test_session_manager.py | 127 +++++++++++- 5 files changed, 454 insertions(+), 54 deletions(-) diff --git a/src/bedrock_agentcore/memory/integrations/strands/README.md b/src/bedrock_agentcore/memory/integrations/strands/README.md index 6186bf3e..5c926507 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/README.md +++ b/src/bedrock_agentcore/memory/integrations/strands/README.md @@ -219,6 +219,7 @@ result = agent_with_tools("/path/to/image.png") - `actor_id`: Unique identifier for the user/actor - `retrieval_config`: Dictionary mapping namespaces to RetrievalConfig objects - `batch_size`: Number of messages to buffer before sending to AgentCore Memory (1-100, default: 1). A value of 1 sends immediately (no batching). +- `default_metadata`: Optional dictionary of key-value metadata to attach to every message event. Maximum 15 total keys per event (including internal keys). Example: `{"location": {"stringValue": "NYC"}}` ### RetrievalConfig Parameters @@ -239,6 +240,46 @@ https://docs.aws.amazon.com/bedrock-agentcore/latest/devguide/memory-strategies. - `/summaries/{actorId}/{sessionId}/`: Session-specific summaries +--- + +## Event Metadata + +You can attach custom key-value metadata to every message event. This is useful for tagging +conversations with contextual information (e.g., location, project, case type) that can later +be used to filter events with `list_events`. + +### Default Metadata (applied to all messages) + +```python +config = AgentCoreMemoryConfig( + memory_id=MEM_ID, + session_id=SESSION_ID, + actor_id=ACTOR_ID, + default_metadata={ + "project": {"stringValue": "atlas"}, + "env": {"stringValue": "production"}, + }, +) +session_manager = AgentCoreMemorySessionManager(config, region_name='us-east-1') +agent = Agent(session_manager=session_manager) +agent("Hello!") # This event will have project=atlas and env=production metadata +``` + +### Per-call Metadata + +You can also pass metadata on individual `create_message` calls. Per-call metadata is merged +with `default_metadata` (per-call values override defaults for the same key): + +```python +session_manager.create_message( + session_id, agent_id, message, + metadata={"priority": {"stringValue": "high"}}, +) +``` + +> **Note:** The API allows a maximum of 15 metadata key-value pairs per event. +> The keys `stateType` and `agentId` are reserved for internal use. + --- ## Message Batching diff --git a/src/bedrock_agentcore/memory/integrations/strands/config.py b/src/bedrock_agentcore/memory/integrations/strands/config.py index 20fbbd8c..3a586512 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/config.py +++ b/src/bedrock_agentcore/memory/integrations/strands/config.py @@ -1,6 +1,6 @@ """Configuration for AgentCore Memory Session Manager.""" -from typing import Dict, Optional +from typing import Any, Dict, Optional from pydantic import BaseModel, Field @@ -38,6 +38,9 @@ class AgentCoreMemoryConfig(BaseModel): Default is "user_context". filter_restored_tool_context: When True, strip historical toolUse/toolResult blocks from restored messages before loading them into Strands runtime memory. Default is False. + default_metadata: Optional default metadata key-value pairs to attach to every message event. + Merged with any per-call metadata. Maximum 15 total keys per event (including internal keys). + Example: {"location": {"stringValue": "NYC"}} """ memory_id: str = Field(min_length=1) @@ -48,3 +51,4 @@ class AgentCoreMemoryConfig(BaseModel): flush_interval_seconds: Optional[float] = Field(default=None, gt=0) context_tag: str = Field(default="user_context", min_length=1) filter_restored_tool_context: bool = Field(default=False) + default_metadata: Optional[Dict[str, Any]] = None diff --git a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py index 68bc05b9..4ff12363 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py +++ b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py @@ -6,7 +6,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime, timedelta, timezone from enum import Enum -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Dict, NamedTuple, Optional import boto3 from botocore.config import Config as BotocoreConfig @@ -23,6 +23,7 @@ from bedrock_agentcore.memory.models.filters import ( EventMetadataFilter, LeftExpression, + MetadataValue, OperatorType, RightExpression, ) @@ -46,6 +47,22 @@ STATE_TYPE_KEY = "stateType" AGENT_ID_KEY = "agentId" +# Maximum metadata key-value pairs per event (API limit) +MAX_METADATA_KEYS = 15 + +# Reserved internal metadata keys that users cannot override +RESERVED_METADATA_KEYS = frozenset({STATE_TYPE_KEY, AGENT_ID_KEY}) + + +class BufferedMessage(NamedTuple): + """A pre-processed message waiting to be flushed to AgentCore Memory.""" + + session_id: str + messages: list[tuple[str, str]] + is_blob: bool + timestamp: datetime + metadata: Optional[Dict[str, MetadataValue]] = None + class StateType(Enum): """State type for distinguishing session and agent metadata in events.""" @@ -129,8 +146,8 @@ def __init__( session = boto_session or boto3.Session(region_name=region_name) self.has_existing_agent = False - # Batching support - stores pre-processed messages: (session_id, messages, is_blob, timestamp) - self._message_buffer: list[tuple[str, list[tuple[str, str]], bool, datetime]] = [] + # Batching support - stores pre-processed messages + self._message_buffer: list[BufferedMessage] = [] self._message_lock = threading.Lock() # Agent state buffering - stores all agent state updates: (session_id, agent) @@ -169,6 +186,54 @@ def __init__( if self.config.flush_interval_seconds: self._start_flush_timer() + def _build_metadata( + self, + internal_metadata: Optional[Dict[str, MetadataValue]] = None, + per_call_metadata: Optional[Dict[str, MetadataValue]] = None, + ) -> Optional[Dict[str, MetadataValue]]: + """Build merged metadata from config defaults, per-call overrides, and internal keys. + + Merge precedence (highest wins): + 1. internal_metadata (stateType, agentId) — always wins + 2. per_call_metadata (passed via **kwargs) + 3. self.config.default_metadata (set at config construction time) + + Args: + internal_metadata: System-reserved metadata (e.g. stateType, agentId). + per_call_metadata: Caller-supplied metadata for a single operation. + + Returns: + Merged metadata dict, or None if empty. + + Raises: + ValueError: If user metadata contains reserved keys or total keys exceed MAX_METADATA_KEYS. + """ + merged: Dict[str, MetadataValue] = {} + + if self.config.default_metadata: + merged.update(self.config.default_metadata) + + if per_call_metadata: + merged.update(per_call_metadata) + + # Validate user-supplied keys before merging internal keys + user_reserved = RESERVED_METADATA_KEYS & merged.keys() + if user_reserved: + raise ValueError( + f"Metadata keys {user_reserved} are reserved for internal use. " + f"Reserved keys: {RESERVED_METADATA_KEYS}" + ) + + if internal_metadata: + merged.update(internal_metadata) + + if len(merged) > MAX_METADATA_KEYS: + raise ValueError( + f"Combined metadata has {len(merged)} keys, exceeding the maximum of {MAX_METADATA_KEYS}." + ) + + return merged or None + # region SessionRepository interface implementation def create_session(self, session: Session, **kwargs: Any) -> Session: """Create a new session in AgentCore Memory. @@ -482,6 +547,9 @@ def create_message( is_blob = self.converter.exceeds_conversational_limit(messages[0]) + # Build merged metadata from config defaults + per-call overrides + merged_metadata = self._build_metadata(per_call_metadata=kwargs.get("metadata")) + # Parse the original timestamp and use it as desired timestamp original_timestamp = datetime.fromisoformat(session_message.created_at.replace("Z", "+00:00")) monotonic_timestamp = self._get_monotonic_timestamp(original_timestamp) @@ -490,7 +558,15 @@ def create_message( # Buffer the pre-processed message should_flush = False with self._message_lock: - self._message_buffer.append((session_id, messages, is_blob, monotonic_timestamp)) + self._message_buffer.append( + BufferedMessage( + session_id=session_id, + messages=messages, + is_blob=is_blob, + timestamp=monotonic_timestamp, + metadata=merged_metadata, + ) + ) should_flush = len(self._message_buffer) >= self.config.batch_size # Flush only messages outside the lock to prevent deadlock @@ -508,17 +584,19 @@ def create_message( session_id=session_id, messages=messages, event_timestamp=monotonic_timestamp, + metadata=merged_metadata, ) else: - event = self.memory_client.gmdp_client.create_event( - memoryId=self.config.memory_id, - actorId=self.config.actor_id, - sessionId=session_id, - payload=[ - {"blob": json.dumps(messages[0])}, - ], - eventTimestamp=monotonic_timestamp, - ) + create_event_kwargs: dict[str, Any] = { + "memoryId": self.config.memory_id, + "actorId": self.config.actor_id, + "sessionId": session_id, + "payload": [{"blob": json.dumps(messages[0])}], + "eventTimestamp": monotonic_timestamp, + } + if merged_metadata: + create_event_kwargs["metadata"] = merged_metadata + event = self.memory_client.gmdp_client.create_event(**create_event_kwargs) logger.debug("Created event: %s for message: %s", event.get("eventId"), session_message.message_id) return event except Exception as e: @@ -790,39 +868,45 @@ def _flush_messages_only(self) -> list[dict[str, Any]]: return [] # Group all messages by session_id, combining conversational and blob messages - # Structure: {session_id: {"payload": [...], "timestamp": latest_timestamp}} + # Structure: {session_id: {"payload": [...], "timestamp": latest_timestamp, "metadata": {...}}} session_groups: dict[str, dict[str, Any]] = {} - for session_id, messages, is_blob, monotonic_timestamp in messages_to_send: - if session_id not in session_groups: - session_groups[session_id] = {"payload": [], "timestamp": monotonic_timestamp} + for buffered_msg in messages_to_send: + sid = buffered_msg.session_id + if sid not in session_groups: + session_groups[sid] = {"payload": [], "timestamp": buffered_msg.timestamp, "metadata": {}} - if is_blob: - # Add blob messages to payload - for msg in messages: - session_groups[session_id]["payload"].append({"blob": json.dumps(msg)}) + if buffered_msg.is_blob: + for msg in buffered_msg.messages: + session_groups[sid]["payload"].append({"blob": json.dumps(msg)}) else: - # Add conversational messages to payload - for text, role in messages: - session_groups[session_id]["payload"].append( + for text, role in buffered_msg.messages: + session_groups[sid]["payload"].append( {"conversational": {"content": {"text": text}, "role": role.upper()}} ) # Use the latest timestamp for the combined event - if monotonic_timestamp > session_groups[session_id]["timestamp"]: - session_groups[session_id]["timestamp"] = monotonic_timestamp + if buffered_msg.timestamp > session_groups[sid]["timestamp"]: + session_groups[sid]["timestamp"] = buffered_msg.timestamp + + # Merge metadata (later entries override earlier for same key) + if buffered_msg.metadata: + session_groups[sid]["metadata"].update(buffered_msg.metadata) results = [] try: # Send one create_event per session_id with all messages (conversational + blob) for session_id, group in session_groups.items(): - event = self.memory_client.gmdp_client.create_event( - memoryId=self.config.memory_id, - actorId=self.config.actor_id, - sessionId=session_id, - payload=group["payload"], - eventTimestamp=group["timestamp"], - ) + create_event_kwargs: dict[str, Any] = { + "memoryId": self.config.memory_id, + "actorId": self.config.actor_id, + "sessionId": session_id, + "payload": group["payload"], + "eventTimestamp": group["timestamp"], + } + if group["metadata"]: + create_event_kwargs["metadata"] = group["metadata"] + event = self.memory_client.gmdp_client.create_event(**create_event_kwargs) results.append(event) logger.debug( "Flushed batched event for session %s with %d messages: %s", diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py index 3d872a6e..898bb68c 100644 --- a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py @@ -18,7 +18,10 @@ AgentCoreMemoryConverter, ) from bedrock_agentcore.memory.integrations.strands.config import AgentCoreMemoryConfig, RetrievalConfig -from bedrock_agentcore.memory.integrations.strands.session_manager import AgentCoreMemorySessionManager +from bedrock_agentcore.memory.integrations.strands.session_manager import ( + AgentCoreMemorySessionManager, + BufferedMessage, +) @pytest.fixture @@ -1690,15 +1693,14 @@ def track_create_event(**kwargs): mock_memory_client.gmdp_client.create_event.side_effect = track_create_event # Directly populate buffer with messages for multiple sessions - # Buffer format: (session_id, messages, is_blob, monotonic_timestamp) base_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) batching_session_manager._message_buffer = [ - ("session-A", [("SessionA_Message_0", "user")], False, base_time), - ("session-A", [("SessionA_Message_1", "user")], False, base_time), - ("session-B", [("SessionB_Message_0", "user")], False, base_time), - ("session-B", [("SessionB_Message_1", "user")], False, base_time), - ("session-B", [("SessionB_Message_2", "user")], False, base_time), - ("session-A", [("SessionA_Message_2", "user")], False, base_time), # Non-consecutive + BufferedMessage("session-A", [("SessionA_Message_0", "user")], False, base_time), + BufferedMessage("session-A", [("SessionA_Message_1", "user")], False, base_time), + BufferedMessage("session-B", [("SessionB_Message_0", "user")], False, base_time), + BufferedMessage("session-B", [("SessionB_Message_1", "user")], False, base_time), + BufferedMessage("session-B", [("SessionB_Message_2", "user")], False, base_time), + BufferedMessage("session-A", [("SessionA_Message_2", "user")], False, base_time), # Non-consecutive ] batching_session_manager._flush_messages() @@ -1773,10 +1775,10 @@ def fail_on_second_session(**kwargs): # Directly populate buffer with messages for multiple sessions base_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) batching_session_manager._message_buffer = [ - ("session-A", [("SessionA_Message_0", "user")], False, base_time), - ("session-A", [("SessionA_Message_1", "user")], False, base_time), - ("session-B", [("SessionB_Message_0", "user")], False, base_time), - ("session-B", [("SessionB_Message_1", "user")], False, base_time), + BufferedMessage("session-A", [("SessionA_Message_0", "user")], False, base_time), + BufferedMessage("session-A", [("SessionA_Message_1", "user")], False, base_time), + BufferedMessage("session-B", [("SessionB_Message_0", "user")], False, base_time), + BufferedMessage("session-B", [("SessionB_Message_1", "user")], False, base_time), ] assert batching_session_manager.pending_message_count() == 4 @@ -1845,12 +1847,12 @@ def track_create_event(**kwargs): blob_content = {"role": "user", "content": [{"text": "blob_A_" + "x" * (CONVERSATIONAL_MAX_SIZE + 100)}]} batching_session_manager._message_buffer = [ # Session A: 2 conversational messages - ("session-A", [("SessionA_conv_0", "user")], False, base_time), - ("session-A", [("SessionA_conv_1", "user")], False, base_time), + BufferedMessage("session-A", [("SessionA_conv_0", "user")], False, base_time), + BufferedMessage("session-A", [("SessionA_conv_1", "user")], False, base_time), # Session A: 1 blob message - ("session-A", [blob_content], True, base_time), + BufferedMessage("session-A", [blob_content], True, base_time), # Session B: 1 conversational message - ("session-B", [("SessionB_conv_0", "user")], False, base_time), + BufferedMessage("session-B", [("SessionB_conv_0", "user")], False, base_time), ] batching_session_manager._flush_messages() @@ -2541,7 +2543,7 @@ def test_after_invocation_hook_flushes_buffer(self, batching_session_manager, mo # Add messages to buffer with batching_session_manager._message_lock: batching_session_manager._message_buffer.append( - ("test-session", [("user", "test message")], False, batching_session_manager._get_monotonic_timestamp()) + BufferedMessage("test-session", [("user", "test message")], False, batching_session_manager._get_monotonic_timestamp()) ) assert batching_session_manager.pending_message_count() == 1 @@ -2764,7 +2766,7 @@ def test_interval_flush_callback_flushes_when_buffer_has_messages(self): # Add messages to buffer with manager._message_lock: manager._message_buffer.append( - ("test-session", [("user", "test message")], False, manager._get_monotonic_timestamp()) + BufferedMessage("test-session", [("user", "test message")], False, manager._get_monotonic_timestamp()) ) assert manager.pending_message_count() == 1 @@ -2919,7 +2921,7 @@ def test_interval_flush_callback_flushes_when_both_buffers_have_data(self): # Add both messages and agent state to buffers with manager._message_lock: manager._message_buffer.append( - ("test-session", [("user", "test message")], False, manager._get_monotonic_timestamp()) + BufferedMessage("test-session", [("user", "test message")], False, manager._get_monotonic_timestamp()) ) from strands.types.session import SessionAgent @@ -2984,3 +2986,147 @@ def test_config_flush_interval_validation(self): actor_id="test-actor", flush_interval_seconds=-5.0, ) + + +class TestMetadataSupport: + """Tests for user-supplied event metadata on messages.""" + + @pytest.fixture + def config_with_metadata(self): + """Config with default metadata.""" + return AgentCoreMemoryConfig( + memory_id="test-memory-123", + session_id="test-session-456", + actor_id="test-actor-789", + default_metadata={"location": {"stringValue": "NYC"}, "team": {"stringValue": "support"}}, + ) + + @pytest.fixture + def session_manager_with_metadata(self, config_with_metadata, mock_memory_client): + """Session manager with default metadata configured.""" + return _create_session_manager(config_with_metadata, mock_memory_client) + + def test_create_message_with_default_metadata(self, session_manager_with_metadata, mock_memory_client): + """Config-level default_metadata flows to create_event.""" + mock_memory_client.create_event.return_value = {"eventId": "evt_1"} + session_message = SessionMessage.from_message({"role": "user", "content": [{"text": "hello"}]}, 0) + + session_manager_with_metadata.create_message("test-session-456", "agent-1", session_message) + + mock_memory_client.create_event.assert_called_once() + call_kwargs = mock_memory_client.create_event.call_args[1] + assert call_kwargs["metadata"] == {"location": {"stringValue": "NYC"}, "team": {"stringValue": "support"}} + + def test_create_message_with_per_call_metadata(self, session_manager, mock_memory_client): + """Per-call metadata passed via kwargs flows to create_event.""" + mock_memory_client.create_event.return_value = {"eventId": "evt_1"} + session_message = SessionMessage.from_message({"role": "user", "content": [{"text": "hello"}]}, 0) + per_call = {"project": {"stringValue": "alpha"}} + + session_manager.create_message("test-session-456", "agent-1", session_message, metadata=per_call) + + mock_memory_client.create_event.assert_called_once() + call_kwargs = mock_memory_client.create_event.call_args[1] + assert call_kwargs["metadata"] == {"project": {"stringValue": "alpha"}} + + def test_metadata_merging_precedence(self, session_manager_with_metadata, mock_memory_client): + """Per-call metadata overrides config default for the same key.""" + mock_memory_client.create_event.return_value = {"eventId": "evt_1"} + session_message = SessionMessage.from_message({"role": "user", "content": [{"text": "hello"}]}, 0) + per_call = {"location": {"stringValue": "SF"}, "project": {"stringValue": "beta"}} + + session_manager_with_metadata.create_message("test-session-456", "agent-1", session_message, metadata=per_call) + + call_kwargs = mock_memory_client.create_event.call_args[1] + assert call_kwargs["metadata"]["location"] == {"stringValue": "SF"} + assert call_kwargs["metadata"]["team"] == {"stringValue": "support"} + assert call_kwargs["metadata"]["project"] == {"stringValue": "beta"} + + def test_metadata_reserved_keys_rejected(self, session_manager): + """ValueError raised when user metadata contains reserved keys.""" + from bedrock_agentcore.memory.integrations.strands.session_manager import RESERVED_METADATA_KEYS + + session_message = SessionMessage.from_message({"role": "user", "content": [{"text": "hello"}]}, 0) + + for reserved_key in RESERVED_METADATA_KEYS: + with pytest.raises(ValueError, match="reserved"): + session_manager.create_message( + "test-session-456", + "agent-1", + session_message, + metadata={reserved_key: {"stringValue": "bad"}}, + ) + + def test_metadata_max_keys_exceeded(self, session_manager): + """ValueError raised when combined metadata exceeds MAX_METADATA_KEYS.""" + from bedrock_agentcore.memory.integrations.strands.session_manager import MAX_METADATA_KEYS + + session_message = SessionMessage.from_message({"role": "user", "content": [{"text": "hello"}]}, 0) + too_many = {f"key_{i}": {"stringValue": f"val_{i}"} for i in range(MAX_METADATA_KEYS + 1)} + + with pytest.raises(ValueError, match="exceeding the maximum"): + session_manager.create_message("test-session-456", "agent-1", session_message, metadata=too_many) + + def test_create_message_no_metadata_passes_none(self, session_manager, mock_memory_client): + """When no metadata configured and none passed, metadata kwarg is None.""" + mock_memory_client.create_event.return_value = {"eventId": "evt_1"} + session_message = SessionMessage.from_message({"role": "user", "content": [{"text": "hello"}]}, 0) + + session_manager.create_message("test-session-456", "agent-1", session_message) + + call_kwargs = mock_memory_client.create_event.call_args[1] + assert call_kwargs.get("metadata") is None + + def test_batched_messages_include_metadata(self, mock_memory_client): + """Metadata flows through the batching path and appears in the flushed event.""" + from datetime import datetime, timezone + + config = AgentCoreMemoryConfig( + memory_id="test-memory-123", + session_id="test-session-456", + actor_id="test-actor-789", + batch_size=5, + default_metadata={"env": {"stringValue": "staging"}}, + ) + manager = _create_session_manager(config, mock_memory_client) + + mock_memory_client.gmdp_client.create_event.return_value = {"event": {"eventId": "batch_evt_1"}} + + # Buffer messages with metadata + base_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + manager._message_buffer = [ + BufferedMessage( + "test-session-456", + [("hello", "user")], + False, + base_time, + metadata={"env": {"stringValue": "staging"}}, + ), + BufferedMessage( + "test-session-456", + [("world", "assistant")], + False, + base_time, + metadata={"env": {"stringValue": "staging"}, "extra": {"stringValue": "val"}}, + ), + ] + + manager._flush_messages_only() + + call_kwargs = mock_memory_client.gmdp_client.create_event.call_args[1] + # Merged metadata: later message's metadata overrides earlier + assert call_kwargs["metadata"]["env"] == {"stringValue": "staging"} + assert call_kwargs["metadata"]["extra"] == {"stringValue": "val"} + + def test_blob_message_with_metadata(self, session_manager_with_metadata, mock_memory_client): + """Blob messages also receive metadata.""" + from bedrock_agentcore.memory.integrations.strands.bedrock_converter import CONVERSATIONAL_MAX_SIZE + + mock_memory_client.gmdp_client.create_event.return_value = {"event": {"eventId": "blob_1"}} + big_text = "x" * (CONVERSATIONAL_MAX_SIZE + 100) + session_message = SessionMessage.from_message({"role": "user", "content": [{"text": big_text}]}, 0) + + session_manager_with_metadata.create_message("test-session-456", "agent-1", session_message) + + call_kwargs = mock_memory_client.gmdp_client.create_event.call_args[1] + assert call_kwargs["metadata"] == {"location": {"stringValue": "NYC"}, "team": {"stringValue": "support"}} diff --git a/tests_integ/memory/integrations/test_session_manager.py b/tests_integ/memory/integrations/test_session_manager.py index 5ab308ba..754f1bcb 100644 --- a/tests_integ/memory/integrations/test_session_manager.py +++ b/tests_integ/memory/integrations/test_session_manager.py @@ -19,7 +19,12 @@ from bedrock_agentcore.memory.integrations.strands.bedrock_converter import AgentCoreMemoryConverter from bedrock_agentcore.memory.integrations.strands.config import AgentCoreMemoryConfig, RetrievalConfig from bedrock_agentcore.memory.integrations.strands.session_manager import AgentCoreMemorySessionManager -from bedrock_agentcore.memory.models.filters import EventMetadataFilter, LeftExpression, OperatorType +from bedrock_agentcore.memory.models.filters import ( + EventMetadataFilter, + LeftExpression, + OperatorType, + RightExpression, +) logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) @@ -375,3 +380,123 @@ def test_agent_multi_turn_with_batching(self, test_memory_stm): assert len(messages) >= 6 # endregion End-to-end agent with batching tests + + # region Event metadata integration tests + + def test_create_message_with_metadata_and_filter(self, test_memory_stm, memory_client): + """Test that user-supplied metadata is persisted and that filters actually exclude non-matching events.""" + session_id = f"test-meta-{uuid.uuid4().hex[:8]}" + actor_id = f"test-actor-{uuid.uuid4().hex[:8]}" + + config = AgentCoreMemoryConfig( + memory_id=test_memory_stm["id"], + session_id=session_id, + actor_id=actor_id, + default_metadata={"project": {"stringValue": "atlas"}, "env": {"stringValue": "test"}}, + ) + session_manager = AgentCoreMemorySessionManager(agentcore_memory_config=config, region_name=REGION) + + agent = Agent(system_prompt="You are a helpful assistant.", session_manager=session_manager) + agent("Hello, remember my project is Atlas") + + # Get ALL events (unfiltered) to know the total count + all_events = memory_client.list_events( + memory_id=test_memory_stm["id"], + actor_id=actor_id, + session_id=session_id, + ) + assert len(all_events) >= 1 # Sanity: events exist + + # Positive filter: query events matching our custom metadata + project_filter = EventMetadataFilter.build_expression( + left_operand=LeftExpression.build("project"), + operator=OperatorType.EQUALS_TO, + right_operand=RightExpression.build("atlas"), + ) + matching_events = memory_client.list_events( + memory_id=test_memory_stm["id"], + actor_id=actor_id, + session_id=session_id, + event_metadata=[project_filter], + ) + assert len(matching_events) >= 1 + + # Verify metadata values round-trip correctly + for event in matching_events: + meta = event.get("metadata", {}) + assert meta.get("project", {}).get("stringValue") == "atlas" + assert meta.get("env", {}).get("stringValue") == "test" + + # Negative filter: query with a value that was never written + wrong_filter = EventMetadataFilter.build_expression( + left_operand=LeftExpression.build("project"), + operator=OperatorType.EQUALS_TO, + right_operand=RightExpression.build("nonexistent_project"), + ) + non_matching_events = memory_client.list_events( + memory_id=test_memory_stm["id"], + actor_id=actor_id, + session_id=session_id, + event_metadata=[wrong_filter], + ) + assert len(non_matching_events) == 0, ( + f"Expected 0 events for nonexistent metadata value, got {len(non_matching_events)}" + ) + + # The positive set should be a strict subset of all events + # (all_events includes state events that don't have user metadata) + assert len(matching_events) < len(all_events) + + def test_metadata_survives_session_resume(self, test_memory_stm, memory_client): + """Events with metadata written by one session manager are filterable from another.""" + session_id = f"test-meta-resume-{uuid.uuid4().hex[:8]}" + actor_id = f"test-actor-{uuid.uuid4().hex[:8]}" + + # First session: write messages with metadata + config1 = AgentCoreMemoryConfig( + memory_id=test_memory_stm["id"], + session_id=session_id, + actor_id=actor_id, + default_metadata={"source": {"stringValue": "session1"}}, + ) + sm1 = AgentCoreMemorySessionManager(agentcore_memory_config=config1, region_name=REGION) + agent1 = Agent(system_prompt="You are a helpful assistant.", session_manager=sm1) + agent1("My favourite colour is blue") + + # Second session: resume without metadata configured + config2 = AgentCoreMemoryConfig( + memory_id=test_memory_stm["id"], + session_id=session_id, + actor_id=actor_id, + ) + sm2 = AgentCoreMemorySessionManager(agentcore_memory_config=config2, region_name=REGION) + + # Positive: filter matches metadata written by session 1 + source_filter = EventMetadataFilter.build_expression( + left_operand=LeftExpression.build("source"), + operator=OperatorType.EQUALS_TO, + right_operand=RightExpression.build("session1"), + ) + matching = sm2.memory_client.list_events( + memory_id=test_memory_stm["id"], + actor_id=actor_id, + session_id=session_id, + event_metadata=[source_filter], + ) + assert len(matching) >= 1 + + # Negative: filter for a different source value returns nothing + wrong_source_filter = EventMetadataFilter.build_expression( + left_operand=LeftExpression.build("source"), + operator=OperatorType.EQUALS_TO, + right_operand=RightExpression.build("session999"), + ) + not_matching = sm2.memory_client.list_events( + memory_id=test_memory_stm["id"], + actor_id=actor_id, + session_id=session_id, + event_metadata=[wrong_source_filter], + ) + assert len(not_matching) == 0 + + # endregion Event metadata integration tests From feefb4dea9df6383b7fba625cc4badfec2a1f8fe Mon Sep 17 00:00:00 2001 From: Tejas Kashinath Date: Fri, 13 Mar 2026 17:23:15 -0400 Subject: [PATCH 2/4] feat(strands-memory): add metadata_provider for dynamic per-invocation metadata MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add `metadata_provider` config field — a callable invoked at each event creation, enabling dynamic metadata like traceId that changes per agent invocation. This solves the Langfuse/user-feedback use case where a static `default_metadata` is insufficient because Strands controls the append_message → create_message call path. Merge precedence: default_metadata < metadata_provider() < per-call kwargs < internal keys. --- .../memory/integrations/strands/README.md | 25 ++++++- .../memory/integrations/strands/config.py | 10 ++- .../integrations/strands/session_manager.py | 8 +- .../test_agentcore_memory_session_manager.py | 69 +++++++++++++++++ .../integrations/test_session_manager.py | 74 +++++++++++++++++++ 5 files changed, 181 insertions(+), 5 deletions(-) diff --git a/src/bedrock_agentcore/memory/integrations/strands/README.md b/src/bedrock_agentcore/memory/integrations/strands/README.md index 5c926507..5cd96929 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/README.md +++ b/src/bedrock_agentcore/memory/integrations/strands/README.md @@ -220,6 +220,7 @@ result = agent_with_tools("/path/to/image.png") - `retrieval_config`: Dictionary mapping namespaces to RetrievalConfig objects - `batch_size`: Number of messages to buffer before sending to AgentCore Memory (1-100, default: 1). A value of 1 sends immediately (no batching). - `default_metadata`: Optional dictionary of key-value metadata to attach to every message event. Maximum 15 total keys per event (including internal keys). Example: `{"location": {"stringValue": "NYC"}}` +- `metadata_provider`: Optional callable returning a metadata dictionary. Called at each event creation for dynamic values (e.g., traceId). Merged after `default_metadata`. ### RetrievalConfig Parameters @@ -265,10 +266,32 @@ agent = Agent(session_manager=session_manager) agent("Hello!") # This event will have project=atlas and env=production metadata ``` +### Dynamic Metadata (metadata_provider) + +For values that change per invocation (e.g., traceId for Langfuse), use `metadata_provider` — +a callable invoked at each event creation: + +```python +from langfuse.decorators import langfuse_context + +def get_trace_metadata(): + return {"traceId": {"stringValue": langfuse_context.get_current_trace_id() or ""}} + +config = AgentCoreMemoryConfig( + memory_id=MEM_ID, + session_id=SESSION_ID, + actor_id=ACTOR_ID, + metadata_provider=get_trace_metadata, +) +session_manager = AgentCoreMemorySessionManager(config, region_name='us-east-1') +agent = Agent(session_manager=session_manager) +agent("Hello!") # Event gets the current traceId automatically +``` + ### Per-call Metadata You can also pass metadata on individual `create_message` calls. Per-call metadata is merged -with `default_metadata` (per-call values override defaults for the same key): +with `default_metadata` and `metadata_provider` (per-call values override both for the same key): ```python session_manager.create_message( diff --git a/src/bedrock_agentcore/memory/integrations/strands/config.py b/src/bedrock_agentcore/memory/integrations/strands/config.py index 3a586512..780ef56a 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/config.py +++ b/src/bedrock_agentcore/memory/integrations/strands/config.py @@ -1,8 +1,8 @@ """Configuration for AgentCore Memory Session Manager.""" -from typing import Any, Dict, Optional +from typing import Any, Callable, Dict, Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field class RetrievalConfig(BaseModel): @@ -41,8 +41,13 @@ class AgentCoreMemoryConfig(BaseModel): default_metadata: Optional default metadata key-value pairs to attach to every message event. Merged with any per-call metadata. Maximum 15 total keys per event (including internal keys). Example: {"location": {"stringValue": "NYC"}} + metadata_provider: Optional callable that returns metadata key-value pairs. Called at each + event creation, so it can return dynamic values (e.g. current traceId). The returned + dict is merged after default_metadata but before per-call metadata. """ + model_config = ConfigDict(arbitrary_types_allowed=True) + memory_id: str = Field(min_length=1) session_id: str = Field(min_length=1) actor_id: str = Field(min_length=1) @@ -52,3 +57,4 @@ class AgentCoreMemoryConfig(BaseModel): context_tag: str = Field(default="user_context", min_length=1) filter_restored_tool_context: bool = Field(default=False) default_metadata: Optional[Dict[str, Any]] = None + metadata_provider: Optional[Callable[[], Dict[str, Any]]] = None diff --git a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py index 4ff12363..c9dc45aa 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py +++ b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py @@ -191,12 +191,13 @@ def _build_metadata( internal_metadata: Optional[Dict[str, MetadataValue]] = None, per_call_metadata: Optional[Dict[str, MetadataValue]] = None, ) -> Optional[Dict[str, MetadataValue]]: - """Build merged metadata from config defaults, per-call overrides, and internal keys. + """Build merged metadata from config defaults, provider, per-call overrides, and internal keys. Merge precedence (highest wins): 1. internal_metadata (stateType, agentId) — always wins 2. per_call_metadata (passed via **kwargs) - 3. self.config.default_metadata (set at config construction time) + 3. metadata_provider() (called at event creation time for dynamic values) + 4. self.config.default_metadata (set at config construction time) Args: internal_metadata: System-reserved metadata (e.g. stateType, agentId). @@ -213,6 +214,9 @@ def _build_metadata( if self.config.default_metadata: merged.update(self.config.default_metadata) + if self.config.metadata_provider: + merged.update(self.config.metadata_provider()) + if per_call_metadata: merged.update(per_call_metadata) diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py index 898bb68c..bf5a4fd6 100644 --- a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py @@ -3130,3 +3130,72 @@ def test_blob_message_with_metadata(self, session_manager_with_metadata, mock_me call_kwargs = mock_memory_client.gmdp_client.create_event.call_args[1] assert call_kwargs["metadata"] == {"location": {"stringValue": "NYC"}, "team": {"stringValue": "support"}} + + def test_metadata_provider_called_per_event(self, mock_memory_client): + """metadata_provider is called at each create_message and its values appear in the event.""" + call_count = 0 + + def provider(): + nonlocal call_count + call_count += 1 + return {"traceId": {"stringValue": f"trace-{call_count}"}} + + config = AgentCoreMemoryConfig( + memory_id="test-memory-123", + session_id="test-session-456", + actor_id="test-actor-789", + metadata_provider=provider, + ) + manager = _create_session_manager(config, mock_memory_client) + mock_memory_client.create_event.return_value = {"eventId": "evt_1"} + + msg1 = SessionMessage.from_message({"role": "user", "content": [{"text": "hello"}]}, 0) + manager.create_message("test-session-456", "agent-1", msg1) + + assert call_count == 1 + kwargs1 = mock_memory_client.create_event.call_args[1] + assert kwargs1["metadata"]["traceId"] == {"stringValue": "trace-1"} + + msg2 = SessionMessage.from_message({"role": "user", "content": [{"text": "world"}]}, 0) + manager.create_message("test-session-456", "agent-1", msg2) + + assert call_count == 2 + kwargs2 = mock_memory_client.create_event.call_args[1] + assert kwargs2["metadata"]["traceId"] == {"stringValue": "trace-2"} + + def test_metadata_provider_merged_with_defaults(self, mock_memory_client): + """metadata_provider values override default_metadata for same key, but both appear.""" + config = AgentCoreMemoryConfig( + memory_id="test-memory-123", + session_id="test-session-456", + actor_id="test-actor-789", + default_metadata={"env": {"stringValue": "prod"}, "team": {"stringValue": "support"}}, + metadata_provider=lambda: {"env": {"stringValue": "staging"}, "traceId": {"stringValue": "t-1"}}, + ) + manager = _create_session_manager(config, mock_memory_client) + mock_memory_client.create_event.return_value = {"eventId": "evt_1"} + + msg = SessionMessage.from_message({"role": "user", "content": [{"text": "hello"}]}, 0) + manager.create_message("test-session-456", "agent-1", msg) + + call_kwargs = mock_memory_client.create_event.call_args[1] + # provider overrides default for "env" + assert call_kwargs["metadata"]["env"] == {"stringValue": "staging"} + # default still present + assert call_kwargs["metadata"]["team"] == {"stringValue": "support"} + # provider adds new key + assert call_kwargs["metadata"]["traceId"] == {"stringValue": "t-1"} + + def test_metadata_provider_reserved_keys_rejected(self, mock_memory_client): + """metadata_provider returning reserved keys raises ValueError.""" + config = AgentCoreMemoryConfig( + memory_id="test-memory-123", + session_id="test-session-456", + actor_id="test-actor-789", + metadata_provider=lambda: {"stateType": {"stringValue": "bad"}}, + ) + manager = _create_session_manager(config, mock_memory_client) + + msg = SessionMessage.from_message({"role": "user", "content": [{"text": "hello"}]}, 0) + with pytest.raises(ValueError, match="reserved"): + manager.create_message("test-session-456", "agent-1", msg) diff --git a/tests_integ/memory/integrations/test_session_manager.py b/tests_integ/memory/integrations/test_session_manager.py index 754f1bcb..b0ec5767 100644 --- a/tests_integ/memory/integrations/test_session_manager.py +++ b/tests_integ/memory/integrations/test_session_manager.py @@ -499,4 +499,78 @@ def test_metadata_survives_session_resume(self, test_memory_stm, memory_client): ) assert len(not_matching) == 0 + def test_metadata_provider_attaches_dynamic_trace_id(self, test_memory_stm, memory_client): + """metadata_provider injects a different traceId per invocation and events are filterable by each.""" + session_id = f"test-meta-prov-{uuid.uuid4().hex[:8]}" + actor_id = f"test-actor-{uuid.uuid4().hex[:8]}" + + current_trace = {"traceId": {"stringValue": "trace-AAA"}} + + config = AgentCoreMemoryConfig( + memory_id=test_memory_stm["id"], + session_id=session_id, + actor_id=actor_id, + metadata_provider=lambda: dict(current_trace), + ) + sm = AgentCoreMemorySessionManager(agentcore_memory_config=config, region_name=REGION) + agent = Agent(system_prompt="You are a helpful assistant.", session_manager=sm) + + # First invocation with trace-AAA + agent("Hello from trace AAA") + + # Switch trace for second invocation + current_trace["traceId"] = {"stringValue": "trace-BBB"} + agent("Hello from trace BBB") + + # Filter for trace-AAA — should find events + filter_aaa = EventMetadataFilter.build_expression( + left_operand=LeftExpression.build("traceId"), + operator=OperatorType.EQUALS_TO, + right_operand=RightExpression.build("trace-AAA"), + ) + events_aaa = memory_client.list_events( + memory_id=test_memory_stm["id"], + actor_id=actor_id, + session_id=session_id, + event_metadata=[filter_aaa], + ) + assert len(events_aaa) >= 1 + for e in events_aaa: + assert e.get("metadata", {}).get("traceId", {}).get("stringValue") == "trace-AAA" + + # Filter for trace-BBB — should find events + filter_bbb = EventMetadataFilter.build_expression( + left_operand=LeftExpression.build("traceId"), + operator=OperatorType.EQUALS_TO, + right_operand=RightExpression.build("trace-BBB"), + ) + events_bbb = memory_client.list_events( + memory_id=test_memory_stm["id"], + actor_id=actor_id, + session_id=session_id, + event_metadata=[filter_bbb], + ) + assert len(events_bbb) >= 1 + for e in events_bbb: + assert e.get("metadata", {}).get("traceId", {}).get("stringValue") == "trace-BBB" + + # Negative: nonexistent trace returns nothing + filter_none = EventMetadataFilter.build_expression( + left_operand=LeftExpression.build("traceId"), + operator=OperatorType.EQUALS_TO, + right_operand=RightExpression.build("trace-DOES-NOT-EXIST"), + ) + events_none = memory_client.list_events( + memory_id=test_memory_stm["id"], + actor_id=actor_id, + session_id=session_id, + event_metadata=[filter_none], + ) + assert len(events_none) == 0 + + # The two traces should be disjoint sets + aaa_ids = {e["eventId"] for e in events_aaa} + bbb_ids = {e["eventId"] for e in events_bbb} + assert aaa_ids.isdisjoint(bbb_ids), "trace-AAA and trace-BBB events should not overlap" + # endregion Event metadata integration tests From d00706e35ac0103de481131736d7da793e738723 Mon Sep 17 00:00:00 2001 From: Tejas Kashinath Date: Mon, 16 Mar 2026 14:06:07 -0400 Subject: [PATCH 3/4] fix: address PR review feedback - Auto-normalize plain string metadata values to {"stringValue": ...} so users can write {"project": "atlas"} instead of the verbose form. Applied via pydantic validator on default_metadata and at runtime for metadata_provider return values. - Move inline datetime imports to top of test file (nit from Hweinstock) - Fix lint/format issues that caused CI Lint and Format check to fail - Add tests for normalization in both config and session manager --- .../memory/integrations/strands/README.md | 11 ++-- .../memory/integrations/strands/config.py | 18 +++++- .../integrations/strands/session_manager.py | 11 ++-- .../strands/test_agentcore_memory_config.py | 36 +++++++++++ .../test_agentcore_memory_session_manager.py | 59 +++++++++++++++---- 5 files changed, 110 insertions(+), 25 deletions(-) diff --git a/src/bedrock_agentcore/memory/integrations/strands/README.md b/src/bedrock_agentcore/memory/integrations/strands/README.md index 5cd96929..d7c7b2ff 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/README.md +++ b/src/bedrock_agentcore/memory/integrations/strands/README.md @@ -257,8 +257,8 @@ config = AgentCoreMemoryConfig( session_id=SESSION_ID, actor_id=ACTOR_ID, default_metadata={ - "project": {"stringValue": "atlas"}, - "env": {"stringValue": "production"}, + "project": "atlas", + "env": "production", }, ) session_manager = AgentCoreMemorySessionManager(config, region_name='us-east-1') @@ -266,6 +266,9 @@ agent = Agent(session_manager=session_manager) agent("Hello!") # This event will have project=atlas and env=production metadata ``` +> Plain strings are auto-wrapped to `{"stringValue": "..."}`. The explicit form +> `{"project": {"stringValue": "atlas"}}` also works. + ### Dynamic Metadata (metadata_provider) For values that change per invocation (e.g., traceId for Langfuse), use `metadata_provider` — @@ -275,7 +278,7 @@ a callable invoked at each event creation: from langfuse.decorators import langfuse_context def get_trace_metadata(): - return {"traceId": {"stringValue": langfuse_context.get_current_trace_id() or ""}} + return {"traceId": langfuse_context.get_current_trace_id() or ""} config = AgentCoreMemoryConfig( memory_id=MEM_ID, @@ -296,7 +299,7 @@ with `default_metadata` and `metadata_provider` (per-call values override both f ```python session_manager.create_message( session_id, agent_id, message, - metadata={"priority": {"stringValue": "high"}}, + metadata={"priority": "high"}, ) ``` diff --git a/src/bedrock_agentcore/memory/integrations/strands/config.py b/src/bedrock_agentcore/memory/integrations/strands/config.py index 780ef56a..388fd13d 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/config.py +++ b/src/bedrock_agentcore/memory/integrations/strands/config.py @@ -2,7 +2,12 @@ from typing import Any, Callable, Dict, Optional -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, field_validator + + +def normalize_metadata(raw: Dict[str, Any]) -> Dict[str, Any]: + """Normalize metadata values: plain strings become {"stringValue": value}.""" + return {k: {"stringValue": v} if isinstance(v, str) else v for k, v in raw.items()} class RetrievalConfig(BaseModel): @@ -40,10 +45,12 @@ class AgentCoreMemoryConfig(BaseModel): restored messages before loading them into Strands runtime memory. Default is False. default_metadata: Optional default metadata key-value pairs to attach to every message event. Merged with any per-call metadata. Maximum 15 total keys per event (including internal keys). - Example: {"location": {"stringValue": "NYC"}} + Accepts plain strings (auto-wrapped) or explicit MetadataValue dicts. + Example: {"location": "NYC"} or {"location": {"stringValue": "NYC"}} metadata_provider: Optional callable that returns metadata key-value pairs. Called at each event creation, so it can return dynamic values (e.g. current traceId). The returned dict is merged after default_metadata but before per-call metadata. + Accepts plain strings (auto-wrapped) or explicit MetadataValue dicts. """ model_config = ConfigDict(arbitrary_types_allowed=True) @@ -58,3 +65,10 @@ class AgentCoreMemoryConfig(BaseModel): filter_restored_tool_context: bool = Field(default=False) default_metadata: Optional[Dict[str, Any]] = None metadata_provider: Optional[Callable[[], Dict[str, Any]]] = None + + @field_validator("default_metadata", mode="before") + @classmethod + def _normalize_default_metadata(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: + if v is None: + return None + return normalize_metadata(v) diff --git a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py index c9dc45aa..3fd9ced6 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py +++ b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py @@ -29,7 +29,7 @@ ) from .bedrock_converter import AgentCoreMemoryConverter -from .config import AgentCoreMemoryConfig, RetrievalConfig +from .config import AgentCoreMemoryConfig, RetrievalConfig, normalize_metadata from .converters import MemoryConverter if TYPE_CHECKING: @@ -215,7 +215,7 @@ def _build_metadata( merged.update(self.config.default_metadata) if self.config.metadata_provider: - merged.update(self.config.metadata_provider()) + merged.update(normalize_metadata(self.config.metadata_provider())) if per_call_metadata: merged.update(per_call_metadata) @@ -224,17 +224,14 @@ def _build_metadata( user_reserved = RESERVED_METADATA_KEYS & merged.keys() if user_reserved: raise ValueError( - f"Metadata keys {user_reserved} are reserved for internal use. " - f"Reserved keys: {RESERVED_METADATA_KEYS}" + f"Metadata keys {user_reserved} are reserved for internal use. Reserved keys: {RESERVED_METADATA_KEYS}" ) if internal_metadata: merged.update(internal_metadata) if len(merged) > MAX_METADATA_KEYS: - raise ValueError( - f"Combined metadata has {len(merged)} keys, exceeding the maximum of {MAX_METADATA_KEYS}." - ) + raise ValueError(f"Combined metadata has {len(merged)} keys, exceeding the maximum of {MAX_METADATA_KEYS}.") return merged or None diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_config.py b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_config.py index cacef972..97f0bdae 100644 --- a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_config.py +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_config.py @@ -87,3 +87,39 @@ def test_with_retrieval_config(self): memory_id="mem-123", session_id="sess-456", actor_id="actor-789", retrieval_config={"namespace1": retrieval} ) assert config.retrieval_config["namespace1"].top_k == 5 + + def test_default_metadata_plain_strings_normalized(self): + """Plain string values are auto-wrapped to {"stringValue": ...}.""" + config = AgentCoreMemoryConfig( + memory_id="mem-123", + session_id="sess-456", + actor_id="actor-789", + default_metadata={"project": "atlas", "env": "prod"}, + ) + assert config.default_metadata == { + "project": {"stringValue": "atlas"}, + "env": {"stringValue": "prod"}, + } + + def test_default_metadata_explicit_format_unchanged(self): + """Explicit {"stringValue": ...} dicts pass through unchanged.""" + config = AgentCoreMemoryConfig( + memory_id="mem-123", + session_id="sess-456", + actor_id="actor-789", + default_metadata={"project": {"stringValue": "atlas"}}, + ) + assert config.default_metadata == {"project": {"stringValue": "atlas"}} + + def test_default_metadata_mixed_formats(self): + """Mixed plain strings and explicit dicts both work.""" + config = AgentCoreMemoryConfig( + memory_id="mem-123", + session_id="sess-456", + actor_id="actor-789", + default_metadata={"project": "atlas", "env": {"stringValue": "prod"}}, + ) + assert config.default_metadata == { + "project": {"stringValue": "atlas"}, + "env": {"stringValue": "prod"}, + } diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py index bf5a4fd6..e0a6503d 100644 --- a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py @@ -2,6 +2,7 @@ import logging import time +from datetime import datetime, timezone from unittest.mock import Mock, patch import pytest @@ -1680,8 +1681,6 @@ def test__flush_messages_multiple_sessions_grouped_into_separate_api_calls( so all messages go to one session. This test verifies the internal grouping logic by directly manipulating the buffer. """ - from datetime import datetime, timezone - calls_by_session = {} def track_create_event(**kwargs): @@ -1749,8 +1748,6 @@ def track_create_event(**kwargs): # The combined event should use the latest timestamp (12:10:00) assert len(captured_timestamps) == 1 # The timestamp should be the latest one (12:10:00) - from datetime import datetime, timezone - expected_latest = datetime(2024, 1, 1, 12, 10, 0, tzinfo=timezone.utc) # Account for monotonic timestamp adjustment (may add microseconds) assert captured_timestamps[0] >= expected_latest @@ -1762,7 +1759,6 @@ def test__flush_messages_partial_failure_multiple_sessions_preserves_buffer( Note: Tests internal grouping logic by directly manipulating buffer. """ - from datetime import datetime, timezone def fail_on_second_session(**kwargs): session_id = kwargs.get("sessionId") @@ -1830,8 +1826,6 @@ def test_mixed_sessions_with_blobs_and_conversational(self, batching_session_man Note: Tests internal grouping logic by directly manipulating buffer. """ - from datetime import datetime, timezone - calls_by_session = {} def track_create_event(**kwargs): @@ -2543,7 +2537,12 @@ def test_after_invocation_hook_flushes_buffer(self, batching_session_manager, mo # Add messages to buffer with batching_session_manager._message_lock: batching_session_manager._message_buffer.append( - BufferedMessage("test-session", [("user", "test message")], False, batching_session_manager._get_monotonic_timestamp()) + BufferedMessage( + "test-session", + [("user", "test message")], + False, + batching_session_manager._get_monotonic_timestamp(), + ) ) assert batching_session_manager.pending_message_count() == 1 @@ -2766,7 +2765,9 @@ def test_interval_flush_callback_flushes_when_buffer_has_messages(self): # Add messages to buffer with manager._message_lock: manager._message_buffer.append( - BufferedMessage("test-session", [("user", "test message")], False, manager._get_monotonic_timestamp()) + BufferedMessage( + "test-session", [("user", "test message")], False, manager._get_monotonic_timestamp() + ) ) assert manager.pending_message_count() == 1 @@ -2921,7 +2922,9 @@ def test_interval_flush_callback_flushes_when_both_buffers_have_data(self): # Add both messages and agent state to buffers with manager._message_lock: manager._message_buffer.append( - BufferedMessage("test-session", [("user", "test message")], False, manager._get_monotonic_timestamp()) + BufferedMessage( + "test-session", [("user", "test message")], False, manager._get_monotonic_timestamp() + ) ) from strands.types.session import SessionAgent @@ -3079,8 +3082,6 @@ def test_create_message_no_metadata_passes_none(self, session_manager, mock_memo def test_batched_messages_include_metadata(self, mock_memory_client): """Metadata flows through the batching path and appears in the flushed event.""" - from datetime import datetime, timezone - config = AgentCoreMemoryConfig( memory_id="test-memory-123", session_id="test-session-456", @@ -3199,3 +3200,37 @@ def test_metadata_provider_reserved_keys_rejected(self, mock_memory_client): msg = SessionMessage.from_message({"role": "user", "content": [{"text": "hello"}]}, 0) with pytest.raises(ValueError, match="reserved"): manager.create_message("test-session-456", "agent-1", msg) + + def test_metadata_provider_plain_strings_normalized(self, mock_memory_client): + """metadata_provider returning plain strings gets auto-normalized.""" + config = AgentCoreMemoryConfig( + memory_id="test-memory-123", + session_id="test-session-456", + actor_id="test-actor-789", + metadata_provider=lambda: {"traceId": "trace-abc"}, + ) + manager = _create_session_manager(config, mock_memory_client) + mock_memory_client.create_event.return_value = {"eventId": "evt_1"} + + msg = SessionMessage.from_message({"role": "user", "content": [{"text": "hello"}]}, 0) + manager.create_message("test-session-456", "agent-1", msg) + + kwargs = mock_memory_client.create_event.call_args[1] + assert kwargs["metadata"]["traceId"] == {"stringValue": "trace-abc"} + + def test_default_metadata_plain_strings_normalized(self, mock_memory_client): + """default_metadata with plain strings gets auto-normalized at config time.""" + config = AgentCoreMemoryConfig( + memory_id="test-memory-123", + session_id="test-session-456", + actor_id="test-actor-789", + default_metadata={"project": "atlas"}, + ) + manager = _create_session_manager(config, mock_memory_client) + mock_memory_client.create_event.return_value = {"eventId": "evt_1"} + + msg = SessionMessage.from_message({"role": "user", "content": [{"text": "hello"}]}, 0) + manager.create_message("test-session-456", "agent-1", msg) + + kwargs = mock_memory_client.create_event.call_args[1] + assert kwargs["metadata"]["project"] == {"stringValue": "atlas"} From 4fabb957a64b66cc263751cc0698ada9c947bf8a Mon Sep 17 00:00:00 2001 From: Tejas Kashinath Date: Mon, 16 Mar 2026 15:16:53 -0400 Subject: [PATCH 4/4] fix: remove unnecessary model_config from AgentCoreMemoryConfig Pydantic v2 handles Callable natively, so arbitrary_types_allowed is not needed. Removing it avoids any risk of breaking subclasses or downstream validators. --- src/bedrock_agentcore/memory/integrations/strands/config.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/bedrock_agentcore/memory/integrations/strands/config.py b/src/bedrock_agentcore/memory/integrations/strands/config.py index 388fd13d..2a9ee472 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/config.py +++ b/src/bedrock_agentcore/memory/integrations/strands/config.py @@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, Optional -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import BaseModel, Field, field_validator def normalize_metadata(raw: Dict[str, Any]) -> Dict[str, Any]: @@ -53,8 +53,6 @@ class AgentCoreMemoryConfig(BaseModel): Accepts plain strings (auto-wrapped) or explicit MetadataValue dicts. """ - model_config = ConfigDict(arbitrary_types_allowed=True) - memory_id: str = Field(min_length=1) session_id: str = Field(min_length=1) actor_id: str = Field(min_length=1)