From 771ff86bab9aab00f59de35d3c64b9ddc94bb0cf Mon Sep 17 00:00:00 2001 From: 0xVox Date: Tue, 2 Jun 2026 17:16:16 -0700 Subject: [PATCH] fix(agentic): search API queries all requested memory_types, not just the first The keyword and vector search paths read only `memory_types[0]`, so any request asking for multiple memory types silently returned hits from the first type and dropped the rest. The retrieve_mem search helpers (`get_keyword_search_results` and `get_vector_search_results`) now iterate every requested memory type for both keyword (ES_REPO_MAP) and vector (MILVUS_REPO_MAP) search, merge the results, and sort by score. In the vector path an unsupported memory type is now logged and skipped, as the keyword path does, instead of raising ValueError, and the error-path stage metric is timed from `stage_start` rather than `milvus_start`, which is not yet assigned if the failure happens before the Milvus call. Hybrid dedup is also fixed: previously hits were deduped by `id` alone, so two distinct collections that happen to share a backend id erased each other. Dedup now keys on `(memory_type, id)` via `_hit_identity`. Adds an offline regression test (monkeypatched repos, no live stack) that asserts both EPISODIC_MEMORY and AGENT_CASE return hits and that hybrid dedup keeps same-id hits from distinct memory types. Closes #78 Co-authored-by: Exploreunive <117084012+Exploreunive@users.noreply.github.com> Co-authored-by: Jah-yee <166608075+Jah-yee@users.noreply.github.com> Co-authored-by: wucm667 <109257021+wucm667@users.noreply.github.com> Co-Authored-By: Claude Opus 4.8 (1M context) --- .../src/agentic_layer/memory_manager.py | 330 +++++++++--------- .../test_memory_manager_multi_type_search.py | 80 +++++ 2 files changed, 250 insertions(+), 160 deletions(-) create mode 100644 methods/EverCore/tests/test_memory_manager_multi_type_search.py diff --git a/methods/EverCore/src/agentic_layer/memory_manager.py b/methods/EverCore/src/agentic_layer/memory_manager.py index e6e93005..08c6b59b 100644 --- a/methods/EverCore/src/agentic_layer/memory_manager.py +++ b/methods/EverCore/src/agentic_layer/memory_manager.py @@ -117,6 +117,39 @@ MemoryType.AGENT_SKILL: AgentSkillEsRepository, } +MILVUS_REPO_MAP = { + MemoryType.FORESIGHT: ForesightMilvusRepository, + MemoryType.ATOMIC_FACT: AtomicFactMilvusRepository, + MemoryType.EPISODIC_MEMORY: 
EpisodicMemoryMilvusRepository, + MemoryType.AGENT_CASE: AgentCaseMilvusRepository, + MemoryType.AGENT_SKILL: AgentSkillMilvusRepository, +} + + +def _memory_type_label(memory_types: List[MemoryType]) -> str: + if not memory_types: + return 'unknown' + return ','.join(memory_type.value for memory_type in memory_types) + + +def _hit_score(hit: Dict[str, Any]) -> float: + raw_score = hit.get('score', hit.get('_score', 0.0)) + try: + return float(raw_score) + except (TypeError, ValueError): + return 0.0 + + +def _sort_hits_by_score(hits: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + return sorted(hits, key=_hit_score, reverse=True) + + +def _hit_identity(hit: Dict[str, Any]) -> Optional[tuple[str, str]]: + hit_id = hit.get('id') or hit.get('_id') + if not hit_id: + return None + return (str(hit.get('memory_type', 'unknown')), str(hit_id)) + @dataclass class AtomicFactCandidate: @@ -457,11 +490,6 @@ async def retrieve_mem_keyword( """Keyword-based memory retrieval""" top_k = retrieve_mem_request.top_k is_unlimited_mode = top_k == -1 - memory_type = ( - retrieve_mem_request.memory_types[0].value - if retrieve_mem_request.memory_types - else 'unknown' - ) try: hits = await self.get_keyword_search_results( @@ -485,11 +513,7 @@ async def get_keyword_search_results( ) -> List[Dict[str, Any]]: """Keyword search with stage-level metrics""" stage_start = time.perf_counter() - memory_type = ( - retrieve_mem_request.memory_types[0].value - if retrieve_mem_request.memory_types - else 'unknown' - ) + memory_type = _memory_type_label(retrieve_mem_request.memory_types) try: # Get parameters from Request @@ -528,32 +552,34 @@ async def get_keyword_search_results( if end_time is not None: date_range["lte"] = end_time - mem_type = memory_types[0] - - repo_class = ES_REPO_MAP.get(mem_type) - if not repo_class: - logger.warning(f"Unsupported memory_type: {mem_type}") - return [] + all_results = [] + for mem_type in memory_types: + repo_class = ES_REPO_MAP.get(mem_type) + if not 
repo_class: + logger.warning(f"Unsupported memory_type: {mem_type}") + continue - es_repo = get_bean_by_type(repo_class) - logger.debug(f"Using {repo_class.__name__} for {mem_type}") + es_repo = get_bean_by_type(repo_class) + logger.debug(f"Using {repo_class.__name__} for {mem_type}") - results = await es_repo.multi_search( - query=query_words, - user_id=user_id, - group_ids=group_ids, # Pass normalized list - size=effective_limit, - from_=0, - date_range=date_range, - ) + results = await es_repo.multi_search( + query=query_words, + user_id=user_id, + group_ids=group_ids, # Pass normalized list + size=effective_limit, + from_=0, + date_range=date_range, + ) - # Mark memory_type, search_source, and unified score - if results: - for r in results: - r['memory_type'] = mem_type.value - r['_search_source'] = RetrieveMethod.KEYWORD.value - r['id'] = r.get('_id', '') # Unify ES '_id' to 'id' - r['score'] = r.get('_score', 0.0) # Unified score field + # Mark memory_type, search_source, and unified score + if results: + for r in results: + r['memory_type'] = mem_type.value + r['_search_source'] = RetrieveMethod.KEYWORD.value + r['id'] = r.get('_id', '') # Unify ES '_id' to 'id' + r['score'] = r.get('_score', 0.0) # Unified score field + all_results.extend(results) + results = _sort_hits_by_score(all_results) # Record stage metrics record_retrieve_stage( @@ -587,11 +613,6 @@ async def retrieve_mem_vector( """Vector-based memory retrieval""" top_k = retrieve_mem_request.top_k is_unlimited_mode = top_k == -1 - memory_type = ( - retrieve_mem_request.memory_types[0].value - if retrieve_mem_request.memory_types - else 'unknown' - ) try: hits = await self.get_vector_search_results( @@ -614,11 +635,8 @@ async def get_vector_search_results( retrieve_method: str = RetrieveMethod.VECTOR.value, ) -> List[Dict[str, Any]]: """Vector search with stage-level metrics (embedding + milvus_search)""" - memory_type = ( - retrieve_mem_request.memory_types[0].value - if 
retrieve_mem_request.memory_types - else 'unknown' - ) + stage_start = time.perf_counter() + memory_type = _memory_type_label(retrieve_mem_request.memory_types) try: # Get parameters from Request @@ -643,12 +661,8 @@ async def get_vector_search_results( effective_limit = DEFAULT_TOPK_LIMIT else: effective_limit = top_k * DEFAULT_RECALL_MULTIPLIER - # Milvus similarity threshold (only applied in unlimited mode or when user specifies radius) - effective_radius = None start_time = retrieve_mem_request.start_time end_time = retrieve_mem_request.end_time - mem_type = retrieve_mem_request.memory_types[0] - logger.debug( f"retrieve_mem_vector called with query: {query}, user_id: {user_id}, group_ids: {group_ids}, top_k: {top_k}" ) @@ -671,115 +685,113 @@ async def get_vector_search_results( f"Query text vectorization completed, vector dimension: {len(query_vector_list)}" ) - # Select Milvus repository based on memory type - match mem_type: - case MemoryType.FORESIGHT: - milvus_repo = get_bean_by_type(ForesightMilvusRepository) - case MemoryType.ATOMIC_FACT: - milvus_repo = get_bean_by_type(AtomicFactMilvusRepository) - case MemoryType.EPISODIC_MEMORY: - milvus_repo = get_bean_by_type(EpisodicMemoryMilvusRepository) - case MemoryType.AGENT_CASE: - milvus_repo = get_bean_by_type(AgentCaseMilvusRepository) - case MemoryType.AGENT_SKILL: - milvus_repo = get_bean_by_type(AgentSkillMilvusRepository) - case _: - raise ValueError(f"Unsupported memory type: {mem_type}") + all_search_results = [] + for mem_type in retrieve_mem_request.memory_types: + repo_class = MILVUS_REPO_MAP.get(mem_type) + if not repo_class: + logger.warning(f"Unsupported memory type: {mem_type}") + continue - # Handle time range filter conditions - start_time_dt = None - end_time_dt = None + milvus_repo = get_bean_by_type(repo_class) - if start_time is not None: - start_time_dt = ( - from_iso_format(start_time) - if isinstance(start_time, str) - else start_time - ) + # Handle time range filter conditions + 
start_time_dt = None + end_time_dt = None - if end_time is not None: - if isinstance(end_time, str): - end_time_dt = from_iso_format(end_time) - # If date only format, set to end of day - if len(end_time) == 10: - end_time_dt = end_time_dt.replace(hour=23, minute=59, second=59) + if start_time is not None: + start_time_dt = ( + from_iso_format(start_time) + if isinstance(start_time, str) + else start_time + ) + + if end_time is not None: + if isinstance(end_time, str): + end_time_dt = from_iso_format(end_time) + # If date only format, set to end of day + if len(end_time) == 10: + end_time_dt = end_time_dt.replace( + hour=23, minute=59, second=59 + ) + else: + end_time_dt = end_time + + # Handle foresight time range (only valid for foresight) + if mem_type == MemoryType.FORESIGHT: + if retrieve_mem_request.start_time: + start_time_dt = from_iso_format(retrieve_mem_request.start_time) + if retrieve_mem_request.end_time: + end_time_dt = from_iso_format(retrieve_mem_request.end_time) + + # Call Milvus vector search (pass different parameters based on memory type) + # Threshold logic: + # - User specified radius: always use it + # - Unlimited mode (top_k=-1): apply DEFAULT_MILVUS_SIMILARITY_THRESHOLD (0.6) + # - Normal mode (top_k>0): no threshold filtering (rely on top_k limit) + if retrieve_mem_request.radius is not None: + # User specified radius, use it + effective_radius = retrieve_mem_request.radius + elif top_k == -1: + # Unlimited mode: apply default Milvus threshold for quality filtering + effective_radius = DEFAULT_MILVUS_SIMILARITY_THRESHOLD else: - end_time_dt = end_time - - # Handle foresight time range (only valid for foresight) - if mem_type == MemoryType.FORESIGHT: - if retrieve_mem_request.start_time: - start_time_dt = from_iso_format(retrieve_mem_request.start_time) - if retrieve_mem_request.end_time: - end_time_dt = from_iso_format(retrieve_mem_request.end_time) - - # Call Milvus vector search (pass different parameters based on memory type) - # 
Threshold logic: - # - User specified radius: always use it - # - Unlimited mode (top_k=-1): apply DEFAULT_MILVUS_SIMILARITY_THRESHOLD (0.6) - # - Normal mode (top_k>0): no threshold filtering (rely on top_k limit) - if retrieve_mem_request.radius is not None: - # User specified radius, use it - effective_radius = retrieve_mem_request.radius - elif top_k == -1: - # Unlimited mode: apply default Milvus threshold for quality filtering - effective_radius = DEFAULT_MILVUS_SIMILARITY_THRESHOLD - # else: keep None (no threshold filtering for normal top_k mode) - - milvus_start = time.perf_counter() - if mem_type == MemoryType.FORESIGHT: - # Foresight: supports time range and validity filtering, supports radius parameter - search_results = await milvus_repo.vector_search( - query_vector=query_vector_list, - user_id=user_id, - group_ids=group_ids, # Pass normalized list - start_time=start_time_dt, - end_time=end_time_dt, - limit=effective_limit, - score_threshold=0.0, - radius=effective_radius, - ) - elif mem_type == MemoryType.AGENT_SKILL: - # Agent skill: no timestamp filtering - search_results = await milvus_repo.vector_search( - query_vector=query_vector_list, - user_id=user_id, - group_ids=group_ids, - limit=effective_limit, - score_threshold=0.0, - radius=effective_radius, - ) - else: - # Episodic memory, atomic fact, agent case: use timestamp filtering - search_results = await milvus_repo.vector_search( - query_vector=query_vector_list, - user_id=user_id, - group_ids=group_ids, # Pass normalized list - start_time=start_time_dt, - end_time=end_time_dt, - limit=effective_limit, - score_threshold=0.0, - radius=effective_radius, + effective_radius = None + + milvus_start = time.perf_counter() + if mem_type == MemoryType.FORESIGHT: + # Foresight: supports time range and validity filtering, supports radius parameter + search_results = await milvus_repo.vector_search( + query_vector=query_vector_list, + user_id=user_id, + group_ids=group_ids, # Pass normalized list + 
start_time=start_time_dt, + end_time=end_time_dt, + limit=effective_limit, + score_threshold=0.0, + radius=effective_radius, + ) + elif mem_type == MemoryType.AGENT_SKILL: + # Agent skill: no timestamp filtering + search_results = await milvus_repo.vector_search( + query_vector=query_vector_list, + user_id=user_id, + group_ids=group_ids, + limit=effective_limit, + score_threshold=0.0, + radius=effective_radius, + ) + else: + # Episodic memory, atomic fact, agent case: use timestamp filtering + search_results = await milvus_repo.vector_search( + query_vector=query_vector_list, + user_id=user_id, + group_ids=group_ids, # Pass normalized list + start_time=start_time_dt, + end_time=end_time_dt, + limit=effective_limit, + score_threshold=0.0, + radius=effective_radius, + ) + record_retrieve_stage( + retrieve_method=retrieve_method, + stage='milvus_search', + memory_type=mem_type.value, + duration_seconds=time.perf_counter() - milvus_start, ) - record_retrieve_stage( - retrieve_method=retrieve_method, - stage='milvus_search', - memory_type=memory_type, - duration_seconds=time.perf_counter() - milvus_start, - ) - for r in search_results: - r['memory_type'] = mem_type.value - r['_search_source'] = RetrieveMethod.VECTOR.value - # Milvus already uses 'score', no need to rename + for r in search_results or []: + r['memory_type'] = mem_type.value + r['_search_source'] = RetrieveMethod.VECTOR.value + # Milvus already uses 'score', no need to rename + all_search_results.extend(search_results or []) - return search_results + return _sort_hits_by_score(all_search_results) except Exception as e: record_retrieve_stage( retrieve_method=retrieve_method, stage=RetrieveMethod.VECTOR.value, memory_type=memory_type, - duration_seconds=time.perf_counter() - milvus_start, + duration_seconds=time.perf_counter() - stage_start, ) record_retrieve_error( retrieve_method=retrieve_method, @@ -795,12 +807,6 @@ async def retrieve_mem_hybrid( self, retrieve_mem_request: 'RetrieveMemRequest' ) -> 
RetrieveMemResponse: """Hybrid memory retrieval: keyword + vector + rerank""" - memory_type = ( - retrieve_mem_request.memory_types[0].value - if retrieve_mem_request.memory_types - else 'unknown' - ) - try: hits = await self._search_hybrid( retrieve_mem_request, retrieve_method=RetrieveMethod.HYBRID.value @@ -882,9 +888,7 @@ async def _search_hybrid( retrieve_method: str = RetrieveMethod.HYBRID.value, ) -> List[Dict]: """Core hybrid search: keyword + vector + rerank, returns flat list""" - memory_type = ( - request.memory_types[0].value if request.memory_types else 'unknown' - ) + memory_type = _memory_type_label(request.memory_types) top_k = request.top_k is_unlimited_mode = top_k == -1 @@ -893,11 +897,17 @@ async def _search_hybrid( self.get_keyword_search_results(request, retrieve_method=retrieve_method), self.get_vector_search_results(request, retrieve_method=retrieve_method), ) - # Deduplicate by id - seen_ids = {h.get('id') for h in kw_results} - merged_results = kw_results + [ - h for h in vec_results if h.get('id') not in seen_ids - ] + # Deduplicate by memory collection and id so unrelated collections with + # the same backend id do not erase each other. 
+ seen_ids = set() + merged_results = [] + for hit in [*kw_results, *vec_results]: + identity = _hit_identity(hit) + if identity is not None: + if identity in seen_ids: + continue + seen_ids.add(identity) + merged_results.append(hit) # When top_k is -1, use DEFAULT_TOPK_LIMIT for rerank rerank_limit = DEFAULT_TOPK_LIMIT if is_unlimited_mode else top_k @@ -991,7 +1001,7 @@ async def retrieve_mem_agentic( top_k = req.top_k is_unlimited_mode = top_k == -1 config = AgenticConfig() - memory_type = req.memory_types[0].value if req.memory_types else 'unknown' + memory_type = _memory_type_label(req.memory_types) try: llm_provider = build_default_provider() @@ -1359,7 +1369,7 @@ async def group_by_groupid_stratagy( task_intent=fields.get('task_intent', ''), approach=fields.get('approach', ''), quality_score=fields.get('quality_score'), - key_insight=fields.get('key_insight', '') + key_insight=fields.get('key_insight', ''), ) case MemoryType.AGENT_SKILL.value: # AgentSkill doesn't have parent_type/parent_id fields diff --git a/methods/EverCore/tests/test_memory_manager_multi_type_search.py b/methods/EverCore/tests/test_memory_manager_multi_type_search.py new file mode 100644 index 00000000..380c698e --- /dev/null +++ b/methods/EverCore/tests/test_memory_manager_multi_type_search.py @@ -0,0 +1,80 @@ +import pytest + +from agentic_layer import memory_manager as memory_manager_module +from agentic_layer.memory_manager import MemoryManager +from api_specs.dtos import RetrieveMemRequest +from api_specs.memory_models import MemoryType + + +class _RepoA: + async def multi_search(self, **kwargs): + return [{'_id': 'a', '_score': 0.2}] + + +class _RepoB: + async def multi_search(self, **kwargs): + return [{'_id': 'b', '_score': 0.9}] + + +@pytest.mark.asyncio +async def test_keyword_search_queries_all_requested_memory_types(monkeypatch): + repos = {_RepoA: _RepoA(), _RepoB: _RepoB()} + monkeypatch.setattr( + memory_manager_module, + 'ES_REPO_MAP', + {MemoryType.EPISODIC_MEMORY: 
_RepoA, MemoryType.AGENT_CASE: _RepoB}, + ) + monkeypatch.setattr( + memory_manager_module, 'get_bean_by_type', lambda repo_class: repos[repo_class] + ) + + manager = object.__new__(MemoryManager) + request = RetrieveMemRequest( + query='soccer', + group_ids=['group-1'], + top_k=10, + memory_types=[MemoryType.EPISODIC_MEMORY, MemoryType.AGENT_CASE], + ) + + hits = await manager.get_keyword_search_results(request) + + assert [hit['memory_type'] for hit in hits] == [ + MemoryType.AGENT_CASE.value, + MemoryType.EPISODIC_MEMORY.value, + ] + assert [hit['id'] for hit in hits] == ['b', 'a'] + + +@pytest.mark.asyncio +async def test_hybrid_dedupe_keeps_same_id_from_distinct_memory_types(): + manager = object.__new__(MemoryManager) + + async def keyword_results(*args, **kwargs): + return [{'id': 'same', 'memory_type': 'episodic_memory', 'score': 0.8}] + + async def vector_results(*args, **kwargs): + return [ + {'id': 'same', 'memory_type': 'agent_case', 'score': 0.9}, + {'id': 'same', 'memory_type': 'episodic_memory', 'score': 0.7}, + ] + + async def rerank(query, hits, top_k, *args, **kwargs): + return hits + + manager.get_keyword_search_results = keyword_results + manager.get_vector_search_results = vector_results + manager._rerank = rerank + + request = RetrieveMemRequest( + query='soccer', + group_ids=['group-1'], + top_k=10, + memory_types=[MemoryType.EPISODIC_MEMORY, MemoryType.AGENT_CASE], + ) + + hits = await manager._search_hybrid(request) + + assert hits == [ + {'id': 'same', 'memory_type': 'episodic_memory', 'score': 0.8}, + {'id': 'same', 'memory_type': 'agent_case', 'score': 0.9}, + ]