From 409563413de9863f1119ba48de7272c57ec7d213 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 1 Aug 2025 20:32:35 +0200 Subject: [PATCH 1/9] add class for custom cache --- chebifier/ensemble/_custom_cache.py | 76 +++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 chebifier/ensemble/_custom_cache.py diff --git a/chebifier/ensemble/_custom_cache.py b/chebifier/ensemble/_custom_cache.py new file mode 100644 index 0000000..3681b28 --- /dev/null +++ b/chebifier/ensemble/_custom_cache.py @@ -0,0 +1,76 @@ +import os +import pickle +import threading +from collections import OrderedDict +from typing import Any + + +class PerSmilesPerModelLRUCache: + def __init__(self, max_size: int = 100, persist_path: str | None = None): + self._cache = OrderedDict() + self._max_size = max_size + self._lock = threading.Lock() + self._persist_path = persist_path + + self.hits = 0 + self.misses = 0 + + if self._persist_path: + self._load_cache() + + def get(self, smiles: str, model_name: str) -> Any | None: + key = (smiles, model_name) + with self._lock: + if key in self._cache: + self._cache.move_to_end(key) + self.hits += 1 + return self._cache[key] + else: + self.misses += 1 + return None + + def set(self, smiles: str, model_name: str, value: Any) -> None: + key = (smiles, model_name) + with self._lock: + if key in self._cache: + self._cache.move_to_end(key) + self._cache[key] = value + if len(self._cache) > self._max_size: + self._cache.popitem(last=False) + + def clear(self) -> None: + self._save_cache() + with self._lock: + self._cache.clear() + self.hits = 0 + self.misses = 0 + if self._persist_path and os.path.exists(self._persist_path): + os.remove(self._persist_path) + + def stats(self) -> dict: + return {"hits": self.hits, "misses": self.misses} + + def _save_cache(self) -> None: + """Serialize the cache to disk.""" + if not self._persist_path: + try: + with open(self._persist_path, "wb") as f: + pickle.dump(self._cache, f) + except Exception as e: + print(f"[Cache Save Error] {e}") + + def _load_cache(self) -> None: + """Load the cache from disk.""" + if os.path.exists(self._persist_path): + try: + with open(self._persist_path, "rb") as f: + loaded = pickle.load(f) + if isinstance(loaded, OrderedDict): + self._cache = loaded + except Exception as e: + print(f"[Cache Load Error] {e}") + + +if __name__ == "__main__": + # Example usage + cache = PerSmilesPerModelLRUCache(max_size=100, persist_path="cache.pkl") From ed28dfed9f30a4aa82ccbc2cf4513d32de5eb84a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 1 Aug 2025 21:35:28 +0200 Subject: [PATCH 2/9] add decorator for custom cache --- chebifier/ensemble/_custom_cache.py | 102 ++++++++++++++++++++++++++-- 1 file changed, 98 insertions(+), 4 deletions(-) diff --git a/chebifier/ensemble/_custom_cache.py b/chebifier/ensemble/_custom_cache.py index 3681b28..44ef292 100644 --- a/chebifier/ensemble/_custom_cache.py +++ b/chebifier/ensemble/_custom_cache.py @@ -2,7 +2,9 @@ import pickle import threading from collections import OrderedDict -from typing import Any +from collections.abc import Iterable +from functools import wraps +from typing import Any, Callable class PerSmilesPerModelLRUCache: @@ -30,6 +32,7 @@ def get(self, smiles: str, model_name: str) -> Any | None: return None def set(self, smiles: str, model_name: str, value: Any) -> None: + assert value is not None, "Value must not be None" key = (smiles, model_name) with self._lock: if key in self._cache: @@ -50,9 +53,65 @@ def clear(self) -> None: def stats(self) -> 
dict:
         return {"hits": self.hits, "misses": self.misses}
 
+    def batch_decorator(self, func: Callable) -> Callable:
+        """Decorator for class methods that accept a batch of SMILES as a tuple,
+        and want caching per (smiles, model_name) combination.
+        """
+
+        @wraps(func)
+        def wrapper(instance, smiles_list: list[str]):
+            assert isinstance(smiles_list, list), "smiles_list must be a list."
+            model_name = getattr(instance, "model_name", None)
+            assert model_name is not None, "Instance must have a model_name attribute."
+
+            results = []
+            missing_smiles = []
+            missing_indices = []
+
+            # First: try to fetch all from cache
+            for i, smiles in enumerate(smiles_list):
+                result = self.get(smiles=smiles, model_name=model_name)
+                if result is not None:
+                    results.append((i, result))  # save index for reordering
+                else:
+                    missing_smiles.append(smiles)
+                    missing_indices.append(i)
+
+            # If some are missing, call original function
+            if missing_smiles:
+                new_results = func(instance, tuple(missing_smiles))
+                assert isinstance(
+                    new_results, Iterable
+                ), "Function must return an Iterable."
+                # Save to cache and append
+                for smiles, prediction in zip(missing_smiles, new_results):
+                    if prediction is not None:
+                        self.set(smiles, model_name, prediction)
+                    results.append((missing_indices.pop(0), prediction))
+
+            # Reorder results to match original indices
+            results.sort(key=lambda x: x[0])  # sort by index
+            ordered = [result for _, result in results]
+            return ordered
+
+        return wrapper
+
+    def __len__(self):
+        with self._lock:
+            return len(self._cache)
+
+    def __repr__(self):
+        return self._cache.__repr__()
+
+    def save(self):
+        self._save_cache()
+
+    def load(self):
+        self._load_cache()
+
     def _save_cache(self) -> None:
         """Serialize the cache to disk."""
-        if not self._persist_path:
+        if self._persist_path:
             try:
                 with open(self._persist_path, "wb") as f:
                     pickle.dump(self._cache, f)
@@ -72,5 +131,40 @@ def _load_cache(self) -> None:
 
 
 if __name__ == "__main__":
-    # Example usage
-    cache = PerSmilesPerModelLRUCache(max_size=100, persist_path="cache.pkl")
+    # in-memory example cache; pass persist_path="cache.pkl" to persist across runs
+    cache = PerSmilesPerModelLRUCache(max_size=50)
+
+    class ExamplePredictor:
+        model_name = "example_model"
+
+        @cache.batch_decorator
+        def predict(self, smiles_list: tuple[str]) -> list[dict]:
+            # Simulate a prediction function
+            return [{"prediction": hash(smiles) % 100} for smiles in smiles_list]
+
+    # Create an instance of the predictor
+    predictor = ExamplePredictor()
+
+    # Prediction set 1 — new model, all should be cache misses
+    predictor.model_name = "example_model"
+    predictor.predict(["CCC", "C", "CCO", "CCN"])  # MISS × 4
+    print("Cache Stats:", cache.stats())
+
+    # Prediction set 2 — same model, partial hit/miss
+    predictor.model_name = "example_model"
+    predictor.predict(["CCC", "CO", "CCO", "CN"])  # HIT: CCC, CCO — MISS: CO, CN
+    print("Cache Stats:", cache.stats())
+
+    # Prediction set 3 — new model, same SMILES — should all be misses (per-model caching)
+    predictor.model_name = "example_model_2"
+    predictor.predict(["CCC", "C", "CO", "CN"])  # MISS × 4 (new model)
+    print("Cache Stats:", cache.stats())
+
+    # Prediction set 4 — another model
+    predictor.model_name = "example_model_3"
+    predictor.predict(["CCCC", "CCCl", "CCBr", "C(C)C"])  # MISS × 4
+    print("Cache Stats:", cache.stats())
+
+    from pprint import pprint
+
+    pprint(cache)

From ad14325f63edc76a53767bb94c9ed49984e4efb5 Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Fri, 1 Aug 2025 23:24:25 +0200
Subject: [PATCH 3/9] decorate each predict method 
with cache --- chebifier/__init__.py | 6 ++++++ chebifier/{ensemble => }/_custom_cache.py | 3 +++ chebifier/prediction_models/base_predictor.py | 10 +++------- chebifier/prediction_models/c3p_predictor.py | 8 ++++---- chebifier/prediction_models/chebi_lookup.py | 17 +++++++++-------- .../prediction_models/chemlog_predictor.py | 12 +++++------- chebifier/prediction_models/nn_predictor.py | 8 ++++---- 7 files changed, 34 insertions(+), 30 deletions(-) rename chebifier/{ensemble => }/_custom_cache.py (97%) diff --git a/chebifier/__init__.py b/chebifier/__init__.py index e69de29..aa1e6ec 100644 --- a/chebifier/__init__.py +++ b/chebifier/__init__.py @@ -0,0 +1,6 @@ +# Note: The top-level package __init__.py runs only once, +# even if multiple subpackages are imported later. + +from ._custom_cache import PerSmilesPerModelLRUCache + +modelwise_smiles_lru_cache = PerSmilesPerModelLRUCache(max_size=100) diff --git a/chebifier/ensemble/_custom_cache.py b/chebifier/_custom_cache.py similarity index 97% rename from chebifier/ensemble/_custom_cache.py rename to chebifier/_custom_cache.py index 44ef292..42fcc2f 100644 --- a/chebifier/ensemble/_custom_cache.py +++ b/chebifier/_custom_cache.py @@ -92,6 +92,9 @@ def wrapper(instance, smiles_list: list[str]): # Reorder results to match original indices results.sort(key=lambda x: x[0]) # sort by index ordered = [result for _, result in results] + assert len(ordered) == len( + smiles_list + ), "Result length does not match input length." return ordered return wrapper diff --git a/chebifier/prediction_models/base_predictor.py b/chebifier/prediction_models/base_predictor.py index ba1412d..a175366 100644 --- a/chebifier/prediction_models/base_predictor.py +++ b/chebifier/prediction_models/base_predictor.py @@ -1,7 +1,7 @@ import json from abc import ABC -from functools import lru_cache +from chebifier import modelwise_smiles_lru_cache class BasePredictor(ABC): @@ -23,17 +23,13 @@ def __init__( self._description = kwargs.get("description", None) + @modelwise_smiles_lru_cache.batch_decorator def predict_smiles_list(self, smiles_list: list[str]) -> dict: - # list is not hashable, so we convert it to a tuple (useful for caching) - return self.predict_smiles_tuple(tuple(smiles_list)) - - @lru_cache(maxsize=100) - def predict_smiles_tuple(self, smiles_tuple: tuple[str]) -> dict: raise NotImplementedError() def predict_smiles(self, smiles: str) -> dict: # by default, use list-based prediction - return self.predict_smiles_tuple((smiles,))[0] + return self.predict_smiles_list([smiles])[0] @property def info_text(self): diff --git a/chebifier/prediction_models/c3p_predictor.py b/chebifier/prediction_models/c3p_predictor.py index 00c71f7..dc4704d 100644 --- a/chebifier/prediction_models/c3p_predictor.py +++ b/chebifier/prediction_models/c3p_predictor.py @@ -1,9 +1,9 @@ -from functools import lru_cache -from typing import Optional, List from pathlib import Path +from typing import List, Optional from c3p import classifier as c3p_classifier +from chebifier import modelwise_smiles_lru_cache from chebifier.prediction_models import BasePredictor @@ -24,8 +24,8 @@ def __init__( self.chemical_classes = chemical_classes self.chebi_graph = kwargs.get("chebi_graph", None) - @lru_cache(maxsize=100) - def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list: + @modelwise_smiles_lru_cache.batch_decorator + def predict_smiles_list(self, smiles_list: list[str]) -> list: result_list = c3p_classifier.classify( list(smiles_list), self.program_directory, diff --git 
a/chebifier/prediction_models/chebi_lookup.py b/chebifier/prediction_models/chebi_lookup.py index 2f6a7b0..d145e24 100644 --- a/chebifier/prediction_models/chebi_lookup.py +++ b/chebifier/prediction_models/chebi_lookup.py @@ -1,16 +1,16 @@ -from functools import lru_cache +import json +import os from typing import Optional -from chebifier.prediction_models import BasePredictor -import os import networkx as nx from rdkit import Chem -import json + +from chebifier import modelwise_smiles_lru_cache +from chebifier.prediction_models import BasePredictor from chebifier.utils import load_chebi_graph class ChEBILookupPredictor(BasePredictor): - def __init__( self, model_name: str, @@ -67,7 +67,6 @@ def build_smiles_lookup(self): ) return smiles_lookup - @lru_cache(maxsize=100) def predict_smiles(self, smiles: str) -> Optional[dict]: if not smiles: return None @@ -94,7 +93,8 @@ def predict_smiles(self, smiles: str) -> Optional[dict]: else: return None - def predict_smiles_tuple(self, smiles_list: list[str]) -> list: + @modelwise_smiles_lru_cache.batch_decorator + def predict_smiles_list(self, smiles_list: list[str]) -> list: predictions = [] for smiles in smiles_list: predictions.append(self.predict_smiles(smiles)) @@ -145,7 +145,8 @@ def explain_smiles(self, smiles: str) -> dict: # Example usage smiles_list = [ "CCO", - "C1=CC=CC=C1" "*C(=O)OC[C@H](COP(=O)([O-])OCC[N+](C)(C)C)OC(*)=O", + "C1=CC=CC=C1", + "*C(=O)OC[C@H](COP(=O)([O-])OCC[N+](C)(C)C)OC(*)=O", ] # SMILES with 251 matches in ChEBI predictions = predictor.predict_smiles_list(smiles_list) print(predictions) diff --git a/chebifier/prediction_models/chemlog_predictor.py b/chebifier/prediction_models/chemlog_predictor.py index 8232641..0cd5fa5 100644 --- a/chebifier/prediction_models/chemlog_predictor.py +++ b/chebifier/prediction_models/chemlog_predictor.py @@ -12,10 +12,11 @@ ) from chemlog.cli import CLASSIFIERS, _smiles_to_mol, strategy_call from chemlog_extra.alg_classification.by_element_classification import ( - XMolecularEntityClassifier, OrganoXCompoundClassifier, + XMolecularEntityClassifier, ) -from functools import lru_cache + +from chebifier import modelwise_smiles_lru_cache from .base_predictor import BasePredictor @@ -47,7 +48,6 @@ class ChemlogExtraPredictor(BasePredictor): - CHEMLOG_CLASSIFIER = None def __init__(self, model_name: str, **kwargs): @@ -72,12 +72,10 @@ def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list: class ChemlogXMolecularEntityPredictor(ChemlogExtraPredictor): - CHEMLOG_CLASSIFIER = XMolecularEntityClassifier class ChemlogOrganoXCompoundPredictor(ChemlogExtraPredictor): - CHEMLOG_CLASSIFIER = OrganoXCompoundClassifier @@ -97,7 +95,6 @@ def __init__(self, model_name: str, **kwargs): # fmt: on print(f"Initialised ChemLog model {self.model_name}") - @lru_cache(maxsize=100) def predict_smiles(self, smiles: str) -> Optional[dict]: mol = _smiles_to_mol(smiles) if mol is None: @@ -122,7 +119,8 @@ def predict_smiles(self, smiles: str) -> Optional[dict]: for label in self.peptide_labels + pos_labels } - def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list: + @modelwise_smiles_lru_cache.batch_decorator + def predict_smiles_list(self, smiles_list: list[str]) -> list: results = [] for i, smiles in tqdm.tqdm(enumerate(smiles_list)): results.append(self.predict_smiles(smiles)) diff --git a/chebifier/prediction_models/nn_predictor.py b/chebifier/prediction_models/nn_predictor.py index e7d72c9..79dcad9 100644 --- a/chebifier/prediction_models/nn_predictor.py +++ 
b/chebifier/prediction_models/nn_predictor.py
@@ -1,10 +1,10 @@
-from functools import lru_cache
-
 import numpy as np
 import torch
 import tqdm
 from rdkit import Chem
 
+from chebifier import modelwise_smiles_lru_cache
+
 from .base_predictor import BasePredictor
 
 
@@ -52,8 +52,8 @@ def read_smiles(self, smiles):
         d = reader.to_data(dict(features=smiles, labels=None))
         return d
 
-    @lru_cache(maxsize=100)
-    def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
+    @modelwise_smiles_lru_cache.batch_decorator
+    def predict_smiles_list(self, smiles_list: list[str]) -> list:
         """Returns a list with the length of smiles_list, each element is either None (=failure)
         or a dictionary Of classes and predicted values."""
         token_dicts = []

From cdd7de9c5a8c6a6ecf98614e77942b44d293bb3b Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Sat, 2 Aug 2025 00:23:07 +0200
Subject: [PATCH 4/9] tests for custom cache

---
 chebifier/_custom_cache.py |  46 ++-----------
 tests/test_cache.py        | 128 +++++++++++++++++++++++++++++++++++++
 2 files changed, 132 insertions(+), 42 deletions(-)
 create mode 100644 tests/test_cache.py

diff --git a/chebifier/_custom_cache.py b/chebifier/_custom_cache.py
index 42fcc2f..967abc4 100644
--- a/chebifier/_custom_cache.py
+++ b/chebifier/_custom_cache.py
@@ -84,10 +84,12 @@ def wrapper(instance, smiles_list: list[str]):
                     new_results, Iterable
                 ), "Function must return an Iterable."
                 # Save to cache and append
-                for smiles, prediction in zip(missing_smiles, new_results):
+                for smiles, prediction, missing_idx in zip(
+                    missing_smiles, new_results, missing_indices
+                ):
                     if prediction is not None:
                         self.set(smiles, model_name, prediction)
-                    results.append((missing_indices.pop(0), prediction))
+                    results.append((missing_idx, prediction))
 
             # Reorder results to match original indices
             results.sort(key=lambda x: x[0])  # sort by index
             ordered = [result for _, result in results]
@@ -134,40 +135,0 @@ def _load_cache(self) -> None:
-
-
-if __name__ == "__main__":
-    # in-memory example cache; pass persist_path="cache.pkl" to persist across runs
-    cache = PerSmilesPerModelLRUCache(max_size=50)
-
-    class ExamplePredictor:
-        model_name = "example_model"
-
-        @cache.batch_decorator
-        def predict(self, smiles_list: tuple[str]) -> list[dict]:
-            # Simulate a prediction function
-            return [{"prediction": hash(smiles) % 100} for smiles in smiles_list]
-
-    # Create an instance of the predictor
-    predictor = ExamplePredictor()
-
-    # Prediction set 1 — new model, all should be cache misses
-    predictor.model_name = "example_model"
-    predictor.predict(["CCC", "C", "CCO", "CCN"])  # MISS × 4
-    print("Cache Stats:", cache.stats())
-
-    # Prediction set 2 — same model, partial hit/miss
-    predictor.model_name = "example_model"
-    predictor.predict(["CCC", "CO", "CCO", "CN"])  # HIT: CCC, CCO — MISS: CO, CN
-    print("Cache Stats:", cache.stats())
-
-    # Prediction set 3 — new model, same SMILES — should all be misses (per-model caching)
-    predictor.model_name = "example_model_2"
-    predictor.predict(["CCC", "C", "CO", "CN"])  # MISS × 4 (new model)
-    print("Cache Stats:", cache.stats())
-
-    # Prediction set 4 — another model
-    predictor.model_name = "example_model_3"
-    predictor.predict(["CCCC", "CCCl", "CCBr", "C(C)C"])  # MISS × 4
-    print("Cache Stats:", cache.stats())
-
-    from pprint import pprint
-
-    pprint(cache)
diff --git a/tests/test_cache.py b/tests/test_cache.py
new file mode 100644
index 0000000..0286b86
--- /dev/null
+++ b/tests/test_cache.py
@@ -0,0 +1,128 @@
+import os
+import tempfile
+import unittest
+
+from chebifier import PerSmilesPerModelLRUCache
+
+g_cache = PerSmilesPerModelLRUCache(max_size=3)
+
+
+class DummyPredictor:
+    def __init__(self, model_name):
+        self.model_name = model_name
+
+    @g_cache.batch_decorator
+    def predict(self, smiles_list: tuple[str]):
+        # Simple predictable dummy function for tests
+        return [f"{self.model_name}{i}" for i in range(len(smiles_list))]
+
+
+class TestPerSmilesPerModelLRUCache(unittest.TestCase):
+    def setUp(self):
+        # Create temp file for persistence tests
+        self.temp_file = tempfile.NamedTemporaryFile(delete=False)
+        self.temp_file.close()
+        self.cache = PerSmilesPerModelLRUCache(
+            max_size=3, persist_path=self.temp_file.name
+        )
+
+    def tearDown(self):
+        if os.path.exists(self.temp_file.name):
+            os.remove(self.temp_file.name)
+
+    def test_cache_miss_and_set_get(self):
+        # Initially empty
+        self.assertEqual(len(self.cache), 0)
+        self.assertIsNone(self.cache.get("CCC", "model1"))
+
+        # Set and get
+        self.cache.set("CCC", "model1", "result1")
+        self.assertEqual(self.cache.get("CCC", "model1"), "result1")
+        self.assertEqual(self.cache.hits, 1)
+        self.assertEqual(self.cache.misses, 1)  # One miss from first get
+
+    def test_cache_eviction(self):
+        self.cache.set("a", "m", "v1")
+        self.cache.set("b", "m", "v2")
+        self.cache.set("c", "m", "v3")
+        self.assertEqual(len(self.cache), 3)
+        # Adding one more triggers eviction of oldest
+        self.cache.set("d", "m", "v4")
+        self.assertEqual(len(self.cache), 3)
+        self.assertIsNone(self.cache.get("a", "m"))  # 'a' evicted
+        self.assertIsNotNone(self.cache.get("d", "m"))  # 'd' present
+
+    def test_batch_decorator_hits_and_misses(self):
+        predictor = DummyPredictor("modelA")
+        predictor2 = DummyPredictor("modelB")
+
+        # Clear cache before starting the test
+        g_cache.clear()
+
+        smiles = ["AAA", "BBB", "CCC", "DDD", "EEE"]
+        # First call: all misses
+        results1 = predictor.predict(smiles)
+        results1_model2 = predictor2.predict(smiles)
+
+        # all predictions are retrieved from the actual prediction function and not from the cache
+        self.assertListEqual(
+            results1, ["modelA_P0", "modelA_P1", "modelA_P2", "modelA_P3", "modelA_P4"]
+        )
+        self.assertListEqual(
+            results1_model2,
+            ["modelB_P0", "modelB_P1", "modelB_P2", "modelB_P3", "modelB_P4"],
+        )
+        stats_after_first = g_cache.stats()
+        self.assertEqual(stats_after_first["misses"], 3)
+
+        # cache = {("AAA", "modelA"): "modelA_P0", ("BBB", "modelA"): "modelA_P1", ("CCC", "modelA"): "modelA_P2"}
+        # Second call with some hits and some misses
+        results2 = predictor.predict(["FFF", "DDD"])
+        # AAA from cache
+        # FFF is not in cache, so it predicted, hence it has P0 as its the only one passed to prediction function
+        # and dummy predictor returns iterates over the smiles list and return P{idx} corresponding to the index
+        self.assertListEqual(results2, ["P3", "P0"])
+        stats_after_second = g_cache.stats()
+        self.assertEqual(stats_after_second["hits"], 1)
+        self.assertEqual(stats_after_second["misses"], 4)
+
+        # cache = {("AAA", "modelA"): "P0", ("BBB", "modelA"): "P1", ("CCC", "modelA"): "P2",
+        #          ("DDD", "modelA"): "P3", ("EEE", "modelA"): "P4", ("FFF", "modelA"): "P0"}
+
+        # Third call with some hits and some misses
+        results3 = predictor.predict(["EEE", "GGG", "DDD", "HHH", "BBB", "ZZZ"])
+        # Here, predictions for [EEE, DDD, BBB] are retrieved from the cache,
+        # while [GGG, HHH, ZZZ] are not in the cache and hence are passed to the prediction function
+        self.assertListEqual(results3, ["P4", "P0", "P3", "P0", "P1", "P0"])
+        stats_after_third = g_cache.stats()
+        self.assertEqual(stats_after_third["hits"], 1)
+        self.assertEqual(stats_after_third["misses"], 4)
+
+    def test_persistence_save_and_load(self):
+        # Set some values
+        self.cache.set("sm1", "modelX", "val1")
+        self.cache.set("sm2", "modelX", "val2")
+
+        # Save cache to file
+        self.cache.save()
+
+        # Create new cache instance loading from file
+        new_cache = PerSmilesPerModelLRUCache(
+            max_size=3, persist_path=self.temp_file.name
+        )
+        new_cache.load()
+
+        self.assertEqual(new_cache.get("sm1", "modelX"), "val1")
+        self.assertEqual(new_cache.get("sm2", "modelX"), "val2")
+
+    def test_clear_cache(self):
+        self.cache.set("x", "m", "v")
+        self.cache.save()
+        self.assertTrue(os.path.exists(self.temp_file.name))
+        self.cache.clear()
+        self.assertEqual(len(self.cache), 0)
+        self.assertFalse(os.path.exists(self.temp_file.name))
+
+
+if __name__ == "__main__":
+    unittest.main()

From 9b636db515425489c6a18e601c414681aadd404b Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Sat, 2 Aug 2025 20:18:01 +0200
Subject: [PATCH 5/9] refine test for realistic scenario with 2 models

---
 chebifier/_custom_cache.py |  6 ++++-
 tests/test_cache.py        | 52 +++++++++++++++++++++++++++-----------
 2 files changed, 42 insertions(+), 16 deletions(-)

diff --git a/chebifier/_custom_cache.py b/chebifier/_custom_cache.py
index 967abc4..8d02c68 100644
--- a/chebifier/_custom_cache.py
+++ b/chebifier/_custom_cache.py
@@ -125,7 +125,11 @@ def _save_cache(self) -> None:
 
     def _load_cache(self) -> None:
         """Load the cache from disk."""
-        if os.path.exists(self._persist_path):
+        if (
+            self._persist_path
+            and os.path.exists(self._persist_path)
+            and os.path.getsize(self._persist_path) > 0
+        ):
             try:
                 with open(self._persist_path, "rb") as f:
                     loaded = pickle.load(f)
diff --git a/tests/test_cache.py b/tests/test_cache.py
index 0286b86..72c7797 100644
--- a/tests/test_cache.py
+++ b/tests/test_cache.py
@@ -4,7 +4,7 @@
 
 from chebifier import PerSmilesPerModelLRUCache
 
-g_cache = PerSmilesPerModelLRUCache(max_size=3)
+g_cache = PerSmilesPerModelLRUCache(max_size=100, persist_path=None)
 
 
 class DummyPredictor:
@@ -14,7 +14,7 @@ def __init__(self, model_name):
     @g_cache.batch_decorator
     def predict(self, smiles_list: tuple[str]):
         # Simple predictable dummy function for tests
-        return [f"{self.model_name}{i}" for i in range(len(smiles_list))]
+        return [f"{self.model_name}_P{i}" for i in range(len(smiles_list))]
 
 
 class TestPerSmilesPerModelLRUCache(unittest.TestCase):
@@ -73,30 +73,52 @@ def test_batch_decorator_hits_and_misses(self):
             ["modelB_P0", "modelB_P1", "modelB_P2", "modelB_P3", "modelB_P4"],
         )
         stats_after_first = g_cache.stats()
-        self.assertEqual(stats_after_first["misses"], 3)
+        self.assertEqual(
+            stats_after_first["misses"], 10
+        )  # 5 for modelA and 5 for modelB
+        self.assertEqual(stats_after_first["hits"], 0)
+        self.assertEqual(len(g_cache), 10)  # 5 for each model
+
+        # cache = {("AAA", "modelA"): "modelA_P0", ("BBB", "modelA"): "modelA_P1", ("CCC", "modelA"): "modelA_P2",
+        #          ("DDD", "modelA"): "modelA_P3", ("EEE", "modelA"): "modelA_P4",
+        #          ("AAA", "modelB"): "modelB_P0", ("BBB", "modelB"): "modelB_P1", ("CCC", "modelB"): "modelB_P2",
+        #          ("DDD", "modelB"): "modelB_P3", ("EEE", "modelB"): "modelB_P4"}
 
-        # cache = {("AAA", "modelA"): "modelA_P0", ("BBB", "modelA"): "modelA_P1", ("CCC", "modelA"): "modelA_P2"}
         # Second call with some hits and some misses
         results2 = predictor.predict(["FFF", "DDD"])
-        # AAA from cache
-        # FFF is not in cache, so it predicted, hence it has P0 as its the only one passed to prediction function
-        # and dummy predictor returns iterates over the smiles list and return P{idx} corresponding to the index
-        self.assertListEqual(results2, ["P3", "P0"])
+        # DDD from cache
+        # FFF is not in the cache, so it is predicted; it gets P0 as it is the only SMILES passed to the prediction function
+        # (the dummy predictor iterates over the smiles list it receives and returns P{idx} corresponding to the index)
+        self.assertListEqual(results2, ["modelA_P0", "modelA_P3"])
         stats_after_second = g_cache.stats()
-        self.assertEqual(stats_after_second["hits"], 1)
-        self.assertEqual(stats_after_second["misses"], 4)
+        self.assertEqual(stats_after_second["hits"], 1)  # additional 1 hit for DDD
+        self.assertEqual(stats_after_second["misses"], 11)  # 1 miss for FFF
 
-        # cache = {("AAA", "modelA"): "P0", ("BBB", "modelA"): "P1", ("CCC", "modelA"): "P2",
-        #          ("DDD", "modelA"): "P3", ("EEE", "modelA"): "P4", ("FFF", "modelA"): "P0"}
+        # cache = {("AAA", "modelA"): "modelA_P0", ("BBB", "modelA"): "modelA_P1", ("CCC", "modelA"): "modelA_P2",
+        #          ("DDD", "modelA"): "modelA_P3", ("EEE", "modelA"): "modelA_P4", ("FFF", "modelA"): "modelA_P0", ...}
 
         # Third call with some hits and some misses
         results3 = predictor.predict(["EEE", "GGG", "DDD", "HHH", "BBB", "ZZZ"])
         # Here, predictions for [EEE, DDD, BBB] are retrieved from the cache,
         # while [GGG, HHH, ZZZ] are not in the cache and hence are passed to the prediction function
-        self.assertListEqual(results3, ["P4", "P0", "P3", "P0", "P1", "P0"])
+        self.assertListEqual(
+            results3,
+            [
+                "modelA_P4",  # EEE from cache
+                "modelA_P0",  # GGG not in the cache, so it is predicted; P0 as the first of the three missing SMILES
+                "modelA_P3",  # DDD from cache
+                "modelA_P1",  # HHH not in the cache, so it is predicted; P1 as the second of the three missing SMILES
+                "modelA_P1",  # BBB from cache
+                "modelA_P2",  # ZZZ not in the cache, so it is predicted; P2 as the third of the three missing SMILES
+            ],
+        )
         stats_after_third = g_cache.stats()
-        self.assertEqual(stats_after_third["hits"], 1)
-        self.assertEqual(stats_after_third["misses"], 4)
+        self.assertEqual(
+            stats_after_third["hits"], 4
+        )  # additional 3 hits for EEE, DDD, BBB
+        self.assertEqual(
+            stats_after_third["misses"], 14
+        )  # additional 3 misses for GGG, HHH, ZZZ

From 238be4e8a2749987f4f507bb8b35d9ec5ab8b6d7 Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Sat, 2 Aug 2025 20:23:44 +0200
Subject: [PATCH 6/9] cache: docstrings and typehints

---
 chebifier/_custom_cache.py | 101 +++++++++++++++++++++++++++++++------
 tests/test_cache.py        |  50 ++++++++++++++----
 2 files changed, 127 insertions(+), 24 deletions(-)

diff --git a/chebifier/_custom_cache.py b/chebifier/_custom_cache.py
index 8d02c68..992cdcd 100644
--- a/chebifier/_custom_cache.py
+++ b/chebifier/_custom_cache.py
@@ -8,8 +8,20 @@
 
 
 class PerSmilesPerModelLRUCache:
+    """
+    A thread-safe, optionally persistent LRU cache for storing
+    (SMILES, model_name) → result mappings.
+    """
+
     def __init__(self, max_size: int = 100, persist_path: str | None = None):
-        self._cache = OrderedDict()
+        """
+        Initialize the cache.
+
+        Args:
+            max_size (int): Maximum number of items to keep in the cache.
+            persist_path (str | None): Optional path to persist cache using pickle.
+ """ + self._cache: OrderedDict[tuple[str, str], Any] = OrderedDict() self._max_size = max_size self._lock = threading.Lock() self._persist_path = persist_path @@ -21,6 +33,16 @@ def __init__(self, max_size: int = 100, persist_path: str | None = None): self._load_cache() def get(self, smiles: str, model_name: str) -> Any | None: + """ + Retrieve value from cache if present, otherwise return None. + + Args: + smiles (str): SMILES string key. + model_name (str): Model identifier. + + Returns: + Any | None: Cached value or None. + """ key = (smiles, model_name) with self._lock: if key in self._cache: @@ -32,6 +54,14 @@ def get(self, smiles: str, model_name: str) -> Any | None: return None def set(self, smiles: str, model_name: str, value: Any) -> None: + """ + Store value in cache under (smiles, model_name) key. + + Args: + smiles (str): SMILES string key. + model_name (str): Model identifier. + value (Any): Value to cache. + """ assert value is not None, "Value must not be None" key = (smiles, model_name) with self._lock: @@ -42,6 +72,9 @@ def set(self, smiles: str, model_name: str, value: Any) -> None: self._cache.popitem(last=False) def clear(self) -> None: + """ + Clear the cache and remove the persistence file if present. + """ self._save_cache() with self._lock: self._cache.clear() @@ -50,23 +83,38 @@ def clear(self) -> None: if self._persist_path and os.path.exists(self._persist_path): os.remove(self._persist_path) - def stats(self) -> dict: + def stats(self) -> dict[str, int]: + """ + Return cache hit/miss statistics. + + Returns: + dict[str, int]: Dictionary with 'hits' and 'misses' keys. + """ return {"hits": self.hits, "misses": self.misses} def batch_decorator(self, func: Callable) -> Callable: - """Decorator for class methods that accept a batch of SMILES as a tuple, - and want caching per (smiles, model_name) combination. + """ + Decorator for class methods that accept a batch of SMILES as a list, + and cache predictions per (smiles, model_name) key. + + The instance is expected to have a `model_name` attribute. + + Args: + func (Callable): The method to decorate. + + Returns: + Callable: The wrapped method. """ @wraps(func) - def wrapper(instance, smiles_list: list[str]): + def wrapper(instance, smiles_list: list[str]) -> list[Any]: assert isinstance(smiles_list, list), "smiles_list must be a list." model_name = getattr(instance, "model_name", None) assert model_name is not None, "Instance must have a model_name attribute." - results = [] - missing_smiles = [] - missing_indices = [] + results: list[tuple[int, Any]] = [] + missing_smiles: list[str] = [] + missing_indices: list[int] = [] # First: try to fetch all from cache for i, smiles in enumerate(smiles_list): @@ -82,7 +130,8 @@ def wrapper(instance, smiles_list: list[str]): new_results = func(instance, tuple(missing_smiles)) assert isinstance( new_results, Iterable - ), "Function must return an Iterable." + ), "Function must return an Iterable." + # Save to cache and append for smiles, prediction, missing_idx in zip( missing_smiles, new_results, missing_indices @@ -101,21 +150,41 @@ def wrapper(instance, smiles_list: list[str]): return wrapper - def __len__(self): + def __len__(self) -> int: + """ + Return number of items in the cache. + + Returns: + int: Number of entries in the cache. + """ with self._lock: return len(self._cache) - def __repr__(self): + def __repr__(self) -> str: + """ + String representation of the underlying cache. + + Returns: + str: String version of the OrderedDict. 
+ """ return self._cache.__repr__() - def save(self): + def save(self) -> None: + """ + Save the cache to disk, if persistence is enabled. + """ self._save_cache() - def load(self): + def load(self) -> None: + """ + Load the cache from disk, if persistence is enabled. + """ self._load_cache() def _save_cache(self) -> None: - """Serialize the cache to disk.""" + """ + Serialize the cache to disk using pickle. + """ if self._persist_path: try: with open(self._persist_path, "wb") as f: @@ -124,7 +193,9 @@ def _save_cache(self) -> None: print(f"[Cache Save Error] {e}") def _load_cache(self) -> None: - """Load the cache from disk.""" + """ + Load the cache from disk, if the file exists and is non-empty. + """ if ( self._persist_path and os.path.exists(self._persist_path) diff --git a/tests/test_cache.py b/tests/test_cache.py index 72c7797..eaa98f4 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -8,17 +8,28 @@ class DummyPredictor: - def __init__(self, model_name): + def __init__(self, model_name: str): + """ + Dummy predictor for testing cache decorator. + :param model_name: Name of the model instance (used for key separation). + """ self.model_name = model_name @g_cache.batch_decorator - def predict(self, smiles_list: tuple[str]): + def predict(self, smiles_list: tuple[str]) -> list[str]: + """ + Dummy predict method to simulate model inference. + Returns list of predictions with predictable format. + """ # Simple predictable dummy function for tests return [f"{self.model_name}_P{i}" for i in range(len(smiles_list))] class TestPerSmilesPerModelLRUCache(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: + """ + Set up a temporary cache file and cache instance before each test. + """ # Create temp file for persistence tests self.temp_file = tempfile.NamedTemporaryFile(delete=False) self.temp_file.close() @@ -26,11 +37,17 @@ def setUp(self): max_size=3, persist_path=self.temp_file.name ) - def tearDown(self): + def tearDown(self) -> None: + """ + Clean up the temporary file after each test. + """ if os.path.exists(self.temp_file.name): os.remove(self.temp_file.name) - def test_cache_miss_and_set_get(self): + def test_cache_miss_and_set_get(self) -> None: + """ + Test cache miss on initial get, then set and confirm hit. + """ # Initially empty self.assertEqual(len(self.cache), 0) self.assertIsNone(self.cache.get("CCC", "model1")) @@ -41,7 +58,10 @@ def test_cache_miss_and_set_get(self): self.assertEqual(self.cache.hits, 1) self.assertEqual(self.cache.misses, 1) # One miss from first get - def test_cache_eviction(self): + def test_cache_eviction(self) -> None: + """ + Test LRU eviction when capacity is exceeded. 
+ """ self.cache.set("a", "m", "v1") self.cache.set("b", "m", "v2") self.cache.set("c", "m", "v3") @@ -52,7 +72,13 @@ def test_cache_eviction(self): self.assertIsNone(self.cache.get("a", "m")) # 'a' evicted self.assertIsNotNone(self.cache.get("d", "m")) # 'd' present - def test_batch_decorator_hits_and_misses(self): + def test_batch_decorator_hits_and_misses(self) -> None: + """ + Test decorator behavior on batch prediction: + - first call (all misses) + - second call (mixed hits and misses) + - third call (more hits and misses) + """ predictor = DummyPredictor("modelA") predictor2 = DummyPredictor("modelB") @@ -120,7 +146,10 @@ def test_batch_decorator_hits_and_misses(self): stats_after_third["misses"], 14 ) # additional 3 misses for GGG, HHH, ZZZ - def test_persistence_save_and_load(self): + def test_persistence_save_and_load(self) -> None: + """ + Test that cache is properly saved to disk and reloaded. + """ # Set some values self.cache.set("sm1", "modelX", "val1") self.cache.set("sm2", "modelX", "val2") @@ -137,7 +166,10 @@ def test_persistence_save_and_load(self): self.assertEqual(new_cache.get("sm1", "modelX"), "val1") self.assertEqual(new_cache.get("sm2", "modelX"), "val2") - def test_clear_cache(self): + def test_clear_cache(self) -> None: + """ + Test clearing the cache and removing persisted file. + """ self.cache.set("x", "m", "v") self.cache.save() self.assertTrue(os.path.exists(self.temp_file.name)) From c8ba0612a68a8ed71b78a6bb74f2a605c512a917 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 2 Aug 2025 20:39:21 +0200 Subject: [PATCH 7/9] avoid sorting and re-iterating --- chebifier/_custom_cache.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/chebifier/_custom_cache.py b/chebifier/_custom_cache.py index 992cdcd..38b500f 100644 --- a/chebifier/_custom_cache.py +++ b/chebifier/_custom_cache.py @@ -112,18 +112,22 @@ def wrapper(instance, smiles_list: list[str]) -> list[Any]: model_name = getattr(instance, "model_name", None) assert model_name is not None, "Instance must have a model_name attribute." - results: list[tuple[int, Any]] = [] missing_smiles: list[str] = [] missing_indices: list[int] = [] + ordered_results: list[Any] = [None] * len(smiles_list) # First: try to fetch all from cache - for i, smiles in enumerate(smiles_list): - result = self.get(smiles=smiles, model_name=model_name) - if result is not None: - results.append((i, result)) # save index for reordering + for idx, smiles in enumerate(smiles_list): + prediction = self.get(smiles=smiles, model_name=model_name) + if prediction is not None: + # For debugging purposes, you can uncomment the print statement below + # print( + # f"[Cache Hit] Prediction for smiles: {smiles} and model: {model_name} are retrieved from cache." + # ) + ordered_results[idx] = prediction else: missing_smiles.append(smiles) - missing_indices.append(i) + missing_indices.append(idx) # If some are missing, call original function if missing_smiles: @@ -138,15 +142,9 @@ def wrapper(instance, smiles_list: list[str]) -> list[Any]: ): if prediction is not None: self.set(smiles, model_name, prediction) - results.append((missing_idx, prediction)) - - # Reorder results to match original indices - results.sort(key=lambda x: x[0]) # sort by index - ordered = [result for _, result in results] - assert len(ordered) == len( - smiles_list - ), "Result length does not match input length." 
- return ordered + ordered_results[missing_idx] = prediction + + return ordered_results return wrapper From c304092ed4b004835172b736eb60a940bf23eb61 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 11 Aug 2025 12:20:03 +0200 Subject: [PATCH 8/9] add predict_smiles_list for chemlog extra --- chebifier/prediction_models/chemlog_predictor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebifier/prediction_models/chemlog_predictor.py b/chebifier/prediction_models/chemlog_predictor.py index 0cd5fa5..99ca402 100644 --- a/chebifier/prediction_models/chemlog_predictor.py +++ b/chebifier/prediction_models/chemlog_predictor.py @@ -55,7 +55,7 @@ def __init__(self, model_name: str, **kwargs): self.chebi_graph = kwargs.get("chebi_graph", None) self.classifier = self.CHEMLOG_CLASSIFIER() - def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list: + def predict_smiles_list(self, smiles_list: list[str]) -> list: mol_list = [_smiles_to_mol(smiles) for smiles in smiles_list] res = self.classifier.classify(mol_list) if self.chebi_graph is not None: From 606ebda07e3a00db3479c12ad3e9fb8d11299664 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 11 Aug 2025 12:34:35 +0200 Subject: [PATCH 9/9] cache to chemlog extra --- chebifier/prediction_models/chemlog_predictor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/chebifier/prediction_models/chemlog_predictor.py b/chebifier/prediction_models/chemlog_predictor.py index 99ca402..99fa3b9 100644 --- a/chebifier/prediction_models/chemlog_predictor.py +++ b/chebifier/prediction_models/chemlog_predictor.py @@ -55,6 +55,7 @@ def __init__(self, model_name: str, **kwargs): self.chebi_graph = kwargs.get("chebi_graph", None) self.classifier = self.CHEMLOG_CLASSIFIER() + @modelwise_smiles_lru_cache.batch_decorator def predict_smiles_list(self, smiles_list: list[str]) -> list: mol_list = [_smiles_to_mol(smiles) for smiles in smiles_list] res = self.classifier.classify(mol_list)
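After PATCH 9, the wiring is: chebifier/__init__.py exposes a single modelwise_smiles_lru_cache instance, and every predictor's predict_smiles_list is wrapped by its batch_decorator. The following is a minimal, self-contained sketch of how a predictor interacts with the shared cache; ExamplePredictor, its model_name, and its constant predictions are hypothetical stand-ins for the real BasePredictor subclasses, not part of the patches:

    from chebifier import modelwise_smiles_lru_cache


    class ExamplePredictor:
        # The decorator reads this attribute to build (smiles, model_name) keys.
        model_name = "example_model"  # hypothetical name, for illustration only

        @modelwise_smiles_lru_cache.batch_decorator
        def predict_smiles_list(self, smiles_list):
            # The wrapper forwards only the cache-missing SMILES (as a tuple)
            # and expects one result per input, in the same order.
            return [{"CHEBI:16236": 0.9} for _ in smiles_list]


    predictor = ExamplePredictor()
    predictor.predict_smiles_list(["CCO", "CCN"])  # both miss -> misses = 2
    predictor.predict_smiles_list(["CCO", "CCC"])  # CCO hits -> hits = 1, misses = 3
    print(modelwise_smiles_lru_cache.stats())      # {'hits': 1, 'misses': 3}

Because keys include model_name, a second predictor class with a different name misses on the same SMILES, which is exactly the behaviour test_batch_decorator_hits_and_misses pins down with its two DummyPredictor instances.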
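Persistence, exercised by test_persistence_save_and_load, is a plain pickle round-trip and stays opt-in. A sketch of the same flow outside the test suite, assuming the post-series package layout (the file name "predictions.pkl" is illustrative):

    from chebifier import PerSmilesPerModelLRUCache

    cache = PerSmilesPerModelLRUCache(max_size=100, persist_path="predictions.pkl")
    cache.set("CCO", "example_model", {"CHEBI:16236": 0.9})
    cache.save()  # pickles the underlying OrderedDict to predictions.pkl

    # A fresh instance pointing at the same path reloads the entries;
    # _load_cache() also runs automatically in __init__ when persist_path is set.
    restored = PerSmilesPerModelLRUCache(max_size=100, persist_path="predictions.pkl")
    print(restored.get("CCO", "example_model"))  # {'CHEBI:16236': 0.9}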