diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 196de087..bda11008 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -6,6 +6,17 @@ All notable changes to EverMemOS will be documented in this file. --- +## [Unreleased] - 2026-03-02 + +### Added +- ⚡ **Two-Phase Memory Extraction**: Added a `REALTIME_EVENT_LOG_ONLY` mode. The real-time phase extracts only event_logs (atomic facts); episodic memories are generated by the background `BatchEpisodeWorker`, which periodically merges multi-turn context before extraction, preserving narrative memory quality without losing data. See the [usage guide](advanced/TWO_PHASE_EXTRACTION.md). +- 🔧 **BatchEpisodeWorker**: A scheduled background task that aggregates MemCells by group and batch-generates high-quality episodic memories. Tunable via environment variables (`BATCH_EPISODE_INTERVAL` / `BATCH_EPISODE_LOOKBACK` / `BATCH_EPISODE_MIN_CELLS`). + +### Fixed
+- 🐛 **Decoupled client initialization**: Fixed `auto_memorize` silently failing when `inject_memories=false`. The EverMemOS client is now initialized whenever `enabled=true`; `inject_memories` only controls the retrieval-injection step. + +--- + ## [1.2.0] - 2025-01-20 ### Changed diff --git a/docs/advanced/TWO_PHASE_EXTRACTION.md b/docs/advanced/TWO_PHASE_EXTRACTION.md new file mode 100644 index 00000000..9d932fb7 --- /dev/null +++ b/docs/advanced/TWO_PHASE_EXTRACTION.md @@ -0,0 +1,124 @@ +# Two-Phase Memory Extraction
+
+[Home](../../README.md) > [Docs](../README.md) > [Advanced](README.md) > Two-Phase Extraction
+
+---
+
+## Background: Problem and Motivation
+
+By default, EverMemOS uses LLM **boundary detection** to decide when a conversation segment is cut into a MemCell, which then triggers episodic memory extraction. This guarantees that the episode LLM sees the complete multi-turn context, so the generated narratives are high quality.
+
+There is one pain point, however: **if a conversation ends mid-way and the boundary never fires, the messages buffered in the Redis queue are permanently lost once the 60-minute TTL expires**, and not a single memory is generated.
+
+A common workaround is `DISABLE_BOUNDARY_DETECTION=true`, which turns every message into a MemCell immediately. But then the episode LLM only ever sees 1–2 isolated messages, and narrative quality drops noticeably.
+
+**Two-phase extraction** resolves this tension:
+
+```
+Real-time phase (after every message, millisecond-level)
+  MemCell generated immediately ──→ extracts only event_log (atomic facts)
+  Atomic facts do not depend on multi-turn context; a single message suffices
+
+Batch phase (background worker, every 15 minutes by default)
+  BatchEpisodeWorker scans multiple MemCells from the same group
+  └─ merges them into a complete multi-turn conversation context
+  └─ calls the episode LLM ──→ generates a high-quality narrative memory
+```
+
+---
+
+## Quick Start
+
+Add the following two lines to the EverMemOS `.env` (**both are required**):
+
+```bash
+# 1. Generate a MemCell immediately for every message; do not wait for LLM boundary detection
+DISABLE_BOUNDARY_DETECTION=true
+
+# 2. Extract only event_log in the real-time phase; leave episodes to the batch worker
+REALTIME_EVENT_LOG_ONLY=true
+```
+
+Restart EverMemOS; the feature is active once the startup log shows:
+
+```
+✅ BatchEpisodeWorker started
+```
+
+---
+
+## Configuration
+
+| Environment variable | Default | Description |
+|---|---|---|
+| `DISABLE_BOUNDARY_DETECTION` | `false` | Set to `true` to skip LLM boundary detection and generate a MemCell for every message |
+| `REALTIME_EVENT_LOG_ONLY` | `false` | Set to `true` to enable two-phase mode |
+| `BATCH_EPISODE_INTERVAL` | `15` | Worker run interval (minutes) |
+| `BATCH_EPISODE_LOOKBACK` | `120` | How many minutes back each run scans for MemCells |
+| `BATCH_EPISODE_MIN_CELLS` | `3` | Minimum MemCells a group must accumulate before episode generation triggers |
+
+### Tuning tips
+
+| Scenario | Suggestion |
+|---|---|
+| Slow-paced conversations (a few messages per day) | `BATCH_EPISODE_MIN_CELLS=2` |
+| Episodes should appear sooner | `BATCH_EPISODE_INTERVAL=5` |
+| Limit the worker's scan volume | `BATCH_EPISODE_LOOKBACK=60` |
+| LLM with a small context window | Lower both `BATCH_EPISODE_MIN_CELLS` and `BATCH_EPISODE_LOOKBACK` to reduce the number of merged messages |
+
+---
+
+## How It Works in Detail
+
+### Real-time phase
+
+```
+User sends a message
+  └─ CountBot (or another client) POST /api/v1/memories
+     └─ MemorizeRequest → preprocess_conv_request
+        └─ ConvMemCellExtractor
+           DISABLE_BOUNDARY_DETECTION=true
+           → should_end=True (immediately)
+           → MemCell generated (with this message's original_data)
+        └─ process_memory_extraction
+           REALTIME_EVENT_LOG_ONLY=true
+           → calls only EventLogExtractor
+           → event_log written to the database with parent_type="memcell"
+           → skips EpisodeMemoryExtractor / ForesightExtractor
+           → session state updated
+```
+
+### Batch phase
+
+```
+BatchEpisodeWorker.run_once() (every BATCH_EPISODE_INTERVAL minutes)
+  └─ query all MemCells from the last BATCH_EPISODE_LOOKBACK minutes
+  └─ filter out extend.batch_ep_done=true (already processed)
+  └─ group by group_id
+     each group with ≥ BATCH_EPISODE_MIN_CELLS
+     └─ merge all original_data (ascending by timestamp)
+     └─ build a merged MemCell (full context)
+     └─ EpisodeMemoryExtractor → group episode + one episode per user
+     └─ save episodes to the database (MongoDB + ES + Milvus)
+     └─ mark each MemCell: extend.batch_ep_done=true
+```
+
+---
+
+## Comparison with the Default Modes
+
+| | Default mode | `DISABLE_BOUNDARY_DETECTION` | Two-phase mode (this feature) |
+|---|---|---|---|
+| Data-loss risk | ⚠️ lost 60 minutes after an interrupted conversation | ✅ none | ✅ none |
+| event_log quality | ✅ high | ✅ high | ✅ high |
+| episode quality | ✅ high | ❌ low (insufficient context) | ✅ high |
+| episode latency | real-time | real-time | ≤ BATCH_EPISODE_INTERVAL minutes |
+| LLM calls | boundary detection + extraction | extraction only | extraction only (decoupled timing) |
+
+---
+
+## Known Limitations
+
+- **Clustering is not triggered in the batch phase**: BatchEpisodeWorker does not call `_trigger_clustering` after saving episodes. Clustering mainly affects **user profile extraction**; if your use case depends heavily on profiles, keep using the default mode for now or trigger clustering manually.
+
+- **Effect of `BATCH_EPISODE_MIN_CELLS`**: If a user's MemCell count within the lookback window stays below the threshold (e.g. a new user who has sent only one message), no episode is generated in that cycle. Lowering `BATCH_EPISODE_MIN_CELLS=1` fixes this (at the cost of less context per episode). diff --git a/env.template b/env.template index 72ca8855..1a3e10ed 100755 --- a/env.template +++ b/env.template @@ -165,4 +165,4 @@ API_BASE_URL=http://localhost:1995 LOG_LEVEL=INFO ENV=dev PYTHONASYNCIODEBUG=1 -MEMORY_LANGUAGE=en +MEMORY_LANGUAGE=zh diff --git a/src/biz_layer/mem_batch_episode_worker.py b/src/biz_layer/mem_batch_episode_worker.py new file mode 100644 index 00000000..6f1a1f99 --- /dev/null +++ b/src/biz_layer/mem_batch_episode_worker.py @@ -0,0 +1,262 @@ +""" +Batch Episode Worker +==================== +Periodically generates high-quality episodic memories from accumulated MemCells. + +Designed to work alongside ``REALTIME_EVENT_LOG_ONLY=true``. In that mode, +real-time message processing extracts only event_logs immediately (no LLM +context needed for atomic facts). This worker wakes up on a schedule, merges +several consecutive MemCells from the same group into one rich context window, +and then asks the episode LLM to produce a narrative, preserving the full +multi-turn quality that a single isolated message cannot provide. 
+ +Env vars (all optional): + REALTIME_EVENT_LOG_ONLY=true – turns on realtime mode (worker is needed) + BATCH_EPISODE_INTERVAL=15 – worker run interval in minutes (default 15) + BATCH_EPISODE_LOOKBACK=120 – lookback window in minutes (default 120) + BATCH_EPISODE_MIN_CELLS=3 – min pending MemCells per group to trigger + episode generation (default 3) +""" + +from __future__ import annotations + +import asyncio +import os +from collections import defaultdict +from datetime import datetime, timedelta, timezone +from typing import Dict, List, Optional + +from api_specs.memory_types import EpisodeMemory, MemCell, MemoryType, RawDataType +from common_utils.datetime_utils import get_now_with_timezone +from core.di import get_bean_by_type +from core.observation.logger import get_logger +from infra_layer.adapters.out.persistence.repository.memcell_raw_repository import ( + MemCellRawRepository, +) +from memory_layer.memory_manager import MemoryManager + +logger = get_logger(__name__) + +_ROBOT_KEYWORDS = ("robot", "assistant") + + +class BatchEpisodeWorker: + """Generates episodic memories from buffered MemCells in a background loop.""" + + def __init__( + self, + interval_minutes: int = 15, + lookback_minutes: int = 120, + min_cells: int = 3, + ) -> None: + self.interval_minutes = int( + os.getenv("BATCH_EPISODE_INTERVAL", str(interval_minutes)) + ) + self.lookback_minutes = int( + os.getenv("BATCH_EPISODE_LOOKBACK", str(lookback_minutes)) + ) + self.min_cells = int( + os.getenv("BATCH_EPISODE_MIN_CELLS", str(min_cells)) + ) + self._running = False + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + async def run_once(self) -> int: + """Scan recent MemCells and generate episodes for every pending group. + + Returns: + Total number of episode documents saved across all groups. 
+ """ + now = get_now_with_timezone() + start_time = now - timedelta(minutes=self.lookback_minutes) + + try: + repo = get_bean_by_type(MemCellRawRepository) + recent_cells: List[MemCell] = await repo.find_by_time_range( + start_time=start_time, end_time=now + ) + except Exception as exc: + logger.error("[BatchEpisode] Failed to query MemCells: %s", exc) + return 0 + + # Group unprocessed cells by group_id + by_group: Dict[str, List[MemCell]] = defaultdict(list) + for mc in recent_cells: + if (mc.extend or {}).get("batch_ep_done"): + continue + gid = mc.group_id or "__ungrouped__" + by_group[gid].append(mc) + + if not by_group: + logger.debug("[BatchEpisode] No pending MemCells found in lookback window") + return 0 + + total = 0 + for group_id, cells in by_group.items(): + if len(cells) < self.min_cells: + logger.debug( + "[BatchEpisode] Group %s: %d cells < min=%d, skipping", + group_id, + len(cells), + self.min_cells, + ) + continue + try: + count = await self._generate_episodes_for_group(group_id, cells) + total += count + except Exception as exc: + logger.error( + "[BatchEpisode] Group %s episode generation failed: %s", + group_id, + exc, + ) + + logger.info( + "[BatchEpisode] Run complete – %d episodes saved across %d groups", + total, + len(by_group), + ) + return total + + async def start_periodic(self) -> None: + """Blocking periodic loop – run as an asyncio background Task.""" + self._running = True + logger.info( + "[BatchEpisode] Worker started (interval=%dm, lookback=%dm, min_cells=%d)", + self.interval_minutes, + self.lookback_minutes, + self.min_cells, + ) + while self._running: + try: + await self.run_once() + except Exception as exc: + logger.error("[BatchEpisode] Unexpected error in periodic loop: %s", exc) + await asyncio.sleep(self.interval_minutes * 60) + + def stop(self) -> None: + self._running = False + + # ------------------------------------------------------------------ + # Private helpers + # 
------------------------------------------------------------------ + + async def _generate_episodes_for_group( + self, group_id: str, memcells: List[MemCell] + ) -> int: + """Merge *memcells* for *group_id* and generate episodic memories. + + Returns number of episode documents saved. + """ + # Sort ascending so the narrative is chronological + # Use a timezone-aware sentinel so comparison with aware timestamps works + _EPOCH = datetime.min.replace(tzinfo=timezone.utc) + cells = sorted( + memcells, key=lambda mc: mc.timestamp or _EPOCH + ) + + # Merge raw conversation messages + merged_messages: list = [] + all_user_ids: list = [] + all_participants: list = [] + for mc in cells: + merged_messages.extend(mc.original_data or []) + all_user_ids.extend(mc.user_id_list or []) + all_participants.extend(mc.participants or []) + + # Deduplicate while preserving order + all_user_ids = list(dict.fromkeys(all_user_ids)) + all_participants = list(dict.fromkeys(all_participants)) + + if not merged_messages: + return 0 + + real_group_id = group_id if group_id != "__ungrouped__" else None + + # Build a synthetic MemCell representing the merged context window + merged_cell = MemCell( + original_data=merged_messages, + user_id_list=all_user_ids, + timestamp=cells[-1].timestamp or get_now_with_timezone(), + group_id=real_group_id, + participants=all_participants, + type=cells[0].type or RawDataType.CONVERSATION, + summary="", + ) + + # --- Extract group episode ------------------------------------------ + memory_manager = MemoryManager() + try: + group_episode = await asyncio.wait_for( + memory_manager.extract_memory( + memcell=merged_cell, + memory_type=MemoryType.EPISODIC_MEMORY, + user_id=None, # None → group episode + group_id=real_group_id, + ), + timeout=120, + ) + except asyncio.TimeoutError: + logger.warning("[BatchEpisode] Group %s: episode extraction timed out", group_id) + return 0 + except Exception as exc: + logger.error("[BatchEpisode] Group %s: extraction error: %s", 
group_id, exc) + return 0 + + if group_episode is None or isinstance(group_episode, Exception): + logger.debug("[BatchEpisode] Group %s: no episode extracted", group_id) + return 0 + + # Ensure metadata fields are populated + if not getattr(group_episode, "group_id", None): + group_episode.group_id = real_group_id + if not getattr(group_episode, "user_name", None): + group_episode.user_name = getattr(group_episode, "user_id", None) + + # --- Save episodes -------------------------------------------------- + # Defer imports to avoid circular dependencies + from biz_layer.mem_db_operations import _convert_episode_memory_to_doc + from biz_layer.mem_memorize import MemoryDocPayload, save_memory_docs + from dataclasses import replace as dc_replace + + current_time = get_now_with_timezone() + + episodes_to_save: List[EpisodeMemory] = [group_episode] + + # Clone group episode for each real user (mirrors _clone_episodes_for_users) + for uid in all_participants: + if any(kw in uid.lower() for kw in _ROBOT_KEYWORDS): + continue + episodes_to_save.append( + dc_replace(group_episode, user_id=uid, user_name=uid) + ) + + docs = [_convert_episode_memory_to_doc(ep, current_time) for ep in episodes_to_save] + payloads = [MemoryDocPayload(MemoryType.EPISODIC_MEMORY, doc) for doc in docs] + saved_map = await save_memory_docs(payloads) + count = len(saved_map.get(MemoryType.EPISODIC_MEMORY, [])) + + # --- Mark MemCells as batch-processed -------------------------------- + repo = get_bean_by_type(MemCellRawRepository) + for mc in cells: + try: + extend = dict(mc.extend or {}) + extend["batch_ep_done"] = True + await repo.update_by_event_id(str(mc.event_id), {"extend": extend}) + except Exception as exc: + logger.warning( + "[BatchEpisode] Failed to mark MemCell %s as done: %s", + mc.event_id, + exc, + ) + + logger.info( + "[BatchEpisode] Group %s: merged %d MemCells → %d episode docs saved", + group_id, + len(cells), + count, + ) + return count diff --git 
a/src/biz_layer/mem_memorize.py b/src/biz_layer/mem_memorize.py index e951a13c..881c73fc 100644 --- a/src/biz_layer/mem_memorize.py +++ b/src/biz_layer/mem_memorize.py @@ -523,11 +523,80 @@ async def _timed_extract_event_logs(): return result if state.is_assistant_scene: + # Episode extraction modifies state in-place. Run it concurrently with + # Foresight/EventLog, but cap each with a per-task timeout so a slow or + # hanging LLM call never blocks episode persistence. + _EXTRACTION_TIMEOUT = 120 # seconds per extraction task + + # ★ REALTIME_EVENT_LOG_ONLY mode: only extract event_log right now. + # Episodic memories are generated later by BatchEpisodeWorker which + # merges multiple MemCells into one rich context before calling the + # episode LLM, preserving full multi-turn quality. + if os.getenv("REALTIME_EVENT_LOG_ONLY", "false").lower() == "true": + try: + event_logs = await asyncio.wait_for( + _timed_extract_event_logs(), timeout=_EXTRACTION_TIMEOUT + ) + event_logs = event_logs or [] + except (asyncio.TimeoutError, Exception) as e: + logger.warning("[Realtime] EventLog extraction failed: %s", e) + event_logs = [] + record_extraction_stage( + space_id=space_id, + raw_data_type=raw_data_type, + stage='extract_parallel', + duration_seconds=time.perf_counter() - extract_start, + ) + if event_logs: + record_memory_extracted( + space_id=space_id, + raw_data_type=raw_data_type, + memory_type='event_log', + count=len(event_logs), + ) + await _update_memcell_and_cluster(state) + if if_memorize(memcell): + rt_count = await _save_event_logs_without_episode(state, event_logs) + await update_status_after_memcell( + state.request, state.memcell, state.current_time, + state.request.raw_data_type, + ) + return rt_count + return 0 + + async def _safe_extract_foresights(): + try: + return await asyncio.wait_for(_timed_extract_foresights(), timeout=_EXTRACTION_TIMEOUT) + except asyncio.TimeoutError: + logger.warning( + "[Extraction] Foresight extraction timed out after %ss, 
episode will still be saved", + _EXTRACTION_TIMEOUT, + ) + return [] + except Exception as e: + logger.error("[Extraction] Foresight extraction failed: %s", e) + return [] + + async def _safe_extract_event_logs(): + try: + return await asyncio.wait_for(_timed_extract_event_logs(), timeout=_EXTRACTION_TIMEOUT) + except asyncio.TimeoutError: + logger.warning( + "[Extraction] EventLog extraction timed out after %ss, episode will still be saved", + _EXTRACTION_TIMEOUT, + ) + return [] + except Exception as e: + logger.error("[Extraction] EventLog extraction failed: %s", e) + return [] + _, foresight_memories, event_logs = await asyncio.gather( _timed_extract_episodes(), - _timed_extract_foresights(), - _timed_extract_event_logs(), + _safe_extract_foresights(), + _safe_extract_event_logs(), ) + foresight_memories = foresight_memories or [] + event_logs = event_logs or [] else: await _timed_extract_episodes() record_extraction_stage( @@ -797,6 +866,68 @@ async def _extract_event_logs( return [result] +async def _save_event_logs_without_episode( + state: ExtractionState, + event_logs: List[EventLog], +) -> int: + """Save event logs against the MemCell directly (no episode parent). + + Used in REALTIME_EVENT_LOG_ONLY mode where episode extraction is deferred to + BatchEpisodeWorker. 
Each EventLogRecord will have: + parent_type = "memcell" + parent_id = str(memcell.event_id) + """ + if not event_logs: + return 0 + + class _MemCellParent: + """Minimal duck-type shim to satisfy _convert_event_log_to_docs.""" + timestamp = state.memcell.timestamp + participants = state.memcell.participants or [] + type = ( + state.memcell.type.value + if state.memcell.type + else RawDataType.CONVERSATION.value + ) + + pseudo_parent = _MemCellParent() + + # Override parent pointers to MemCell + for el in event_logs: + el.parent_type = "memcell" + el.parent_id = str(state.memcell.event_id) + + base_docs = [] + for el in event_logs: + base_docs.extend(_convert_event_log_to_docs(el, pseudo_parent, state.current_time)) + + if not base_docs: + return 0 + + all_docs = list(base_docs) + # assistant scene: copy to each non-robot user + if state.is_assistant_scene: + user_ids = [ + u for u in state.participants + if "robot" not in u.lower() and "assistant" not in u.lower() + ] + all_docs.extend([ + doc.model_copy(update={"user_id": uid, "user_name": uid}) + for doc in base_docs + for uid in user_ids + ]) + + payloads = [MemoryDocPayload(MemoryType.EVENT_LOG, doc) for doc in all_docs] + saved_map = await save_memory_docs(payloads) + count = len(saved_map.get(MemoryType.EVENT_LOG, [])) + logger.info( + "[Realtime] Saved %d event_log facts (parent=memcell %s)", + count, + state.memcell.event_id, + ) + return count + + def _clone_episodes_for_users(state: ExtractionState) -> List[EpisodeMemory]: """Copy group Episode to each user""" from dataclasses import replace diff --git a/src/core/lifespan/batch_episode_lifespan.py b/src/core/lifespan/batch_episode_lifespan.py new file mode 100644 index 00000000..1e8d52d3 --- /dev/null +++ b/src/core/lifespan/batch_episode_lifespan.py @@ -0,0 +1,70 @@ +""" +Lifespan provider for BatchEpisodeWorker +========================================= +When ``REALTIME_EVENT_LOG_ONLY=true`` is set, this provider starts the 
+:class:`~biz_layer.mem_batch_episode_worker.BatchEpisodeWorker` as a background +asyncio Task during application startup and cancels it cleanly on shutdown. + +The provider is registered at order 200 so it starts *after* all database / +search-engine lifespans (typically order < 100) have completed their startup. +""" + +from __future__ import annotations + +import asyncio +import os +from typing import Any, Optional + +from fastapi import FastAPI + +from core.di.decorators import component +from core.lifespan.lifespan_interface import LifespanProvider +from core.observation.logger import get_logger + +logger = get_logger(__name__) + + +@component(name="batch_episode_lifespan_provider") +class BatchEpisodeLifespanProvider(LifespanProvider): + """Manages the lifecycle of the BatchEpisodeWorker background task.""" + + def __init__(self, order: int = 200) -> None: + super().__init__("batch_episode_worker", order) + self._worker = None + self._task: Optional[asyncio.Task] = None + + async def startup(self, app: FastAPI) -> Any: + """Start BatchEpisodeWorker if REALTIME_EVENT_LOG_ONLY=true.""" + if os.getenv("REALTIME_EVENT_LOG_ONLY", "false").lower() != "true": + logger.info( + "[BatchEpisode] REALTIME_EVENT_LOG_ONLY is not set – worker skipped" + ) + return None + + from biz_layer.mem_batch_episode_worker import BatchEpisodeWorker + + self._worker = BatchEpisodeWorker() + self._task = asyncio.create_task(self._worker.start_periodic()) + + # Expose the task on app.state for external inspection / testing + app.state.batch_episode_task = self._task + + logger.info("✅ BatchEpisodeWorker started") + return self._task + + async def shutdown(self, app: FastAPI) -> None: + """Stop the background task gracefully.""" + if self._worker is not None: + self._worker.stop() + + if self._task is not None and not self._task.done(): + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + + if hasattr(app.state, "batch_episode_task"): + 
delattr(app.state, "batch_episode_task") + + logger.info("[BatchEpisode] Worker stopped") diff --git a/src/service/memcell_delete_service.py b/src/service/memcell_delete_service.py index d60ba548..3bd923ec 100644 --- a/src/service/memcell_delete_service.py +++ b/src/service/memcell_delete_service.py @@ -182,6 +182,93 @@ async def delete_by_group_id( ) raise + async def _delete_milvus_vectors( + self, + user_id: Optional[str], + group_id: Optional[str], + event_id: Optional[str], + filters_used: list, + ) -> int: + """Synchronously clean up the Milvus vector index so deleted memories can no longer be recalled. + + Args: + user_id: User ID (MAGIC_ALL or None means do not filter on this dimension) + group_id: Group ID (MAGIC_ALL or None means do not filter on this dimension) + event_id: Single memory ID (MAGIC_ALL or None means do not filter by it) + filters_used: Filter conditions already applied (for logging) + + Returns: + Total number of vectors deleted + """ + import asyncio + from core.oxm.constants import MAGIC_ALL + from infra_layer.adapters.out.search.repository.episodic_memory_milvus_repository import ( + EpisodicMemoryMilvusRepository, + ) + from infra_layer.adapters.out.search.repository.foresight_milvus_repository import ( + ForesightMilvusRepository, + ) + from infra_layer.adapters.out.search.repository.event_log_milvus_repository import ( + EventLogMilvusRepository, + ) + from core.di.context import get_bean_by_type + + episodic_repo = get_bean_by_type(EpisodicMemoryMilvusRepository) + foresight_repo = get_bean_by_type(ForesightMilvusRepository) + event_log_repo = get_bean_by_type(EventLogMilvusRepository) + + effective_user_id = user_id if (user_id and user_id != MAGIC_ALL) else None + effective_group_id = group_id if (group_id and group_id != MAGIC_ALL) else None + effective_event_id = event_id if (event_id and event_id != MAGIC_ALL) else None + + tasks = [] + + # Batch-delete vectors by user/group + if effective_user_id or effective_group_id: + tasks.append( + episodic_repo.delete_by_filters( + user_id=effective_user_id, group_id=effective_group_id + ) + ) + tasks.append( + foresight_repo.delete_by_filters( + user_id=effective_user_id, group_id=effective_group_id + ) + ) + tasks.append( + event_log_repo.delete_by_filters( + user_id=effective_user_id, group_id=effective_group_id + ) + ) + + # Precisely delete a single memory's vectors by event_id + if effective_event_id: + tasks.append(episodic_repo.delete_by_event_id(effective_event_id)) + tasks.append(foresight_repo.delete_by_parent_id(effective_event_id)) + tasks.append(event_log_repo.delete_by_parent_id(effective_event_id)) + + if not tasks: + return 0 + + results = await asyncio.gather(*tasks, return_exceptions=True) + total = 0 + for r in results: + if isinstance(r, Exception): + logger.warning( + "Milvus vector deletion failed (non-fatal): filters=%s, error=%s", + filters_used, + r, + ) + elif isinstance(r, int): + total += r + + logger.info( + "Milvus vector cleanup completed: filters=%s, deleted_vectors=%d", + filters_used, + total, + ) + return total + async def delete_by_combined_criteria( self, event_id: Optional[str] = None, @@ -214,6 +301,21 @@ async def delete_by_combined_criteria( """ from core.oxm.constants import MAGIC_ALL from infra_layer.adapters.out.persistence.document.memory.memcell import MemCell + from infra_layer.adapters.out.persistence.document.memory.episodic_memory import ( + EpisodicMemory, + ) + from infra_layer.adapters.out.persistence.document.memory.foresight_record import ( + ForesightRecord, + ) + from infra_layer.adapters.out.persistence.document.memory.event_log_record import ( + EventLogRecord, + ) + from infra_layer.adapters.out.persistence.document.memory.user_profile import ( + UserProfile, + ) + from infra_layer.adapters.out.persistence.document.memory.global_user_profile import ( + GlobalUserProfile, + ) # Build filter conditions filter_dict = {} @@ -253,21 +355,103 @@ } logger.info( - "Deleting MemCells with combined criteria: filters=%s", filters_used + "Deleting memories with combined criteria: filters=%s", filters_used ) + def _count_deleted(result) -> int: + if not result: + return 0 + if hasattr(result, "deleted_count"): + return int(result.deleted_count or 0) + if hasattr(result, "modified_count"): + return int(result.modified_count or 0) + return 0 + try: - # Use delete_many to batch soft delete - result = await MemCell.delete_many(filter_dict) - count = result.modified_count if result else 0 + total_count = 0 + + # 1) Delete MemCells + memcell_result = await MemCell.delete_many(filter_dict) + memcell_count = _count_deleted(memcell_result) + total_count += memcell_count + + # 2) Delete EpisodicMemory (by user/group, or linked via event_id) + ep_filter = {} + if user_id and user_id != MAGIC_ALL: + ep_filter["user_id"] = user_id + if group_id and group_id != MAGIC_ALL: + ep_filter["group_id"] = group_id + if event_id and event_id != MAGIC_ALL: + ep_filter["event_id"] = event_id + if ep_filter: + ep_result = await EpisodicMemory.find(ep_filter).delete() + total_count += _count_deleted(ep_result) + + # 3) Delete ForesightRecord (by user/group, or linked via parent_id) + foresight_filter = {} + if user_id and user_id != MAGIC_ALL: + foresight_filter["user_id"] = user_id + if group_id and group_id != MAGIC_ALL: + foresight_filter["group_id"] = group_id + if event_id and event_id != MAGIC_ALL: + foresight_filter["parent_id"] = event_id + if foresight_filter: + foresight_result = await ForesightRecord.find(foresight_filter).delete() + total_count += _count_deleted(foresight_result) + + # 4) Delete EventLogRecord (by user/group, or linked via parent_id) + event_log_filter = {} + if user_id and user_id != MAGIC_ALL: + event_log_filter["user_id"] = user_id + if group_id and group_id != MAGIC_ALL: + event_log_filter["group_id"] = group_id + if event_id and event_id != MAGIC_ALL: + event_log_filter["parent_id"] = event_id + if event_log_filter: + event_log_result = await EventLogRecord.find(event_log_filter).delete() + total_count += _count_deleted(event_log_result) + + # 5) Delete UserProfile (by user/group) + profile_filter = {} + if user_id and user_id != MAGIC_ALL: + profile_filter["user_id"] = user_id + if group_id and group_id != MAGIC_ALL: + profile_filter["group_id"] = group_id + if profile_filter: + profile_result = await UserProfile.find(profile_filter).delete() + total_count += _count_deleted(profile_result) + + # 6) Delete GlobalUserProfile (by user only) + if user_id and user_id != MAGIC_ALL: + global_profile_result = await GlobalUserProfile.find( + {"user_id": user_id} + ).delete() + total_count += _count_deleted(global_profile_result) logger.info( - "Successfully deleted MemCells: filters=%s, count=%d", + "Successfully deleted memories: filters=%s, total_count=%d (memcell=%d)", filters_used, - count, + total_count, + memcell_count, ) - return {"filters": filters_used, "count": count, "success": count > 0} + # 7) Synchronously clean up the Milvus vector index so deleted memories are not recalled via vector search + try: + await self._delete_milvus_vectors( + user_id=user_id, + group_id=group_id, + event_id=event_id, + filters_used=filters_used, + ) + except Exception as milvus_err: + # A Milvus cleanup failure does not prevent the main flow from returning success + logger.warning( + "Milvus vector cleanup encountered an error (non-fatal): filters=%s, error=%s", + filters_used, + milvus_err, + ) + + return {"filters": filters_used, "count": total_count, "success": True} except Exception as e: logger.error(
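The grouping step inside `BatchEpisodeWorker.run_once` (skip cells already marked `batch_ep_done`, bucket the rest by `group_id`, apply the `min_cells` threshold) is pure data shuffling and can be exercised in isolation. A minimal sketch; `FakeCell` is a hypothetical stand-in for `MemCell`, not the real class:

```python
from collections import defaultdict

class FakeCell:
    """Hypothetical stand-in for MemCell: only the two fields the grouping reads."""
    def __init__(self, group_id, done=False):
        self.group_id = group_id
        self.extend = {"batch_ep_done": True} if done else {}

def group_pending_cells(cells, min_cells):
    # Skip cells a previous run already merged, bucket the rest by group_id,
    # and keep only groups that reached the min_cells threshold.
    by_group = defaultdict(list)
    for mc in cells:
        if (mc.extend or {}).get("batch_ep_done"):
            continue
        by_group[mc.group_id or "__ungrouped__"].append(mc)
    return {gid: cs for gid, cs in by_group.items() if len(cs) >= min_cells}

cells = [
    FakeCell("g1"), FakeCell("g1"), FakeCell("g1"),   # 3 pending -> qualifies
    FakeCell("g2"), FakeCell("g2", done=True),        # only 1 pending -> skipped
    FakeCell(None),                                   # ungrouped, 1 pending -> skipped
]
ready = group_pending_cells(cells, min_cells=3)
```

This also makes the `BATCH_EPISODE_MIN_CELLS` limitation concrete: a group with fewer pending cells than the threshold simply drops out of the result for that cycle.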
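The `_safe_extract_foresights` / `_safe_extract_event_logs` wrappers added to mem_memorize.py share one pattern: give each concurrent extraction its own `asyncio.wait_for` cap and a fallback value, so the surrounding `asyncio.gather` always completes even if one LLM call hangs. A self-contained sketch of that pattern (the `safe_call` helper and extractor names are illustrative, not from the codebase):

```python
import asyncio

async def safe_call(coro, timeout, fallback, label="task"):
    """Cap one task with its own timeout and degrade to a fallback value,
    so a single hanging call never fails the surrounding gather()."""
    try:
        return await asyncio.wait_for(coro, timeout=timeout)
    except asyncio.TimeoutError:
        print(f"[Extraction] {label} timed out after {timeout}s")
        return fallback
    except Exception as exc:
        print(f"[Extraction] {label} failed: {exc}")
        return fallback

async def main():
    async def fast_extractor():
        return ["atomic fact"]

    async def hanging_extractor():
        await asyncio.sleep(10)  # simulates a stuck LLM call
        return ["never reached"]

    # Both run concurrently; the hanging one is cancelled after 0.1s and
    # degrades to [], while the fast one still returns its result.
    return await asyncio.gather(
        safe_call(fast_extractor(), timeout=1.0, fallback=[], label="event_log"),
        safe_call(hanging_extractor(), timeout=0.1, fallback=[], label="foresight"),
    )

event_logs, foresights = asyncio.run(main())
```

The key design choice mirrored here is per-task timeouts rather than one timeout around the whole `gather`, which is what lets episode persistence proceed when only one extractor stalls.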
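Two small conventions recur in the delete path above: `MAGIC_ALL` (or an empty value) means "do not filter on this dimension", and `_count_deleted` accepts either hard-delete results (`deleted_count`) or soft-delete/update results (`modified_count`). A standalone sketch, assuming a placeholder sentinel string since the real `MAGIC_ALL` constant lives in `core.oxm.constants`:

```python
MAGIC_ALL = "__all__"  # placeholder sentinel; the real value comes from core.oxm.constants

def effective(value):
    """MAGIC_ALL and empty values both mean: do not filter on this dimension."""
    return value if (value and value != MAGIC_ALL) else None

def count_deleted(result):
    """Accept either a hard-delete result (deleted_count) or a
    soft-delete/update result (modified_count); anything else counts as 0."""
    if not result:
        return 0
    if hasattr(result, "deleted_count"):
        return int(result.deleted_count or 0)
    if hasattr(result, "modified_count"):
        return int(result.modified_count or 0)
    return 0

class FakeResult:
    """Hypothetical stand-in for a driver's delete/update result object."""
    def __init__(self, **fields):
        self.__dict__.update(fields)

user_filter = effective(MAGIC_ALL)    # wildcard: skip this dimension
group_filter = effective("group-7")   # a real filter value passes through
hard = count_deleted(FakeResult(deleted_count=3))
soft = count_deleted(FakeResult(modified_count=2))
```

Duck-typing on `deleted_count` / `modified_count` is what lets the same counter cover both `.find(...).delete()` results and the soft-delete `delete_many` path.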