Skip to content

Commit 328acba

Browse files
authored
fix: Session manager batching improvements (aws#298)
1. Unify blob and conversational messages into a single payload structure.
2. Batch the creation of agent state events.
3. Separate the buffers and flushing for messages and agent states.
1 parent 2371461 commit 328acba

2 files changed

Lines changed: 511 additions & 203 deletions

File tree

src/bedrock_agentcore/memory/integrations/strands/session_manager.py

Lines changed: 147 additions & 99 deletions
Original file line number | Diff line number | Diff line change
@@ -283,30 +283,48 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A
283283
if session_id != self.config.session_id:
284284
raise SessionException(f"Session ID mismatch: expected {self.config.session_id}, got {session_id}")
285285

286-
event = self.memory_client.gmdp_client.create_event(
287-
memoryId=self.config.memory_id,
288-
actorId=self.config.actor_id,
289-
sessionId=self.session_id,
290-
payload=[
291-
{"blob": json.dumps(session_agent.to_dict())},
292-
],
293-
eventTimestamp=self._get_monotonic_timestamp(),
294-
metadata={
295-
STATE_TYPE_KEY: {"stringValue": StateType.AGENT.value},
296-
AGENT_ID_KEY: {"stringValue": session_agent.agent_id},
297-
},
298-
)
299-
300286
# Cache the created_at timestamp to avoid re-fetching on updates
301287
if session_agent.created_at:
302288
self._agent_created_at_cache[session_agent.agent_id] = session_agent.created_at
303289

304-
logger.info(
305-
"Created agent: %s in session: %s with event %s",
306-
session_agent.agent_id,
307-
session_id,
308-
event.get("event", {}).get("eventId"),
309-
)
290+
if self.config.batch_size > 1:
291+
# Buffer the agent state events
292+
should_flush = False
293+
with self._agent_state_lock:
294+
self._agent_state_buffer.append((session_id, session_agent))
295+
should_flush = len(self._agent_state_buffer) >= self.config.batch_size
296+
297+
# Flush only agent states outside the lock to prevent deadlock
298+
if should_flush:
299+
self._flush_agent_states_only()
300+
301+
logger.info(
302+
"Buffered agent creation: %s in session: %s",
303+
session_agent.agent_id,
304+
session_id,
305+
)
306+
else:
307+
# Immediate send when batching is disabled
308+
event = self.memory_client.gmdp_client.create_event(
309+
memoryId=self.config.memory_id,
310+
actorId=self.config.actor_id,
311+
sessionId=self.session_id,
312+
payload=[
313+
{"blob": json.dumps(session_agent.to_dict())},
314+
],
315+
eventTimestamp=self._get_monotonic_timestamp(),
316+
metadata={
317+
STATE_TYPE_KEY: {"stringValue": StateType.AGENT.value},
318+
AGENT_ID_KEY: {"stringValue": session_agent.agent_id},
319+
},
320+
)
321+
322+
logger.info(
323+
"Created agent: %s in session: %s with event %s",
324+
session_agent.agent_id,
325+
session_id,
326+
event.get("event", {}).get("eventId"),
327+
)
310328

311329
def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[SessionAgent]:
312330
"""Read agent data from AgentCore Memory events.
@@ -395,20 +413,18 @@ def update_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A
395413
"""
396414
agent_id = session_agent.agent_id
397415

416+
# Verify agent exists and get created_at timestamp if not cached
398417
if agent_id not in self._agent_created_at_cache:
399418
previous_agent = self.read_agent(session_id=session_id, agent_id=agent_id)
400419
if previous_agent is None:
401420
raise SessionException(f"Agent {agent_id} in session {session_id} does not exist")
421+
422+
# Set created_at from cache before creating the update event
402423
session_agent.created_at = self._agent_created_at_cache[agent_id]
403424

404-
if self.config.batch_size > 1:
405-
# Buffer the agent state update
406-
with self._agent_state_lock:
407-
self._agent_state_buffer.append((session_id, session_agent))
408-
else:
409-
# Immediate send create_event without buffering
410-
# Create a new agent as AgentCore Memory is immutable. We always get the latest one in `read_agent`
411-
self.create_agent(session_id, session_agent)
425+
# Create a new agent event (AgentCore Memory is immutable)
426+
# create_agent will handle batching and caching appropriately
427+
self.create_agent(session_id, session_agent)
412428

413429
def create_message(
414430
self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any
@@ -466,9 +482,9 @@ def create_message(
466482
self._message_buffer.append((session_id, messages, is_blob, monotonic_timestamp))
467483
should_flush = len(self._message_buffer) >= self.config.batch_size
468484

469-
# Flush outside the lock to prevent deadlock
485+
# Flush only messages outside the lock to prevent deadlock
470486
if should_flush:
471-
self._flush_messages()
487+
self._flush_messages_only()
472488

473489
return {} # No eventId yet
474490

@@ -711,116 +727,148 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None:
711727

712728
# region Batching support
713729

714-
def _flush_messages(self) -> list[dict[str, Any]]:
715-
"""Flush all buffered messages and agent state to AgentCore Memory.
716-
717-
Call this method to send any remaining buffered messages and agent state when batch_size > 1.
718-
This is automatically called when the buffer reaches batch_size, but should
719-
also be called explicitly when the session is complete (via close() or context manager).
730+
def _flush_messages_only(self) -> list[dict[str, Any]]:
731+
"""Flush only buffered messages to AgentCore Memory.
720732
733+
Call this method to send any remaining buffered messages when batch_size > 1.
734+
This is called when the message buffer reaches batch_size.
721735
Messages are batched by session_id - all conversational messages for the same
722736
session are combined into a single create_event() call to reduce API calls.
723737
Blob messages (>9KB) are sent individually as they require a different API path.
724-
Agent state updates are sent after messages.
725738
726739
Returns:
727740
list[dict[str, Any]]: List of created event responses from AgentCore Memory.
728741
729742
Raises:
730-
SessionException: If any message or agent state creation fails. On failure, all messages
731-
and agent state remain in the buffer to prevent data loss.
743+
SessionException: If message creation fails. On failure, messages remain in the buffer.
732744
"""
733745
with self._message_lock:
734746
messages_to_send = list(self._message_buffer)
735747

736-
with self._agent_state_lock:
737-
agent_states_to_send = list(self._agent_state_buffer)
738-
739-
if not messages_to_send and not agent_states_to_send:
748+
if not messages_to_send:
740749
return []
741750

742-
# Group conversational messages by session_id, preserve order
743-
# Structure: {session_id: {"messages": [...], "timestamp": latest_timestamp}}
751+
# Group all messages by session_id, combining conversational and blob messages
752+
# Structure: {session_id: {"payload": [...], "timestamp": latest_timestamp}}
744753
session_groups: dict[str, dict[str, Any]] = {}
745-
blob_messages: list[tuple[str, list[tuple[str, str]], datetime]] = []
746754

747755
for session_id, messages, is_blob, monotonic_timestamp in messages_to_send:
756+
if session_id not in session_groups:
757+
session_groups[session_id] = {"payload": [], "timestamp": monotonic_timestamp}
758+
748759
if is_blob:
749-
# Blobs cannot be combined - collect them separately
750-
blob_messages.append((session_id, messages, monotonic_timestamp))
760+
# Add blob messages to payload
761+
for msg in messages:
762+
session_groups[session_id]["payload"].append({"blob": json.dumps(msg)})
751763
else:
752-
# Group conversational messages by session_id
753-
if session_id not in session_groups:
754-
session_groups[session_id] = {"messages": [], "timestamp": monotonic_timestamp}
755-
# Extend messages list to preserve order (earlier messages first)
756-
session_groups[session_id]["messages"].extend(messages)
757-
# Use the latest timestamp for the combined event
758-
if monotonic_timestamp > session_groups[session_id]["timestamp"]:
759-
session_groups[session_id]["timestamp"] = monotonic_timestamp
764+
# Add conversational messages to payload
765+
for text, role in messages:
766+
session_groups[session_id]["payload"].append(
767+
{"conversational": {"content": {"text": text}, "role": role.upper()}}
768+
)
769+
770+
# Use the latest timestamp for the combined event
771+
if monotonic_timestamp > session_groups[session_id]["timestamp"]:
772+
session_groups[session_id]["timestamp"] = monotonic_timestamp
760773

761774
results = []
762775
try:
763-
# Send one create_event per session_id with combined messages
776+
# Send one create_event per session_id with all messages (conversational + blob)
764777
for session_id, group in session_groups.items():
765-
event = self.memory_client.create_event(
766-
memory_id=self.config.memory_id,
767-
actor_id=self.config.actor_id,
768-
session_id=session_id,
769-
messages=group["messages"],
770-
event_timestamp=group["timestamp"],
771-
)
772-
results.append(event)
773-
logger.debug("Flushed batched event for session %s: %s", session_id, event.get("eventId"))
774-
775-
# Send blob messages individually (they use a different API path)
776-
for session_id, messages, monotonic_timestamp in blob_messages:
777778
event = self.memory_client.gmdp_client.create_event(
778779
memoryId=self.config.memory_id,
779780
actorId=self.config.actor_id,
780781
sessionId=session_id,
781-
payload=[
782-
{"blob": json.dumps(messages[0])},
783-
],
784-
eventTimestamp=monotonic_timestamp,
785-
)
786-
results.append(event)
787-
logger.debug("Flushed blob event for session %s: %s", session_id, event.get("eventId"))
788-
789-
# Flush agent state updates after messages - batch all agent states into a single API call
790-
if agent_states_to_send:
791-
# Convert all agent states to payload format
792-
agent_state_payloads = []
793-
for _session_id, session_agent in agent_states_to_send:
794-
agent_state_payloads.append({"blob": json.dumps(session_agent.to_dict())})
795-
796-
# Send all agent states in a single batched create_event call
797-
event = self.memory_client.gmdp_client.create_event(
798-
memoryId=self.config.memory_id,
799-
actorId=self.config.actor_id,
800-
sessionId=self.config.session_id,
801-
payload=agent_state_payloads,
802-
eventTimestamp=self._get_monotonic_timestamp(),
803-
metadata={
804-
STATE_TYPE_KEY: {"stringValue": StateType.AGENT.value},
805-
},
782+
payload=group["payload"],
783+
eventTimestamp=group["timestamp"],
806784
)
807785
results.append(event)
808786
logger.debug(
809-
"Flushed %d agent states in batched event: %s", len(agent_states_to_send), event.get("eventId")
787+
"Flushed batched event for session %s with %d messages: %s",
788+
session_id,
789+
len(group["payload"]),
790+
event.get("eventId"),
810791
)
811792

812-
# Clear buffers only after ALL messages and agent state succeed
793+
# Clear message buffer only after ALL messages succeed
813794
with self._message_lock:
814795
self._message_buffer.clear()
815796

797+
except Exception as e:
798+
logger.error("Failed to flush messages to AgentCore Memory: %s", e)
799+
raise SessionException(f"Failed to flush messages: {e}") from e
800+
801+
logger.info("Flushed %d message events to AgentCore Memory", len(results))
802+
return results
803+
804+
def _flush_agent_states_only(self) -> list[dict[str, Any]]:
805+
"""Flush only buffered agent states to AgentCore Memory.
806+
807+
Call this method to send any remaining agent state when batch_size > 1.
808+
This is called when the agent state buffer reaches batch_size.
809+
All agent states are batched into a single create_event() call.
810+
811+
Returns:
812+
list[dict[str, Any]]: List of created event responses from AgentCore Memory.
813+
814+
Raises:
815+
SessionException: If agent state creation fails. On failure, agent states remain in the buffer.
816+
"""
817+
with self._agent_state_lock:
818+
agent_states_to_send = list(self._agent_state_buffer)
819+
820+
if not agent_states_to_send:
821+
return []
822+
823+
results = []
824+
try:
825+
# Convert all agent states to payload format
826+
agent_state_payloads = []
827+
for _session_id, session_agent in agent_states_to_send:
828+
agent_state_payloads.append({"blob": json.dumps(session_agent.to_dict())})
829+
830+
# Send all agent states in a single batched create_event call
831+
event = self.memory_client.gmdp_client.create_event(
832+
memoryId=self.config.memory_id,
833+
actorId=self.config.actor_id,
834+
sessionId=self.config.session_id,
835+
payload=agent_state_payloads,
836+
eventTimestamp=self._get_monotonic_timestamp(),
837+
metadata={
838+
STATE_TYPE_KEY: {"stringValue": StateType.AGENT.value},
839+
},
840+
)
841+
results.append(event)
842+
logger.debug(
843+
"Flushed %d agent states in batched event: %s", len(agent_states_to_send), event.get("eventId")
844+
)
845+
846+
# Clear agent state buffer only after success
816847
with self._agent_state_lock:
817848
self._agent_state_buffer.clear()
818849

819850
except Exception as e:
820-
logger.error("Failed to flush messages and agent state to AgentCore Memory: %s", e)
821-
raise SessionException(f"Failed to flush messages and agent state: {e}") from e
851+
logger.error("Failed to flush agent states to AgentCore Memory: %s", e)
852+
raise SessionException(f"Failed to flush agent states: {e}") from e
822853

823-
logger.info("Flushed %d events to AgentCore Memory", len(results))
854+
logger.info("Flushed %d agent state events to AgentCore Memory", len(results))
855+
return results
856+
857+
def _flush_messages(self) -> list[dict[str, Any]]:
858+
"""Flush all buffered messages and agent state to AgentCore Memory.
859+
860+
Call this method to send any remaining buffered messages and agent state messages.
861+
This is automatically called when the session is complete (via close() or context manager).
862+
863+
Returns:
864+
list[dict[str, Any]]: List of created event responses from AgentCore Memory.
865+
866+
Raises:
867+
SessionException: If any message or agent state creation fails.
868+
"""
869+
results = []
870+
results.extend(self._flush_messages_only())
871+
results.extend(self._flush_agent_states_only())
824872
return results
825873

826874
def pending_message_count(self) -> int:

0 commit comments

Comments
 (0)