Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,10 +320,13 @@ def _parse_task(
)

query = parsed_goal.rephrased_query or query
# if goal has extra memories, embed them too
if parsed_goal.memories:
embed_texts = list(dict.fromkeys([query, *parsed_goal.memories]))
query_embedding = self.embedder.embed(embed_texts)
embed_texts = [
text.strip()
for text in [query, *parsed_goal.memories]
if isinstance(text, str) and text.strip()
]
if embed_texts:
query_embedding = self.embedder.embed(list(dict.fromkeys(embed_texts)))
return parsed_goal, query_embedding, context, query

@timed
Expand Down
32 changes: 31 additions & 1 deletion tests/memories/textual/test_tree_searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
from memos.memories.textual.tree_text_memory.retrieve.retrieval_mid_structs import ParsedTaskGoal
from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
from memos.reranker.base import BaseReranker

Expand All @@ -24,13 +25,14 @@ def mock_searcher():
return s


def make_item(content: str, score: float):
def make_item(content: str, score: float, memory_type: str = "WorkingMemory"):
# Simulate a TextualMemoryItem with usage list for update test
return (
TextualMemoryItem(
memory=content,
metadata=TreeNodeTextualMemoryMetadata(
embedding=[0.1] * 5,
memory_type=memory_type,
usage=[],
),
),
Expand Down Expand Up @@ -104,6 +106,34 @@ def test_searcher_fine_mode_triggers_reasoner(mock_searcher):
assert len(result) == 1


def test_fine_search_embeds_query_when_parser_returns_no_memory_expansions(mock_searcher):
query = "我喜欢什么"
user_memory = make_item("我喜欢草莓", 0.9, memory_type="UserMemory")[0]

mock_searcher.task_goal_parser.parse.return_value = ParsedTaskGoal(
keys=["喜欢", "偏好", "兴趣"],
tags=["personal preference", "user interest", "taste"],
memories=[],
rephrased_query="",
)
mock_searcher.embedder.embed.return_value = [[0.1] * 5]

def retrieve_side_effect(*args, **kwargs):
if kwargs.get("memory_scope") == "UserMemory":
assert kwargs["query_embedding"] == [[0.1] * 5]
return [user_memory]
return []

mock_searcher.graph_retriever.retrieve.side_effect = retrieve_side_effect
mock_searcher.reranker.rerank.return_value = [(user_memory, 0.9)]

result = mock_searcher.search(query=query, top_k=1, mode="fine", memory_type="UserMemory")

mock_searcher.embedder.embed.assert_called_once_with([query])
assert len(result) == 1
assert result[0].memory == "我喜欢草莓"


def test_searcher_respects_memory_type(mock_searcher):
parsed_goal = MagicMock()
parsed_goal.memories = ["Something"]
Expand Down
Loading