From 5f2fe03ed1252a3e12f3da513afb273c6585110a Mon Sep 17 00:00:00 2001 From: biefan Date: Tue, 17 Mar 2026 15:21:57 +0800 Subject: [PATCH] Deduplicate message pieces before batch scoring --- pyrit/score/batch_scorer.py | 5 ++++ tests/unit/score/test_batch_scorer.py | 40 +++++++++++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/pyrit/score/batch_scorer.py b/pyrit/score/batch_scorer.py index de571db9d0..f0e566f5b9 100644 --- a/pyrit/score/batch_scorer.py +++ b/pyrit/score/batch_scorer.py @@ -102,6 +102,11 @@ async def score_responses_by_filters_async( converted_value_sha256=converted_value_sha256, ) + if not message_pieces: + raise ValueError("No entries match the provided filters. Please check your filters.") + + message_pieces = self._remove_duplicates(message_pieces) + if not message_pieces: raise ValueError("No entries match the provided filters. Please check your filters.") diff --git a/tests/unit/score/test_batch_scorer.py b/tests/unit/score/test_batch_scorer.py index f91e8e3d15..136d3ed336 100644 --- a/tests/unit/score/test_batch_scorer.py +++ b/tests/unit/score/test_batch_scorer.py @@ -325,3 +325,43 @@ async def test_score_responses_by_filters_groups_by_sequence_within_conversation assert len(messages[0].message_pieces) == 1 assert len(messages[1].message_pieces) == 2 assert len(messages[2].message_pieces) == 1 + + @pytest.mark.asyncio + async def test_score_responses_by_filters_removes_duplicate_message_pieces(self) -> None: + """Test that duplicate message pieces are filtered out before batch scoring.""" + memory = MagicMock() + original_piece_id = uuid.uuid4() + + pieces = [ + MessagePiece( + id=original_piece_id, + role="assistant", + conversation_id="conv1", + sequence=1, + original_value="Original response", + ), + MessagePiece( + role="assistant", + conversation_id="conv1", + sequence=1, + original_value="Duplicate response copy", + original_prompt_id=original_piece_id, + ), + ] + + memory.get_message_pieces.return_value = pieces + + with patch.object(CentralMemory, "get_memory_instance", return_value=memory): + scorer = MagicMock() + scorer.score_prompts_batch_async = AsyncMock(return_value=[]) + + batch_scorer = BatchScorer() + + await batch_scorer.score_responses_by_filters_async(scorer=scorer, conversation_id="conv1") + + call_args = scorer.score_prompts_batch_async.call_args + messages = call_args.kwargs["messages"] + + assert len(messages) == 1 + assert len(messages[0].message_pieces) == 1 + assert messages[0].message_pieces[0].id == original_piece_id