6 changes: 2 additions & 4 deletions kv_cache_benchmark/kv_cache/backends.py
@@ -316,10 +316,8 @@ def read(self, key: str) -> Tuple[np.ndarray, StorageBackend.IOTiming]:
 
     def delete(self, key: str):
         path = self._get_path(key)
-        if path.exists():
-            path.unlink()
-        if key in self.metadata:
-            del self.metadata[key]
+        path.unlink(missing_ok=True)
+        self.metadata.pop(key, None)
 
     def clear(self):
         """Deletes all .npy files from the cache directory."""
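Review note: `Path.unlink(missing_ok=True)` requires Python 3.8+, and together with `dict.pop(key, None)` it makes `delete` idempotent without the check-then-act race of the old `exists()`/`unlink()` pair. A minimal sketch of the same pattern outside the benchmark (class and field names hypothetical):

```python
from pathlib import Path

class TinyStore:
    """Hypothetical minimal backend illustrating race-tolerant delete."""
    def __init__(self, base: str):
        self.base = Path(base)
        self.metadata: dict[str, int] = {}

    def delete(self, key: str) -> None:
        # No exists()/unlink() race: missing_ok swallows FileNotFoundError,
        # and dict.pop(key, None) is a no-op if another thread got there first.
        (self.base / f"{key}.npy").unlink(missing_ok=True)
        self.metadata.pop(key, None)
```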
9 changes: 9 additions & 0 deletions kv_cache_benchmark/kv_cache/benchmark.py
@@ -792,6 +792,7 @@ def _run_preconditioning(self):
         state = {'written_bytes': 0, 'seq': 0, 'last_report': 0}
 
         def worker():
+            consecutive_failures = 0
             while True:
                 with lock:
                     if state['written_bytes'] >= target_bytes:
@@ -803,6 +804,7 @@ def worker():
                 success, tier, latency = self.cache.allocate_cache(key, tokens_per_entry)
 
                 if success:
+                    consecutive_failures = 0
                     entry = self.cache.cache_entries.get(key)
                     if entry:
                         with lock:
@@ -811,6 +813,13 @@ def worker():
                             if gb_written - state['last_report'] >= 10:
                                 print(f" Preconditioning progress: {gb_written:.1f} / {target_gb:.1f} GB")
                                 state['last_report'] = gb_written
+                else:
+                    consecutive_failures += 1
+                    if consecutive_failures > 50:
+                        with lock:
+                            print(f" WARNING: Preconditioning stalled at {state['written_bytes']/1024**3:.1f} GB — filesystem full. Continuing.")
+                        return
+                    time.sleep(0.1)
 
         with ThreadPoolExecutor(max_workers=num_threads) as executor:
             futures = [executor.submit(worker) for _ in range(num_threads)]
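Review note: the guard only trips on persistent failure, because any success resets the counter; 50 consecutive failures under load almost certainly means the filesystem is full rather than transient contention. The same pattern in isolation (all names hypothetical):

```python
import time

def write_until_full(write_once, target_bytes, max_consecutive_failures=50):
    """Sketch of the preconditioning stall guard: `write_once` is a
    hypothetical callable returning (ok, bytes_written)."""
    written = 0
    consecutive_failures = 0
    while written < target_bytes:
        ok, nbytes = write_once()
        if ok:
            consecutive_failures = 0   # any success resets the streak
            written += nbytes
        else:
            consecutive_failures += 1
            if consecutive_failures > max_consecutive_failures:
                return written         # stalled: bail out, caller continues
            time.sleep(0.1)            # brief backoff before retrying
    return written
```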
175 changes: 140 additions & 35 deletions kv_cache_benchmark/kv_cache/cache.py
@@ -5,9 +5,9 @@
 and MultiTierCache (3-tier LRU cache with waterfall eviction).
 """
 
+import os
 import time
 import hashlib
-import shutil
 import logging
 import threading
 from typing import Dict, List, Optional, Tuple
@@ -137,7 +137,8 @@ def __init__(self,
         else:
             try:
                 nvme_base = self.backends['nvme'].base_path
-                self.nvme_memory_limit = float(shutil.disk_usage(nvme_base).free)
+                st = os.statvfs(str(nvme_base))
+                self.nvme_memory_limit = float(st.f_bavail * st.f_frsize) * 0.95
             except Exception:
                 self.nvme_memory_limit = float('inf')

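Review note: on POSIX, `shutil.disk_usage(path).free` is itself derived from `statvfs` as `f_bavail * f_frsize` (blocks available to unprivileged users), so the substantive change here is the 0.95 headroom factor that keeps the limit safely below true free space, plus dropping the `shutil` import. A quick check of the equivalence (any mounted path works):

```python
import os
import shutil

path = "/tmp"  # any mounted path
st = os.statvfs(path)
free_statvfs = st.f_bavail * st.f_frsize   # bytes available to non-root users
free_shutil = shutil.disk_usage(path).free

# Both figures come from the same syscall and should match, modulo files
# created or deleted between the two calls.
print(f"statvfs: {free_statvfs / 1024**3:.1f} GiB, shutil: {free_shutil / 1024**3:.1f} GiB")
print(f"nvme_memory_limit with 5% headroom: {free_statvfs * 0.95 / 1024**3:.1f} GiB")
```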
@@ -322,88 +323,190 @@ def _ensure_space_in_tier(self, tier: str, required_bytes: int, recursion_depth:
         if next_tier is None and tier != 'nvme':
             return False
 
+        # When NVMe is the terminal tier (no tier after it), the entry MUST
+        # be written here — relax capacity guards and evict to full limit.
+        is_last_tier = (next_tier is None)
+
         limit = self._get_tier_limit(tier)
         target_usage_ratio = cfg('eviction', 'target_usage_ratio', default=0.8)
         target_usage = limit * target_usage_ratio
 
         large_entry_limit_ratio = cfg('eviction', 'large_entry_limit_ratio', default=0.95)
-        if required_bytes > limit * large_entry_limit_ratio:
+        # Only reject oversized entries on non-terminal tiers (they can cascade).
+        # On the last tier, we must accommodate the entry regardless of size.
+        if not is_last_tier and required_bytes > limit * large_entry_limit_ratio:
             return False
 
-        entries_in_tier = len(self._get_lru_entries_in_tier(tier))
+        # On the last tier, evict to full capacity (not 80%) since there's
+        # no next tier that needs a buffer for cascading entries.
+        effective_target = limit if is_last_tier else target_usage
+
+        # ────────────────────────────────────────────────────────────────
+        # SNAPSHOT-BASED LRU EVICTION
+        #
+        # Performance context:
+        #   _get_lru_entries_in_tier() copies every entry in cache_entries
+        #   that belongs to this tier, then sorts by last_access time.
+        #   At 15 TB with 60k entries, that's ~60k dict copies + sort.
+        #
+        # Old behavior (O(n²)):
+        #   The while loop called _get_lru_entries_in_tier() on EVERY
+        #   iteration, but only used lru_entries[0] — the single oldest
+        #   entry. Evicting 100 entries meant 100 full scans.
+        #
+        # New behavior (O(n)):
+        #   Take ONE sorted snapshot before the loop. Walk through it
+        #   with an index. Each entry is either:
+        #     - Still valid → evict it (delete or demote)
+        #     - Already gone (another thread got it) → skip, advance index
+        #   If we exhaust the snapshot without freeing enough space,
+        #   refresh it ONCE (new entries may have been written since the
+        #   snapshot). Worst case: 2 scans instead of thousands.
+        #
+        # Why stale snapshots are safe:
+        #   - DELETE path: the existence check under metadata_lock already
+        #     skips entries that another thread evicted. A stale snapshot
+        #     just means we hit more skips — no double-decrement.
+        #   - DEMOTE path: _demote_entry() checks that the entry still
+        #     exists in from_tier before moving it. If it's gone, it
+        #     returns False and we advance to the next entry.
+        #   - New entries added after the snapshot are NEWER than
+        #     everything in it (higher last_access time), so LRU order
+        #     says evict them last. Not including them is correct.
+        #
+        # Impact on MLPerf metrics:
+        #   Storage device latencies (write_device_p50, read_device_p50)
+        #   are timed INSIDE the backend — after eviction has already
+        #   freed space. This optimization only reduces the untimed CPU
+        #   overhead between I/O operations. Throughput (req/s) improves
+        #   because the benchmark can push I/O faster; device-level
+        #   numbers stay the same.
+        # ────────────────────────────────────────────────────────────────
+
+        lru_entries = self._get_lru_entries_in_tier(tier)
+        lru_idx = 0
 
         max_evictions_hard_cap = cfg('eviction', 'max_evictions_hard_cap', default=5000)
         max_evictions_min = cfg('eviction', 'max_evictions_min', default=1000)
-        max_evictions_per_call = min(max_evictions_hard_cap, max(max_evictions_min, entries_in_tier + 100))
+        max_evictions_per_call = min(max_evictions_hard_cap, max(max_evictions_min, len(lru_entries) + 100))
         eviction_count = 0
 
         while eviction_count < max_evictions_per_call:
+            # ── Check 1: Is there already enough space? ──
             with self.memory_lock:
                 current_usage = self._get_tier_usage(tier)
-                if current_usage + required_bytes <= target_usage:
+                if current_usage + required_bytes <= effective_target:
                     self._update_tier_usage(tier, required_bytes)
                     return True
 
-                if current_usage < limit * 0.05 and required_bytes <= limit * large_entry_limit_ratio:
+                # Near-empty tier: usage tracking may have drifted from
+                # accumulated rounding. Trust it and allow the write.
+                if current_usage < limit * 0.05:
                     self._update_tier_usage(tier, required_bytes)
                     return True
 
-            lru_entries = self._get_lru_entries_in_tier(tier)
-
-            if not lru_entries:
-                with self.metadata_lock:
-                    actual_usage = sum(
-                        entry['size'] for entry in self.cache_entries.values()
-                        if entry['location'] == tier
-                    )
-                with self.memory_lock:
-                    if tier == 'gpu':
-                        self.gpu_memory_used = actual_usage
-                    elif tier == 'cpu':
-                        self.cpu_memory_used = actual_usage
-                    elif tier == 'nvme':
-                        self.nvme_memory_used = actual_usage
-
-                with self.memory_lock:
-                    current_usage = self._get_tier_usage(tier)
-                    if current_usage + required_bytes <= target_usage:
-                        self._update_tier_usage(tier, required_bytes)
-                        return True
-
-                return False
+            # ── Check 2: Advance through the LRU snapshot ──
+            # If we've walked past the end of the snapshot, try one
+            # refresh — concurrent threads may have evicted most of our
+            # snapshot, or new entries may have landed in this tier.
+            if lru_idx >= len(lru_entries):
+                lru_entries = self._get_lru_entries_in_tier(tier)
+                lru_idx = 0
+
+                if not lru_entries:
+                    # Tier is truly empty. Recount actual usage from
+                    # cache_entries to correct any drift, then decide.
+                    with self.metadata_lock:
+                        actual_usage = sum(
+                            entry['size'] for entry in self.cache_entries.values()
+                            if entry['location'] == tier
+                        )
+                    with self.memory_lock:
+                        if tier == 'gpu':
+                            self.gpu_memory_used = actual_usage
+                        elif tier == 'cpu':
+                            self.cpu_memory_used = actual_usage
+                        elif tier == 'nvme':
+                            self.nvme_memory_used = actual_usage
+
+                    with self.memory_lock:
+                        current_usage = self._get_tier_usage(tier)
+                        if current_usage + required_bytes <= effective_target:
+                            self._update_tier_usage(tier, required_bytes)
+                            return True
+
+                    # Last tier with nothing left to evict — allow the
+                    # write and let the OS enforce disk space.
+                    if is_last_tier:
+                        with self.memory_lock:
+                            self._update_tier_usage(tier, required_bytes)
+                        return True
+
+                    return False
 
-            total_size_in_tier = sum(e['size'] for _, e in lru_entries)
-            if total_size_in_tier < limit * 0.2 and required_bytes > target_usage * 0.5:
-                return False
+            # On non-terminal tiers, bail out if there's little data to
+            # evict relative to what we need. On the last tier, keep
+            # going — there's nowhere else to send the entry.
+            # (Only check on first pass through the snapshot to avoid
+            # re-summing on every iteration.)
+            if lru_idx == 0 and not is_last_tier:
+                total_size_in_tier = sum(e['size'] for _, e in lru_entries)
+                if total_size_in_tier < limit * 0.2 and required_bytes > target_usage * 0.5:
+                    return False
 
-            lru_key, lru_entry = lru_entries[0]
+            # ── Pick the next LRU entry from the snapshot ──
+            lru_key, lru_entry = lru_entries[lru_idx]
             lru_size = lru_entry['size']
+            lru_idx += 1
 
+            # ── Evict: DELETE (terminal tier) or DEMOTE (non-terminal) ──
             if next_tier is None and tier == 'nvme':
+                # Terminal tier: delete the .npy file from disk.
+                # The existence check prevents double-decrementing when
+                # multiple threads race on the same stale snapshot entry.
+                entry_lock = self._get_entry_lock(lru_key)
+                with entry_lock:
+                    with self.metadata_lock:
+                        existing = self.cache_entries.get(lru_key)
+                        if existing is None or existing['location'] != 'nvme':
+                            # Another thread already evicted this entry.
+                            # Safe to skip — just advance to the next one.
+                            eviction_count += 1
+                            continue
+                        actual_size = existing['size']
+                        del self.cache_entries[lru_key]
+                        self.entry_locks.pop(lru_key, None)
                 try:
                     self.backends['nvme'].delete(lru_key)
                 except Exception as e:
                     logger.warning(f"Failed to delete NVMe entry {lru_key}: {e}")
-                with self.metadata_lock:
-                    self.cache_entries.pop(lru_key, None)
                 with self.memory_lock:
-                    self.nvme_memory_used = max(0, self.nvme_memory_used - lru_size)
+                    self.nvme_memory_used = max(0, self.nvme_memory_used - actual_size)
                 with self.stats_lock:
                     self.stats['evictions'] += 1
             else:
+                # Non-terminal tier: demote entry to the next tier down.
+                # Recursively ensure space in next_tier first.
                 if not self._ensure_space_in_tier(next_tier, lru_size, recursion_depth + 1):
                     logger.warning(f"Could not make space in {next_tier} for demotion")
                     return False
 
                 success, _ = self._demote_entry(lru_key, tier, next_tier)
                 if not success:
-                    # Entry may have been deleted/moved by another thread; skip to next
+                    # Entry was deleted/moved by another thread between
+                    # the snapshot and now. Skip to the next one.
                    eviction_count += 1
                     continue
 
             eviction_count += 1
 
+        # Exhausted eviction budget. On the last tier, allow the write
+        # anyway — we've freed as much as we can.
+        if is_last_tier:
+            with self.memory_lock:
+                self._update_tier_usage(tier, required_bytes)
+            return True
+
         return False
 
     def allocate_cache(self, key: str, num_tokens: int, phase: InferencePhase = InferencePhase.PREFILL) -> Tuple[bool, str, float]:
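Review note: the snapshot walk is easiest to verify in isolation. Below is a toy, single-threaded model of the loop above (all names hypothetical; concurrent eviction is simulated by entries vanishing from `live` between snapshot and use):

```python
def evict_until_fits(live, usage, limit, required, refresh_budget=1):
    """Toy model of the snapshot-based LRU walk: one sorted snapshot plus
    an index, refreshed at most once, instead of a full re-scan per
    eviction. `live` maps key -> {'last_access': float, 'size': int}."""
    snapshot = sorted(live.items(), key=lambda kv: kv[1]['last_access'])
    idx = 0
    while usage + required > limit:
        if idx >= len(snapshot):
            if refresh_budget == 0:
                return usage, False  # snapshot exhausted twice: give up
            refresh_budget -= 1
            snapshot = sorted(live.items(), key=lambda kv: kv[1]['last_access'])
            idx = 0
            continue
        key, meta = snapshot[idx]
        idx += 1
        if key not in live:   # stale entry: evicted concurrently, just skip
            continue
        del live[key]         # evict the oldest surviving entry
        usage -= meta['size']
    return usage, True

# Demo: five 10-byte entries in a 50-byte tier; make room for a 25-byte write.
cache = {f"k{i}": {'last_access': i, 'size': 10} for i in range(5)}
print(evict_until_fits(cache, usage=50, limit=50, required=25))  # (20, True)
```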
@@ -451,6 +554,8 @@ def _allocate_cache_inner(self, key: str, num_tokens: int, phase: InferencePhase
         if allocated_tier is None:
             logger.warning("All tiers full — eviction could not free space, forcing write to NVMe")
             allocated_tier = 'nvme'
+            with self.memory_lock:
+                self._update_tier_usage('nvme', size_bytes)
 
         try:
             if allocated_tier == 'gpu':
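Review note: the two added lines close an accounting leak. Every path that places an entry in a tier must also bump that tier's usage counter; the forced-write fallback previously skipped `_update_tier_usage`, so `nvme_memory_used` drifted low and the space check in `_ensure_space_in_tier` approved writes too eagerly. The restored invariant, as a sketch of a post-run assertion (uses the fields shown in this diff):

```python
def assert_nvme_accounting(cache):
    # The counter must equal the sum of entry sizes recorded for the tier,
    # regardless of which code path placed each entry there.
    actual = sum(e['size'] for e in cache.cache_entries.values()
                 if e['location'] == 'nvme')
    assert cache.nvme_memory_used == actual, (
        f"drift: counter={cache.nvme_memory_used}, actual={actual}")
```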