diff --git a/src/agents/extensions/memory/advanced_sqlite_session.py b/src/agents/extensions/memory/advanced_sqlite_session.py
index 83c289bdf8..c7a01dc3c5 100644
--- a/src/agents/extensions/memory/advanced_sqlite_session.py
+++ b/src/agents/extensions/memory/advanced_sqlite_session.py
@@ -487,30 +487,30 @@ def _cleanup_sync():
 
     def _cleanup_orphaned_messages_sync(self, conn: sqlite3.Connection) -> int:
+        """Delete this session's messages that no message_structure row references.
+
+        Returns the number of rows deleted. The caller owns the transaction; this
+        helper does not commit.
+        """
         with closing(conn.cursor()) as cursor:
-            # Find messages without structure metadata.
+            # Delete orphans in one set-based statement. NOT EXISTS is used instead
+            # of NOT IN because NOT IN matches no rows if the subquery yields a NULL.
             cursor.execute(
                 f"""
-                SELECT am.id
-                FROM {self.messages_table} am
-                LEFT JOIN message_structure ms ON am.id = ms.message_id
-                WHERE am.session_id = ? AND ms.message_id IS NULL
-            """,
-                (self.session_id,),
-            )
-
-            orphaned_ids = [row[0] for row in cursor.fetchall()]
-
-            if not orphaned_ids:
-                return 0
-
-            placeholders = ",".join("?" * len(orphaned_ids))
-            cursor.execute(
-                f"DELETE FROM {self.messages_table} WHERE id IN ({placeholders})",
-                orphaned_ids,
+                DELETE FROM {self.messages_table}
+                WHERE session_id = ?
+                  AND NOT EXISTS (
+                      SELECT 1
+                      FROM message_structure ms
+                      WHERE ms.session_id = ?
+                        AND ms.message_id = {self.messages_table}.id
+                  )
+                """,
+                (self.session_id, self.session_id),
             )
 
             deleted_count = cursor.rowcount
-            self._logger.info(f"Cleaned up {deleted_count} orphaned messages")
+            if deleted_count:
+                self._logger.info(f"Cleaned up {deleted_count} orphaned messages")
             return deleted_count
 
     def _classify_message_type(self, item: TResponseInputItem) -> str:
@@ -786,14 +786,19 @@ def _delete_sync():
 
                     structure_deleted = cursor.rowcount
 
+                    orphaned_messages_deleted = self._cleanup_orphaned_messages_sync(conn)
+
                     conn.commit()
 
-                    return usage_deleted, structure_deleted
+                    return usage_deleted, structure_deleted, orphaned_messages_deleted
 
-        usage_deleted, structure_deleted = await asyncio.to_thread(_delete_sync)
+        usage_deleted, structure_deleted, orphaned_messages_deleted = await asyncio.to_thread(
+            _delete_sync
+        )
 
         self._logger.info(
-            f"Deleted branch '{branch_id}': {structure_deleted} message entries, {usage_deleted} usage entries"  # noqa: E501
+            f"Deleted branch '{branch_id}': {structure_deleted} message entries, "
+            f"{usage_deleted} usage entries, {orphaned_messages_deleted} orphaned messages"
         )
 
     async def list_branches(self) -> list[dict[str, Any]]:
diff --git a/tests/extensions/memory/test_advanced_sqlite_session.py b/tests/extensions/memory/test_advanced_sqlite_session.py
index ad4b5c4d86..915e142b91 100644
--- a/tests/extensions/memory/test_advanced_sqlite_session.py
+++ b/tests/extensions/memory/test_advanced_sqlite_session.py
@@ -442,6 +442,170 @@ async def test_branching_functionality(agent: Agent):
     session.close()
 
 
+async def test_delete_branch_removes_branch_only_messages():
+    """Deleting a branch should not leave unreferenced branch-only messages behind."""
+    session_id = "branch_delete_cleanup_test"
+    session = AdvancedSQLiteSession(session_id=session_id, create_tables=True)
+
+    main_items: list[TResponseInputItem] = [
+        {"role": "user", "content": "First question"},
+        {"role": "assistant", "content": "First answer"},
+        {"role": "user", "content": "Second question"},
+        {"role": "assistant", "content": "Second answer"},
+    ]
+
+    try:
+        await session.add_items(main_items)
+
+        await session.create_branch_from_turn(2, "cleanup_branch")
+        branch_items: list[TResponseInputItem] = [
+            {"role": "user", "content": "Branch-only question"},
+            {"role": "assistant", "content": "Branch-only answer"},
+        ]
+        await session.add_items(branch_items)
+
+        await session.delete_branch("cleanup_branch", force=True)
+
+        with session._locked_connection() as conn:
+            rows = conn.execute(
+                f"""
+                SELECT message_data
+                FROM {session.messages_table}
+                WHERE session_id = ?
+                ORDER BY id
+                """,
+                (session.session_id,),
+            ).fetchall()
+
+        contents = [json.loads(message_data)["content"] for (message_data,) in rows]
+        assert contents == [
+            "First question",
+            "First answer",
+            "Second question",
+            "Second answer",
+        ]
+        assert await session.get_items(branch_id="main") == main_items
+    finally:
+        session.close()
+
+
+async def test_delete_branch_keeps_messages_still_referenced_by_another_branch():
+    """Deleting one branch should keep messages inherited by a surviving branch."""
+    session = AdvancedSQLiteSession(
+        session_id="branch_delete_shared_descendant_test",
+        create_tables=True,
+    )
+
+    main_items: list[TResponseInputItem] = [
+        {"role": "user", "content": "Main first question"},
+        {"role": "assistant", "content": "Main first answer"},
+        {"role": "user", "content": "Main second question"},
+        {"role": "assistant", "content": "Main second answer"},
+    ]
+    branch_a_shared_items: list[TResponseInputItem] = [
+        {"role": "user", "content": "Branch A shared question"},
+        {"role": "assistant", "content": "Branch A shared answer"},
+    ]
+    branch_a_only_items: list[TResponseInputItem] = [
+        {"role": "user", "content": "Branch A only question"},
+        {"role": "assistant", "content": "Branch A only answer"},
+    ]
+
+    try:
+        await session.add_items(main_items)
+        await session.create_branch_from_turn(2, "branch_a")
+        await session.add_items(branch_a_shared_items + branch_a_only_items)
+
+        await session.create_branch_from_turn(3, "branch_b")
+        await session.delete_branch("branch_a")
+
+        with session._locked_connection() as conn:
+            rows = conn.execute(
+                f"""
+                SELECT message_data
+                FROM {session.messages_table}
+                WHERE session_id = ?
+                ORDER BY id
+                """,
+                (session.session_id,),
+            ).fetchall()
+
+        contents = [json.loads(message_data)["content"] for (message_data,) in rows]
+        assert "Branch A shared question" in contents
+        assert "Branch A shared answer" in contents
+        assert "Branch A only question" not in contents
+        assert "Branch A only answer" not in contents
+        assert await session.get_items(branch_id="branch_b") == [
+            *main_items[:2],
+            *branch_a_shared_items,
+        ]
+    finally:
+        session.close()
+
+
+async def test_orphan_cleanup_uses_set_based_delete_for_many_messages():
+    """Orphan cleanup should not build one DELETE parameter per orphaned row."""
+
+    class RecordingCursor:
+        def __init__(self, cursor: Any, connection: "RecordingConnection") -> None:
+            self._cursor = cursor
+            self._connection = connection
+
+        @property
+        def rowcount(self) -> int:
+            return cast(int, self._cursor.rowcount)
+
+        def execute(self, sql: str, parameters: Any = None) -> Any:
+            normalized_sql = " ".join(sql.split()).upper()
+            if normalized_sql.startswith("DELETE"):
+                self._connection.delete_parameter_counts.append(len(parameters or ()))
+            if parameters is None:
+                return self._cursor.execute(sql)
+            return self._cursor.execute(sql, parameters)
+
+        def fetchall(self) -> Any:
+            return self._cursor.fetchall()
+
+        def close(self) -> None:
+            self._cursor.close()
+
+    class RecordingConnection:
+        def __init__(self, conn: Any) -> None:
+            self._conn = conn
+            self.delete_parameter_counts: list[int] = []
+
+        def cursor(self) -> RecordingCursor:
+            return RecordingCursor(self._conn.cursor(), self)
+
+    session = AdvancedSQLiteSession(
+        session_id="branch_delete_many_orphans_cleanup",
+        create_tables=True,
+    )
+    orphan_items: list[TResponseInputItem] = [
+        {"role": "user", "content": f"orphan {i}"} for i in range(1200)
+    ]
+
+    try:
+        with session._locked_connection() as conn:
+            session._insert_items(conn, orphan_items)
+            conn.commit()
+
+            recording_conn = RecordingConnection(conn)
+            deleted_count = session._cleanup_orphaned_messages_sync(cast(Any, recording_conn))
+            conn.commit()
+
+            remaining_count = conn.execute(
+                f"SELECT COUNT(*) FROM {session.messages_table} WHERE session_id = ?",
+                (session.session_id,),
+            ).fetchone()[0]
+
+        assert deleted_count == len(orphan_items)
+        assert remaining_count == 0
+        assert recording_conn.delete_parameter_counts == [2]
+    finally:
+        session.close()
+
+
 async def test_get_conversation_turns():
     """Test get_conversation_turns functionality."""
     session_id = "conversation_turns_test"