Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 14 additions & 25 deletions src/agents/extensions/memory/advanced_sqlite_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,26 +133,15 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
def _add_items_sync():
"""Synchronous helper to add items and structure metadata together."""
with self._locked_connection() as conn:
# Keep both writes in one critical section so message IDs and metadata stay aligned.
self._insert_items(conn, items)
conn.commit()
try:
# Keep both writes in one transaction so metadata failures do not leave orphans.
self._insert_items(conn, items)
self._insert_structure_metadata(conn, items)
conn.commit()
except Exception as e:
except Exception:
conn.rollback()
self._logger.error(
f"Failed to add structure metadata for session {self.session_id}: {e}"
)
try:
deleted_count = self._cleanup_orphaned_messages_sync(conn)
if deleted_count:
conn.commit()
else:
conn.rollback()
except Exception as cleanup_error:
conn.rollback()
self._logger.error(f"Failed to cleanup orphaned messages: {cleanup_error}")
self._logger.exception("Failed to add items for session %s", self.session_id)
raise

await asyncio.to_thread(_add_items_sync)

Expand Down Expand Up @@ -367,16 +356,16 @@ def _add_structure_sync():

try:
await asyncio.to_thread(_add_structure_sync)
except Exception as e:
self._logger.error(
f"Failed to add structure metadata for session {self.session_id}: {e}"
except Exception:
self._logger.exception(
"Failed to add structure metadata for session %s", self.session_id
)
# Try to clean up any orphaned messages to maintain consistency
# Try to clean up any orphaned messages to maintain consistency.
try:
await self._cleanup_orphaned_messages()
except Exception as cleanup_error:
self._logger.error(f"Failed to cleanup orphaned messages: {cleanup_error}")
# Don't re-raise - structure metadata is supplementary
except Exception:
self._logger.exception("Failed to cleanup orphaned messages")
raise

def _insert_structure_metadata(
self,
Expand Down Expand Up @@ -469,8 +458,8 @@ def _insert_structure_metadata(
async def _cleanup_orphaned_messages(self) -> int:
"""Remove messages that exist in the configured message table but not in message_structure.

This can happen if _add_structure_metadata fails after super().add_items() succeeds.
Used for maintaining data consistency.
This can happen for rows written by older or non-atomic structure metadata paths.
`add_items()` writes message rows and structure metadata in a single transaction.
"""

def _cleanup_sync():
Expand Down
173 changes: 173 additions & 0 deletions tests/extensions/memory/test_advanced_sqlite_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,52 @@ def create_mock_run_result(usage: Usage | None = None, agent: Agent | None = Non
)


class FailingOnceStructureMetadataSession(AdvancedSQLiteSession):
"""Advanced session test double that fails the next structure metadata write."""

def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self.fail_structure_metadata_once = True

def _insert_structure_metadata(
self,
conn: Any,
items: list[TResponseInputItem],
) -> None:
if self.fail_structure_metadata_once:
self.fail_structure_metadata_once = False
raise RuntimeError("structure metadata failed")
super()._insert_structure_metadata(conn, items)


class PartiallyFailingStructureMetadataSession(AdvancedSQLiteSession):
"""Advanced session test double that fails after writing one structure row."""

def _insert_structure_metadata(
self,
conn: Any,
items: list[TResponseInputItem],
) -> None:
cursor = conn.execute(
f"SELECT id FROM {self.messages_table} WHERE session_id = ? ORDER BY id ASC LIMIT 1",
(self.session_id,),
)
row = cursor.fetchone()
if row is None:
raise RuntimeError("no inserted message id found")

conn.execute(
"""
INSERT INTO message_structure
(session_id, message_id, branch_id, message_type, sequence_number,
user_turn_number, branch_turn_number, tool_name)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(self.session_id, row[0], self._current_branch_id, "user", 1, 1, 1, None),
)
raise RuntimeError("structure metadata failed after partial write")


async def test_advanced_session_basic_functionality(agent: Agent):
"""Test basic AdvancedSQLiteSession functionality."""
session_id = "advanced_test"
Expand Down Expand Up @@ -147,6 +193,133 @@ async def test_advanced_session_respects_custom_table_names():
session.close()


async def test_add_items_rolls_back_messages_when_structure_metadata_fails():
"""Failed structure metadata writes should not leave invisible message rows."""
session = FailingOnceStructureMetadataSession(
session_id="advanced_add_items_rollback",
create_tables=True,
)
items: list[TResponseInputItem] = [{"role": "user", "content": "not saved"}]

try:
with pytest.raises(RuntimeError, match="structure metadata failed"):
await session.add_items(items)

assert await session.get_items() == []

with session._locked_connection() as conn:
message_count = conn.execute(
f"SELECT COUNT(*) FROM {session.messages_table} WHERE session_id = ?",
(session.session_id,),
).fetchone()[0]
structure_count = conn.execute(
"SELECT COUNT(*) FROM message_structure WHERE session_id = ?",
(session.session_id,),
).fetchone()[0]

assert message_count == 0
assert structure_count == 0
finally:
session.close()


async def test_add_items_can_retry_after_structure_metadata_failure():
"""Retrying after a metadata failure should persist the batch exactly once."""
session = FailingOnceStructureMetadataSession(
session_id="advanced_add_items_retry",
create_tables=True,
)
items: list[TResponseInputItem] = [{"role": "user", "content": "saved once"}]

try:
with pytest.raises(RuntimeError, match="structure metadata failed"):
await session.add_items(items)

await session.add_items(items)

assert await session.get_items() == items

with session._locked_connection() as conn:
message_count = conn.execute(
f"SELECT COUNT(*) FROM {session.messages_table} WHERE session_id = ?",
(session.session_id,),
).fetchone()[0]
structure_count = conn.execute(
"SELECT COUNT(*) FROM message_structure WHERE session_id = ?",
(session.session_id,),
).fetchone()[0]

assert message_count == 1
assert structure_count == 1
finally:
session.close()


async def test_add_items_failure_preserves_existing_history():
"""A failed batch should not roll back or hide previously committed messages."""
session = FailingOnceStructureMetadataSession(
session_id="advanced_add_items_existing_history",
create_tables=True,
)
existing_items: list[TResponseInputItem] = [{"role": "user", "content": "already saved"}]
failed_items: list[TResponseInputItem] = [{"role": "assistant", "content": "not saved"}]

try:
session.fail_structure_metadata_once = False
await session.add_items(existing_items)

session.fail_structure_metadata_once = True
with pytest.raises(RuntimeError, match="structure metadata failed"):
await session.add_items(failed_items)

assert await session.get_items() == existing_items

with session._locked_connection() as conn:
message_count = conn.execute(
f"SELECT COUNT(*) FROM {session.messages_table} WHERE session_id = ?",
(session.session_id,),
).fetchone()[0]
structure_count = conn.execute(
"SELECT COUNT(*) FROM message_structure WHERE session_id = ?",
(session.session_id,),
).fetchone()[0]

assert message_count == 1
assert structure_count == 1
finally:
session.close()


async def test_add_items_rolls_back_partial_structure_metadata_write():
"""Partial metadata writes should roll back with the message rows in the same batch."""
session = PartiallyFailingStructureMetadataSession(
session_id="advanced_add_items_partial_metadata",
create_tables=True,
)
items: list[TResponseInputItem] = [{"role": "user", "content": "not saved"}]

try:
with pytest.raises(RuntimeError, match="structure metadata failed after partial write"):
await session.add_items(items)

assert await session.get_items() == []

with session._locked_connection() as conn:
message_count = conn.execute(
f"SELECT COUNT(*) FROM {session.messages_table} WHERE session_id = ?",
(session.session_id,),
).fetchone()[0]
structure_count = conn.execute(
"SELECT COUNT(*) FROM message_structure WHERE session_id = ?",
(session.session_id,),
).fetchone()[0]

assert message_count == 0
assert structure_count == 0
finally:
session.close()


async def test_message_structure_tracking(agent: Agent):
"""Test that message structure is properly tracked."""
session_id = "structure_test"
Expand Down
Loading