Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/3700.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`CacheStore`, `LoggingStore`, and `LatencyStore` now support `with_read_only`.
132 changes: 78 additions & 54 deletions src/zarr/experimental/cache_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import time
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Literal, Self

from zarr.abc.store import ByteRequest, Store
from zarr.storage._wrapper import WrapperStore
Expand All @@ -15,6 +15,16 @@
from zarr.core.buffer.core import Buffer, BufferPrototype


class _CacheState:
_cache_order: OrderedDict[str, None] # Track access order for LRU
_current_size: int # Track current cache size
_key_sizes: dict[str, int] # Track size of each cached key
_lock: asyncio.Lock
_hits: int # Cache hit counter
_misses: int # Cache miss counter
_evictions: int # Cache eviction counter


class CacheStore(WrapperStore[Store]):
"""
A dual-store caching implementation for Zarr stores.
Expand Down Expand Up @@ -71,13 +81,7 @@ class CacheStore(WrapperStore[Store]):
max_size: int | None
key_insert_times: dict[str, float]
cache_set_data: bool
_cache_order: OrderedDict[str, None] # Track access order for LRU
_current_size: int # Track current cache size
_key_sizes: dict[str, int] # Track size of each cached key
_lock: asyncio.Lock
_hits: int # Cache hit counter
_misses: int # Cache miss counter
_evictions: int # Cache eviction counter
_state: _CacheState

def __init__(
self,
Expand Down Expand Up @@ -107,18 +111,38 @@ def __init__(
else:
self.max_age_seconds = max_age_seconds
self.max_size = max_size
self.cache_set_data = cache_set_data
self._state = _CacheState()

if key_insert_times is None:
self.key_insert_times = {}
else:
self.key_insert_times = key_insert_times
self.cache_set_data = cache_set_data
self._cache_order = OrderedDict()
self._current_size = 0
self._key_sizes = {}
self._lock = asyncio.Lock()
self._hits = 0
self._misses = 0
self._evictions = 0
self._state._cache_order = OrderedDict()
self._state._current_size = 0
self._state._key_sizes = {}
self._state._lock = asyncio.Lock()
self._state._hits = 0
self._state._misses = 0
self._state._evictions = 0

def _with_store(self, store: Store) -> Self:
    """Unsupported for ``CacheStore``.

    Rebinding a different source store onto the shared cache would let stale
    cache entries answer for another store's keys, so this always raises.
    """
    message = "CacheStore does not support this operation."
    raise NotImplementedError(message)

def with_read_only(self, read_only: bool = False) -> Self:
    """Return a ``CacheStore`` over a read-only (or writable) view of the source store.

    The returned store shares this store's cache backend, its
    ``key_insert_times`` mapping and its ``_state`` object (LRU order, size
    accounting, statistics, lock), so both views observe one cache.

    Parameters
    ----------
    read_only : bool
        Forwarded to the wrapped source store's ``with_read_only``.
    """
    # Create a new cache store that shares the same cache and mutable state
    store = type(self)(
        store=self._store.with_read_only(read_only),
        cache_store=self._cache,
        max_age_seconds=self.max_age_seconds,
        max_size=self.max_size,
        key_insert_times=self.key_insert_times,
        cache_set_data=self.cache_set_data,
    )
    # The constructor built a fresh _CacheState; replace it so the LRU and
    # statistics bookkeeping is shared with this instance rather than forked.
    store._state = self._state
    return store

def _is_key_fresh(self, key: str) -> bool:
"""Check if a cached key is still fresh based on max_age_seconds.
Expand All @@ -140,9 +164,9 @@ async def _accommodate_value(self, value_size: int) -> None:
return

# Remove least recently used items until we have enough space
while self._current_size + value_size > self.max_size and self._cache_order:
while self._state._current_size + value_size > self.max_size and self._state._cache_order:
# Get the least recently used key (first in OrderedDict)
lru_key = next(iter(self._cache_order))
lru_key = next(iter(self._state._cache_order))
await self._evict_key(lru_key)

async def _evict_key(self, key: str) -> None:
Expand All @@ -152,15 +176,15 @@ async def _evict_key(self, key: str) -> None:
Updates size tracking atomically with deletion.
"""
try:
key_size = self._key_sizes.get(key, 0)
key_size = self._state._key_sizes.get(key, 0)

# Delete from cache store
await self._cache.delete(key)

# Update tracking after successful deletion
self._remove_from_tracking(key)
self._current_size = max(0, self._current_size - key_size)
self._evictions += 1
self._state._current_size = max(0, self._state._current_size - key_size)
self._state._evictions += 1

logger.debug("_evict_key: evicted key %s, freed %d bytes", key, key_size)
except Exception:
Expand All @@ -183,39 +207,39 @@ async def _cache_value(self, key: str, value: Buffer) -> None:
)
return

async with self._lock:
async with self._state._lock:
# If key already exists, subtract old size first
if key in self._key_sizes:
old_size = self._key_sizes[key]
self._current_size -= old_size
if key in self._state._key_sizes:
old_size = self._state._key_sizes[key]
self._state._current_size -= old_size
logger.debug("_cache_value: updating existing key %s, old size %d", key, old_size)

# Make room for the new value (this calls _evict_key_locked internally)
await self._accommodate_value(value_size)

# Update tracking atomically
self._cache_order[key] = None # OrderedDict to track access order
self._current_size += value_size
self._key_sizes[key] = value_size
self._state._cache_order[key] = None # OrderedDict to track access order
self._state._current_size += value_size
self._state._key_sizes[key] = value_size
self.key_insert_times[key] = time.monotonic()

logger.debug("_cache_value: cached key %s with size %d bytes", key, value_size)

async def _update_access_order(self, key: str) -> None:
"""Update the access order for LRU tracking."""
if key in self._cache_order:
async with self._lock:
if key in self._state._cache_order:
async with self._state._lock:
# Move to end (most recently used)
self._cache_order.move_to_end(key)
self._state._cache_order.move_to_end(key)

def _remove_from_tracking(self, key: str) -> None:
"""Remove a key from all tracking structures.

Must be called while holding self._lock.
Must be called while holding self._state._lock.
"""
self._cache_order.pop(key, None)
self._state._cache_order.pop(key, None)
self.key_insert_times.pop(key, None)
self._key_sizes.pop(key, None)
self._state._key_sizes.pop(key, None)

async def _get_try_cache(
self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None
Expand All @@ -224,20 +248,20 @@ async def _get_try_cache(
maybe_cached_result = await self._cache.get(key, prototype, byte_range)
if maybe_cached_result is not None:
logger.debug("_get_try_cache: key %s found in cache (HIT)", key)
self._hits += 1
self._state._hits += 1
# Update access order for LRU
await self._update_access_order(key)
return maybe_cached_result
else:
logger.debug(
"_get_try_cache: key %s not found in cache (MISS), fetching from store", key
)
self._misses += 1
self._state._misses += 1
maybe_fresh_result = await super().get(key, prototype, byte_range)
if maybe_fresh_result is None:
# Key doesn't exist in source store
await self._cache.delete(key)
async with self._lock:
async with self._state._lock:
self._remove_from_tracking(key)
else:
# Cache the newly fetched value
Expand All @@ -249,12 +273,12 @@ async def _get_no_cache(
self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None
) -> Buffer | None:
"""Get data directly from source store and update cache."""
self._misses += 1
self._state._misses += 1
maybe_fresh_result = await super().get(key, prototype, byte_range)
if maybe_fresh_result is None:
# Key doesn't exist in source, remove from cache and tracking
await self._cache.delete(key)
async with self._lock:
async with self._state._lock:
self._remove_from_tracking(key)
else:
logger.debug("_get_no_cache: key %s found in store, setting in cache", key)
Expand Down Expand Up @@ -312,7 +336,7 @@ async def set(self, key: str, value: Buffer) -> None:
else:
logger.debug("set: deleting key %s from cache", key)
await self._cache.delete(key)
async with self._lock:
async with self._state._lock:
self._remove_from_tracking(key)

async def delete(self, key: str) -> None:
Expand All @@ -328,7 +352,7 @@ async def delete(self, key: str) -> None:
await super().delete(key)
logger.debug("delete: deleting key %s from cache", key)
await self._cache.delete(key)
async with self._lock:
async with self._state._lock:
self._remove_from_tracking(key)

def cache_info(self) -> dict[str, Any]:
Expand All @@ -339,20 +363,20 @@ def cache_info(self) -> dict[str, Any]:
if self.max_age_seconds == "infinity"
else self.max_age_seconds,
"max_size": self.max_size,
"current_size": self._current_size,
"current_size": self._state._current_size,
"cache_set_data": self.cache_set_data,
"tracked_keys": len(self.key_insert_times),
"cached_keys": len(self._cache_order),
"cached_keys": len(self._state._cache_order),
}

def cache_stats(self) -> dict[str, Any]:
"""Return cache performance statistics."""
total_requests = self._hits + self._misses
hit_rate = self._hits / total_requests if total_requests > 0 else 0.0
total_requests = self._state._hits + self._state._misses
hit_rate = self._state._hits / total_requests if total_requests > 0 else 0.0
return {
"hits": self._hits,
"misses": self._misses,
"evictions": self._evictions,
"hits": self._state._hits,
"misses": self._state._misses,
"evictions": self._state._evictions,
"total_requests": total_requests,
"hit_rate": hit_rate,
}
Expand All @@ -364,11 +388,11 @@ async def clear_cache(self) -> None:
await self._cache.clear()

# Reset tracking
async with self._lock:
async with self._state._lock:
self.key_insert_times.clear()
self._cache_order.clear()
self._key_sizes.clear()
self._current_size = 0
self._state._cache_order.clear()
self._state._key_sizes.clear()
self._state._current_size = 0
logger.debug("clear_cache: cleared all cache data")

def __repr__(self) -> str:
Expand All @@ -379,6 +403,6 @@ def __repr__(self) -> str:
f"cache_store={self._cache!r}, "
f"max_age_seconds={self.max_age_seconds}, "
f"max_size={self.max_size}, "
f"current_size={self._current_size}, "
f"cached_keys={len(self._cache_order)})"
f"current_size={self._state._current_size}, "
f"cached_keys={len(self._state._cache_order)})"
)
3 changes: 3 additions & 0 deletions src/zarr/storage/_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ def _default_handler(self) -> logging.Handler:
)
return handler

def _with_store(self, store: T_Store) -> Self:
    """Wrap ``store`` in a new instance that reuses this store's logging configuration."""
    cls = type(self)
    return cls(store=store, log_level=self.log_level, log_handler=self.log_handler)

@contextmanager
def log(self, hint: Any = "") -> Generator[None, None, None]:
"""Context manager to log method calls
Expand Down
13 changes: 11 additions & 2 deletions src/zarr/storage/_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Generic, TypeVar
from typing import TYPE_CHECKING, Generic, TypeVar, cast

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, AsyncIterator, Iterable
Expand Down Expand Up @@ -31,14 +31,23 @@ class WrapperStore(Store, Generic[T_Store]):
def __init__(self, store: T_Store) -> None:
self._store = store

def _with_store(self, store: T_Store) -> Self:
    """
    Build a new wrapper of the same concrete type around ``store``.

    Subclasses with extra configuration override this to carry it across.
    """
    cls = type(self)
    return cls(store=store)

@classmethod
async def open(cls: type[Self], store_cls: type[T_Store], *args: Any, **kwargs: Any) -> Self:
    """Construct a ``store_cls`` from the given arguments, open it, and wrap it.

    All positional and keyword arguments are forwarded verbatim to
    ``store_cls``; the inner store is opened via its ``_open`` coroutine
    before being wrapped in ``cls``.
    """
    store = store_cls(*args, **kwargs)
    await store._open()
    return cls(store=store)

def with_read_only(self, read_only: bool = False) -> Self:
    """Return a wrapper of the same type around a read-only (or writable) view of the inner store."""
    inner = self._store.with_read_only(read_only)
    return self._with_store(cast(T_Store, inner))

def __enter__(self) -> Self:
return type(self)(self._store.__enter__())
return self._with_store(self._store.__enter__())

def __exit__(
self,
Expand Down
9 changes: 6 additions & 3 deletions src/zarr/testing/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
import pickle
from abc import abstractmethod
from typing import TYPE_CHECKING, Generic, TypeVar
from typing import TYPE_CHECKING, Generic, Self, TypeVar

from zarr.storage import WrapperStore

Expand Down Expand Up @@ -578,10 +578,13 @@ class LatencyStore(WrapperStore[Store]):
get_latency: float
set_latency: float

def __init__(self, cls: Store, *, get_latency: float = 0, set_latency: float = 0) -> None:
def __init__(self, store: Store, *, get_latency: float = 0, set_latency: float = 0) -> None:
self.get_latency = float(get_latency)
self.set_latency = float(set_latency)
self._store = cls
self._store = store

def _with_store(self, store: Store) -> Self:
    """Wrap ``store`` with the same artificial get/set latencies as this instance."""
    cls = type(self)
    return cls(store, get_latency=self.get_latency, set_latency=self.set_latency)

async def set(self, key: str, value: Buffer) -> None:
"""
Expand Down
Loading