41 changes: 39 additions & 2 deletions engram_demo_v1.py
@@ -302,6 +302,43 @@ def hash(self, input_ids):
            hash_ids_for_all_layers[layer_id] = self._get_ngram_hashes(input_ids, layer_id=layer_id)
        return hash_ids_for_all_layers

_HASH_MAPPING_CACHE = {}
# Ensures that an NgramHashMapping with identical configuration is created only once.
def get_or_create_hash_mapping(
    engram_vocab_size,
    max_ngram_size,
    n_embed_per_ngram,
    n_head_per_ngram,
    layer_ids,
    tokenizer_name_or_path,
    pad_id,
    seed,
):
    cache_key = (
        tuple(engram_vocab_size),
        max_ngram_size,
        n_embed_per_ngram,
        n_head_per_ngram,
        tuple(layer_ids),
        tokenizer_name_or_path,
        pad_id,
        seed,
    )

    if cache_key not in _HASH_MAPPING_CACHE:
        _HASH_MAPPING_CACHE[cache_key] = NgramHashMapping(
            engram_vocab_size=engram_vocab_size,
            max_ngram_size=max_ngram_size,
            n_embed_per_ngram=n_embed_per_ngram,
            n_head_per_ngram=n_head_per_ngram,
            layer_ids=layer_ids,
            tokenizer_name_or_path=tokenizer_name_or_path,
            pad_id=pad_id,
            seed=seed,
        )
Comment on lines +317 to +338


Issue: layer_ids is passed by reference to NgramHashMapping, but the cache key uses tuple(layer_ids). If the caller mutates the list later, the cached instance silently uses stale data.
Fix: Pass an immutable copy:

Suggested change

     cache_key = (
         tuple(engram_vocab_size),
         max_ngram_size,
         n_embed_per_ngram,
         n_head_per_ngram,
         tuple(layer_ids),
         tokenizer_name_or_path,
         pad_id,
         seed,
     )
     if cache_key not in _HASH_MAPPING_CACHE:
         _HASH_MAPPING_CACHE[cache_key] = NgramHashMapping(
             engram_vocab_size=engram_vocab_size,
             max_ngram_size=max_ngram_size,
             n_embed_per_ngram=n_embed_per_ngram,
             n_head_per_ngram=n_head_per_ngram,
-            layer_ids=layer_ids,
+            layer_ids=tuple(layer_ids),  # <- immutable copy
             tokenizer_name_or_path=tokenizer_name_or_path,
             pad_id=pad_id,
             seed=seed,
         )
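
For illustration, here is a self-contained sketch of the aliasing hazard; Holder and get_or_create are hypothetical stand-ins for NgramHashMapping and get_or_create_hash_mapping:

_CACHE = {}

class Holder:
    def __init__(self, layer_ids):
        self.layer_ids = layer_ids  # keeps a reference, not a copy

def get_or_create(layer_ids):
    key = tuple(layer_ids)
    if key not in _CACHE:
        _CACHE[key] = Holder(layer_ids)
    return _CACHE[key]

ids = [0, 1, 2]
h = get_or_create(ids)
ids.append(3)  # caller mutates the list after the cached instance was built
print(h.layer_ids)  # [0, 1, 2, 3], stale relative to its cache key (0, 1, 2)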


Comment on lines +328 to +339


Issue: The check-then-set pattern isn't thread-safe. Multiple threads can race past the cache_key not in _HASH_MAPPING_CACHE check and redundantly create expensive NgramHashMapping instances (tokenizer loading, prime computation).

Fix: Add a lock:

Suggested change

-    if cache_key not in _HASH_MAPPING_CACHE:
-        _HASH_MAPPING_CACHE[cache_key] = NgramHashMapping(
-            engram_vocab_size=engram_vocab_size,
-            max_ngram_size=max_ngram_size,
-            n_embed_per_ngram=n_embed_per_ngram,
-            n_head_per_ngram=n_head_per_ngram,
-            layer_ids=layer_ids,
-            tokenizer_name_or_path=tokenizer_name_or_path,
-            pad_id=pad_id,
-            seed=seed,
-        )
+import threading
+
+_HASH_MAPPING_CACHE = {}
+_HASH_MAPPING_LOCK = threading.Lock()
+
+def get_or_create_hash_mapping(...):
+    cache_key = (...)
+    with _HASH_MAPPING_LOCK:
+        if cache_key not in _HASH_MAPPING_CACHE:
+            _HASH_MAPPING_CACHE[cache_key] = NgramHashMapping(...)
+    return _HASH_MAPPING_CACHE[cache_key]
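
To make the race concrete, here is a runnable, self-contained sketch with a stand-in Expensive class (all names illustrative, not from the PR); with the lock removed, several threads typically construct redundant instances:

import threading
import time

_CACHE = {}
_LOCK = threading.Lock()

class Expensive:
    instances = 0
    def __init__(self):
        Expensive.instances += 1
        time.sleep(0.05)  # simulate tokenizer loading / prime computation

def get_or_create(key):
    with _LOCK:  # without this lock, instances often ends up > 1
        if key not in _CACHE:
            _CACHE[key] = Expensive()
    return _CACHE[key]

threads = [threading.Thread(target=get_or_create, args=("cfg",)) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(Expensive.instances)  # 1 with the lock held around check-then-set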

Or use @lru_cache, which handles its cache bookkeeping thread-safely and also fixes the unbounded cache issue (note that a concurrent miss can still be computed more than once, since lru_cache calls the wrapped function outside its internal lock, but the cache itself stays consistent).

    return _HASH_MAPPING_CACHE[cache_key]
Comment on lines +305 to +340


Issue: _HASH_MAPPING_CACHE is unbounded. Each entry holds a HuggingFace tokenizer + lookup tables. In long-running processes or hyperparameter sweeps, this can grow indefinitely and OOM.

Fix: Use lru_cache with a size limit (also handles thread safety):

Suggested change

-_HASH_MAPPING_CACHE = {}
-# Ensures that an NgramHashMapping with identical configuration is created only once.
-def get_or_create_hash_mapping(
-    engram_vocab_size,
-    max_ngram_size,
-    n_embed_per_ngram,
-    n_head_per_ngram,
-    layer_ids,
-    tokenizer_name_or_path,
-    pad_id,
-    seed,
-):
-    cache_key = (
-        tuple(engram_vocab_size),
-        max_ngram_size,
-        n_embed_per_ngram,
-        n_head_per_ngram,
-        tuple(layer_ids),
-        tokenizer_name_or_path,
-        pad_id,
-        seed,
-    )
-    if cache_key not in _HASH_MAPPING_CACHE:
-        _HASH_MAPPING_CACHE[cache_key] = NgramHashMapping(
-            engram_vocab_size=engram_vocab_size,
-            max_ngram_size=max_ngram_size,
-            n_embed_per_ngram=n_embed_per_ngram,
-            n_head_per_ngram=n_head_per_ngram,
-            layer_ids=layer_ids,
-            tokenizer_name_or_path=tokenizer_name_or_path,
-            pad_id=pad_id,
-            seed=seed,
-        )
-    return _HASH_MAPPING_CACHE[cache_key]
+from functools import lru_cache
+
+@lru_cache(maxsize=8)
+def get_or_create_hash_mapping(
+    engram_vocab_size,  # must be tuple, not list
+    max_ngram_size,
+    n_embed_per_ngram,
+    n_head_per_ngram,
+    layer_ids,  # must be tuple, not list
+    tokenizer_name_or_path,
+    pad_id,
+    seed,
+):
+    return NgramHashMapping(
+        engram_vocab_size=engram_vocab_size,
+        max_ngram_size=max_ngram_size,
+        n_embed_per_ngram=n_embed_per_ngram,
+        n_head_per_ngram=n_head_per_ngram,
+        layer_ids=layer_ids,
+        tokenizer_name_or_path=tokenizer_name_or_path,
+        pad_id=pad_id,
+        seed=seed,
+    )

Callers must pass tuples instead of lists. Alternatively, you could keep the manual cache but add a clear_hash_mapping_cache() helper.
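
A minimal sketch of that helper, assuming the manual module-level _HASH_MAPPING_CACHE shown above (the name clear_hash_mapping_cache comes from the comment; the body is illustrative):

def clear_hash_mapping_cache():
    # Drop all cached NgramHashMapping instances, releasing tokenizers and lookup tables.
    _HASH_MAPPING_CACHE.clear()

If the lru_cache route is taken instead, the decorator already provides the equivalent: get_or_create_hash_mapping.cache_clear().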


class MultiHeadEmbedding(nn.Module):
    def __init__(self, list_of_N: List[int], D: int):
        super().__init__()
@@ -327,7 +364,7 @@ class Engram(nn.Module):
     def __init__(self, layer_id):
         super().__init__()
         self.layer_id = layer_id
-        self.hash_mapping = NgramHashMapping(
+        self.hash_mapping = get_or_create_hash_mapping(
             engram_vocab_size=engram_cfg.engram_vocab_size,
             max_ngram_size = engram_cfg.max_ngram_size,
             n_embed_per_ngram = engram_cfg.n_embed_per_ngram,
@@ -420,4 +457,4 @@ def forward(self,input_ids,hidden_states):

print("✅ Forward Complete!")
print(f"{input_ids.shape=}\n{output.shape=}")