diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index eb15b48ed..a021f56e1 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -320,10 +320,13 @@ def _parse_task( ) query = parsed_goal.rephrased_query or query - # if goal has extra memories, embed them too - if parsed_goal.memories: - embed_texts = list(dict.fromkeys([query, *parsed_goal.memories])) - query_embedding = self.embedder.embed(embed_texts) + embed_texts = [ + text.strip() + for text in [query, *parsed_goal.memories] + if isinstance(text, str) and text.strip() + ] + if embed_texts: + query_embedding = self.embedder.embed(list(dict.fromkeys(embed_texts))) return parsed_goal, query_embedding, context, query @timed diff --git a/tests/memories/textual/test_tree_searcher.py b/tests/memories/textual/test_tree_searcher.py index 3d1469d00..82b733233 100644 --- a/tests/memories/textual/test_tree_searcher.py +++ b/tests/memories/textual/test_tree_searcher.py @@ -3,6 +3,7 @@ import pytest from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata +from memos.memories.textual.tree_text_memory.retrieve.retrieval_mid_structs import ParsedTaskGoal from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher from memos.reranker.base import BaseReranker @@ -24,13 +25,14 @@ def mock_searcher(): return s -def make_item(content: str, score: float): +def make_item(content: str, score: float, memory_type: str = "WorkingMemory"): # Simulate a TextualMemoryItem with usage list for update test return ( TextualMemoryItem( memory=content, metadata=TreeNodeTextualMemoryMetadata( embedding=[0.1] * 5, + memory_type=memory_type, usage=[], ), ), @@ -104,6 +106,34 @@ def test_searcher_fine_mode_triggers_reasoner(mock_searcher): assert len(result) == 1 +def test_fine_search_embeds_query_when_parser_returns_no_memory_expansions(mock_searcher): + query = "我喜欢什么" + user_memory = make_item("我喜欢草莓", 0.9, memory_type="UserMemory")[0] + + mock_searcher.task_goal_parser.parse.return_value = ParsedTaskGoal( + keys=["喜欢", "偏好", "兴趣"], + tags=["personal preference", "user interest", "taste"], + memories=[], + rephrased_query="", + ) + mock_searcher.embedder.embed.return_value = [[0.1] * 5] + + def retrieve_side_effect(*args, **kwargs): + if kwargs.get("memory_scope") == "UserMemory": + assert kwargs["query_embedding"] == [[0.1] * 5] + return [user_memory] + return [] + + mock_searcher.graph_retriever.retrieve.side_effect = retrieve_side_effect + mock_searcher.reranker.rerank.return_value = [(user_memory, 0.9)] + + result = mock_searcher.search(query=query, top_k=1, mode="fine", memory_type="UserMemory") + + mock_searcher.embedder.embed.assert_called_once_with([query]) + assert len(result) == 1 + assert result[0].memory == "我喜欢草莓" + + def test_searcher_respects_memory_type(mock_searcher): parsed_goal = MagicMock() parsed_goal.memories = ["Something"]