Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 19 additions & 22 deletions src/agents/extensions/memory/advanced_sqlite_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,30 +487,22 @@ def _cleanup_sync():

def _cleanup_orphaned_messages_sync(self, conn: sqlite3.Connection) -> int:
with closing(conn.cursor()) as cursor:
# Find messages without structure metadata.
cursor.execute(
f"""
SELECT am.id
FROM {self.messages_table} am
LEFT JOIN message_structure ms ON am.id = ms.message_id
WHERE am.session_id = ? AND ms.message_id IS NULL
""",
(self.session_id,),
)

orphaned_ids = [row[0] for row in cursor.fetchall()]

if not orphaned_ids:
return 0

placeholders = ",".join("?" * len(orphaned_ids))
cursor.execute(
f"DELETE FROM {self.messages_table} WHERE id IN ({placeholders})",
orphaned_ids,
DELETE FROM {self.messages_table}
WHERE session_id = ?
AND id NOT IN (
SELECT message_id
FROM message_structure ms
WHERE ms.session_id = ?
)
""",
(self.session_id, self.session_id),
)

deleted_count = cursor.rowcount
self._logger.info(f"Cleaned up {deleted_count} orphaned messages")
if deleted_count:
self._logger.info(f"Cleaned up {deleted_count} orphaned messages")
return deleted_count

def _classify_message_type(self, item: TResponseInputItem) -> str:
Expand Down Expand Up @@ -786,14 +778,19 @@ def _delete_sync():

structure_deleted = cursor.rowcount

orphaned_messages_deleted = self._cleanup_orphaned_messages_sync(conn)

conn.commit()

return usage_deleted, structure_deleted
return usage_deleted, structure_deleted, orphaned_messages_deleted

usage_deleted, structure_deleted = await asyncio.to_thread(_delete_sync)
usage_deleted, structure_deleted, orphaned_messages_deleted = await asyncio.to_thread(
_delete_sync
)

self._logger.info(
f"Deleted branch '{branch_id}': {structure_deleted} message entries, {usage_deleted} usage entries" # noqa: E501
f"Deleted branch '{branch_id}': {structure_deleted} message entries, "
f"{usage_deleted} usage entries, {orphaned_messages_deleted} orphaned messages"
)

async def list_branches(self) -> list[dict[str, Any]]:
Expand Down
162 changes: 162 additions & 0 deletions tests/extensions/memory/test_advanced_sqlite_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,168 @@ async def test_branching_functionality(agent: Agent):
session.close()


async def test_delete_branch_removes_branch_only_messages():
"""Deleting a branch should not leave unreferenced branch-only messages behind."""
session_id = "branch_delete_cleanup_test"
session = AdvancedSQLiteSession(session_id=session_id, create_tables=True)

main_items: list[TResponseInputItem] = [
{"role": "user", "content": "First question"},
{"role": "assistant", "content": "First answer"},
{"role": "user", "content": "Second question"},
{"role": "assistant", "content": "Second answer"},
]
await session.add_items(main_items)

await session.create_branch_from_turn(2, "cleanup_branch")
branch_items: list[TResponseInputItem] = [
{"role": "user", "content": "Branch-only question"},
{"role": "assistant", "content": "Branch-only answer"},
]
await session.add_items(branch_items)

await session.delete_branch("cleanup_branch", force=True)

with session._locked_connection() as conn:
rows = conn.execute(
f"""
SELECT message_data
FROM {session.messages_table}
WHERE session_id = ?
ORDER BY id
""",
(session.session_id,),
).fetchall()

contents = [json.loads(message_data)["content"] for (message_data,) in rows]
assert contents == [
"First question",
"First answer",
"Second question",
"Second answer",
]
assert await session.get_items(branch_id="main") == main_items

session.close()


async def test_delete_branch_keeps_messages_still_referenced_by_another_branch():
"""Deleting one branch should keep messages inherited by a surviving branch."""
session = AdvancedSQLiteSession(
session_id="branch_delete_shared_descendant_test",
create_tables=True,
)

main_items: list[TResponseInputItem] = [
{"role": "user", "content": "Main first question"},
{"role": "assistant", "content": "Main first answer"},
{"role": "user", "content": "Main second question"},
{"role": "assistant", "content": "Main second answer"},
]
branch_a_shared_items: list[TResponseInputItem] = [
{"role": "user", "content": "Branch A shared question"},
{"role": "assistant", "content": "Branch A shared answer"},
]
branch_a_only_items: list[TResponseInputItem] = [
{"role": "user", "content": "Branch A only question"},
{"role": "assistant", "content": "Branch A only answer"},
]

try:
await session.add_items(main_items)
await session.create_branch_from_turn(2, "branch_a")
await session.add_items(branch_a_shared_items + branch_a_only_items)

await session.create_branch_from_turn(3, "branch_b")
await session.delete_branch("branch_a")

with session._locked_connection() as conn:
rows = conn.execute(
f"""
SELECT message_data
FROM {session.messages_table}
WHERE session_id = ?
ORDER BY id
""",
(session.session_id,),
).fetchall()

contents = [json.loads(message_data)["content"] for (message_data,) in rows]
assert "Branch A shared question" in contents
assert "Branch A shared answer" in contents
assert "Branch A only question" not in contents
assert "Branch A only answer" not in contents
assert await session.get_items(branch_id="branch_b") == [
*main_items[:2],
*branch_a_shared_items,
]
finally:
session.close()


async def test_orphan_cleanup_uses_set_based_delete_for_many_messages():
"""Orphan cleanup should not build one DELETE parameter per orphaned row."""

class RecordingCursor:
def __init__(self, cursor: Any, connection: "RecordingConnection") -> None:
self._cursor = cursor
self._connection = connection

@property
def rowcount(self) -> int:
return cast(int, self._cursor.rowcount)

def execute(self, sql: str, parameters: Any = None) -> Any:
normalized_sql = " ".join(sql.split()).upper()
if normalized_sql.startswith("DELETE"):
self._connection.delete_parameter_counts.append(len(parameters or ()))
if parameters is None:
return self._cursor.execute(sql)
return self._cursor.execute(sql, parameters)

def fetchall(self) -> Any:
return self._cursor.fetchall()

def close(self) -> None:
self._cursor.close()

class RecordingConnection:
def __init__(self, conn: Any) -> None:
self._conn = conn
self.delete_parameter_counts: list[int] = []

def cursor(self) -> RecordingCursor:
return RecordingCursor(self._conn.cursor(), self)

session = AdvancedSQLiteSession(
session_id="branch_delete_many_orphans_cleanup",
create_tables=True,
)
orphan_items: list[TResponseInputItem] = [
{"role": "user", "content": f"orphan {i}"} for i in range(1200)
]

try:
with session._locked_connection() as conn:
session._insert_items(conn, orphan_items)
conn.commit()

recording_conn = RecordingConnection(conn)
deleted_count = session._cleanup_orphaned_messages_sync(cast(Any, recording_conn))
conn.commit()

remaining_count = conn.execute(
f"SELECT COUNT(*) FROM {session.messages_table} WHERE session_id = ?",
(session.session_id,),
).fetchone()[0]

assert deleted_count == len(orphan_items)
assert remaining_count == 0
assert recording_conn.delete_parameter_counts == [2]
finally:
session.close()


async def test_get_conversation_turns():
"""Test get_conversation_turns functionality."""
session_id = "conversation_turns_test"
Expand Down
Loading