From 558f26a4e77e081f438f2ce490634e8d130fdcfe Mon Sep 17 00:00:00 2001
From: yunkchen
Date: Thu, 22 Jan 2026 15:20:55 +0800
Subject: [PATCH] Implement caching for NgramHashMapping creation

Added caching mechanism for NgramHashMapping to optimize performance.

---
 engram_demo_v1.py | 41 +++++++++++++++++++++++++++++++++++++++--
 1 file changed, 39 insertions(+), 2 deletions(-)

diff --git a/engram_demo_v1.py b/engram_demo_v1.py
index f3ce993..319f25d 100644
--- a/engram_demo_v1.py
+++ b/engram_demo_v1.py
@@ -302,6 +302,43 @@ def hash(self, input_ids):
             hash_ids_for_all_layers[layer_id] = self._get_ngram_hashes(input_ids, layer_id=layer_id)
         return hash_ids_for_all_layers
 
+_HASH_MAPPING_CACHE = {}
+# Ensures that an NgramHashMapping with identical configuration is created only once.
+def get_or_create_hash_mapping(
+    engram_vocab_size,
+    max_ngram_size,
+    n_embed_per_ngram,
+    n_head_per_ngram,
+    layer_ids,
+    tokenizer_name_or_path,
+    pad_id,
+    seed,
+):
+    cache_key = (
+        tuple(engram_vocab_size),
+        max_ngram_size,
+        n_embed_per_ngram,
+        n_head_per_ngram,
+        tuple(layer_ids),
+        tokenizer_name_or_path,
+        pad_id,
+        seed,
+    )
+
+    if cache_key not in _HASH_MAPPING_CACHE:
+        _HASH_MAPPING_CACHE[cache_key] = NgramHashMapping(
+            engram_vocab_size=engram_vocab_size,
+            max_ngram_size=max_ngram_size,
+            n_embed_per_ngram=n_embed_per_ngram,
+            n_head_per_ngram=n_head_per_ngram,
+            layer_ids=layer_ids,
+            tokenizer_name_or_path=tokenizer_name_or_path,
+            pad_id=pad_id,
+            seed=seed,
+        )
+
+    return _HASH_MAPPING_CACHE[cache_key]
+
 class MultiHeadEmbedding(nn.Module):
     def __init__(self, list_of_N: List[int], D: int):
         super().__init__()
@@ -327,7 +364,7 @@ class Engram(nn.Module):
     def __init__(self, layer_id):
         super().__init__()
         self.layer_id = layer_id
-        self.hash_mapping = NgramHashMapping(
+        self.hash_mapping = get_or_create_hash_mapping(
            engram_vocab_size=engram_cfg.engram_vocab_size,
            max_ngram_size = engram_cfg.max_ngram_size,
            n_embed_per_ngram = engram_cfg.n_embed_per_ngram,
@@ -420,4 +457,4 @@ def forward(self,input_ids,hidden_states):
 
 print("✅ Forward Complete!")
 print(f"{input_ids.shape=}\n{output.shape=}")
-    
\ No newline at end of file
+