From ccfd55d75caae0c06aac1ceb7f07f524ad9cc24e Mon Sep 17 00:00:00 2001 From: r266-tech Date: Mon, 2 Mar 2026 15:51:49 +0800 Subject: [PATCH] fix: add deduplication for episodic/event_log write and foresight expiry cleanup Closes #95 ## Changes ### 1. Delete-before-insert dedup in save_memory_docs() - For episodic_memory: before inserting, delete existing records with the same parent_id from MongoDB, Elasticsearch, and Milvus - For event_log: same delete-before-insert by parent_id across all stores - Dedup is best-effort: failures are logged as warnings but do not block insert ### 2. Foresight expiry cleanup - New cleanup_expired_foresights() function that removes ForesightRecords where end_time < today from all three stores (MongoDB, ES, Milvus) - Can be called periodically (e.g., via cron/scheduler) to keep storage lean ### 3. New delete_by_parent_id on EpisodicMemoryRawRepository - Added missing method to delete episodic memories by parent_id (EventLogRecordRawRepository already had this method) ### 4. Tests - tests/test_write_pipeline_dedup.py covers dedup and cleanup with mocked repos --- src/biz_layer/mem_memorize.py | 139 ++++++++ .../episodic_memory_raw_repository.py | 36 ++ tests/test_write_pipeline_dedup.py | 322 ++++++++++++++++++ 3 files changed, 497 insertions(+) create mode 100644 tests/test_write_pipeline_dedup.py diff --git a/src/biz_layer/mem_memorize.py b/src/biz_layer/mem_memorize.py index e951a13c..a77711fd 100644 --- a/src/biz_layer/mem_memorize.py +++ b/src/biz_layer/mem_memorize.py @@ -79,6 +79,15 @@ from infra_layer.adapters.out.search.repository.episodic_memory_es_repository import ( EpisodicMemoryEsRepository, ) +from infra_layer.adapters.out.search.repository.event_log_milvus_repository import ( + EventLogMilvusRepository, +) +from infra_layer.adapters.out.search.repository.foresight_es_repository import ( + ForesightEsRepository, +) +from infra_layer.adapters.out.search.repository.foresight_milvus_repository import ( + ForesightMilvusRepository, +) from biz_layer.mem_sync import MemorySyncService from core.context.context import get_current_app_info @@ -1119,6 +1128,28 @@ async def save_memory_docs( episodic_repo = get_bean_by_type(EpisodicMemoryRawRepository) episodic_es_repo = get_bean_by_type(EpisodicMemoryEsRepository) episodic_milvus_repo = get_bean_by_type(EpisodicMemoryMilvusRepository) + + # Dedup: delete existing records with the same parent_id before insert + parent_ids_seen: set = set() + for doc in episodic_docs: + pid = getattr(doc, "parent_id", None) + if pid and pid not in parent_ids_seen: + parent_ids_seen.add(pid) + try: + await episodic_repo.delete_by_parent_id(pid) + await episodic_es_repo.delete_by_filters( + filters={"parent_id": pid} + ) + await episodic_milvus_repo.delete_by_filters( + filters={"parent_id": pid} + ) + except Exception as e: + logger.warning( + "[Dedup] Failed to delete old episodic records for parent_id=%s: %s", + pid, + e, + ) + saved_episodic: List[Any] = [] for doc in episodic_docs: @@ -1158,6 +1189,24 @@ async def save_memory_docs( event_log_docs = grouped_docs.get(MemoryType.EVENT_LOG, []) if event_log_docs: event_log_repo = get_bean_by_type(EventLogRecordRawRepository) + event_log_milvus_repo = get_bean_by_type(EventLogMilvusRepository) + + # Dedup: delete existing event_log records with the same parent_id + el_parent_ids_seen: set = set() + for doc in event_log_docs: + pid = getattr(doc, "parent_id", None) + if pid and pid not in el_parent_ids_seen: + el_parent_ids_seen.add(pid) + try: + await event_log_repo.delete_by_parent_id(pid) + await event_log_milvus_repo.delete_by_parent_id(pid) + except Exception as e: + logger.warning( + "[Dedup] Failed to delete old event_log records for parent_id=%s: %s", + pid, + e, + ) + saved_event_logs = await event_log_repo.create_batch(event_log_docs) saved_result[MemoryType.EVENT_LOG] = saved_event_logs @@ -1410,3 +1459,93 @@ async def memorize(request: MemorizeRequest) -> int: logger.error(f"[mem_memorize] ❌ Memory extraction failed: {e}") traceback.print_exc() return 0 + + +async def cleanup_expired_foresights() -> int: + """ + Remove expired foresight records from all stores (MongoDB, Elasticsearch, Milvus). + + ForesightRecord has a validity window defined by ``start_time`` / ``end_time`` + (date strings in YYYY-MM-DD format). Once ``end_time`` is in the past the + record is no longer useful, but the current pipeline never deletes it. This + helper performs the housekeeping. + + Returns: + Total number of expired foresight records deleted from MongoDB. + """ + from infra_layer.adapters.out.persistence.document.memory.foresight_record import ( + ForesightRecord, + ) + + today_str = to_date_str(get_now_with_timezone()) + logger.info( + "[ForesightCleanup] Starting expired foresight cleanup (today=%s)", today_str + ) + + # 1. Query expired records from MongoDB + try: + expired_records: List[Any] = await ForesightRecord.find( + { + "end_time": {"$lt": today_str, "$ne": None}, + } + ).to_list() + except Exception as e: + logger.error("[ForesightCleanup] Failed to query expired foresights: %s", e) + return 0 + + if not expired_records: + logger.info("[ForesightCleanup] No expired foresight records found") + return 0 + + logger.info( + "[ForesightCleanup] Found %d expired foresight records to delete", + len(expired_records), + ) + + # 2. Delete from search stores (best-effort) + foresight_es_repo = get_bean_by_type(ForesightEsRepository) + foresight_milvus_repo = get_bean_by_type(ForesightMilvusRepository) + + for record in expired_records: + record_id = str(record.id) if record.id else None + if not record_id: + continue + try: + await foresight_milvus_repo.delete_by_id(record_id) + except Exception as e: + logger.warning( + "[ForesightCleanup] Failed to delete from Milvus id=%s: %s", + record_id, + e, + ) + try: + await foresight_es_repo.delete_by_filters(filters={"_id": record_id}) + except Exception as e: + logger.warning( + "[ForesightCleanup] Failed to delete from ES id=%s: %s", + record_id, + e, + ) + + # 3. Delete from MongoDB + deleted_count = 0 + try: + foresight_repo = get_bean_by_type(ForesightRecordRawRepository) + for record in expired_records: + record_id = str(record.id) if record.id else None + if record_id: + result = await foresight_repo.delete_by_id(record_id) + if result: + deleted_count += 1 + except Exception as e: + logger.error( + "[ForesightCleanup] Failed to delete expired foresights from MongoDB: %s", + e, + ) + + logger.info( + "[ForesightCleanup] ✅ Cleanup complete: deleted %d/%d expired foresight records", + deleted_count, + len(expired_records), + ) + return deleted_count diff --git a/src/infra_layer/adapters/out/persistence/repository/episodic_memory_raw_repository.py b/src/infra_layer/adapters/out/persistence/repository/episodic_memory_raw_repository.py index 7240c403..c5f17b36 100644 --- a/src/infra_layer/adapters/out/persistence/repository/episodic_memory_raw_repository.py +++ b/src/infra_layer/adapters/out/persistence/repository/episodic_memory_raw_repository.py @@ -318,6 +318,42 @@ async def delete_by_user_id( logger.error("❌ Failed to delete episodic memories by user ID: %s", e) return 0 + async def delete_by_parent_id( + self, parent_id: str, session: Optional[AsyncClientSession] = None + ) -> int: + """ + Delete all episodic memories by parent ID (e.g. memcell event_id). + + Used for deduplication: when re-processing the same source, delete old + records before inserting new ones. + + Args: + parent_id: Parent memory ID + session: Optional MongoDB session for transaction support + + Returns: + Number of deleted records + """ + try: + result = await self.model.find( + {"parent_id": parent_id} + ).delete(session=session) + count = result.deleted_count if result else 0 + if count > 0: + logger.info( + "✅ Deleted %d episodic memories by parent_id=%s", + count, + parent_id, + ) + return count + except Exception as e: + logger.error( + "❌ Failed to delete episodic memories by parent_id=%s: %s", + parent_id, + e, + ) + return 0 + async def find_by_filter_paginated( self, query_filter: Optional[Dict[str, Any]] = None, diff --git a/tests/test_write_pipeline_dedup.py b/tests/test_write_pipeline_dedup.py new file mode 100644 index 00000000..f80898e1 --- /dev/null +++ b/tests/test_write_pipeline_dedup.py @@ -0,0 +1,322 @@ +""" +Tests for write pipeline deduplication and foresight expiry cleanup. + +Covers: +- save_memory_docs: deletes old episodic records before inserting new ones +- save_memory_docs: deletes old event_log records before inserting new ones +- cleanup_expired_foresights: removes expired foresight records +""" + +import asyncio +import pytest +from datetime import datetime, timedelta +from unittest.mock import AsyncMock, MagicMock, patch +from collections import defaultdict +from typing import List, Dict, Any + +from api_specs.memory_models import MemoryType + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_episodic_doc(parent_id: str = "memcell_001", episode: str = "test episode"): + """Create a mock episodic memory document.""" + doc = MagicMock() + doc.parent_id = parent_id + doc.episode = episode + doc.vector = [0.1, 0.2, 0.3] + doc.id = "ep_001" + doc.event_id = "ep_001" + return doc + + +def _make_event_log_doc(parent_id: str = "memcell_001"): + """Create a mock event log document.""" + doc = MagicMock() + doc.parent_id = parent_id + doc.atomic_fact = "test fact" + doc.id = "el_001" + return doc + + +def _make_foresight_doc(parent_id: str = "memcell_001"): + """Create a mock foresight document.""" + doc = MagicMock() + doc.parent_id = parent_id + doc.content = "test foresight" + doc.id = "fs_001" + return doc + + +# --------------------------------------------------------------------------- +# Test: Episodic dedup in save_memory_docs +# --------------------------------------------------------------------------- + +class TestEpisodicDedup: + """Verify that save_memory_docs deletes old episodic records before insert.""" + + @pytest.mark.asyncio + async def test_dedup_deletes_old_records_before_insert(self): + """ + When saving episodic docs, old records with the same parent_id + should be deleted from MongoDB, ES, and Milvus before new insert. + """ + # Arrange + mock_episodic_repo = AsyncMock() + mock_episodic_repo.append_episodic_memory = AsyncMock( + side_effect=lambda doc: doc + ) + mock_episodic_repo.delete_by_parent_id = AsyncMock(return_value=1) + + mock_es_repo = AsyncMock() + mock_es_repo.create = AsyncMock() + mock_es_repo.delete_by_filters = AsyncMock() + + mock_milvus_repo = AsyncMock() + mock_milvus_repo.insert = AsyncMock() + mock_milvus_repo.delete_by_filters = AsyncMock() + + doc = _make_episodic_doc(parent_id="mc_123") + from biz_layer.mem_memorize import MemoryDocPayload + + payloads = [MemoryDocPayload(MemoryType.EPISODIC_MEMORY, doc)] + + with patch("biz_layer.mem_memorize.get_bean_by_type") as mock_get_bean, \ + patch("biz_layer.mem_memorize.EpisodicMemoryConverter") as mock_converter, \ + patch("biz_layer.mem_memorize.EpisodicMemoryMilvusConverter") as mock_milvus_converter: + + def _bean_router(cls): + from infra_layer.adapters.out.persistence.repository.episodic_memory_raw_repository import ( + EpisodicMemoryRawRepository, + ) + from infra_layer.adapters.out.search.repository.episodic_memory_es_repository import ( + EpisodicMemoryEsRepository, + ) + from infra_layer.adapters.out.search.repository.episodic_memory_milvus_repository import ( + EpisodicMemoryMilvusRepository, + ) + + repo_map = { + EpisodicMemoryRawRepository: mock_episodic_repo, + EpisodicMemoryEsRepository: mock_es_repo, + EpisodicMemoryMilvusRepository: mock_milvus_repo, + } + return repo_map.get(cls, AsyncMock()) + + mock_get_bean.side_effect = _bean_router + mock_converter.from_mongo.return_value = MagicMock() + mock_milvus_converter.from_mongo.return_value = {"vector": [0.1]} + + # Act + from biz_layer.mem_memorize import save_memory_docs + result = await save_memory_docs(payloads) + + # Assert: delete was called before insert + mock_episodic_repo.delete_by_parent_id.assert_called_once_with("mc_123") + mock_es_repo.delete_by_filters.assert_called_once_with( + filters={"parent_id": "mc_123"} + ) + mock_milvus_repo.delete_by_filters.assert_called_once_with( + filters={"parent_id": "mc_123"} + ) + # And insert still happened + mock_episodic_repo.append_episodic_memory.assert_called_once() + + @pytest.mark.asyncio + async def test_dedup_failure_does_not_block_insert(self): + """If dedup delete fails, insert should still proceed.""" + mock_episodic_repo = AsyncMock() + mock_episodic_repo.append_episodic_memory = AsyncMock( + side_effect=lambda doc: doc + ) + mock_episodic_repo.delete_by_parent_id = AsyncMock( + side_effect=Exception("DB error") + ) + + mock_es_repo = AsyncMock() + mock_milvus_repo = AsyncMock() + + doc = _make_episodic_doc(parent_id="mc_fail") + from biz_layer.mem_memorize import MemoryDocPayload + + payloads = [MemoryDocPayload(MemoryType.EPISODIC_MEMORY, doc)] + + with patch("biz_layer.mem_memorize.get_bean_by_type") as mock_get_bean, \ + patch("biz_layer.mem_memorize.EpisodicMemoryConverter") as mock_converter, \ + patch("biz_layer.mem_memorize.EpisodicMemoryMilvusConverter") as mock_milvus_converter: + + def _bean_router(cls): + from infra_layer.adapters.out.persistence.repository.episodic_memory_raw_repository import ( + EpisodicMemoryRawRepository, + ) + from infra_layer.adapters.out.search.repository.episodic_memory_es_repository import ( + EpisodicMemoryEsRepository, + ) + from infra_layer.adapters.out.search.repository.episodic_memory_milvus_repository import ( + EpisodicMemoryMilvusRepository, + ) + repo_map = { + EpisodicMemoryRawRepository: mock_episodic_repo, + EpisodicMemoryEsRepository: mock_es_repo, + EpisodicMemoryMilvusRepository: mock_milvus_repo, + } + return repo_map.get(cls, AsyncMock()) + + mock_get_bean.side_effect = _bean_router + mock_converter.from_mongo.return_value = MagicMock() + mock_milvus_converter.from_mongo.return_value = {"vector": [0.1]} + + from biz_layer.mem_memorize import save_memory_docs + result = await save_memory_docs(payloads) + + # Insert still happened despite dedup failure + mock_episodic_repo.append_episodic_memory.assert_called_once() + + +# --------------------------------------------------------------------------- +# Test: Event log dedup in save_memory_docs +# --------------------------------------------------------------------------- + +class TestEventLogDedup: + """Verify that save_memory_docs deletes old event_log records before insert.""" + + @pytest.mark.asyncio + async def test_dedup_deletes_old_event_logs(self): + mock_event_log_repo = AsyncMock() + mock_event_log_repo.create_batch = AsyncMock(return_value=[]) + mock_event_log_repo.delete_by_parent_id = AsyncMock(return_value=2) + + mock_milvus_repo = AsyncMock() + mock_milvus_repo.delete_by_parent_id = AsyncMock(return_value=True) + + mock_sync_service = AsyncMock() + + doc = _make_event_log_doc(parent_id="mc_el_001") + from biz_layer.mem_memorize import MemoryDocPayload + + payloads = [MemoryDocPayload(MemoryType.EVENT_LOG, doc)] + + with patch("biz_layer.mem_memorize.get_bean_by_type") as mock_get_bean: + def _bean_router(cls): + from infra_layer.adapters.out.persistence.repository.event_log_record_raw_repository import ( + EventLogRecordRawRepository, + ) + from infra_layer.adapters.out.search.repository.event_log_milvus_repository import ( + EventLogMilvusRepository, + ) + from biz_layer.mem_sync import MemorySyncService + + repo_map = { + EventLogRecordRawRepository: mock_event_log_repo, + EventLogMilvusRepository: mock_milvus_repo, + MemorySyncService: mock_sync_service, + } + return repo_map.get(cls, AsyncMock()) + + mock_get_bean.side_effect = _bean_router + + from biz_layer.mem_memorize import save_memory_docs + await save_memory_docs(payloads) + + # Verify delete was called + mock_event_log_repo.delete_by_parent_id.assert_called_once_with("mc_el_001") + mock_milvus_repo.delete_by_parent_id.assert_called_once_with("mc_el_001") + # And batch create still happened + mock_event_log_repo.create_batch.assert_called_once() + + +# --------------------------------------------------------------------------- +# Test: Foresight expiry cleanup +# --------------------------------------------------------------------------- + +class TestForesightCleanup: + """Verify cleanup_expired_foresights removes expired records.""" + + @pytest.mark.asyncio + async def test_cleanup_removes_expired_records(self): + """Expired foresight records should be deleted from all stores.""" + mock_record_1 = MagicMock() + mock_record_1.id = "fs_expired_001" + mock_record_2 = MagicMock() + mock_record_2.id = "fs_expired_002" + + mock_foresight_repo = AsyncMock() + mock_foresight_repo.delete_by_id = AsyncMock(return_value=True) + + mock_es_repo = AsyncMock() + mock_es_repo.delete_by_filters = AsyncMock() + + mock_milvus_repo = AsyncMock() + mock_milvus_repo.delete_by_id = AsyncMock(return_value=True) + + with patch( + "biz_layer.mem_memorize.ForesightRecord" + ) as MockForesightRecord, patch( + "biz_layer.mem_memorize.get_bean_by_type" + ) as mock_get_bean, patch( + "biz_layer.mem_memorize.get_now_with_timezone" + ) as mock_now, patch( + "biz_layer.mem_memorize.to_date_str" + ) as mock_to_date_str: + + # Setup: 2 expired records + mock_find = MagicMock() + mock_find.to_list = AsyncMock( + return_value=[mock_record_1, mock_record_2] + ) + MockForesightRecord.find.return_value = mock_find + + mock_now.return_value = datetime(2026, 3, 2) + mock_to_date_str.return_value = "2026-03-02" + + def _bean_router(cls): + from infra_layer.adapters.out.persistence.repository.foresight_record_repository import ( + ForesightRecordRawRepository, + ) + from infra_layer.adapters.out.search.repository.foresight_es_repository import ( + ForesightEsRepository, + ) + from infra_layer.adapters.out.search.repository.foresight_milvus_repository import ( + ForesightMilvusRepository, + ) + repo_map = { + ForesightRecordRawRepository: mock_foresight_repo, + ForesightEsRepository: mock_es_repo, + ForesightMilvusRepository: mock_milvus_repo, + } + return repo_map.get(cls, AsyncMock()) + + mock_get_bean.side_effect = _bean_router + + from biz_layer.mem_memorize import cleanup_expired_foresights + count = await cleanup_expired_foresights() + + # Should have deleted 2 records + assert count == 2 + assert mock_foresight_repo.delete_by_id.call_count == 2 + assert mock_milvus_repo.delete_by_id.call_count == 2 + assert mock_es_repo.delete_by_filters.call_count == 2 + + @pytest.mark.asyncio + async def test_cleanup_returns_zero_when_no_expired(self): + """When there are no expired records, cleanup should return 0.""" + with patch( + "biz_layer.mem_memorize.ForesightRecord" + ) as MockForesightRecord, patch( + "biz_layer.mem_memorize.get_now_with_timezone" + ) as mock_now, patch( + "biz_layer.mem_memorize.to_date_str" + ) as mock_to_date_str: + + mock_find = MagicMock() + mock_find.to_list = AsyncMock(return_value=[]) + MockForesightRecord.find.return_value = mock_find + mock_now.return_value = datetime(2026, 3, 2) + mock_to_date_str.return_value = "2026-03-02" + + from biz_layer.mem_memorize import cleanup_expired_foresights + count = await cleanup_expired_foresights() + + assert count == 0