diff --git a/src/bedrock_agentcore/memory/integrations/strands/README.md b/src/bedrock_agentcore/memory/integrations/strands/README.md index 6186bf3..d7c7b2f 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/README.md +++ b/src/bedrock_agentcore/memory/integrations/strands/README.md @@ -219,6 +219,8 @@ result = agent_with_tools("/path/to/image.png") - `actor_id`: Unique identifier for the user/actor - `retrieval_config`: Dictionary mapping namespaces to RetrievalConfig objects - `batch_size`: Number of messages to buffer before sending to AgentCore Memory (1-100, default: 1). A value of 1 sends immediately (no batching). +- `default_metadata`: Optional dictionary of key-value metadata to attach to every message event. Maximum 15 total keys per event (including internal keys). Example: `{"location": {"stringValue": "NYC"}}` +- `metadata_provider`: Optional callable returning a metadata dictionary. Called at each event creation for dynamic values (e.g., traceId). Merged after `default_metadata`. ### RetrievalConfig Parameters @@ -239,6 +241,71 @@ https://docs.aws.amazon.com/bedrock-agentcore/latest/devguide/memory-strategies. - `/summaries/{actorId}/{sessionId}/`: Session-specific summaries +--- + +## Event Metadata + +You can attach custom key-value metadata to every message event. This is useful for tagging +conversations with contextual information (e.g., location, project, case type) that can later +be used to filter events with `list_events`. + +### Default Metadata (applied to all messages) + +```python +config = AgentCoreMemoryConfig( + memory_id=MEM_ID, + session_id=SESSION_ID, + actor_id=ACTOR_ID, + default_metadata={ + "project": "atlas", + "env": "production", + }, +) +session_manager = AgentCoreMemorySessionManager(config, region_name='us-east-1') +agent = Agent(session_manager=session_manager) +agent("Hello!") # This event will have project=atlas and env=production metadata +``` + +> Plain strings are auto-wrapped to `{"stringValue": "..."}`. The explicit form +> `{"project": {"stringValue": "atlas"}}` also works. + +### Dynamic Metadata (metadata_provider) + +For values that change per invocation (e.g., traceId for Langfuse), use `metadata_provider` — +a callable invoked at each event creation: + +```python +from langfuse.decorators import langfuse_context + +def get_trace_metadata(): + return {"traceId": langfuse_context.get_current_trace_id() or ""} + +config = AgentCoreMemoryConfig( + memory_id=MEM_ID, + session_id=SESSION_ID, + actor_id=ACTOR_ID, + metadata_provider=get_trace_metadata, +) +session_manager = AgentCoreMemorySessionManager(config, region_name='us-east-1') +agent = Agent(session_manager=session_manager) +agent("Hello!") # Event gets the current traceId automatically +``` + +### Per-call Metadata + +You can also pass metadata on individual `create_message` calls. Per-call metadata is merged +with `default_metadata` and `metadata_provider` (per-call values override both for the same key): + +```python +session_manager.create_message( + session_id, agent_id, message, + metadata={"priority": "high"}, +) +``` + +> **Note:** The API allows a maximum of 15 metadata key-value pairs per event. +> The keys `stateType` and `agentId` are reserved for internal use. + --- ## Message Batching diff --git a/src/bedrock_agentcore/memory/integrations/strands/config.py b/src/bedrock_agentcore/memory/integrations/strands/config.py index 20fbbd8..2a9ee47 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/config.py +++ b/src/bedrock_agentcore/memory/integrations/strands/config.py @@ -1,8 +1,13 @@ """Configuration for AgentCore Memory Session Manager.""" -from typing import Dict, Optional +from typing import Any, Callable, Dict, Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator + + +def normalize_metadata(raw: Dict[str, Any]) -> Dict[str, Any]: + """Normalize metadata values: plain strings become {"stringValue": value}.""" + return {k: {"stringValue": v} if isinstance(v, str) else v for k, v in raw.items()} class RetrievalConfig(BaseModel): @@ -38,6 +43,14 @@ class AgentCoreMemoryConfig(BaseModel): Default is "user_context". filter_restored_tool_context: When True, strip historical toolUse/toolResult blocks from restored messages before loading them into Strands runtime memory. Default is False. + default_metadata: Optional default metadata key-value pairs to attach to every message event. + Merged with any per-call metadata. Maximum 15 total keys per event (including internal keys). + Accepts plain strings (auto-wrapped) or explicit MetadataValue dicts. + Example: {"location": "NYC"} or {"location": {"stringValue": "NYC"}} + metadata_provider: Optional callable that returns metadata key-value pairs. Called at each + event creation, so it can return dynamic values (e.g. current traceId). The returned + dict is merged after default_metadata but before per-call metadata. + Accepts plain strings (auto-wrapped) or explicit MetadataValue dicts. """ memory_id: str = Field(min_length=1) @@ -48,3 +61,12 @@ class AgentCoreMemoryConfig(BaseModel): flush_interval_seconds: Optional[float] = Field(default=None, gt=0) context_tag: str = Field(default="user_context", min_length=1) filter_restored_tool_context: bool = Field(default=False) + default_metadata: Optional[Dict[str, Any]] = None + metadata_provider: Optional[Callable[[], Dict[str, Any]]] = None + + @field_validator("default_metadata", mode="before") + @classmethod + def _normalize_default_metadata(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: + if v is None: + return None + return normalize_metadata(v) diff --git a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py index 68bc05b..3fd9ced 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py +++ b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py @@ -6,7 +6,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime, timedelta, timezone from enum import Enum -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Dict, NamedTuple, Optional import boto3 from botocore.config import Config as BotocoreConfig @@ -23,12 +23,13 @@ from bedrock_agentcore.memory.models.filters import ( EventMetadataFilter, LeftExpression, + MetadataValue, OperatorType, RightExpression, ) from .bedrock_converter import AgentCoreMemoryConverter -from .config import AgentCoreMemoryConfig, RetrievalConfig +from .config import AgentCoreMemoryConfig, RetrievalConfig, normalize_metadata from .converters import MemoryConverter if TYPE_CHECKING: @@ -46,6 +47,22 @@ STATE_TYPE_KEY = "stateType" AGENT_ID_KEY = "agentId" +# Maximum metadata key-value pairs per event (API limit) +MAX_METADATA_KEYS = 15 + +# Reserved internal metadata keys that users cannot override +RESERVED_METADATA_KEYS = frozenset({STATE_TYPE_KEY, AGENT_ID_KEY}) + + +class BufferedMessage(NamedTuple): + """A pre-processed message waiting to be flushed to AgentCore Memory.""" + + session_id: str + messages: list[tuple[str, str]] + is_blob: bool + timestamp: datetime + metadata: Optional[Dict[str, MetadataValue]] = None + class StateType(Enum): """State type for distinguishing session and agent metadata in events.""" @@ -129,8 +146,8 @@ def __init__( session = boto_session or boto3.Session(region_name=region_name) self.has_existing_agent = False - # Batching support - stores pre-processed messages: (session_id, messages, is_blob, timestamp) - self._message_buffer: list[tuple[str, list[tuple[str, str]], bool, datetime]] = [] + # Batching support - stores pre-processed messages + self._message_buffer: list[BufferedMessage] = [] self._message_lock = threading.Lock() # Agent state buffering - stores all agent state updates: (session_id, agent) @@ -169,6 +186,55 @@ def __init__( if self.config.flush_interval_seconds: self._start_flush_timer() + def _build_metadata( + self, + internal_metadata: Optional[Dict[str, MetadataValue]] = None, + per_call_metadata: Optional[Dict[str, MetadataValue]] = None, + ) -> Optional[Dict[str, MetadataValue]]: + """Build merged metadata from config defaults, provider, per-call overrides, and internal keys. + + Merge precedence (highest wins): + 1. internal_metadata (stateType, agentId) — always wins + 2. per_call_metadata (passed via **kwargs) + 3. metadata_provider() (called at event creation time for dynamic values) + 4. self.config.default_metadata (set at config construction time) + + Args: + internal_metadata: System-reserved metadata (e.g. stateType, agentId). + per_call_metadata: Caller-supplied metadata for a single operation. + + Returns: + Merged metadata dict, or None if empty. + + Raises: + ValueError: If user metadata contains reserved keys or total keys exceed MAX_METADATA_KEYS. + """ + merged: Dict[str, MetadataValue] = {} + + if self.config.default_metadata: + merged.update(self.config.default_metadata) + + if self.config.metadata_provider: + merged.update(normalize_metadata(self.config.metadata_provider())) + + if per_call_metadata: + merged.update(per_call_metadata) + + # Validate user-supplied keys before merging internal keys + user_reserved = RESERVED_METADATA_KEYS & merged.keys() + if user_reserved: + raise ValueError( + f"Metadata keys {user_reserved} are reserved for internal use. Reserved keys: {RESERVED_METADATA_KEYS}" + ) + + if internal_metadata: + merged.update(internal_metadata) + + if len(merged) > MAX_METADATA_KEYS: + raise ValueError(f"Combined metadata has {len(merged)} keys, exceeding the maximum of {MAX_METADATA_KEYS}.") + + return merged or None + # region SessionRepository interface implementation def create_session(self, session: Session, **kwargs: Any) -> Session: """Create a new session in AgentCore Memory. @@ -482,6 +548,9 @@ def create_message( is_blob = self.converter.exceeds_conversational_limit(messages[0]) + # Build merged metadata from config defaults + per-call overrides + merged_metadata = self._build_metadata(per_call_metadata=kwargs.get("metadata")) + # Parse the original timestamp and use it as desired timestamp original_timestamp = datetime.fromisoformat(session_message.created_at.replace("Z", "+00:00")) monotonic_timestamp = self._get_monotonic_timestamp(original_timestamp) @@ -490,7 +559,15 @@ def create_message( # Buffer the pre-processed message should_flush = False with self._message_lock: - self._message_buffer.append((session_id, messages, is_blob, monotonic_timestamp)) + self._message_buffer.append( + BufferedMessage( + session_id=session_id, + messages=messages, + is_blob=is_blob, + timestamp=monotonic_timestamp, + metadata=merged_metadata, + ) + ) should_flush = len(self._message_buffer) >= self.config.batch_size # Flush only messages outside the lock to prevent deadlock @@ -508,17 +585,19 @@ def create_message( session_id=session_id, messages=messages, event_timestamp=monotonic_timestamp, + metadata=merged_metadata, ) else: - event = self.memory_client.gmdp_client.create_event( - memoryId=self.config.memory_id, - actorId=self.config.actor_id, - sessionId=session_id, - payload=[ - {"blob": json.dumps(messages[0])}, - ], - eventTimestamp=monotonic_timestamp, - ) + create_event_kwargs: dict[str, Any] = { + "memoryId": self.config.memory_id, + "actorId": self.config.actor_id, + "sessionId": session_id, + "payload": [{"blob": json.dumps(messages[0])}], + "eventTimestamp": monotonic_timestamp, + } + if merged_metadata: + create_event_kwargs["metadata"] = merged_metadata + event = self.memory_client.gmdp_client.create_event(**create_event_kwargs) logger.debug("Created event: %s for message: %s", event.get("eventId"), session_message.message_id) return event except Exception as e: @@ -790,39 +869,45 @@ def _flush_messages_only(self) -> list[dict[str, Any]]: return [] # Group all messages by session_id, combining conversational and blob messages - # Structure: {session_id: {"payload": [...], "timestamp": latest_timestamp}} + # Structure: {session_id: {"payload": [...], "timestamp": latest_timestamp, "metadata": {...}}} session_groups: dict[str, dict[str, Any]] = {} - for session_id, messages, is_blob, monotonic_timestamp in messages_to_send: - if session_id not in session_groups: - session_groups[session_id] = {"payload": [], "timestamp": monotonic_timestamp} + for buffered_msg in messages_to_send: + sid = buffered_msg.session_id + if sid not in session_groups: + session_groups[sid] = {"payload": [], "timestamp": buffered_msg.timestamp, "metadata": {}} - if is_blob: - # Add blob messages to payload - for msg in messages: - session_groups[session_id]["payload"].append({"blob": json.dumps(msg)}) + if buffered_msg.is_blob: + for msg in buffered_msg.messages: + session_groups[sid]["payload"].append({"blob": json.dumps(msg)}) else: - # Add conversational messages to payload - for text, role in messages: - session_groups[session_id]["payload"].append( + for text, role in buffered_msg.messages: + session_groups[sid]["payload"].append( {"conversational": {"content": {"text": text}, "role": role.upper()}} ) # Use the latest timestamp for the combined event - if monotonic_timestamp > session_groups[session_id]["timestamp"]: - session_groups[session_id]["timestamp"] = monotonic_timestamp + if buffered_msg.timestamp > session_groups[sid]["timestamp"]: + session_groups[sid]["timestamp"] = buffered_msg.timestamp + + # Merge metadata (later entries override earlier for same key) + if buffered_msg.metadata: + session_groups[sid]["metadata"].update(buffered_msg.metadata) results = [] try: # Send one create_event per session_id with all messages (conversational + blob) for session_id, group in session_groups.items(): - event = self.memory_client.gmdp_client.create_event( - memoryId=self.config.memory_id, - actorId=self.config.actor_id, - sessionId=session_id, - payload=group["payload"], - eventTimestamp=group["timestamp"], - ) + create_event_kwargs: dict[str, Any] = { + "memoryId": self.config.memory_id, + "actorId": self.config.actor_id, + "sessionId": session_id, + "payload": group["payload"], + "eventTimestamp": group["timestamp"], + } + if group["metadata"]: + create_event_kwargs["metadata"] = group["metadata"] + event = self.memory_client.gmdp_client.create_event(**create_event_kwargs) results.append(event) logger.debug( "Flushed batched event for session %s with %d messages: %s", diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_config.py b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_config.py index cacef97..97f0bda 100644 --- a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_config.py +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_config.py @@ -87,3 +87,39 @@ def test_with_retrieval_config(self): memory_id="mem-123", session_id="sess-456", actor_id="actor-789", retrieval_config={"namespace1": retrieval} ) assert config.retrieval_config["namespace1"].top_k == 5 + + def test_default_metadata_plain_strings_normalized(self): + """Plain string values are auto-wrapped to {"stringValue": ...}.""" + config = AgentCoreMemoryConfig( + memory_id="mem-123", + session_id="sess-456", + actor_id="actor-789", + default_metadata={"project": "atlas", "env": "prod"}, + ) + assert config.default_metadata == { + "project": {"stringValue": "atlas"}, + "env": {"stringValue": "prod"}, + } + + def test_default_metadata_explicit_format_unchanged(self): + """Explicit {"stringValue": ...} dicts pass through unchanged.""" + config = AgentCoreMemoryConfig( + memory_id="mem-123", + session_id="sess-456", + actor_id="actor-789", + default_metadata={"project": {"stringValue": "atlas"}}, + ) + assert config.default_metadata == {"project": {"stringValue": "atlas"}} + + def test_default_metadata_mixed_formats(self): + """Mixed plain strings and explicit dicts both work.""" + config = AgentCoreMemoryConfig( + memory_id="mem-123", + session_id="sess-456", + actor_id="actor-789", + default_metadata={"project": "atlas", "env": {"stringValue": "prod"}}, + ) + assert config.default_metadata == { + "project": {"stringValue": "atlas"}, + "env": {"stringValue": "prod"}, + } diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py index 3d872a6..e0a6503 100644 --- a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py @@ -2,6 +2,7 @@ import logging import time +from datetime import datetime, timezone from unittest.mock import Mock, patch import pytest @@ -18,7 +19,10 @@ AgentCoreMemoryConverter, ) from bedrock_agentcore.memory.integrations.strands.config import AgentCoreMemoryConfig, RetrievalConfig -from bedrock_agentcore.memory.integrations.strands.session_manager import AgentCoreMemorySessionManager +from bedrock_agentcore.memory.integrations.strands.session_manager import ( + AgentCoreMemorySessionManager, + BufferedMessage, +) @pytest.fixture @@ -1677,8 +1681,6 @@ def test__flush_messages_multiple_sessions_grouped_into_separate_api_calls( so all messages go to one session. This test verifies the internal grouping logic by directly manipulating the buffer. """ - from datetime import datetime, timezone - calls_by_session = {} def track_create_event(**kwargs): @@ -1690,15 +1692,14 @@ def track_create_event(**kwargs): mock_memory_client.gmdp_client.create_event.side_effect = track_create_event # Directly populate buffer with messages for multiple sessions - # Buffer format: (session_id, messages, is_blob, monotonic_timestamp) base_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) batching_session_manager._message_buffer = [ - ("session-A", [("SessionA_Message_0", "user")], False, base_time), - ("session-A", [("SessionA_Message_1", "user")], False, base_time), - ("session-B", [("SessionB_Message_0", "user")], False, base_time), - ("session-B", [("SessionB_Message_1", "user")], False, base_time), - ("session-B", [("SessionB_Message_2", "user")], False, base_time), - ("session-A", [("SessionA_Message_2", "user")], False, base_time), # Non-consecutive + BufferedMessage("session-A", [("SessionA_Message_0", "user")], False, base_time), + BufferedMessage("session-A", [("SessionA_Message_1", "user")], False, base_time), + BufferedMessage("session-B", [("SessionB_Message_0", "user")], False, base_time), + BufferedMessage("session-B", [("SessionB_Message_1", "user")], False, base_time), + BufferedMessage("session-B", [("SessionB_Message_2", "user")], False, base_time), + BufferedMessage("session-A", [("SessionA_Message_2", "user")], False, base_time), # Non-consecutive ] batching_session_manager._flush_messages() @@ -1747,8 +1748,6 @@ def track_create_event(**kwargs): # The combined event should use the latest timestamp (12:10:00) assert len(captured_timestamps) == 1 # The timestamp should be the latest one (12:10:00) - from datetime import datetime, timezone - expected_latest = datetime(2024, 1, 1, 12, 10, 0, tzinfo=timezone.utc) # Account for monotonic timestamp adjustment (may add microseconds) assert captured_timestamps[0] >= expected_latest @@ -1760,7 +1759,6 @@ def test__flush_messages_partial_failure_multiple_sessions_preserves_buffer( Note: Tests internal grouping logic by directly manipulating buffer. """ - from datetime import datetime, timezone def fail_on_second_session(**kwargs): session_id = kwargs.get("sessionId") @@ -1773,10 +1771,10 @@ def fail_on_second_session(**kwargs): # Directly populate buffer with messages for multiple sessions base_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) batching_session_manager._message_buffer = [ - ("session-A", [("SessionA_Message_0", "user")], False, base_time), - ("session-A", [("SessionA_Message_1", "user")], False, base_time), - ("session-B", [("SessionB_Message_0", "user")], False, base_time), - ("session-B", [("SessionB_Message_1", "user")], False, base_time), + BufferedMessage("session-A", [("SessionA_Message_0", "user")], False, base_time), + BufferedMessage("session-A", [("SessionA_Message_1", "user")], False, base_time), + BufferedMessage("session-B", [("SessionB_Message_0", "user")], False, base_time), + BufferedMessage("session-B", [("SessionB_Message_1", "user")], False, base_time), ] assert batching_session_manager.pending_message_count() == 4 @@ -1828,8 +1826,6 @@ def test_mixed_sessions_with_blobs_and_conversational(self, batching_session_man Note: Tests internal grouping logic by directly manipulating buffer. """ - from datetime import datetime, timezone - calls_by_session = {} def track_create_event(**kwargs): @@ -1845,12 +1841,12 @@ def track_create_event(**kwargs): blob_content = {"role": "user", "content": [{"text": "blob_A_" + "x" * (CONVERSATIONAL_MAX_SIZE + 100)}]} batching_session_manager._message_buffer = [ # Session A: 2 conversational messages - ("session-A", [("SessionA_conv_0", "user")], False, base_time), - ("session-A", [("SessionA_conv_1", "user")], False, base_time), + BufferedMessage("session-A", [("SessionA_conv_0", "user")], False, base_time), + BufferedMessage("session-A", [("SessionA_conv_1", "user")], False, base_time), # Session A: 1 blob message - ("session-A", [blob_content], True, base_time), + BufferedMessage("session-A", [blob_content], True, base_time), # Session B: 1 conversational message - ("session-B", [("SessionB_conv_0", "user")], False, base_time), + BufferedMessage("session-B", [("SessionB_conv_0", "user")], False, base_time), ] batching_session_manager._flush_messages() @@ -2541,7 +2537,12 @@ def test_after_invocation_hook_flushes_buffer(self, batching_session_manager, mo # Add messages to buffer with batching_session_manager._message_lock: batching_session_manager._message_buffer.append( - ("test-session", [("user", "test message")], False, batching_session_manager._get_monotonic_timestamp()) + BufferedMessage( + "test-session", + [("user", "test message")], + False, + batching_session_manager._get_monotonic_timestamp(), + ) ) assert batching_session_manager.pending_message_count() == 1 @@ -2764,7 +2765,9 @@ def test_interval_flush_callback_flushes_when_buffer_has_messages(self): # Add messages to buffer with manager._message_lock: manager._message_buffer.append( - ("test-session", [("user", "test message")], False, manager._get_monotonic_timestamp()) + BufferedMessage( + "test-session", [("user", "test message")], False, manager._get_monotonic_timestamp() + ) ) assert manager.pending_message_count() == 1 @@ -2919,7 +2922,9 @@ def test_interval_flush_callback_flushes_when_both_buffers_have_data(self): # Add both messages and agent state to buffers with manager._message_lock: manager._message_buffer.append( - ("test-session", [("user", "test message")], False, manager._get_monotonic_timestamp()) + BufferedMessage( + "test-session", [("user", "test message")], False, manager._get_monotonic_timestamp() + ) ) from strands.types.session import SessionAgent @@ -2984,3 +2989,248 @@ def test_config_flush_interval_validation(self): actor_id="test-actor", flush_interval_seconds=-5.0, ) + + +class TestMetadataSupport: + """Tests for user-supplied event metadata on messages.""" + + @pytest.fixture + def config_with_metadata(self): + """Config with default metadata.""" + return AgentCoreMemoryConfig( + memory_id="test-memory-123", + session_id="test-session-456", + actor_id="test-actor-789", + default_metadata={"location": {"stringValue": "NYC"}, "team": {"stringValue": "support"}}, + ) + + @pytest.fixture + def session_manager_with_metadata(self, config_with_metadata, mock_memory_client): + """Session manager with default metadata configured.""" + return _create_session_manager(config_with_metadata, mock_memory_client) + + def test_create_message_with_default_metadata(self, session_manager_with_metadata, mock_memory_client): + """Config-level default_metadata flows to create_event.""" + mock_memory_client.create_event.return_value = {"eventId": "evt_1"} + session_message = SessionMessage.from_message({"role": "user", "content": [{"text": "hello"}]}, 0) + + session_manager_with_metadata.create_message("test-session-456", "agent-1", session_message) + + mock_memory_client.create_event.assert_called_once() + call_kwargs = mock_memory_client.create_event.call_args[1] + assert call_kwargs["metadata"] == {"location": {"stringValue": "NYC"}, "team": {"stringValue": "support"}} + + def test_create_message_with_per_call_metadata(self, session_manager, mock_memory_client): + """Per-call metadata passed via kwargs flows to create_event.""" + mock_memory_client.create_event.return_value = {"eventId": "evt_1"} + session_message = SessionMessage.from_message({"role": "user", "content": [{"text": "hello"}]}, 0) + per_call = {"project": {"stringValue": "alpha"}} + + session_manager.create_message("test-session-456", "agent-1", session_message, metadata=per_call) + + mock_memory_client.create_event.assert_called_once() + call_kwargs = mock_memory_client.create_event.call_args[1] + assert call_kwargs["metadata"] == {"project": {"stringValue": "alpha"}} + + def test_metadata_merging_precedence(self, session_manager_with_metadata, mock_memory_client): + """Per-call metadata overrides config default for the same key.""" + mock_memory_client.create_event.return_value = {"eventId": "evt_1"} + session_message = SessionMessage.from_message({"role": "user", "content": [{"text": "hello"}]}, 0) + per_call = {"location": {"stringValue": "SF"}, "project": {"stringValue": "beta"}} + + session_manager_with_metadata.create_message("test-session-456", "agent-1", session_message, metadata=per_call) + + call_kwargs = mock_memory_client.create_event.call_args[1] + assert call_kwargs["metadata"]["location"] == {"stringValue": "SF"} + assert call_kwargs["metadata"]["team"] == {"stringValue": "support"} + assert call_kwargs["metadata"]["project"] == {"stringValue": "beta"} + + def test_metadata_reserved_keys_rejected(self, session_manager): + """ValueError raised when user metadata contains reserved keys.""" + from bedrock_agentcore.memory.integrations.strands.session_manager import RESERVED_METADATA_KEYS + + session_message = SessionMessage.from_message({"role": "user", "content": [{"text": "hello"}]}, 0) + + for reserved_key in RESERVED_METADATA_KEYS: + with pytest.raises(ValueError, match="reserved"): + session_manager.create_message( + "test-session-456", + "agent-1", + session_message, + metadata={reserved_key: {"stringValue": "bad"}}, + ) + + def test_metadata_max_keys_exceeded(self, session_manager): + """ValueError raised when combined metadata exceeds MAX_METADATA_KEYS.""" + from bedrock_agentcore.memory.integrations.strands.session_manager import MAX_METADATA_KEYS + + session_message = SessionMessage.from_message({"role": "user", "content": [{"text": "hello"}]}, 0) + too_many = {f"key_{i}": {"stringValue": f"val_{i}"} for i in range(MAX_METADATA_KEYS + 1)} + + with pytest.raises(ValueError, match="exceeding the maximum"): + session_manager.create_message("test-session-456", "agent-1", session_message, metadata=too_many) + + def test_create_message_no_metadata_passes_none(self, session_manager, mock_memory_client): + """When no metadata configured and none passed, metadata kwarg is None.""" + mock_memory_client.create_event.return_value = {"eventId": "evt_1"} + session_message = SessionMessage.from_message({"role": "user", "content": [{"text": "hello"}]}, 0) + + session_manager.create_message("test-session-456", "agent-1", session_message) + + call_kwargs = mock_memory_client.create_event.call_args[1] + assert call_kwargs.get("metadata") is None + + def test_batched_messages_include_metadata(self, mock_memory_client): + """Metadata flows through the batching path and appears in the flushed event.""" + config = AgentCoreMemoryConfig( + memory_id="test-memory-123", + session_id="test-session-456", + actor_id="test-actor-789", + batch_size=5, + default_metadata={"env": {"stringValue": "staging"}}, + ) + manager = _create_session_manager(config, mock_memory_client) + + mock_memory_client.gmdp_client.create_event.return_value = {"event": {"eventId": "batch_evt_1"}} + + # Buffer messages with metadata + base_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + manager._message_buffer = [ + BufferedMessage( + "test-session-456", + [("hello", "user")], + False, + base_time, + metadata={"env": {"stringValue": "staging"}}, + ), + BufferedMessage( + "test-session-456", + [("world", "assistant")], + False, + base_time, + metadata={"env": {"stringValue": "staging"}, "extra": {"stringValue": "val"}}, + ), + ] + + manager._flush_messages_only() + + call_kwargs = mock_memory_client.gmdp_client.create_event.call_args[1] + # Merged metadata: later message's metadata overrides earlier + assert call_kwargs["metadata"]["env"] == {"stringValue": "staging"} + assert call_kwargs["metadata"]["extra"] == {"stringValue": "val"} + + def test_blob_message_with_metadata(self, session_manager_with_metadata, mock_memory_client): + """Blob messages also receive metadata.""" + from bedrock_agentcore.memory.integrations.strands.bedrock_converter import CONVERSATIONAL_MAX_SIZE + + mock_memory_client.gmdp_client.create_event.return_value = {"event": {"eventId": "blob_1"}} + big_text = "x" * (CONVERSATIONAL_MAX_SIZE + 100) + session_message = SessionMessage.from_message({"role": "user", "content": [{"text": big_text}]}, 0) + + session_manager_with_metadata.create_message("test-session-456", "agent-1", session_message) + + call_kwargs = mock_memory_client.gmdp_client.create_event.call_args[1] + assert call_kwargs["metadata"] == {"location": {"stringValue": "NYC"}, "team": {"stringValue": "support"}} + + def test_metadata_provider_called_per_event(self, mock_memory_client): + """metadata_provider is called at each create_message and its values appear in the event.""" + call_count = 0 + + def provider(): + nonlocal call_count + call_count += 1 + return {"traceId": {"stringValue": f"trace-{call_count}"}} + + config = AgentCoreMemoryConfig( + memory_id="test-memory-123", + session_id="test-session-456", + actor_id="test-actor-789", + metadata_provider=provider, + ) + manager = _create_session_manager(config, mock_memory_client) + mock_memory_client.create_event.return_value = {"eventId": "evt_1"} + + msg1 = SessionMessage.from_message({"role": "user", "content": [{"text": "hello"}]}, 0) + manager.create_message("test-session-456", "agent-1", msg1) + + assert call_count == 1 + kwargs1 = mock_memory_client.create_event.call_args[1] + assert kwargs1["metadata"]["traceId"] == {"stringValue": "trace-1"} + + msg2 = SessionMessage.from_message({"role": "user", "content": [{"text": "world"}]}, 0) + manager.create_message("test-session-456", "agent-1", msg2) + + assert call_count == 2 + kwargs2 = mock_memory_client.create_event.call_args[1] + assert kwargs2["metadata"]["traceId"] == {"stringValue": "trace-2"} + + def test_metadata_provider_merged_with_defaults(self, mock_memory_client): + """metadata_provider values override default_metadata for same key, but both appear.""" + config = AgentCoreMemoryConfig( + memory_id="test-memory-123", + session_id="test-session-456", + actor_id="test-actor-789", + default_metadata={"env": {"stringValue": "prod"}, "team": {"stringValue": "support"}}, + metadata_provider=lambda: {"env": {"stringValue": "staging"}, "traceId": {"stringValue": "t-1"}}, + ) + manager = _create_session_manager(config, mock_memory_client) + mock_memory_client.create_event.return_value = {"eventId": "evt_1"} + + msg = SessionMessage.from_message({"role": "user", "content": [{"text": "hello"}]}, 0) + manager.create_message("test-session-456", "agent-1", msg) + + call_kwargs = mock_memory_client.create_event.call_args[1] + # provider overrides default for "env" + assert call_kwargs["metadata"]["env"] == {"stringValue": "staging"} + # default still present + assert call_kwargs["metadata"]["team"] == {"stringValue": "support"} + # provider adds new key + assert call_kwargs["metadata"]["traceId"] == {"stringValue": "t-1"} + + def test_metadata_provider_reserved_keys_rejected(self, mock_memory_client): + """metadata_provider returning reserved keys raises ValueError.""" + config = AgentCoreMemoryConfig( + memory_id="test-memory-123", + session_id="test-session-456", + actor_id="test-actor-789", + metadata_provider=lambda: {"stateType": {"stringValue": "bad"}}, + ) + manager = _create_session_manager(config, mock_memory_client) + + msg = SessionMessage.from_message({"role": "user", "content": [{"text": "hello"}]}, 0) + with pytest.raises(ValueError, match="reserved"): + manager.create_message("test-session-456", "agent-1", msg) + + def test_metadata_provider_plain_strings_normalized(self, mock_memory_client): + """metadata_provider returning plain strings gets auto-normalized.""" + config = AgentCoreMemoryConfig( + memory_id="test-memory-123", + session_id="test-session-456", + actor_id="test-actor-789", + metadata_provider=lambda: {"traceId": "trace-abc"}, + ) + manager = _create_session_manager(config, mock_memory_client) + mock_memory_client.create_event.return_value = {"eventId": "evt_1"} + + msg = SessionMessage.from_message({"role": "user", "content": [{"text": "hello"}]}, 0) + manager.create_message("test-session-456", "agent-1", msg) + + kwargs = mock_memory_client.create_event.call_args[1] + assert kwargs["metadata"]["traceId"] == {"stringValue": "trace-abc"} + + def test_default_metadata_plain_strings_normalized(self, mock_memory_client): + """default_metadata with plain strings gets auto-normalized at config time.""" + config = AgentCoreMemoryConfig( + memory_id="test-memory-123", + session_id="test-session-456", + actor_id="test-actor-789", + default_metadata={"project": "atlas"}, + ) + manager = _create_session_manager(config, mock_memory_client) + mock_memory_client.create_event.return_value = {"eventId": "evt_1"} + + msg = SessionMessage.from_message({"role": "user", "content": [{"text": "hello"}]}, 0) + manager.create_message("test-session-456", "agent-1", msg) + + kwargs = mock_memory_client.create_event.call_args[1] + assert kwargs["metadata"]["project"] == {"stringValue": "atlas"} diff --git a/tests_integ/memory/integrations/test_session_manager.py b/tests_integ/memory/integrations/test_session_manager.py index 5ab308b..b0ec576 100644 --- a/tests_integ/memory/integrations/test_session_manager.py +++ b/tests_integ/memory/integrations/test_session_manager.py @@ -19,7 +19,12 @@ from bedrock_agentcore.memory.integrations.strands.bedrock_converter import AgentCoreMemoryConverter from bedrock_agentcore.memory.integrations.strands.config import AgentCoreMemoryConfig, RetrievalConfig from bedrock_agentcore.memory.integrations.strands.session_manager import AgentCoreMemorySessionManager -from bedrock_agentcore.memory.models.filters import EventMetadataFilter, LeftExpression, OperatorType +from bedrock_agentcore.memory.models.filters import ( + EventMetadataFilter, + LeftExpression, + OperatorType, + RightExpression, +) logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) @@ -375,3 +380,197 @@ def test_agent_multi_turn_with_batching(self, test_memory_stm): assert len(messages) >= 6 # endregion End-to-end agent with batching tests + + # region Event metadata integration tests + + def test_create_message_with_metadata_and_filter(self, test_memory_stm, memory_client): + """Test that user-supplied metadata is persisted and that filters actually exclude non-matching events.""" + session_id = f"test-meta-{uuid.uuid4().hex[:8]}" + actor_id = f"test-actor-{uuid.uuid4().hex[:8]}" + + config = AgentCoreMemoryConfig( + memory_id=test_memory_stm["id"], + session_id=session_id, + actor_id=actor_id, + default_metadata={"project": {"stringValue": "atlas"}, "env": {"stringValue": "test"}}, + ) + session_manager = AgentCoreMemorySessionManager(agentcore_memory_config=config, region_name=REGION) + + agent = Agent(system_prompt="You are a helpful assistant.", session_manager=session_manager) + agent("Hello, remember my project is Atlas") + + # Get ALL events (unfiltered) to know the total count + all_events = memory_client.list_events( + memory_id=test_memory_stm["id"], + actor_id=actor_id, + session_id=session_id, + ) + assert len(all_events) >= 1 # Sanity: events exist + + # Positive filter: query events matching our custom metadata + project_filter = EventMetadataFilter.build_expression( + left_operand=LeftExpression.build("project"), + operator=OperatorType.EQUALS_TO, + right_operand=RightExpression.build("atlas"), + ) + matching_events = memory_client.list_events( + memory_id=test_memory_stm["id"], + actor_id=actor_id, + session_id=session_id, + event_metadata=[project_filter], + ) + assert len(matching_events) >= 1 + + # Verify metadata values round-trip correctly + for event in matching_events: + meta = event.get("metadata", {}) + assert meta.get("project", {}).get("stringValue") == "atlas" + assert meta.get("env", {}).get("stringValue") == "test" + + # Negative filter: query with a value that was never written + wrong_filter = EventMetadataFilter.build_expression( + left_operand=LeftExpression.build("project"), + operator=OperatorType.EQUALS_TO, + right_operand=RightExpression.build("nonexistent_project"), + ) + non_matching_events = memory_client.list_events( + memory_id=test_memory_stm["id"], + actor_id=actor_id, + session_id=session_id, + event_metadata=[wrong_filter], + ) + assert len(non_matching_events) == 0, ( + f"Expected 0 events for nonexistent metadata value, got {len(non_matching_events)}" + ) + + # The positive set should be a strict subset of all events + # (all_events includes state events that don't have user metadata) + assert len(matching_events) < len(all_events) + + def test_metadata_survives_session_resume(self, test_memory_stm, memory_client): + """Events with metadata written by one session manager are filterable from another.""" + session_id = f"test-meta-resume-{uuid.uuid4().hex[:8]}" + actor_id = f"test-actor-{uuid.uuid4().hex[:8]}" + + # First session: write messages with metadata + config1 = AgentCoreMemoryConfig( + memory_id=test_memory_stm["id"], + session_id=session_id, + actor_id=actor_id, + default_metadata={"source": {"stringValue": "session1"}}, + ) + sm1 = AgentCoreMemorySessionManager(agentcore_memory_config=config1, region_name=REGION) + agent1 = Agent(system_prompt="You are a helpful assistant.", session_manager=sm1) + agent1("My favourite colour is blue") + + # Second session: resume without metadata configured + config2 = AgentCoreMemoryConfig( + memory_id=test_memory_stm["id"], + session_id=session_id, + actor_id=actor_id, + ) + sm2 = AgentCoreMemorySessionManager(agentcore_memory_config=config2, region_name=REGION) + + # Positive: filter matches metadata written by session 1 + source_filter = EventMetadataFilter.build_expression( + left_operand=LeftExpression.build("source"), + operator=OperatorType.EQUALS_TO, + right_operand=RightExpression.build("session1"), + ) + matching = sm2.memory_client.list_events( + memory_id=test_memory_stm["id"], + actor_id=actor_id, + session_id=session_id, + event_metadata=[source_filter], + ) + assert len(matching) >= 1 + + # Negative: filter for a different source value returns nothing + wrong_source_filter = EventMetadataFilter.build_expression( + left_operand=LeftExpression.build("source"), + operator=OperatorType.EQUALS_TO, + right_operand=RightExpression.build("session999"), + ) + not_matching = sm2.memory_client.list_events( + memory_id=test_memory_stm["id"], + actor_id=actor_id, + session_id=session_id, + event_metadata=[wrong_source_filter], + ) + assert len(not_matching) == 0 + + def test_metadata_provider_attaches_dynamic_trace_id(self, test_memory_stm, memory_client): + """metadata_provider injects a different traceId per invocation and events are filterable by each.""" + session_id = f"test-meta-prov-{uuid.uuid4().hex[:8]}" + actor_id = f"test-actor-{uuid.uuid4().hex[:8]}" + + current_trace = {"traceId": {"stringValue": "trace-AAA"}} + + config = AgentCoreMemoryConfig( + memory_id=test_memory_stm["id"], + session_id=session_id, + actor_id=actor_id, + metadata_provider=lambda: dict(current_trace), + ) + sm = AgentCoreMemorySessionManager(agentcore_memory_config=config, region_name=REGION) + agent = Agent(system_prompt="You are a helpful assistant.", session_manager=sm) + + # First invocation with trace-AAA + agent("Hello from trace AAA") + + # Switch trace for second invocation + current_trace["traceId"] = {"stringValue": "trace-BBB"} + agent("Hello from trace BBB") + + # Filter for trace-AAA — should find events + filter_aaa = EventMetadataFilter.build_expression( + left_operand=LeftExpression.build("traceId"), + operator=OperatorType.EQUALS_TO, + right_operand=RightExpression.build("trace-AAA"), + ) + events_aaa = memory_client.list_events( + memory_id=test_memory_stm["id"], + actor_id=actor_id, + session_id=session_id, + event_metadata=[filter_aaa], + ) + assert len(events_aaa) >= 1 + for e in events_aaa: + assert e.get("metadata", {}).get("traceId", {}).get("stringValue") == "trace-AAA" + + # Filter for trace-BBB — should find events + filter_bbb = EventMetadataFilter.build_expression( + left_operand=LeftExpression.build("traceId"), + operator=OperatorType.EQUALS_TO, + right_operand=RightExpression.build("trace-BBB"), + ) + events_bbb = memory_client.list_events( + memory_id=test_memory_stm["id"], + actor_id=actor_id, + session_id=session_id, + event_metadata=[filter_bbb], + ) + assert len(events_bbb) >= 1 + for e in events_bbb: + assert e.get("metadata", {}).get("traceId", {}).get("stringValue") == "trace-BBB" + + # Negative: nonexistent trace returns nothing + filter_none = EventMetadataFilter.build_expression( + left_operand=LeftExpression.build("traceId"), + operator=OperatorType.EQUALS_TO, + right_operand=RightExpression.build("trace-DOES-NOT-EXIST"), + ) + events_none = memory_client.list_events( + memory_id=test_memory_stm["id"], + actor_id=actor_id, + session_id=session_id, + event_metadata=[filter_none], + ) + assert len(events_none) == 0 + + # The two traces should be disjoint sets + aaa_ids = {e["eventId"] for e in events_aaa} + bbb_ids = {e["eventId"] for e in events_bbb} + assert aaa_ids.isdisjoint(bbb_ids), "trace-AAA and trace-BBB events should not overlap" + + # endregion Event metadata integration tests