Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions demo/utils/simple_memory_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,10 @@ async def search(
query: Query text
top_k: Number of results to return (default: 3)
mode: Retrieval mode (default: "rrf")
- "rrf": RRF fusion (recommended)
- "rrf": Keyword + Vector + RRF fusion (recommended)
- "keyword": Keyword retrieval (BM25)
- "vector": Vector retrieval
- "hybrid": Keyword + Vector + Rerank
- "rrf": Keyword + Vector + RRF fusion
- "agentic": LLM-guided multi-round retrieval
show_details: Whether to show detailed information (default: True)

Expand Down
288 changes: 151 additions & 137 deletions src/agentic_layer/memory_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,12 +337,11 @@ async def get_keyword_search_results(
retrieve_mem_request: 'RetrieveMemRequest',
retrieve_method: str = RetrieveMethod.KEYWORD.value,
) -> List[Dict[str, Any]]:
"""Keyword search with stage-level metrics"""
"""Keyword search with stage-level metrics - supports multiple memory_types"""
stage_start = time.perf_counter()
memory_types = retrieve_mem_request.memory_types
memory_type = (
retrieve_mem_request.memory_types[0].value
if retrieve_mem_request.memory_types
else 'unknown'
memory_types[0].value if memory_types else 'unknown'
)

try:
Expand All @@ -356,7 +355,6 @@ async def get_keyword_search_results(
group_id = retrieve_mem_request.group_id
start_time = retrieve_mem_request.start_time
end_time = retrieve_mem_request.end_time
memory_types = retrieve_mem_request.memory_types

# Convert query string to search word list
# Use jieba for search mode word segmentation, then filter stopwords
Expand All @@ -375,32 +373,41 @@ async def get_keyword_search_results(
if end_time is not None:
date_range["lte"] = end_time

mem_type = memory_types[0]

repo_class = ES_REPO_MAP.get(mem_type)
if not repo_class:
logger.warning(f"Unsupported memory_type: {mem_type}")
return []

es_repo = get_bean_by_type(repo_class)
logger.debug(f"Using {repo_class.__name__} for {mem_type}")

results = await es_repo.multi_search(
query=query_words,
user_id=user_id,
group_id=group_id,
size=top_k,
from_=0,
date_range=date_range,
)
# Iterate over all requested memory_types and collect results
all_results = []
seen_ids = set()

for mem_type in memory_types:
# Skip unsupported memory types (e.g., profile which is stored in MongoDB)
repo_class = ES_REPO_MAP.get(mem_type)
if not repo_class:
logger.info(f"Skipping unsupported memory_type for keyword search: {mem_type}")
continue

es_repo = get_bean_by_type(repo_class)
logger.debug(f"Using {repo_class.__name__} for {mem_type}")

results = await es_repo.multi_search(
query=query_words,
user_id=user_id,
group_id=group_id,
size=top_k,
from_=0,
date_range=date_range,
)

# Mark memory_type, search_source, and unified score
if results:
for r in results:
r['memory_type'] = mem_type.value
r['_search_source'] = RetrieveMethod.KEYWORD.value
r['id'] = r.get('_id', '') # Unify ES '_id' to 'id'
r['score'] = r.get('_score', 0.0) # Unified score field
# Mark memory_type, search_source, and unified score
# Deduplicate by id
if results:
for r in results:
result_id = r.get('_id', '')
if result_id not in seen_ids:
seen_ids.add(result_id)
r['memory_type'] = mem_type.value
r['_search_source'] = RetrieveMethod.KEYWORD.value
r['id'] = result_id # Unify ES '_id' to 'id'
r['score'] = r.get('_score', 0.0) # Unified score field
all_results.append(r)

# Record stage metrics
record_retrieve_stage(
Expand All @@ -410,7 +417,7 @@ async def get_keyword_search_results(
duration_seconds=time.perf_counter() - stage_start,
)

return results or []
return all_results
except Exception as e:
record_retrieve_stage(
retrieve_method=retrieve_method,
Expand Down Expand Up @@ -472,142 +479,149 @@ async def get_vector_search_results(
retrieve_mem_request: 'RetrieveMemRequest',
retrieve_method: str = RetrieveMethod.VECTOR.value,
) -> List[Dict[str, Any]]:
"""Vector search with stage-level metrics (embedding + milvus_search)"""
"""Vector search with stage-level metrics - supports multiple memory_types"""
memory_types = retrieve_mem_request.memory_types
memory_type = (
retrieve_mem_request.memory_types[0].value
if retrieve_mem_request.memory_types
else 'unknown'
memory_types[0].value if memory_types else 'unknown'
)

# Get vectorization service (shared across all memory types)
vectorize_service = get_vectorize_service()

# Convert query text to vector (embedding stage) - shared across all memory types
logger.debug(f"Starting to vectorize query text: {retrieve_mem_request.query}")
embedding_start = time.perf_counter()
query_vector = await vectorize_service.get_embedding(retrieve_mem_request.query)
query_vector_list = query_vector.tolist() # Convert to list format
record_retrieve_stage(
retrieve_method=retrieve_method,
stage='embedding',
memory_type=memory_type,
duration_seconds=time.perf_counter() - embedding_start,
)
logger.debug(
f"Query text vectorization completed, vector dimension: {len(query_vector_list)}"
)

# Iterate over all requested memory_types and collect results
all_results = []
seen_ids = set()

try:
# Get parameters from Request
logger.debug(
f"get_vector_search_results called with retrieve_mem_request: {retrieve_mem_request}"
)
# Get common parameters from Request
if not retrieve_mem_request:
raise ValueError(
"retrieve_mem_request is required for get_vector_search_results"
)
query = retrieve_mem_request.query
if not query:
if not retrieve_mem_request.query:
raise ValueError("query is required for retrieve_mem_vector")

user_id = retrieve_mem_request.user_id
group_id = retrieve_mem_request.group_id
top_k = retrieve_mem_request.top_k
start_time = retrieve_mem_request.start_time
end_time = retrieve_mem_request.end_time
mem_type = retrieve_mem_request.memory_types[0]

logger.debug(
f"retrieve_mem_vector called with query: {query}, user_id: {user_id}, group_id: {group_id}, top_k: {top_k}"
)

# Get vectorization service
vectorize_service = get_vectorize_service()

# Convert query text to vector (embedding stage)
logger.debug(f"Starting to vectorize query text: {query}")
embedding_start = time.perf_counter()
query_vector = await vectorize_service.get_embedding(query)
query_vector_list = query_vector.tolist() # Convert to list format
record_retrieve_stage(
retrieve_method=retrieve_method,
stage='embedding',
memory_type=memory_type,
duration_seconds=time.perf_counter() - embedding_start,
)
logger.debug(
f"Query text vectorization completed, vector dimension: {len(query_vector_list)}"
f"retrieve_mem_vector called with query: {retrieve_mem_request.query}, user_id: {user_id}, group_id: {group_id}, top_k: {top_k}"
)

# Select Milvus repository based on memory type
match mem_type:
case MemoryType.FORESIGHT:
milvus_repo = get_bean_by_type(ForesightMilvusRepository)
case MemoryType.EVENT_LOG:
milvus_repo = get_bean_by_type(EventLogMilvusRepository)
case MemoryType.EPISODIC_MEMORY:
milvus_repo = get_bean_by_type(EpisodicMemoryMilvusRepository)
case _:
raise ValueError(f"Unsupported memory type: {mem_type}")
for mem_type in memory_types:
# Skip unsupported memory types (e.g., profile which is stored in MongoDB)
# Select Milvus repository based on memory type
match mem_type:
case MemoryType.FORESIGHT:
milvus_repo = get_bean_by_type(ForesightMilvusRepository)
case MemoryType.EVENT_LOG:
milvus_repo = get_bean_by_type(EventLogMilvusRepository)
case MemoryType.EPISODIC_MEMORY:
milvus_repo = get_bean_by_type(EpisodicMemoryMilvusRepository)
case _:
logger.info(f"Skipping unsupported memory_type for vector search: {mem_type}")
continue

# Handle time range filter conditions
start_time_dt = None
end_time_dt = None
current_time_dt = None
# Handle time range filter conditions
start_time_dt = None
end_time_dt = None
current_time_dt = None

if start_time is not None:
start_time_dt = (
from_iso_format(start_time)
if isinstance(start_time, str)
else start_time
)
if start_time is not None:
start_time_dt = (
from_iso_format(start_time)
if isinstance(start_time, str)
else start_time
)

if end_time is not None:
if isinstance(end_time, str):
end_time_dt = from_iso_format(end_time)
# If date only format, set to end of day
if len(end_time) == 10:
end_time_dt = end_time_dt.replace(hour=23, minute=59, second=59)
if end_time is not None:
if isinstance(end_time, str):
end_time_dt = from_iso_format(end_time)
# If date only format, set to end of day
if len(end_time) == 10:
end_time_dt = end_time_dt.replace(hour=23, minute=59, second=59)
else:
end_time_dt = end_time

# Handle foresight time range (only valid for foresight)
if mem_type == MemoryType.FORESIGHT:
if retrieve_mem_request.start_time:
start_time_dt = from_iso_format(retrieve_mem_request.start_time)
if retrieve_mem_request.end_time:
end_time_dt = from_iso_format(retrieve_mem_request.end_time)
if retrieve_mem_request.current_time:
current_time_dt = from_iso_format(retrieve_mem_request.current_time)

# Call Milvus vector search (pass different parameters based on memory type)
milvus_start = time.perf_counter()
if mem_type == MemoryType.FORESIGHT:
# Foresight: supports time range and validity filtering, supports radius parameter
search_results = await milvus_repo.vector_search(
query_vector=query_vector_list,
user_id=user_id,
group_id=group_id,
start_time=start_time_dt,
end_time=end_time_dt,
current_time=current_time_dt,
limit=top_k,
score_threshold=0.0,
radius=retrieve_mem_request.radius,
)
else:
end_time_dt = end_time

# Handle foresight time range (only valid for foresight)
if mem_type == MemoryType.FORESIGHT:
if retrieve_mem_request.start_time:
start_time_dt = from_iso_format(retrieve_mem_request.start_time)
if retrieve_mem_request.end_time:
end_time_dt = from_iso_format(retrieve_mem_request.end_time)
if retrieve_mem_request.current_time:
current_time_dt = from_iso_format(retrieve_mem_request.current_time)

# Call Milvus vector search (pass different parameters based on memory type)
milvus_start = time.perf_counter()
if mem_type == MemoryType.FORESIGHT:
# Foresight: supports time range and validity filtering, supports radius parameter
search_results = await milvus_repo.vector_search(
query_vector=query_vector_list,
user_id=user_id,
group_id=group_id,
start_time=start_time_dt,
end_time=end_time_dt,
current_time=current_time_dt,
limit=top_k,
score_threshold=0.0,
radius=retrieve_mem_request.radius,
)
else:
# Episodic memory and event log: use timestamp filtering, supports radius parameter
search_results = await milvus_repo.vector_search(
query_vector=query_vector_list,
user_id=user_id,
group_id=group_id,
start_time=start_time_dt,
end_time=end_time_dt,
limit=top_k,
score_threshold=0.0,
radius=retrieve_mem_request.radius,
# Episodic memory and event log: use timestamp filtering, supports radius parameter
search_results = await milvus_repo.vector_search(
query_vector=query_vector_list,
user_id=user_id,
group_id=group_id,
start_time=start_time_dt,
end_time=end_time_dt,
limit=top_k,
score_threshold=0.0,
radius=retrieve_mem_request.radius,
)
record_retrieve_stage(
retrieve_method=retrieve_method,
stage='milvus_search',
memory_type=mem_type.value,
duration_seconds=time.perf_counter() - milvus_start,
)
record_retrieve_stage(
retrieve_method=retrieve_method,
stage='milvus_search',
memory_type=memory_type,
duration_seconds=time.perf_counter() - milvus_start,
)

for r in search_results:
r['memory_type'] = mem_type.value
r['_search_source'] = RetrieveMethod.VECTOR.value
# Milvus already uses 'score', no need to rename

return search_results
# Deduplicate by id
if search_results:
for r in search_results:
result_id = r.get('id', '')
if result_id not in seen_ids:
seen_ids.add(result_id)
r['memory_type'] = mem_type.value
r['_search_source'] = RetrieveMethod.VECTOR.value
# Milvus already uses 'score', no need to rename
all_results.append(r)

return all_results
except Exception as e:
record_retrieve_stage(
retrieve_method=retrieve_method,
stage=RetrieveMethod.VECTOR.value,
memory_type=memory_type,
duration_seconds=time.perf_counter() - milvus_start,
duration_seconds=time.perf_counter() - embedding_start,
)
record_retrieve_error(
retrieve_method=retrieve_method,
Expand Down
2 changes: 1 addition & 1 deletion src/biz_layer/mem_db_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def _convert_timestamp_to_time(
try:
dt = from_iso_format(timestamp)
return to_iso_format(dt)
except:
except Exception:
# If parsing fails, return string directly
return timestamp
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ async def create_and_save_episodic_memory(
metadata_json = metadata
try:
metadata_dict = json.loads(metadata)
except:
except Exception:
metadata_dict = {}

# Prepare entity data
Expand Down
Loading