From 1dc9f798664f2e2e763fb428ae621568207aa6ec Mon Sep 17 00:00:00 2001
From: r266-tech
Date: Mon, 2 Mar 2026 21:44:00 +0800
Subject: [PATCH] fix(search): use all memory_types in search API instead of only first one

Previously, all retrieval methods (keyword, vector, hybrid, RRF, agentic)
only used memory_types[0], silently ignoring all other types in the list.
If the first type was unsupported (e.g. profile), keyword search logged a
warning and returned an empty list, and vector search raised ValueError.

Changes:
- get_keyword_search_results: iterate over ALL memory_types, search each
  supported type via ES_REPO_MAP, skip unsupported types with an info log,
  and concatenate the per-type results in request order
- get_vector_search_results: compute the embedding ONCE, iterate over ALL
  memory_types, search each supported Milvus repo, skip unsupported types,
  and concatenate the per-type results (foresight special params handled
  per-type)
- Each type is queried with top_k on its own, so the merged list can hold
  up to top_k hits per supported type; it is not re-sorted or truncated
  inside these two functions
- Add module-level MILVUS_REPO_MAP dict for cleaner repo lookup
- Add _memory_types_label() helper for metrics/logging across all methods
- _to_response uses the comma-joined memory types string for metadata

Also adds comprehensive unit tests covering multi-type search, unsupported
type skipping, single type backward compatibility, and embedding-once
verification.

Closes #78 --- src/agentic_layer/memory_manager.py | 229 +++++++------ tests/test_memory_manager_search.py | 483 ++++++++++++++++++++++++++++ 2 files changed, 595 insertions(+), 117 deletions(-) create mode 100644 tests/test_memory_manager_search.py diff --git a/src/agentic_layer/memory_manager.py b/src/agentic_layer/memory_manager.py index df42b05a..448da792 100644 --- a/src/agentic_layer/memory_manager.py +++ b/src/agentic_layer/memory_manager.py @@ -94,6 +94,20 @@ MemoryType.EPISODIC_MEMORY: EpisodicMemoryEsRepository, } +# MemoryType -> Milvus Repository mapping +MILVUS_REPO_MAP = { + MemoryType.FORESIGHT: ForesightMilvusRepository, + MemoryType.EVENT_LOG: EventLogMilvusRepository, + MemoryType.EPISODIC_MEMORY: EpisodicMemoryMilvusRepository, +} + + +def _memory_types_label(memory_types: List) -> str: + """Return a comma-joined string of memory type values for metrics/logging.""" + if not memory_types: + return 'unknown' + return ','.join(mt.value for mt in memory_types) + @dataclass class EventLogCandidate: @@ -298,11 +312,7 @@ async def retrieve_mem_keyword( ) -> RetrieveMemResponse: """Keyword-based memory retrieval""" start_time = time.perf_counter() - memory_type = ( - retrieve_mem_request.memory_types[0].value - if retrieve_mem_request.memory_types - else 'unknown' - ) + memory_type = _memory_types_label(retrieve_mem_request.memory_types) try: hits = await self.get_keyword_search_results( @@ -339,11 +349,7 @@ async def get_keyword_search_results( ) -> List[Dict[str, Any]]: """Keyword search with stage-level metrics""" stage_start = time.perf_counter() - memory_type = ( - retrieve_mem_request.memory_types[0].value - if retrieve_mem_request.memory_types - else 'unknown' - ) + memory_type = _memory_types_label(retrieve_mem_request.memory_types) try: # Get parameters from Request @@ -375,32 +381,36 @@ async def get_keyword_search_results( if end_time is not None: date_range["lte"] = end_time - mem_type = memory_types[0] - - repo_class = 
ES_REPO_MAP.get(mem_type) - if not repo_class: - logger.warning(f"Unsupported memory_type: {mem_type}") - return [] + # Iterate over ALL memory types and merge results + all_results = [] + for mem_type in memory_types: + repo_class = ES_REPO_MAP.get(mem_type) + if not repo_class: + logger.info( + f"Memory type {mem_type} is not searchable via ES, skipping" + ) + continue - es_repo = get_bean_by_type(repo_class) - logger.debug(f"Using {repo_class.__name__} for {mem_type}") + es_repo = get_bean_by_type(repo_class) + logger.debug(f"Using {repo_class.__name__} for {mem_type}") - results = await es_repo.multi_search( - query=query_words, - user_id=user_id, - group_id=group_id, - size=top_k, - from_=0, - date_range=date_range, - ) + results = await es_repo.multi_search( + query=query_words, + user_id=user_id, + group_id=group_id, + size=top_k, + from_=0, + date_range=date_range, + ) - # Mark memory_type, search_source, and unified score - if results: - for r in results: - r['memory_type'] = mem_type.value - r['_search_source'] = RetrieveMethod.KEYWORD.value - r['id'] = r.get('_id', '') # Unify ES '_id' to 'id' - r['score'] = r.get('_score', 0.0) # Unified score field + # Mark memory_type, search_source, and unified score + if results: + for r in results: + r['memory_type'] = mem_type.value + r['_search_source'] = RetrieveMethod.KEYWORD.value + r['id'] = r.get('_id', '') # Unify ES '_id' to 'id' + r['score'] = r.get('_score', 0.0) # Unified score field + all_results.extend(results) # Record stage metrics record_retrieve_stage( @@ -410,7 +420,7 @@ async def get_keyword_search_results( duration_seconds=time.perf_counter() - stage_start, ) - return results or [] + return all_results except Exception as e: record_retrieve_stage( retrieve_method=retrieve_method, @@ -433,11 +443,7 @@ async def retrieve_mem_vector( ) -> RetrieveMemResponse: """Vector-based memory retrieval""" start_time = time.perf_counter() - memory_type = ( - retrieve_mem_request.memory_types[0].value - if 
retrieve_mem_request.memory_types - else 'unknown' - ) + memory_type = _memory_types_label(retrieve_mem_request.memory_types) try: hits = await self.get_vector_search_results( @@ -473,11 +479,8 @@ async def get_vector_search_results( retrieve_method: str = RetrieveMethod.VECTOR.value, ) -> List[Dict[str, Any]]: """Vector search with stage-level metrics (embedding + milvus_search)""" - memory_type = ( - retrieve_mem_request.memory_types[0].value - if retrieve_mem_request.memory_types - else 'unknown' - ) + memory_type = _memory_types_label(retrieve_mem_request.memory_types) + milvus_start = time.perf_counter() try: # Get parameters from Request @@ -497,7 +500,7 @@ async def get_vector_search_results( top_k = retrieve_mem_request.top_k start_time = retrieve_mem_request.start_time end_time = retrieve_mem_request.end_time - mem_type = retrieve_mem_request.memory_types[0] + memory_types = retrieve_mem_request.memory_types logger.debug( f"retrieve_mem_vector called with query: {query}, user_id: {user_id}, group_id: {group_id}, top_k: {top_k}" @@ -506,7 +509,7 @@ async def get_vector_search_results( # Get vectorization service vectorize_service = get_vectorize_service() - # Convert query text to vector (embedding stage) + # Convert query text to vector ONCE (embedding stage) logger.debug(f"Starting to vectorize query text: {query}") embedding_start = time.perf_counter() query_vector = await vectorize_service.get_embedding(query) @@ -521,21 +524,9 @@ async def get_vector_search_results( f"Query text vectorization completed, vector dimension: {len(query_vector_list)}" ) - # Select Milvus repository based on memory type - match mem_type: - case MemoryType.FORESIGHT: - milvus_repo = get_bean_by_type(ForesightMilvusRepository) - case MemoryType.EVENT_LOG: - milvus_repo = get_bean_by_type(EventLogMilvusRepository) - case MemoryType.EPISODIC_MEMORY: - milvus_repo = get_bean_by_type(EpisodicMemoryMilvusRepository) - case _: - raise ValueError(f"Unsupported memory type: 
{mem_type}") - # Handle time range filter conditions start_time_dt = None end_time_dt = None - current_time_dt = None if start_time is not None: start_time_dt = ( @@ -553,42 +544,63 @@ async def get_vector_search_results( else: end_time_dt = end_time - # Handle foresight time range (only valid for foresight) - if mem_type == MemoryType.FORESIGHT: - if retrieve_mem_request.start_time: - start_time_dt = from_iso_format(retrieve_mem_request.start_time) - if retrieve_mem_request.end_time: - end_time_dt = from_iso_format(retrieve_mem_request.end_time) - if retrieve_mem_request.current_time: - current_time_dt = from_iso_format(retrieve_mem_request.current_time) - - # Call Milvus vector search (pass different parameters based on memory type) + # Iterate over ALL memory types and merge results + all_results = [] milvus_start = time.perf_counter() - if mem_type == MemoryType.FORESIGHT: - # Foresight: supports time range and validity filtering, supports radius parameter - search_results = await milvus_repo.vector_search( - query_vector=query_vector_list, - user_id=user_id, - group_id=group_id, - start_time=start_time_dt, - end_time=end_time_dt, - current_time=current_time_dt, - limit=top_k, - score_threshold=0.0, - radius=retrieve_mem_request.radius, - ) - else: - # Episodic memory and event log: use timestamp filtering, supports radius parameter - search_results = await milvus_repo.vector_search( - query_vector=query_vector_list, - user_id=user_id, - group_id=group_id, - start_time=start_time_dt, - end_time=end_time_dt, - limit=top_k, - score_threshold=0.0, - radius=retrieve_mem_request.radius, - ) + for mem_type in memory_types: + milvus_repo_class = MILVUS_REPO_MAP.get(mem_type) + if not milvus_repo_class: + logger.info( + f"Memory type {mem_type} is not searchable via Milvus, skipping" + ) + continue + + milvus_repo = get_bean_by_type(milvus_repo_class) + + # Call Milvus vector search (pass different parameters based on memory type) + if mem_type == MemoryType.FORESIGHT: 
+ # Handle foresight-specific time range + foresight_start_dt = start_time_dt + foresight_end_dt = end_time_dt + current_time_dt = None + if retrieve_mem_request.start_time: + foresight_start_dt = from_iso_format(retrieve_mem_request.start_time) + if retrieve_mem_request.end_time: + foresight_end_dt = from_iso_format(retrieve_mem_request.end_time) + if retrieve_mem_request.current_time: + current_time_dt = from_iso_format(retrieve_mem_request.current_time) + + # Foresight: supports time range and validity filtering, supports radius parameter + search_results = await milvus_repo.vector_search( + query_vector=query_vector_list, + user_id=user_id, + group_id=group_id, + start_time=foresight_start_dt, + end_time=foresight_end_dt, + current_time=current_time_dt, + limit=top_k, + score_threshold=0.0, + radius=retrieve_mem_request.radius, + ) + else: + # Episodic memory and event log: use timestamp filtering, supports radius parameter + search_results = await milvus_repo.vector_search( + query_vector=query_vector_list, + user_id=user_id, + group_id=group_id, + start_time=start_time_dt, + end_time=end_time_dt, + limit=top_k, + score_threshold=0.0, + radius=retrieve_mem_request.radius, + ) + + for r in search_results: + r['memory_type'] = mem_type.value + r['_search_source'] = RetrieveMethod.VECTOR.value + # Milvus already uses 'score', no need to rename + all_results.extend(search_results) + record_retrieve_stage( retrieve_method=retrieve_method, stage='milvus_search', @@ -596,12 +608,7 @@ async def get_vector_search_results( duration_seconds=time.perf_counter() - milvus_start, ) - for r in search_results: - r['memory_type'] = mem_type.value - r['_search_source'] = RetrieveMethod.VECTOR.value - # Milvus already uses 'score', no need to rename - - return search_results + return all_results except Exception as e: record_retrieve_stage( retrieve_method=retrieve_method, @@ -624,11 +631,7 @@ async def retrieve_mem_hybrid( ) -> RetrieveMemResponse: """Hybrid memory retrieval: 
keyword + vector + rerank""" start_time = time.perf_counter() - memory_type = ( - retrieve_mem_request.memory_types[0].value - if retrieve_mem_request.memory_types - else 'unknown' - ) + memory_type = _memory_types_label(retrieve_mem_request.memory_types) try: hits = await self._search_hybrid( @@ -699,9 +702,7 @@ async def _search_hybrid( retrieve_method: str = RetrieveMethod.HYBRID.value, ) -> List[Dict]: """Core hybrid search: keyword + vector + rerank, returns flat list""" - memory_type = ( - request.memory_types[0].value if request.memory_types else 'unknown' - ) + memory_type = _memory_types_label(request.memory_types) # Run keyword and vector search concurrently kw_results, vec_results = await asyncio.gather( self.get_keyword_search_results(request, retrieve_method=retrieve_method), @@ -722,9 +723,7 @@ async def _search_rrf( retrieve_method: str = RetrieveMethod.RRF.value, ) -> List[Dict]: """Core RRF search: keyword + vector + RRF fusion, returns flat list""" - memory_type = ( - request.memory_types[0].value if request.memory_types else 'unknown' - ) + memory_type = _memory_types_label(request.memory_types) # Run keyword and vector search concurrently kw, vec = await asyncio.gather( @@ -766,7 +765,7 @@ async def _to_response( """Convert flat hits list to grouped RetrieveMemResponse""" user_id = req.user_id if req else "" source_type = req.retrieve_method.value - memory_type = req.memory_types[0].value + memory_type = _memory_types_label(req.memory_types) if not hits: return RetrieveMemResponse( @@ -808,11 +807,7 @@ async def retrieve_mem_rrf( ) -> RetrieveMemResponse: """RRF-based memory retrieval: keyword + vector + RRF fusion""" start_time = time.perf_counter() - memory_type = ( - retrieve_mem_request.memory_types[0].value - if retrieve_mem_request.memory_types - else 'unknown' - ) + memory_type = _memory_types_label(retrieve_mem_request.memory_types) try: hits = await self._search_rrf( @@ -855,7 +850,7 @@ async def retrieve_mem_agentic( req = 
retrieve_mem_request # alias top_k = req.top_k config = AgenticConfig() - memory_type = req.memory_types[0].value if req.memory_types else 'unknown' + memory_type = _memory_types_label(req.memory_types) try: llm_provider = LLMProvider( diff --git a/tests/test_memory_manager_search.py b/tests/test_memory_manager_search.py new file mode 100644 index 00000000..74e6494a --- /dev/null +++ b/tests/test_memory_manager_search.py @@ -0,0 +1,483 @@ +"""Unit tests for multi-memory-type search logic in MemoryManager. + +Tests cover: +- get_keyword_search_results iterating over multiple memory types +- get_vector_search_results iterating over multiple memory types +- Unsupported types (e.g. PROFILE) silently skipped +- Single memory type still works +- Empty memory_types list returns empty +- _memory_types_label helper +""" + +import sys +import os +from unittest.mock import AsyncMock, MagicMock, patch +import pytest + +# Ensure src/ is on the path so that imports like `api_specs.*` resolve. +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) + + +class _FakeEmbedding: + """Mimics an np.ndarray with .tolist() — avoids numpy import at test time.""" + + def __init__(self, values): + self._values = values + + def tolist(self): + return list(self._values) + +from api_specs.memory_models import MemoryType, RetrieveMethod +from api_specs.dtos import RetrieveMemRequest + + +# --------------------------------------------------------------------------- +# Helper: build a RetrieveMemRequest without triggering validation side-effects +# --------------------------------------------------------------------------- +def _make_request( + memory_types, + query="test query", + user_id="u1", + group_id="g1", + top_k=10, +): + return RetrieveMemRequest( + user_id=user_id, + group_id=group_id, + memory_types=memory_types, + query=query, + top_k=top_k, + retrieve_method=RetrieveMethod.KEYWORD, + ) + + +# 
--------------------------------------------------------------------------- +# 7. _memory_types_label +# --------------------------------------------------------------------------- +class TestMemoryTypesLabel: + def test_single_type(self): + from agentic_layer.memory_manager import _memory_types_label + + assert _memory_types_label([MemoryType.EPISODIC_MEMORY]) == "episodic_memory" + + def test_multiple_types(self): + from agentic_layer.memory_manager import _memory_types_label + + result = _memory_types_label( + [MemoryType.EPISODIC_MEMORY, MemoryType.FORESIGHT] + ) + assert result == "episodic_memory,foresight" + + def test_empty_list(self): + from agentic_layer.memory_manager import _memory_types_label + + assert _memory_types_label([]) == "unknown" + + def test_three_types(self): + from agentic_layer.memory_manager import _memory_types_label + + result = _memory_types_label( + [MemoryType.EPISODIC_MEMORY, MemoryType.FORESIGHT, MemoryType.EVENT_LOG] + ) + assert result == "episodic_memory,foresight,event_log" + + +# --------------------------------------------------------------------------- +# Fixtures shared by keyword / vector tests +# --------------------------------------------------------------------------- + +# Patch targets – all within the memory_manager module +_MM = "agentic_layer.memory_manager" + + +def _patch_constructor(): + """Patch MemoryManager.__init__ so it doesn't need real DI beans.""" + return patch(f"{_MM}.MemoryManager.__init__", lambda self: None) + + +@pytest.fixture +def manager(): + """Return a MemoryManager instance with __init__ patched out.""" + with _patch_constructor(): + from agentic_layer.memory_manager import MemoryManager + + return MemoryManager() + + +# --------------------------------------------------------------------------- +# 1 & 2 & 5 & 6. 
get_keyword_search_results +# --------------------------------------------------------------------------- +class TestGetKeywordSearchResults: + + @pytest.mark.asyncio + async def test_multiple_memory_types_merged(self, manager): + """Multiple memory types → results from all repos are merged.""" + episodic_hits = [{"_id": "e1", "_score": 1.0, "summary": "ep hit"}] + foresight_hits = [{"_id": "f1", "_score": 0.8, "content": "fore hit"}] + + mock_ep_repo = MagicMock() + mock_ep_repo.multi_search = AsyncMock(return_value=episodic_hits) + mock_fore_repo = MagicMock() + mock_fore_repo.multi_search = AsyncMock(return_value=foresight_hits) + + def _bean_router(cls): + from infra_layer.adapters.out.search.repository.episodic_memory_es_repository import ( + EpisodicMemoryEsRepository, + ) + from infra_layer.adapters.out.search.repository.foresight_es_repository import ( + ForesightEsRepository, + ) + + if cls is EpisodicMemoryEsRepository: + return mock_ep_repo + if cls is ForesightEsRepository: + return mock_fore_repo + raise ValueError(f"Unexpected class: {cls}") + + req = _make_request( + memory_types=[MemoryType.EPISODIC_MEMORY, MemoryType.FORESIGHT], + ) + + with ( + patch(f"{_MM}.get_bean_by_type", side_effect=_bean_router), + patch(f"{_MM}.jieba") as mock_jieba, + patch(f"{_MM}.filter_stopwords", return_value=["test", "query"]), + patch(f"{_MM}.record_retrieve_stage"), + patch(f"{_MM}.record_retrieve_request"), + patch(f"{_MM}.record_retrieve_error"), + ): + mock_jieba.cut_for_search.return_value = ["test", "query"] + results = await manager.get_keyword_search_results(req) + + # Both repos were called + mock_ep_repo.multi_search.assert_awaited_once() + mock_fore_repo.multi_search.assert_awaited_once() + + # Results merged + assert len(results) == 2 + # Each hit annotated with its memory_type + types_in_results = {r["memory_type"] for r in results} + assert types_in_results == {"episodic_memory", "foresight"} + + # Verify unified fields + for r in results: + assert 
"id" in r + assert "score" in r + assert r["_search_source"] == RetrieveMethod.KEYWORD.value + + @pytest.mark.asyncio + async def test_unsupported_type_skipped(self, manager): + """PROFILE is not in ES_REPO_MAP → silently skipped, no error.""" + episodic_hits = [{"_id": "e1", "_score": 1.0}] + mock_ep_repo = MagicMock() + mock_ep_repo.multi_search = AsyncMock(return_value=episodic_hits) + + def _bean_router(cls): + from infra_layer.adapters.out.search.repository.episodic_memory_es_repository import ( + EpisodicMemoryEsRepository, + ) + + if cls is EpisodicMemoryEsRepository: + return mock_ep_repo + raise ValueError(f"Unexpected class: {cls}") + + req = _make_request( + memory_types=[MemoryType.PROFILE, MemoryType.EPISODIC_MEMORY], + ) + + with ( + patch(f"{_MM}.get_bean_by_type", side_effect=_bean_router), + patch(f"{_MM}.jieba") as mock_jieba, + patch(f"{_MM}.filter_stopwords", return_value=["test"]), + patch(f"{_MM}.record_retrieve_stage"), + patch(f"{_MM}.record_retrieve_request"), + patch(f"{_MM}.record_retrieve_error"), + ): + mock_jieba.cut_for_search.return_value = ["test"] + results = await manager.get_keyword_search_results(req) + + # Only episodic results, PROFILE silently skipped + assert len(results) == 1 + assert results[0]["memory_type"] == "episodic_memory" + + @pytest.mark.asyncio + async def test_single_memory_type(self, manager): + """Single memory type still works exactly like before.""" + hits = [{"_id": "e1", "_score": 2.5}] + mock_repo = MagicMock() + mock_repo.multi_search = AsyncMock(return_value=hits) + + req = _make_request(memory_types=[MemoryType.EVENT_LOG]) + + with ( + patch(f"{_MM}.get_bean_by_type", return_value=mock_repo), + patch(f"{_MM}.jieba") as mock_jieba, + patch(f"{_MM}.filter_stopwords", return_value=["test"]), + patch(f"{_MM}.record_retrieve_stage"), + patch(f"{_MM}.record_retrieve_request"), + patch(f"{_MM}.record_retrieve_error"), + ): + mock_jieba.cut_for_search.return_value = ["test"] + results = await 
manager.get_keyword_search_results(req) + + assert len(results) == 1 + assert results[0]["memory_type"] == "event_log" + mock_repo.multi_search.assert_awaited_once() + + @pytest.mark.asyncio + async def test_empty_memory_types(self, manager): + """Empty memory_types list → no iteration, empty results.""" + req = _make_request(memory_types=[]) + + with ( + patch(f"{_MM}.jieba") as mock_jieba, + patch(f"{_MM}.filter_stopwords", return_value=["test"]), + patch(f"{_MM}.record_retrieve_stage"), + patch(f"{_MM}.record_retrieve_request"), + patch(f"{_MM}.record_retrieve_error"), + ): + mock_jieba.cut_for_search.return_value = ["test"] + results = await manager.get_keyword_search_results(req) + + assert results == [] + + @pytest.mark.asyncio + async def test_repo_returns_empty(self, manager): + """Repo returning [] for a type → no items added, no crash.""" + mock_repo = MagicMock() + mock_repo.multi_search = AsyncMock(return_value=[]) + + req = _make_request(memory_types=[MemoryType.EPISODIC_MEMORY]) + + with ( + patch(f"{_MM}.get_bean_by_type", return_value=mock_repo), + patch(f"{_MM}.jieba") as mock_jieba, + patch(f"{_MM}.filter_stopwords", return_value=[]), + patch(f"{_MM}.record_retrieve_stage"), + patch(f"{_MM}.record_retrieve_request"), + patch(f"{_MM}.record_retrieve_error"), + ): + mock_jieba.cut_for_search.return_value = [] + results = await manager.get_keyword_search_results(req) + + assert results == [] + + +# --------------------------------------------------------------------------- +# 3 & 4. 
get_vector_search_results +# --------------------------------------------------------------------------- +class TestGetVectorSearchResults: + + @pytest.mark.asyncio + async def test_multiple_memory_types_merged(self, manager): + """Multiple memory types → vector results from all repos merged.""" + ep_hits = [{"id": "e1", "score": 0.95}] + fore_hits = [{"id": "f1", "score": 0.88}] + + mock_ep_repo = MagicMock() + mock_ep_repo.vector_search = AsyncMock(return_value=ep_hits) + mock_fore_repo = MagicMock() + mock_fore_repo.vector_search = AsyncMock(return_value=fore_hits) + + def _bean_router(cls): + from infra_layer.adapters.out.search.repository.episodic_memory_milvus_repository import ( + EpisodicMemoryMilvusRepository, + ) + from infra_layer.adapters.out.search.repository.foresight_milvus_repository import ( + ForesightMilvusRepository, + ) + + if cls is EpisodicMemoryMilvusRepository: + return mock_ep_repo + if cls is ForesightMilvusRepository: + return mock_fore_repo + raise ValueError(f"Unexpected class: {cls}") + + mock_vectorize = MagicMock() + mock_vectorize.get_embedding = AsyncMock( + return_value=_FakeEmbedding([0.1, 0.2, 0.3]) + ) + + req = _make_request( + memory_types=[MemoryType.EPISODIC_MEMORY, MemoryType.FORESIGHT], + ) + + with ( + patch(f"{_MM}.get_bean_by_type", side_effect=_bean_router), + patch(f"{_MM}.get_vectorize_service", return_value=mock_vectorize), + patch(f"{_MM}.from_iso_format"), + patch(f"{_MM}.record_retrieve_stage"), + patch(f"{_MM}.record_retrieve_request"), + patch(f"{_MM}.record_retrieve_error"), + ): + results = await manager.get_vector_search_results(req) + + mock_ep_repo.vector_search.assert_awaited_once() + mock_fore_repo.vector_search.assert_awaited_once() + + assert len(results) == 2 + types_in_results = {r["memory_type"] for r in results} + assert types_in_results == {"episodic_memory", "foresight"} + for r in results: + assert r["_search_source"] == RetrieveMethod.VECTOR.value + + @pytest.mark.asyncio + async def 
test_unsupported_type_skipped(self, manager): + """PROFILE is not in MILVUS_REPO_MAP → silently skipped.""" + ep_hits = [{"id": "e1", "score": 0.9}] + mock_ep_repo = MagicMock() + mock_ep_repo.vector_search = AsyncMock(return_value=ep_hits) + + def _bean_router(cls): + from infra_layer.adapters.out.search.repository.episodic_memory_milvus_repository import ( + EpisodicMemoryMilvusRepository, + ) + + if cls is EpisodicMemoryMilvusRepository: + return mock_ep_repo + raise ValueError(f"Unexpected class: {cls}") + + mock_vectorize = MagicMock() + mock_vectorize.get_embedding = AsyncMock( + return_value=_FakeEmbedding([0.1, 0.2, 0.3]) + ) + + req = _make_request( + memory_types=[MemoryType.PROFILE, MemoryType.EPISODIC_MEMORY], + ) + + with ( + patch(f"{_MM}.get_bean_by_type", side_effect=_bean_router), + patch(f"{_MM}.get_vectorize_service", return_value=mock_vectorize), + patch(f"{_MM}.from_iso_format"), + patch(f"{_MM}.record_retrieve_stage"), + patch(f"{_MM}.record_retrieve_request"), + patch(f"{_MM}.record_retrieve_error"), + ): + results = await manager.get_vector_search_results(req) + + assert len(results) == 1 + assert results[0]["memory_type"] == "episodic_memory" + + @pytest.mark.asyncio + async def test_single_memory_type(self, manager): + """Single memory type in vector search still works.""" + hits = [{"id": "el1", "score": 0.7}] + mock_repo = MagicMock() + mock_repo.vector_search = AsyncMock(return_value=hits) + + mock_vectorize = MagicMock() + mock_vectorize.get_embedding = AsyncMock( + return_value=_FakeEmbedding([0.5, 0.6]) + ) + + req = _make_request(memory_types=[MemoryType.EVENT_LOG]) + + with ( + patch(f"{_MM}.get_bean_by_type", return_value=mock_repo), + patch(f"{_MM}.get_vectorize_service", return_value=mock_vectorize), + patch(f"{_MM}.from_iso_format"), + patch(f"{_MM}.record_retrieve_stage"), + patch(f"{_MM}.record_retrieve_request"), + patch(f"{_MM}.record_retrieve_error"), + ): + results = await manager.get_vector_search_results(req) + + assert 
len(results) == 1 + assert results[0]["memory_type"] == "event_log" + mock_repo.vector_search.assert_awaited_once() + + @pytest.mark.asyncio + async def test_empty_memory_types(self, manager): + """Empty memory_types list → no Milvus calls, empty results.""" + mock_vectorize = MagicMock() + mock_vectorize.get_embedding = AsyncMock( + return_value=_FakeEmbedding([0.1]) + ) + + req = _make_request(memory_types=[]) + + with ( + patch(f"{_MM}.get_vectorize_service", return_value=mock_vectorize), + patch(f"{_MM}.from_iso_format"), + patch(f"{_MM}.record_retrieve_stage"), + patch(f"{_MM}.record_retrieve_request"), + patch(f"{_MM}.record_retrieve_error"), + ): + results = await manager.get_vector_search_results(req) + + assert results == [] + + @pytest.mark.asyncio + async def test_embedding_called_once(self, manager): + """Even with multiple memory types, embedding is computed only once.""" + mock_repo = MagicMock() + mock_repo.vector_search = AsyncMock(return_value=[]) + + mock_vectorize = MagicMock() + mock_vectorize.get_embedding = AsyncMock( + return_value=_FakeEmbedding([0.1, 0.2]) + ) + + req = _make_request( + memory_types=[ + MemoryType.EPISODIC_MEMORY, + MemoryType.FORESIGHT, + MemoryType.EVENT_LOG, + ], + ) + + with ( + patch(f"{_MM}.get_bean_by_type", return_value=mock_repo), + patch(f"{_MM}.get_vectorize_service", return_value=mock_vectorize), + patch(f"{_MM}.from_iso_format"), + patch(f"{_MM}.record_retrieve_stage"), + patch(f"{_MM}.record_retrieve_request"), + patch(f"{_MM}.record_retrieve_error"), + ): + await manager.get_vector_search_results(req) + + # Embedding should only be computed once regardless of memory type count + mock_vectorize.get_embedding.assert_awaited_once_with("test query") + + @pytest.mark.asyncio + async def test_foresight_uses_special_params(self, manager): + """Foresight vector search passes current_time and time range params.""" + fore_hits = [{"id": "f1", "score": 0.9}] + mock_fore_repo = MagicMock() + mock_fore_repo.vector_search 
= AsyncMock(return_value=fore_hits) + + mock_vectorize = MagicMock() + mock_vectorize.get_embedding = AsyncMock( + return_value=_FakeEmbedding([0.1]) + ) + + req = RetrieveMemRequest( + user_id="u1", + group_id="g1", + memory_types=[MemoryType.FORESIGHT], + query="meeting tomorrow", + top_k=5, + retrieve_method=RetrieveMethod.VECTOR, + current_time="2025-06-01T10:00:00", + start_time="2025-06-01T00:00:00", + end_time="2025-06-30T23:59:59", + ) + + mock_dt = MagicMock() + + with ( + patch(f"{_MM}.get_bean_by_type", return_value=mock_fore_repo), + patch(f"{_MM}.get_vectorize_service", return_value=mock_vectorize), + patch(f"{_MM}.from_iso_format", return_value=mock_dt), + patch(f"{_MM}.record_retrieve_stage"), + patch(f"{_MM}.record_retrieve_request"), + patch(f"{_MM}.record_retrieve_error"), + ): + results = await manager.get_vector_search_results(req) + + # Foresight path passes current_time kwarg + call_kwargs = mock_fore_repo.vector_search.call_args.kwargs + assert "current_time" in call_kwargs + assert len(results) == 1 + assert results[0]["memory_type"] == "foresight"