From 387752c56e587124f22e5f1aa0decfed8aadd7e3 Mon Sep 17 00:00:00 2001 From: biefan <70761325+biefan@users.noreply.github.com> Date: Tue, 17 Mar 2026 06:52:34 +0000 Subject: [PATCH] Return no memory results for empty prompt ID filters --- pyrit/memory/memory_interface.py | 3 +++ .../test_interface_prompts.py | 12 +++++++++ .../memory_interface/test_interface_scores.py | 25 +++++++++++++++++++ 3 files changed, 40 insertions(+) diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 90322ebec4..2faad11d36 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -610,6 +610,9 @@ def get_message_pieces( Exception: If there is an error retrieving the prompts, an exception is logged and an empty list is returned. """ + if prompt_ids is not None and len(prompt_ids) == 0: + return [] + conditions = [] if attack_id: conditions.append(self._get_message_pieces_attack_conditions(attack_id=str(attack_id))) diff --git a/tests/unit/memory/memory_interface/test_interface_prompts.py b/tests/unit/memory/memory_interface/test_interface_prompts.py index 67a4292f87..53e2d0f180 100644 --- a/tests/unit/memory/memory_interface/test_interface_prompts.py +++ b/tests/unit/memory/memory_interface/test_interface_prompts.py @@ -108,6 +108,18 @@ def test_get_message_pieces_uuid_and_string_ids(sqlite_instance: MemoryInterface assert str(single_str_result[0].id) == str(uuid3) +def test_get_message_pieces_empty_prompt_ids_returns_empty(sqlite_instance: MemoryInterface): + piece = MessagePiece( + id=uuid.uuid4(), + role="user", + original_value="Test prompt", + converted_value="Test prompt", + ) + sqlite_instance.add_message_pieces_to_memory(message_pieces=[piece]) + + assert sqlite_instance.get_message_pieces(prompt_ids=[]) == [] + + def test_duplicate_memory(sqlite_instance: MemoryInterface): attack1 = PromptSendingAttack(objective_target=get_mock_target()) attack2 = PromptSendingAttack(objective_target=get_mock_target("Target2")) diff --git a/tests/unit/memory/memory_interface/test_interface_scores.py b/tests/unit/memory/memory_interface/test_interface_scores.py index 6087af1418..34358d5e93 100644 --- a/tests/unit/memory/memory_interface/test_interface_scores.py +++ b/tests/unit/memory/memory_interface/test_interface_scores.py @@ -131,6 +131,31 @@ def test_add_score_get_score( assert db_score[0].message_piece_id == prompt_id +def test_get_prompt_scores_empty_prompt_ids_returns_empty(sqlite_instance: MemoryInterface): + prompt_id = uuid4() + piece = MessagePiece( + id=prompt_id, + role="user", + original_value="original prompt text", + converted_value="Hello, how are you?", + ) + sqlite_instance.add_message_pieces_to_memory(message_pieces=[piece]) + + score = Score( + score_value=str(0.8), + score_value_description="High score", + score_type="float_scale", + score_category=["test"], + score_rationale="Test score", + score_metadata={"test": "metadata"}, + scorer_class_identifier=_test_scorer_id("TestScorer"), + message_piece_id=prompt_id, + ) + sqlite_instance.add_scores_to_memory(scores=[score]) + + assert sqlite_instance.get_prompt_scores(prompt_ids=[]) == [] + + def test_add_score_duplicate_prompt(sqlite_instance: MemoryInterface): # Ensure that scores of duplicate prompts are linked back to the original original_id = uuid4()