[Opt] Implement caching for NgramHashMapping creation #13
yunkchen wants to merge 1 commit into deepseek-ai:main from
Conversation
Added caching mechanism for NgramHashMapping to optimize performance.
dino65-dev left a comment
Overall: Good optimization. Caching NgramHashMapping avoids redundant tokenizer loading and hash computation across Engram layers. A few robustness issues to address:
Mutable reference bug (L317-338): layer_ids is stored by reference. Pass tuple(layer_ids) to avoid silent corruption if the caller later mutates the list.
Unbounded cache (L305-340): _HASH_MAPPING_CACHE can grow indefinitely while holding large tokenizers. Consider @lru_cache(maxsize=N) or add a clear_cache() helper.
No thread safety (L328-339): Race condition in the check-then-set. Add a threading.Lock, or use @lru_cache, whose cache bookkeeping is thread-safe.
Suggested refactor (addresses all 3):
from functools import lru_cache

@lru_cache(maxsize=8)
def get_or_create_hash_mapping(
    engram_vocab_size,  # pass as tuple
    max_ngram_size,
    n_embed_per_ngram,
    n_head_per_ngram,
    layer_ids,  # pass as tuple
    tokenizer_name_or_path,
    pad_id,
    seed,
):
    return NgramHashMapping(
        engram_vocab_size=engram_vocab_size,
        max_ngram_size=max_ngram_size,
        n_embed_per_ngram=n_embed_per_ngram,
        n_head_per_ngram=n_head_per_ngram,
        layer_ids=layer_ids,
        tokenizer_name_or_path=tokenizer_name_or_path,
        pad_id=pad_id,
        seed=seed,
    )

Callers need to pass tuples: tuple(engram_vocab_size), tuple(layer_ids).
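A sketch of the corresponding call-site change, with a hypothetical config object standing in for wherever these values live in the Engram setup:

mapping = get_or_create_hash_mapping(
    tuple(config.engram_vocab_size),  # list -> tuple so every argument is hashable
    config.max_ngram_size,
    config.n_embed_per_ngram,
    config.n_head_per_ngram,
    tuple(config.layer_ids),          # list -> tuple
    config.tokenizer_name_or_path,
    config.pad_id,
    config.seed,
)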
cache_key = (
    tuple(engram_vocab_size),
    max_ngram_size,
    n_embed_per_ngram,
    n_head_per_ngram,
    tuple(layer_ids),
    tokenizer_name_or_path,
    pad_id,
    seed,
)

if cache_key not in _HASH_MAPPING_CACHE:
    _HASH_MAPPING_CACHE[cache_key] = NgramHashMapping(
        engram_vocab_size=engram_vocab_size,
        max_ngram_size=max_ngram_size,
        n_embed_per_ngram=n_embed_per_ngram,
        n_head_per_ngram=n_head_per_ngram,
        layer_ids=layer_ids,
        tokenizer_name_or_path=tokenizer_name_or_path,
        pad_id=pad_id,
        seed=seed,
    )
Issue: layer_ids is passed by reference to NgramHashMapping, but the cache key uses tuple(layer_ids). If the caller mutates the list later, the cached instance silently drifts out of sync with the key it was stored under.
Fix: Pass an immutable copy:
cache_key = (
    tuple(engram_vocab_size),
    max_ngram_size,
    n_embed_per_ngram,
    n_head_per_ngram,
    tuple(layer_ids),
    tokenizer_name_or_path,
    pad_id,
    seed,
)
if cache_key not in _HASH_MAPPING_CACHE:
    _HASH_MAPPING_CACHE[cache_key] = NgramHashMapping(
        engram_vocab_size=engram_vocab_size,
        max_ngram_size=max_ngram_size,
        n_embed_per_ngram=n_embed_per_ngram,
        n_head_per_ngram=n_head_per_ngram,
        layer_ids=tuple(layer_ids),  # <- immutable copy
        tokenizer_name_or_path=tokenizer_name_or_path,
        pad_id=pad_id,
        seed=seed,
    )
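To make the hazard concrete, a self-contained sketch with a stand-in class (whether the real NgramHashMapping retains the reference depends on its __init__; the copy makes that question moot):

class FakeMapping:  # illustration only, stand-in for NgramHashMapping
    def __init__(self, layer_ids):
        self.layer_ids = layer_ids  # keeps whatever object it was given

shared = [0, 4, 8]
cached = FakeMapping(shared)   # stored by reference
shared.append(12)              # caller mutates the list later
print(cached.layer_ids)        # [0, 4, 8, 12] -- no longer matches the cache key

safe = FakeMapping(tuple([0, 4, 8]))  # immutable copy is unaffected by the caller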
_HASH_MAPPING_CACHE = {}

# Ensures that an NgramHashMapping with identical configuration is created only once.
def get_or_create_hash_mapping(
    engram_vocab_size,
    max_ngram_size,
    n_embed_per_ngram,
    n_head_per_ngram,
    layer_ids,
    tokenizer_name_or_path,
    pad_id,
    seed,
):
    cache_key = (
        tuple(engram_vocab_size),
        max_ngram_size,
        n_embed_per_ngram,
        n_head_per_ngram,
        tuple(layer_ids),
        tokenizer_name_or_path,
        pad_id,
        seed,
    )

    if cache_key not in _HASH_MAPPING_CACHE:
        _HASH_MAPPING_CACHE[cache_key] = NgramHashMapping(
            engram_vocab_size=engram_vocab_size,
            max_ngram_size=max_ngram_size,
            n_embed_per_ngram=n_embed_per_ngram,
            n_head_per_ngram=n_head_per_ngram,
            layer_ids=layer_ids,
            tokenizer_name_or_path=tokenizer_name_or_path,
            pad_id=pad_id,
            seed=seed,
        )

    return _HASH_MAPPING_CACHE[cache_key]
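As context for what the PR's helper guarantees, a quick illustration (argument values are made up; identical configuration returns the same object):

m1 = get_or_create_hash_mapping([1000, 2000], 3, 64, 4, [0, 1], "path/to/tokenizer", 0, 42)
m2 = get_or_create_hash_mapping([1000, 2000], 3, 64, 4, [0, 1], "path/to/tokenizer", 0, 42)
assert m1 is m2  # second call is a cache hit: no tokenizer reload, no hash recomputation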
Issue: _HASH_MAPPING_CACHE is unbounded. Each entry holds a HuggingFace tokenizer plus lookup tables. In long-running processes or hyperparameter sweeps, the cache can grow indefinitely and eventually OOM the process.
Fix: Use lru_cache with a size limit (which also gives you thread-safe cache bookkeeping):
from functools import lru_cache

@lru_cache(maxsize=8)
def get_or_create_hash_mapping(
    engram_vocab_size,  # must be tuple, not list
    max_ngram_size,
    n_embed_per_ngram,
    n_head_per_ngram,
    layer_ids,  # must be tuple, not list
    tokenizer_name_or_path,
    pad_id,
    seed,
):
    return NgramHashMapping(
        engram_vocab_size=engram_vocab_size,
        max_ngram_size=max_ngram_size,
        n_embed_per_ngram=n_embed_per_ngram,
        n_head_per_ngram=n_head_per_ngram,
        layer_ids=layer_ids,
        tokenizer_name_or_path=tokenizer_name_or_path,
        pad_id=pad_id,
        seed=seed,
    )
Callers must pass tuples instead of lists. Alternatively, keep the manual cache but add a clear_hash_mapping_cache() helper:
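A minimal sketch of that helper, assuming the manual dict cache above stays in place:

def clear_hash_mapping_cache():
    """Drop all cached NgramHashMapping instances (e.g. between sweep runs)."""
    _HASH_MAPPING_CACHE.clear()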
    if cache_key not in _HASH_MAPPING_CACHE:
        _HASH_MAPPING_CACHE[cache_key] = NgramHashMapping(
            engram_vocab_size=engram_vocab_size,
            max_ngram_size=max_ngram_size,
            n_embed_per_ngram=n_embed_per_ngram,
            n_head_per_ngram=n_head_per_ngram,
            layer_ids=layer_ids,
            tokenizer_name_or_path=tokenizer_name_or_path,
            pad_id=pad_id,
            seed=seed,
        )
Issue: The check-then-set pattern isn't thread-safe. Multiple threads can race past if cache_key not in _HASH_MAPPING_CACHE and redundantly create expensive NgramHashMapping instances (tokenizer loading, prime computation).
Fix: Add a lock:
import threading

_HASH_MAPPING_CACHE = {}
_HASH_MAPPING_LOCK = threading.Lock()

def get_or_create_hash_mapping(...):  # signature as above
    cache_key = (...)  # same tuple-based key as above
    with _HASH_MAPPING_LOCK:
        if cache_key not in _HASH_MAPPING_CACHE:
            _HASH_MAPPING_CACHE[cache_key] = NgramHashMapping(...)
        return _HASH_MAPPING_CACHE[cache_key]
Or use @lru_cache, which keeps its cache bookkeeping thread-safe and also fixes the unbounded-cache issue. (Concurrent first calls for the same key may still each construct an instance; the cache simply retains one, so correctness is preserved.)
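For reference, lru_cache also ships introspection and reset hooks out of the box, so no custom helper is needed on that route:

info = get_or_create_hash_mapping.cache_info()  # CacheInfo(hits, misses, maxsize, currsize)
get_or_create_hash_mapping.cache_clear()        # drop all cached mappings, e.g. between sweep runs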