diff --git a/haystack/components/generators/hugging_face_local.py b/haystack/components/generators/hugging_face_local.py index 290bc1a11b..f12b7838bc 100644 --- a/haystack/components/generators/hugging_face_local.py +++ b/haystack/components/generators/hugging_face_local.py @@ -260,6 +260,7 @@ def run( if self.stop_words: # the output of the pipeline includes the stop word - replies = [reply.replace(stop_word, "").rstrip() for reply in replies for stop_word in self.stop_words] + for stop_word in self.stop_words: + replies = [reply.replace(stop_word, "").rstrip() for reply in replies] return {"replies": replies} diff --git a/releasenotes/notes/fix-hf-local-stop-words-cross-product-31d12d96c57c442b.yaml b/releasenotes/notes/fix-hf-local-stop-words-cross-product-31d12d96c57c442b.yaml new file mode 100644 index 0000000000..312f3e1240 --- /dev/null +++ b/releasenotes/notes/fix-hf-local-stop-words-cross-product-31d12d96c57c442b.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Fixes ``HuggingFaceLocalGenerator`` so that when multiple ``stop_words`` are configured the generator returns exactly N replies instead of N×M replies. Previously the list comprehension created a cross-product, duplicating replies and leaving some stop words unstripped. diff --git a/test/components/generators/test_hugging_face_local_generator.py b/test/components/generators/test_hugging_face_local_generator.py index 62779e855b..836dd94684 100644 --- a/test/components/generators/test_hugging_face_local_generator.py +++ b/test/components/generators/test_hugging_face_local_generator.py @@ -420,6 +420,22 @@ def test_run_stop_words_removal(self): results = generator.run(prompt="irrelevant") assert results == {"replies": ["Hello"]} + def test_run_multiple_stop_words_removal(self): + """Test that multiple stop words are all removed without producing N*M replies.""" + generator = HuggingFaceLocalGenerator( + model="Qwen/Qwen3-0.6B", task="text-generation", stop_words=["STOP", "END"] + ) + generator.pipeline = Mock( + return_value=[ + {"generated_text": "Paris is the capital. STOP"}, + {"generated_text": "France is in Europe. STOP"}, + ] + ) + generator.stopping_criteria_list = Mock() + results = generator.run(prompt="irrelevant") + # Should return exactly 2 replies, both stop words stripped + assert results == {"replies": ["Paris is the capital.", "France is in Europe."]} + @pytest.mark.integration def test_stop_words_criteria_using_hf_tokenizer(self): """