Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 112 additions & 117 deletions src/agentic_layer/memory_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,20 @@
MemoryType.EPISODIC_MEMORY: EpisodicMemoryEsRepository,
}

# MemoryType -> Milvus Repository mapping
MILVUS_REPO_MAP = {
MemoryType.FORESIGHT: ForesightMilvusRepository,
MemoryType.EVENT_LOG: EventLogMilvusRepository,
MemoryType.EPISODIC_MEMORY: EpisodicMemoryMilvusRepository,
}


def _memory_types_label(memory_types: List) -> str:
"""Return a comma-joined string of memory type values for metrics/logging."""
if not memory_types:
return 'unknown'
return ','.join(mt.value for mt in memory_types)


@dataclass
class EventLogCandidate:
Expand Down Expand Up @@ -298,11 +312,7 @@ async def retrieve_mem_keyword(
) -> RetrieveMemResponse:
"""Keyword-based memory retrieval"""
start_time = time.perf_counter()
memory_type = (
retrieve_mem_request.memory_types[0].value
if retrieve_mem_request.memory_types
else 'unknown'
)
memory_type = _memory_types_label(retrieve_mem_request.memory_types)

try:
hits = await self.get_keyword_search_results(
Expand Down Expand Up @@ -339,11 +349,7 @@ async def get_keyword_search_results(
) -> List[Dict[str, Any]]:
"""Keyword search with stage-level metrics"""
stage_start = time.perf_counter()
memory_type = (
retrieve_mem_request.memory_types[0].value
if retrieve_mem_request.memory_types
else 'unknown'
)
memory_type = _memory_types_label(retrieve_mem_request.memory_types)

try:
# Get parameters from Request
Expand Down Expand Up @@ -375,32 +381,36 @@ async def get_keyword_search_results(
if end_time is not None:
date_range["lte"] = end_time

mem_type = memory_types[0]

repo_class = ES_REPO_MAP.get(mem_type)
if not repo_class:
logger.warning(f"Unsupported memory_type: {mem_type}")
return []
# Iterate over ALL memory types and merge results
all_results = []
for mem_type in memory_types:
repo_class = ES_REPO_MAP.get(mem_type)
if not repo_class:
logger.info(
f"Memory type {mem_type} is not searchable via ES, skipping"
)
continue

es_repo = get_bean_by_type(repo_class)
logger.debug(f"Using {repo_class.__name__} for {mem_type}")
es_repo = get_bean_by_type(repo_class)
logger.debug(f"Using {repo_class.__name__} for {mem_type}")

results = await es_repo.multi_search(
query=query_words,
user_id=user_id,
group_id=group_id,
size=top_k,
from_=0,
date_range=date_range,
)
results = await es_repo.multi_search(
query=query_words,
user_id=user_id,
group_id=group_id,
size=top_k,
from_=0,
date_range=date_range,
)

# Mark memory_type, search_source, and unified score
if results:
for r in results:
r['memory_type'] = mem_type.value
r['_search_source'] = RetrieveMethod.KEYWORD.value
r['id'] = r.get('_id', '') # Unify ES '_id' to 'id'
r['score'] = r.get('_score', 0.0) # Unified score field
# Mark memory_type, search_source, and unified score
if results:
for r in results:
r['memory_type'] = mem_type.value
r['_search_source'] = RetrieveMethod.KEYWORD.value
r['id'] = r.get('_id', '') # Unify ES '_id' to 'id'
r['score'] = r.get('_score', 0.0) # Unified score field
all_results.extend(results)

# Record stage metrics
record_retrieve_stage(
Expand All @@ -410,7 +420,7 @@ async def get_keyword_search_results(
duration_seconds=time.perf_counter() - stage_start,
)

return results or []
return all_results
except Exception as e:
record_retrieve_stage(
retrieve_method=retrieve_method,
Expand All @@ -433,11 +443,7 @@ async def retrieve_mem_vector(
) -> RetrieveMemResponse:
"""Vector-based memory retrieval"""
start_time = time.perf_counter()
memory_type = (
retrieve_mem_request.memory_types[0].value
if retrieve_mem_request.memory_types
else 'unknown'
)
memory_type = _memory_types_label(retrieve_mem_request.memory_types)

try:
hits = await self.get_vector_search_results(
Expand Down Expand Up @@ -473,11 +479,8 @@ async def get_vector_search_results(
retrieve_method: str = RetrieveMethod.VECTOR.value,
) -> List[Dict[str, Any]]:
"""Vector search with stage-level metrics (embedding + milvus_search)"""
memory_type = (
retrieve_mem_request.memory_types[0].value
if retrieve_mem_request.memory_types
else 'unknown'
)
memory_type = _memory_types_label(retrieve_mem_request.memory_types)
milvus_start = time.perf_counter()

try:
# Get parameters from Request
Expand All @@ -497,7 +500,7 @@ async def get_vector_search_results(
top_k = retrieve_mem_request.top_k
start_time = retrieve_mem_request.start_time
end_time = retrieve_mem_request.end_time
mem_type = retrieve_mem_request.memory_types[0]
memory_types = retrieve_mem_request.memory_types

logger.debug(
f"retrieve_mem_vector called with query: {query}, user_id: {user_id}, group_id: {group_id}, top_k: {top_k}"
Expand All @@ -506,7 +509,7 @@ async def get_vector_search_results(
# Get vectorization service
vectorize_service = get_vectorize_service()

# Convert query text to vector (embedding stage)
# Convert query text to vector ONCE (embedding stage)
logger.debug(f"Starting to vectorize query text: {query}")
embedding_start = time.perf_counter()
query_vector = await vectorize_service.get_embedding(query)
Expand All @@ -521,21 +524,9 @@ async def get_vector_search_results(
f"Query text vectorization completed, vector dimension: {len(query_vector_list)}"
)

# Select Milvus repository based on memory type
match mem_type:
case MemoryType.FORESIGHT:
milvus_repo = get_bean_by_type(ForesightMilvusRepository)
case MemoryType.EVENT_LOG:
milvus_repo = get_bean_by_type(EventLogMilvusRepository)
case MemoryType.EPISODIC_MEMORY:
milvus_repo = get_bean_by_type(EpisodicMemoryMilvusRepository)
case _:
raise ValueError(f"Unsupported memory type: {mem_type}")

# Handle time range filter conditions
start_time_dt = None
end_time_dt = None
current_time_dt = None

if start_time is not None:
start_time_dt = (
Expand All @@ -553,55 +544,71 @@ async def get_vector_search_results(
else:
end_time_dt = end_time

# Handle foresight time range (only valid for foresight)
if mem_type == MemoryType.FORESIGHT:
if retrieve_mem_request.start_time:
start_time_dt = from_iso_format(retrieve_mem_request.start_time)
if retrieve_mem_request.end_time:
end_time_dt = from_iso_format(retrieve_mem_request.end_time)
if retrieve_mem_request.current_time:
current_time_dt = from_iso_format(retrieve_mem_request.current_time)

# Call Milvus vector search (pass different parameters based on memory type)
# Iterate over ALL memory types and merge results
all_results = []
milvus_start = time.perf_counter()
if mem_type == MemoryType.FORESIGHT:
# Foresight: supports time range and validity filtering, supports radius parameter
search_results = await milvus_repo.vector_search(
query_vector=query_vector_list,
user_id=user_id,
group_id=group_id,
start_time=start_time_dt,
end_time=end_time_dt,
current_time=current_time_dt,
limit=top_k,
score_threshold=0.0,
radius=retrieve_mem_request.radius,
)
else:
# Episodic memory and event log: use timestamp filtering, supports radius parameter
search_results = await milvus_repo.vector_search(
query_vector=query_vector_list,
user_id=user_id,
group_id=group_id,
start_time=start_time_dt,
end_time=end_time_dt,
limit=top_k,
score_threshold=0.0,
radius=retrieve_mem_request.radius,
)
for mem_type in memory_types:
milvus_repo_class = MILVUS_REPO_MAP.get(mem_type)
if not milvus_repo_class:
logger.info(
f"Memory type {mem_type} is not searchable via Milvus, skipping"
)
continue

milvus_repo = get_bean_by_type(milvus_repo_class)

# Call Milvus vector search (pass different parameters based on memory type)
if mem_type == MemoryType.FORESIGHT:
# Handle foresight-specific time range
foresight_start_dt = start_time_dt
foresight_end_dt = end_time_dt
current_time_dt = None
if retrieve_mem_request.start_time:
foresight_start_dt = from_iso_format(retrieve_mem_request.start_time)
if retrieve_mem_request.end_time:
foresight_end_dt = from_iso_format(retrieve_mem_request.end_time)
if retrieve_mem_request.current_time:
current_time_dt = from_iso_format(retrieve_mem_request.current_time)

# Foresight: supports time range and validity filtering, supports radius parameter
search_results = await milvus_repo.vector_search(
query_vector=query_vector_list,
user_id=user_id,
group_id=group_id,
start_time=foresight_start_dt,
end_time=foresight_end_dt,
current_time=current_time_dt,
limit=top_k,
score_threshold=0.0,
radius=retrieve_mem_request.radius,
)
else:
# Episodic memory and event log: use timestamp filtering, supports radius parameter
search_results = await milvus_repo.vector_search(
query_vector=query_vector_list,
user_id=user_id,
group_id=group_id,
start_time=start_time_dt,
end_time=end_time_dt,
limit=top_k,
score_threshold=0.0,
radius=retrieve_mem_request.radius,
)

for r in search_results:
r['memory_type'] = mem_type.value
r['_search_source'] = RetrieveMethod.VECTOR.value
# Milvus already uses 'score', no need to rename
all_results.extend(search_results)

record_retrieve_stage(
retrieve_method=retrieve_method,
stage='milvus_search',
memory_type=memory_type,
duration_seconds=time.perf_counter() - milvus_start,
)

for r in search_results:
r['memory_type'] = mem_type.value
r['_search_source'] = RetrieveMethod.VECTOR.value
# Milvus already uses 'score', no need to rename

return search_results
return all_results
except Exception as e:
record_retrieve_stage(
retrieve_method=retrieve_method,
Expand All @@ -624,11 +631,7 @@ async def retrieve_mem_hybrid(
) -> RetrieveMemResponse:
"""Hybrid memory retrieval: keyword + vector + rerank"""
start_time = time.perf_counter()
memory_type = (
retrieve_mem_request.memory_types[0].value
if retrieve_mem_request.memory_types
else 'unknown'
)
memory_type = _memory_types_label(retrieve_mem_request.memory_types)

try:
hits = await self._search_hybrid(
Expand Down Expand Up @@ -699,9 +702,7 @@ async def _search_hybrid(
retrieve_method: str = RetrieveMethod.HYBRID.value,
) -> List[Dict]:
"""Core hybrid search: keyword + vector + rerank, returns flat list"""
memory_type = (
request.memory_types[0].value if request.memory_types else 'unknown'
)
memory_type = _memory_types_label(request.memory_types)
# Run keyword and vector search concurrently
kw_results, vec_results = await asyncio.gather(
self.get_keyword_search_results(request, retrieve_method=retrieve_method),
Expand All @@ -722,9 +723,7 @@ async def _search_rrf(
retrieve_method: str = RetrieveMethod.RRF.value,
) -> List[Dict]:
"""Core RRF search: keyword + vector + RRF fusion, returns flat list"""
memory_type = (
request.memory_types[0].value if request.memory_types else 'unknown'
)
memory_type = _memory_types_label(request.memory_types)

# Run keyword and vector search concurrently
kw, vec = await asyncio.gather(
Expand Down Expand Up @@ -766,7 +765,7 @@ async def _to_response(
"""Convert flat hits list to grouped RetrieveMemResponse"""
user_id = req.user_id if req else ""
source_type = req.retrieve_method.value
memory_type = req.memory_types[0].value
memory_type = _memory_types_label(req.memory_types)

if not hits:
return RetrieveMemResponse(
Expand Down Expand Up @@ -808,11 +807,7 @@ async def retrieve_mem_rrf(
) -> RetrieveMemResponse:
"""RRF-based memory retrieval: keyword + vector + RRF fusion"""
start_time = time.perf_counter()
memory_type = (
retrieve_mem_request.memory_types[0].value
if retrieve_mem_request.memory_types
else 'unknown'
)
memory_type = _memory_types_label(retrieve_mem_request.memory_types)

try:
hits = await self._search_rrf(
Expand Down Expand Up @@ -855,7 +850,7 @@ async def retrieve_mem_agentic(
req = retrieve_mem_request # alias
top_k = req.top_k
config = AgenticConfig()
memory_type = req.memory_types[0].value if req.memory_types else 'unknown'
memory_type = _memory_types_label(req.memory_types)

try:
llm_provider = LLMProvider(
Expand Down
Loading