From c848c1dfd0ec04ad2b1e9a55f4fcd56ef379d45e Mon Sep 17 00:00:00 2001 From: Tejas Kashinath Date: Fri, 13 Mar 2026 23:25:19 -0400 Subject: [PATCH 1/2] feat(strands-memory): add multi-agent support to AgentCoreMemorySessionManager Tag every message event with agentId metadata so each agent in a multi-agent session only retrieves its own messages. The list_messages() path now filters by agent_id first, with a backward- compatible fallback for sessions created before this change. Changes: - create_message: always attach {agentId: agent_id} metadata to both conversational and blob events (immediate and batched paths) - _flush_messages_only: group buffered messages by (session_id, agent_id) so each agent gets its own batched event with correct metadata - list_messages: filter by agentId metadata; fall back to unfiltered query when the filtered result is empty (backward compat) - initialize: allow multiple agents per session (info log instead of warning gate) Closes #149 (multi-agent portion) --- .../integrations/strands/session_manager.py | 129 ++++--- .../test_agentcore_memory_session_manager.py | 327 ++++++++++++++++-- .../integrations/test_session_manager.py | 149 ++++++++ 3 files changed, 538 insertions(+), 67 deletions(-) diff --git a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py index 68bc05b..846175c 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py +++ b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py @@ -129,8 +129,8 @@ def __init__( session = boto_session or boto3.Session(region_name=region_name) self.has_existing_agent = False - # Batching support - stores pre-processed messages: (session_id, messages, is_blob, timestamp) - self._message_buffer: list[tuple[str, list[tuple[str, str]], bool, datetime]] = [] + # Batching support - stores pre-processed messages: (session_id, agent_id, messages, is_blob, timestamp) + self._message_buffer: list[tuple[str, Optional[str], list[tuple[str, str]], bool, datetime]] = [] self._message_lock = threading.Lock() # Agent state buffering - stores all agent state updates: (session_id, agent) @@ -482,6 +482,11 @@ def create_message( is_blob = self.converter.exceeds_conversational_limit(messages[0]) + # Build agent_id metadata for multi-agent message tagging + agent_metadata = None + if agent_id: + agent_metadata = {AGENT_ID_KEY: {"stringValue": agent_id}} + # Parse the original timestamp and use it as desired timestamp original_timestamp = datetime.fromisoformat(session_message.created_at.replace("Z", "+00:00")) monotonic_timestamp = self._get_monotonic_timestamp(original_timestamp) @@ -490,7 +495,7 @@ def create_message( # Buffer the pre-processed message should_flush = False with self._message_lock: - self._message_buffer.append((session_id, messages, is_blob, monotonic_timestamp)) + self._message_buffer.append((session_id, agent_id, messages, is_blob, monotonic_timestamp)) should_flush = len(self._message_buffer) >= self.config.batch_size # Flush only messages outside the lock to prevent deadlock @@ -508,17 +513,19 @@ def create_message( session_id=session_id, messages=messages, event_timestamp=monotonic_timestamp, + metadata=agent_metadata, ) else: - event = self.memory_client.gmdp_client.create_event( - memoryId=self.config.memory_id, - actorId=self.config.actor_id, - sessionId=session_id, - payload=[ - {"blob": json.dumps(messages[0])}, - ], - eventTimestamp=monotonic_timestamp, - ) + create_event_kwargs: dict[str, Any] = { + "memoryId": self.config.memory_id, + "actorId": self.config.actor_id, + "sessionId": session_id, + "payload": [{"blob": json.dumps(messages[0])}], + "eventTimestamp": monotonic_timestamp, + } + if agent_metadata: + create_event_kwargs["metadata"] = agent_metadata + event = self.memory_client.gmdp_client.create_event(**create_event_kwargs) logger.debug("Created event: %s for message: %s", event.get("eventId"), session_message.message_id) return event except Exception as e: @@ -598,12 +605,46 @@ def list_messages( try: max_results = (limit + offset) if limit else MAX_FETCH_ALL_RESULTS - events = self.memory_client.list_events( - memory_id=self.config.memory_id, - actor_id=self.config.actor_id, - session_id=session_id, - max_results=max_results, - ) + # Try filtering by agent_id first for multi-agent support + if agent_id: + agent_id_filter = [ + EventMetadataFilter.build_expression( + left_operand=LeftExpression.build(AGENT_ID_KEY), + operator=OperatorType.EQUALS_TO, + right_operand=RightExpression.build(agent_id), + ) + ] + events = self.memory_client.list_events( + memory_id=self.config.memory_id, + actor_id=self.config.actor_id, + session_id=session_id, + max_results=max_results, + event_metadata=agent_id_filter, + ) + + # Backward compatibility: if filtered query returns empty, retry without + # the agent_id filter. This handles sessions created before multi-agent + # metadata was added to message events. + if not events: + logger.debug( + "No events found with agent_id filter for agent %s, " + "falling back to unfiltered query for backward compatibility.", + agent_id, + ) + events = self.memory_client.list_events( + memory_id=self.config.memory_id, + actor_id=self.config.actor_id, + session_id=session_id, + max_results=max_results, + ) + else: + events = self.memory_client.list_events( + memory_id=self.config.memory_id, + actor_id=self.config.actor_id, + session_id=session_id, + max_results=max_results, + ) + messages = self.converter.events_to_messages(events) if self.config.filter_restored_tool_context: messages = self._filter_restored_tool_context(messages) @@ -756,11 +797,8 @@ def register_hooks(self, registry: HookRegistry, **kwargs) -> None: @override def initialize(self, agent: "Agent", **kwargs: Any) -> None: if self.has_existing_agent: - logger.warning( - "An Agent already exists in session %s. We currently support one agent per session.", self.session_id - ) - else: - self.has_existing_agent = True + logger.info("Multiple agents registered in session %s.", self.session_id) + self.has_existing_agent = True RepositorySessionManager.initialize(self, agent, **kwargs) # endregion RepositorySessionManager overrides @@ -789,44 +827,49 @@ def _flush_messages_only(self) -> list[dict[str, Any]]: if not messages_to_send: return [] - # Group all messages by session_id, combining conversational and blob messages - # Structure: {session_id: {"payload": [...], "timestamp": latest_timestamp}} - session_groups: dict[str, dict[str, Any]] = {} + # Group all messages by (session_id, agent_id), combining conversational and blob messages + # Structure: {(session_id, agent_id): {"payload": [...], "timestamp": latest_timestamp}} + session_groups: dict[tuple[str, Optional[str]], dict[str, Any]] = {} - for session_id, messages, is_blob, monotonic_timestamp in messages_to_send: - if session_id not in session_groups: - session_groups[session_id] = {"payload": [], "timestamp": monotonic_timestamp} + for session_id, agent_id, messages, is_blob, monotonic_timestamp in messages_to_send: + group_key = (session_id, agent_id) + if group_key not in session_groups: + session_groups[group_key] = {"payload": [], "timestamp": monotonic_timestamp} if is_blob: # Add blob messages to payload for msg in messages: - session_groups[session_id]["payload"].append({"blob": json.dumps(msg)}) + session_groups[group_key]["payload"].append({"blob": json.dumps(msg)}) else: # Add conversational messages to payload for text, role in messages: - session_groups[session_id]["payload"].append( + session_groups[group_key]["payload"].append( {"conversational": {"content": {"text": text}, "role": role.upper()}} ) # Use the latest timestamp for the combined event - if monotonic_timestamp > session_groups[session_id]["timestamp"]: - session_groups[session_id]["timestamp"] = monotonic_timestamp + if monotonic_timestamp > session_groups[group_key]["timestamp"]: + session_groups[group_key]["timestamp"] = monotonic_timestamp results = [] try: - # Send one create_event per session_id with all messages (conversational + blob) - for session_id, group in session_groups.items(): - event = self.memory_client.gmdp_client.create_event( - memoryId=self.config.memory_id, - actorId=self.config.actor_id, - sessionId=session_id, - payload=group["payload"], - eventTimestamp=group["timestamp"], - ) + # Send one create_event per (session_id, agent_id) with all messages + for (session_id, agent_id), group in session_groups.items(): + create_event_kwargs: dict[str, Any] = { + "memoryId": self.config.memory_id, + "actorId": self.config.actor_id, + "sessionId": session_id, + "payload": group["payload"], + "eventTimestamp": group["timestamp"], + } + if agent_id: + create_event_kwargs["metadata"] = {AGENT_ID_KEY: {"stringValue": agent_id}} + event = self.memory_client.gmdp_client.create_event(**create_event_kwargs) results.append(event) logger.debug( - "Flushed batched event for session %s with %d messages: %s", + "Flushed batched event for session %s agent %s with %d messages: %s", session_id, + agent_id, len(group["payload"]), event.get("eventId"), ) diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py index 3d872a6..7c2f3d6 100644 --- a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py @@ -1195,9 +1195,11 @@ def test_list_messages_default_max_results(self, session_manager, mock_memory_cl session_manager.list_messages("test-session-456", "test-agent-123") - mock_memory_client.list_events.assert_called_once() - call_kwargs = mock_memory_client.list_events.call_args[1] - assert call_kwargs["max_results"] == 10000 + # With multi-agent support, first call filters by agent_id; second is fallback + assert mock_memory_client.list_events.call_count == 2 + # Both calls should use the same max_results + for call in mock_memory_client.list_events.call_args_list: + assert call[1]["max_results"] == 10000 def test_list_messages_with_limit_calculates_max_results(self, session_manager, mock_memory_client): """Test listing messages with limit calculates max_results correctly.""" @@ -1205,9 +1207,10 @@ def test_list_messages_with_limit_calculates_max_results(self, session_manager, session_manager.list_messages("test-session-456", "test-agent-123", limit=500, offset=50) - mock_memory_client.list_events.assert_called_once() - call_kwargs = mock_memory_client.list_events.call_args[1] - assert call_kwargs["max_results"] == 550 # limit + offset + # With multi-agent support, first call filters by agent_id; second is fallback + assert mock_memory_client.list_events.call_count == 2 + for call in mock_memory_client.list_events.call_args_list: + assert call[1]["max_results"] == 550 # limit + offset def test_append_message_handles_none_from_create_message(self, session_manager, test_agent): """Test that append_message gracefully handles None return from create_message.""" @@ -1690,15 +1693,15 @@ def track_create_event(**kwargs): mock_memory_client.gmdp_client.create_event.side_effect = track_create_event # Directly populate buffer with messages for multiple sessions - # Buffer format: (session_id, messages, is_blob, monotonic_timestamp) + # Buffer format: (session_id, agent_id, messages, is_blob, monotonic_timestamp) base_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) batching_session_manager._message_buffer = [ - ("session-A", [("SessionA_Message_0", "user")], False, base_time), - ("session-A", [("SessionA_Message_1", "user")], False, base_time), - ("session-B", [("SessionB_Message_0", "user")], False, base_time), - ("session-B", [("SessionB_Message_1", "user")], False, base_time), - ("session-B", [("SessionB_Message_2", "user")], False, base_time), - ("session-A", [("SessionA_Message_2", "user")], False, base_time), # Non-consecutive + ("session-A", "agent-1", [("SessionA_Message_0", "user")], False, base_time), + ("session-A", "agent-1", [("SessionA_Message_1", "user")], False, base_time), + ("session-B", "agent-1", [("SessionB_Message_0", "user")], False, base_time), + ("session-B", "agent-1", [("SessionB_Message_1", "user")], False, base_time), + ("session-B", "agent-1", [("SessionB_Message_2", "user")], False, base_time), + ("session-A", "agent-1", [("SessionA_Message_2", "user")], False, base_time), # Non-consecutive ] batching_session_manager._flush_messages() @@ -1773,10 +1776,10 @@ def fail_on_second_session(**kwargs): # Directly populate buffer with messages for multiple sessions base_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) batching_session_manager._message_buffer = [ - ("session-A", [("SessionA_Message_0", "user")], False, base_time), - ("session-A", [("SessionA_Message_1", "user")], False, base_time), - ("session-B", [("SessionB_Message_0", "user")], False, base_time), - ("session-B", [("SessionB_Message_1", "user")], False, base_time), + ("session-A", "agent-1", [("SessionA_Message_0", "user")], False, base_time), + ("session-A", "agent-1", [("SessionA_Message_1", "user")], False, base_time), + ("session-B", "agent-1", [("SessionB_Message_0", "user")], False, base_time), + ("session-B", "agent-1", [("SessionB_Message_1", "user")], False, base_time), ] assert batching_session_manager.pending_message_count() == 4 @@ -1845,12 +1848,12 @@ def track_create_event(**kwargs): blob_content = {"role": "user", "content": [{"text": "blob_A_" + "x" * (CONVERSATIONAL_MAX_SIZE + 100)}]} batching_session_manager._message_buffer = [ # Session A: 2 conversational messages - ("session-A", [("SessionA_conv_0", "user")], False, base_time), - ("session-A", [("SessionA_conv_1", "user")], False, base_time), + ("session-A", "agent-1", [("SessionA_conv_0", "user")], False, base_time), + ("session-A", "agent-1", [("SessionA_conv_1", "user")], False, base_time), # Session A: 1 blob message - ("session-A", [blob_content], True, base_time), + ("session-A", "agent-1", [blob_content], True, base_time), # Session B: 1 conversational message - ("session-B", [("SessionB_conv_0", "user")], False, base_time), + ("session-B", "agent-1", [("SessionB_conv_0", "user")], False, base_time), ] batching_session_manager._flush_messages() @@ -2541,7 +2544,7 @@ def test_after_invocation_hook_flushes_buffer(self, batching_session_manager, mo # Add messages to buffer with batching_session_manager._message_lock: batching_session_manager._message_buffer.append( - ("test-session", [("user", "test message")], False, batching_session_manager._get_monotonic_timestamp()) + ("test-session", "agent-1", [("user", "test message")], False, batching_session_manager._get_monotonic_timestamp()) ) assert batching_session_manager.pending_message_count() == 1 @@ -2764,7 +2767,7 @@ def test_interval_flush_callback_flushes_when_buffer_has_messages(self): # Add messages to buffer with manager._message_lock: manager._message_buffer.append( - ("test-session", [("user", "test message")], False, manager._get_monotonic_timestamp()) + ("test-session", "agent-1", [("user", "test message")], False, manager._get_monotonic_timestamp()) ) assert manager.pending_message_count() == 1 @@ -2919,7 +2922,7 @@ def test_interval_flush_callback_flushes_when_both_buffers_have_data(self): # Add both messages and agent state to buffers with manager._message_lock: manager._message_buffer.append( - ("test-session", [("user", "test message")], False, manager._get_monotonic_timestamp()) + ("test-session", "agent-1", [("user", "test message")], False, manager._get_monotonic_timestamp()) ) from strands.types.session import SessionAgent @@ -2984,3 +2987,279 @@ def test_config_flush_interval_validation(self): actor_id="test-actor", flush_interval_seconds=-5.0, ) + + +class TestMultiAgentSupport: + """Test multi-agent support: agent_id metadata tagging, filtering, and initialize behavior.""" + + def test_create_message_tags_agent_id_metadata(self, session_manager, mock_memory_client): + """Verify AGENT_ID_KEY metadata is auto-added to message events.""" + mock_memory_client.create_event.return_value = {"eventId": "event-123"} + + message = SessionMessage( + message={"role": "user", "content": [{"text": "Hello"}]}, + message_id=1, + created_at="2024-01-01T12:00:00Z", + ) + + session_manager.create_message("test-session-456", "my-agent-id", message) + + mock_memory_client.create_event.assert_called_once() + call_kwargs = mock_memory_client.create_event.call_args[1] + assert call_kwargs["metadata"] == {"agentId": {"stringValue": "my-agent-id"}} + + def test_create_message_blob_tags_agent_id_metadata(self, session_manager, mock_memory_client): + """Verify AGENT_ID_KEY metadata is auto-added to blob message events.""" + mock_memory_client.gmdp_client.create_event.return_value = {"eventId": "event-123"} + + # Create a message that exceeds the conversational limit (becomes blob) + large_text = "x" * (CONVERSATIONAL_MAX_SIZE + 100) + message = SessionMessage( + message={"role": "user", "content": [{"text": large_text}]}, + message_id=1, + created_at="2024-01-01T12:00:00Z", + ) + + session_manager.create_message("test-session-456", "my-agent-id", message) + + mock_memory_client.gmdp_client.create_event.assert_called_once() + call_kwargs = mock_memory_client.gmdp_client.create_event.call_args[1] + assert call_kwargs["metadata"] == {"agentId": {"stringValue": "my-agent-id"}} + + def test_create_message_no_agent_id_omits_metadata(self, session_manager, mock_memory_client): + """Verify no metadata is added when agent_id is empty/None.""" + mock_memory_client.create_event.return_value = {"eventId": "event-123"} + + message = SessionMessage( + message={"role": "user", "content": [{"text": "Hello"}]}, + message_id=1, + created_at="2024-01-01T12:00:00Z", + ) + + session_manager.create_message("test-session-456", "", message) + + mock_memory_client.create_event.assert_called_once() + call_kwargs = mock_memory_client.create_event.call_args[1] + assert call_kwargs.get("metadata") is None + + def test_list_messages_filters_by_agent_id(self, session_manager, mock_memory_client): + """Verify agent_id filter is passed to list_events when agent_id is provided.""" + mock_memory_client.list_events.return_value = [ + { + "eventId": "event-1", + "eventTimestamp": "2024-01-01T12:00:00Z", + "payload": [ + { + "conversational": { + "content": { + "text": '{"message": {"role": "user", "content": [{"text": "Hello"}]}, "message_id": 1}' + }, + "role": "USER", + } + } + ], + } + ] + + session_manager.list_messages("test-session-456", "my-agent-id") + + # First call should include the agent_id filter + first_call_kwargs = mock_memory_client.list_events.call_args_list[0][1] + assert "event_metadata" in first_call_kwargs + filter_expr = first_call_kwargs["event_metadata"][0] + assert filter_expr["left"]["metadataKey"] == "agentId" + assert filter_expr["right"]["metadataValue"]["stringValue"] == "my-agent-id" + + # Since filtered call returned results, no fallback needed + assert mock_memory_client.list_events.call_count == 1 + + def test_list_messages_fallback_when_no_agent_metadata(self, session_manager, mock_memory_client): + """Verify backward-compat fallback when filtered query returns empty.""" + events_data = [ + { + "eventId": "event-1", + "eventTimestamp": "2024-01-01T12:00:00Z", + "payload": [ + { + "conversational": { + "content": { + "text": '{"message": {"role": "user", "content": [{"text": "Legacy msg"}]}, "message_id": 1}' # noqa E501 + }, + "role": "USER", + } + } + ], + } + ] + # First call (filtered) returns empty, second (unfiltered) returns events + mock_memory_client.list_events.side_effect = [[], events_data] + + messages = session_manager.list_messages("test-session-456", "my-agent-id") + + # Two calls: filtered then fallback + assert mock_memory_client.list_events.call_count == 2 + # Fallback call should NOT have event_metadata + fallback_kwargs = mock_memory_client.list_events.call_args_list[1][1] + assert "event_metadata" not in fallback_kwargs + # Should still return the legacy messages + assert len(messages) == 1 + assert messages[0].message["content"][0]["text"] == "Legacy msg" + + def test_list_messages_no_agent_id_skips_filter(self, session_manager, mock_memory_client): + """Verify no filter is applied when agent_id is empty.""" + mock_memory_client.list_events.return_value = [] + + session_manager.list_messages("test-session-456", "") + + mock_memory_client.list_events.assert_called_once() + call_kwargs = mock_memory_client.list_events.call_args[1] + assert "event_metadata" not in call_kwargs + + def test_initialize_allows_multiple_agents(self, session_manager, mock_memory_client): + """Verify initialize no longer blocks multiple agents with a warning.""" + session_manager._latest_agent_message = {} + session_manager.session_repository = Mock() + session_manager.session_repository.read_agent = Mock(return_value=None) + + agent1 = Agent( + agent_id="agent-1", + messages=[{"role": "user", "content": [{"text": "Hello"}]}], + ) + agent2 = Agent( + agent_id="agent-2", + messages=[{"role": "user", "content": [{"text": "Hi"}]}], + ) + + # First agent initializes normally + session_manager.initialize(agent1) + assert session_manager.has_existing_agent is True + + # Second agent should also initialize without raising + session_manager.initialize(agent2) + assert session_manager.has_existing_agent is True + + def test_initialize_logs_info_for_second_agent(self, session_manager, mock_memory_client, caplog): + """Verify second agent initialization logs info (not warning).""" + session_manager._latest_agent_message = {} + session_manager.session_repository = Mock() + session_manager.session_repository.read_agent = Mock(return_value=None) + + agent1 = Agent( + agent_id="agent-1", + messages=[{"role": "user", "content": [{"text": "Hello"}]}], + ) + agent2 = Agent( + agent_id="agent-2", + messages=[{"role": "user", "content": [{"text": "Hi"}]}], + ) + + session_manager.initialize(agent1) + + with caplog.at_level(logging.INFO): + session_manager.initialize(agent2) + + assert "Multiple agents registered" in caplog.text + + def test_batched_messages_grouped_by_agent_id(self, mock_memory_client): + """Verify batched flush groups messages by (session_id, agent_id).""" + from datetime import datetime, timedelta, timezone + + config = AgentCoreMemoryConfig( + memory_id="test-memory-123", + session_id="test-session-456", + actor_id="test-actor-789", + batch_size=10, + ) + manager = _create_session_manager(config, mock_memory_client) + + mock_memory_client.gmdp_client.create_event.return_value = {"eventId": "event-123"} + + base_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + manager._message_buffer = [ + ("test-session-456", "agent-A", [("msg1", "user")], False, base_time), + ("test-session-456", "agent-B", [("msg2", "user")], False, base_time + timedelta(seconds=1)), + ("test-session-456", "agent-A", [("msg3", "assistant")], False, base_time + timedelta(seconds=2)), + ] + + manager._flush_messages_only() + + # Should be 2 API calls: one for agent-A, one for agent-B + assert mock_memory_client.gmdp_client.create_event.call_count == 2 + + # Verify each call has correct agent metadata + calls = mock_memory_client.gmdp_client.create_event.call_args_list + agent_ids_seen = set() + for call in calls: + kwargs = call[1] + assert "metadata" in kwargs + agent_id = kwargs["metadata"]["agentId"]["stringValue"] + agent_ids_seen.add(agent_id) + assert agent_ids_seen == {"agent-A", "agent-B"} + + def test_batched_messages_agent_a_has_two_messages(self, mock_memory_client): + """Verify agent-A gets both its messages combined in the batched flush.""" + from datetime import datetime, timedelta, timezone + + config = AgentCoreMemoryConfig( + memory_id="test-memory-123", + session_id="test-session-456", + actor_id="test-actor-789", + batch_size=10, + ) + manager = _create_session_manager(config, mock_memory_client) + + calls_by_agent: dict[str, list] = {} + + def track_create_event(**kwargs): + agent_id = kwargs.get("metadata", {}).get("agentId", {}).get("stringValue", "unknown") + calls_by_agent[agent_id] = kwargs["payload"] + return {"eventId": f"event_{agent_id}"} + + mock_memory_client.gmdp_client.create_event.side_effect = track_create_event + + base_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + manager._message_buffer = [ + ("test-session-456", "agent-A", [("msg1", "user")], False, base_time), + ("test-session-456", "agent-B", [("msg2", "user")], False, base_time + timedelta(seconds=1)), + ("test-session-456", "agent-A", [("msg3", "assistant")], False, base_time + timedelta(seconds=2)), + ] + + manager._flush_messages_only() + + assert len(calls_by_agent["agent-A"]) == 2 + assert len(calls_by_agent["agent-B"]) == 1 + + def test_two_agents_independent_messages(self, mock_memory_client): + """Test two agents create and list messages, each sees only their own.""" + config = AgentCoreMemoryConfig( + memory_id="test-memory-123", + session_id="test-session-456", + actor_id="test-actor-789", + ) + manager = _create_session_manager(config, mock_memory_client) + + # Agent A creates a message + mock_memory_client.create_event.return_value = {"eventId": "event-a1"} + msg_a = SessionMessage( + message={"role": "user", "content": [{"text": "From Agent A"}]}, + message_id=1, + created_at="2024-01-01T12:00:00Z", + ) + manager.create_message("test-session-456", "agent-A", msg_a) + + # Verify agent_id metadata was set for agent A + call_kwargs_a = mock_memory_client.create_event.call_args[1] + assert call_kwargs_a["metadata"] == {"agentId": {"stringValue": "agent-A"}} + + # Agent B creates a message + mock_memory_client.create_event.return_value = {"eventId": "event-b1"} + msg_b = SessionMessage( + message={"role": "user", "content": [{"text": "From Agent B"}]}, + message_id=2, + created_at="2024-01-01T12:01:00Z", + ) + manager.create_message("test-session-456", "agent-B", msg_b) + + # Verify agent_id metadata was set for agent B + call_kwargs_b = mock_memory_client.create_event.call_args[1] + assert call_kwargs_b["metadata"] == {"agentId": {"stringValue": "agent-B"}} diff --git a/tests_integ/memory/integrations/test_session_manager.py b/tests_integ/memory/integrations/test_session_manager.py index 5ab308b..150dcff 100644 --- a/tests_integ/memory/integrations/test_session_manager.py +++ b/tests_integ/memory/integrations/test_session_manager.py @@ -375,3 +375,152 @@ def test_agent_multi_turn_with_batching(self, test_memory_stm): assert len(messages) >= 6 # endregion End-to-end agent with batching tests + + # region Multi-agent tests + + def test_multi_agent_conversation(self, test_memory_stm): + """Test two agents write to same session, each retrieves only their own messages.""" + session_id = f"multi-agent-session-{int(time.time())}" + actor_id = f"test-actor-{int(time.time())}" + + # Create session manager for agent A + config_a = AgentCoreMemoryConfig( + memory_id=test_memory_stm["id"], + session_id=session_id, + actor_id=actor_id, + ) + sm_a = AgentCoreMemorySessionManager(agentcore_memory_config=config_a, region_name=REGION) + + agent_a = Agent( + agent_id="agent-A", + system_prompt="You are Agent A. Always start your responses with 'Agent A here:'.", + session_manager=sm_a, + ) + + # Agent A has a conversation + response_a = agent_a("Hello from user to Agent A. Remember the word 'pineapple'.") + assert response_a is not None + + time.sleep(2) + + # Create session manager for agent B (same session, same actor) + config_b = AgentCoreMemoryConfig( + memory_id=test_memory_stm["id"], + session_id=session_id, + actor_id=actor_id, + ) + sm_b = AgentCoreMemorySessionManager(agentcore_memory_config=config_b, region_name=REGION) + + agent_b = Agent( + agent_id="agent-B", + system_prompt="You are Agent B. Always start your responses with 'Agent B here:'.", + session_manager=sm_b, + ) + + # Agent B has a conversation + response_b = agent_b("Hello from user to Agent B. Remember the word 'strawberry'.") + assert response_b is not None + + time.sleep(2) + + # Now verify each agent retrieves only its own messages via metadata filter + agent_a_filter = EventMetadataFilter.build_expression( + left_operand=LeftExpression.build("agentId"), + operator=OperatorType.EQUALS_TO, + right_operand={"metadataValue": {"stringValue": "agent-A"}}, + ) + agent_b_filter = EventMetadataFilter.build_expression( + left_operand=LeftExpression.build("agentId"), + operator=OperatorType.EQUALS_TO, + right_operand={"metadataValue": {"stringValue": "agent-B"}}, + ) + + events_a = sm_a.memory_client.list_events( + memory_id=test_memory_stm["id"], + actor_id=actor_id, + session_id=session_id, + event_metadata=[agent_a_filter], + ) + events_b = sm_b.memory_client.list_events( + memory_id=test_memory_stm["id"], + actor_id=actor_id, + session_id=session_id, + event_metadata=[agent_b_filter], + ) + + # Each agent should have its own events (at least user + assistant) + assert len(events_a) >= 2, f"Expected >=2 events for agent-A, got {len(events_a)}" + assert len(events_b) >= 2, f"Expected >=2 events for agent-B, got {len(events_b)}" + + # The two sets should be disjoint + event_ids_a = {e["eventId"] for e in events_a} + event_ids_b = {e["eventId"] for e in events_b} + assert event_ids_a.isdisjoint(event_ids_b), "Agent A and B events should be disjoint" + + def test_multi_agent_with_batching(self, test_memory_stm): + """Test two agents with batching enabled, verify flush groups correctly.""" + session_id = f"multi-agent-batch-{int(time.time())}" + actor_id = f"test-actor-{int(time.time())}" + + config = AgentCoreMemoryConfig( + memory_id=test_memory_stm["id"], + session_id=session_id, + actor_id=actor_id, + batch_size=10, + ) + + with AgentCoreMemorySessionManager(agentcore_memory_config=config, region_name=REGION) as sm: + agent_a = Agent( + agent_id="agent-A", + system_prompt="You are Agent A. Keep responses very short (one sentence).", + session_manager=sm, + ) + agent_b = Agent( + agent_id="agent-B", + system_prompt="You are Agent B. Keep responses very short (one sentence).", + session_manager=sm, + ) + + agent_a("Hello Agent A, remember the color blue.") + time.sleep(2) + agent_b("Hello Agent B, remember the color red.") + # Context manager flushes remaining buffered messages + + time.sleep(2) + + # Verify events are tagged correctly with agent metadata + agent_a_filter = EventMetadataFilter.build_expression( + left_operand=LeftExpression.build("agentId"), + operator=OperatorType.EQUALS_TO, + right_operand={"metadataValue": {"stringValue": "agent-A"}}, + ) + agent_b_filter = EventMetadataFilter.build_expression( + left_operand=LeftExpression.build("agentId"), + operator=OperatorType.EQUALS_TO, + right_operand={"metadataValue": {"stringValue": "agent-B"}}, + ) + + client = MemoryClient(region_name=REGION) + events_a = client.list_events( + memory_id=test_memory_stm["id"], + actor_id=actor_id, + session_id=session_id, + event_metadata=[agent_a_filter], + ) + events_b = client.list_events( + memory_id=test_memory_stm["id"], + actor_id=actor_id, + session_id=session_id, + event_metadata=[agent_b_filter], + ) + + # Both agents should have events + assert len(events_a) >= 1, f"Expected >=1 events for agent-A, got {len(events_a)}" + assert len(events_b) >= 1, f"Expected >=1 events for agent-B, got {len(events_b)}" + + # Events should be disjoint between agents + event_ids_a = {e["eventId"] for e in events_a} + event_ids_b = {e["eventId"] for e in events_b} + assert event_ids_a.isdisjoint(event_ids_b), "Agent A and B events should be disjoint" + + # endregion Multi-agent tests From 437ce7e6050370b657a46382edf9d3d207f283f5 Mon Sep 17 00:00:00 2001 From: Tejas Kashinath Date: Mon, 16 Mar 2026 13:25:11 -0400 Subject: [PATCH 2/2] test: add list_messages isolation integ test for multi-agent Verifies that sm.list_messages() with agent_id returns only that agent's messages by checking content unique to each agent's system prompt (pineapple vs strawberry). --- .../integrations/test_session_manager.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tests_integ/memory/integrations/test_session_manager.py b/tests_integ/memory/integrations/test_session_manager.py index 150dcff..2abfe2b 100644 --- a/tests_integ/memory/integrations/test_session_manager.py +++ b/tests_integ/memory/integrations/test_session_manager.py @@ -523,4 +523,50 @@ def test_multi_agent_with_batching(self, test_memory_stm): event_ids_b = {e["eventId"] for e in events_b} assert event_ids_a.isdisjoint(event_ids_b), "Agent A and B events should be disjoint" + def test_multi_agent_list_messages_isolation(self, test_memory_stm): + """Test that list_messages returns only the calling agent's messages.""" + session_id = f"multi-agent-list-{int(time.time())}" + actor_id = f"test-actor-{int(time.time())}" + + config = AgentCoreMemoryConfig( + memory_id=test_memory_stm["id"], + session_id=session_id, + actor_id=actor_id, + ) + sm = AgentCoreMemorySessionManager(agentcore_memory_config=config, region_name=REGION) + + agent_a = Agent( + agent_id="agent-A", + system_prompt="You are Agent A. Always mention the word 'pineapple' in your response.", + session_manager=sm, + ) + agent_b = Agent( + agent_id="agent-B", + system_prompt="You are Agent B. Always mention the word 'strawberry' in your response.", + session_manager=sm, + ) + + agent_a("Hello Agent A") + time.sleep(2) + agent_b("Hello Agent B") + time.sleep(2) + + # Use the session manager's list_messages (the code path under test) + messages_a = sm.list_messages(session_id, "agent-A") + messages_b = sm.list_messages(session_id, "agent-B") + + # Each agent should have at least user + assistant + assert len(messages_a) >= 2, f"Expected >=2 messages for agent-A, got {len(messages_a)}" + assert len(messages_b) >= 2, f"Expected >=2 messages for agent-B, got {len(messages_b)}" + + # Agent A's messages should contain 'pineapple' (from its system prompt), not 'strawberry' + messages_a_text = json.dumps([m.message for m in messages_a]) + assert "pineapple" in messages_a_text.lower(), "Agent A messages should contain 'pineapple'" + assert "strawberry" not in messages_a_text.lower(), "Agent A messages should not contain 'strawberry'" + + # Agent B's messages should contain 'strawberry', not 'pineapple' + messages_b_text = json.dumps([m.message for m in messages_b]) + assert "strawberry" in messages_b_text.lower(), "Agent B messages should contain 'strawberry'" + assert "pineapple" not in messages_b_text.lower(), "Agent B messages should not contain 'pineapple'" + # endregion Multi-agent tests