Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions tests/test_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import tempfile
from unittest import mock

import pytest

from wbdata import cache


def test_get_cache_returns_working_cache():
    """Test that get_cache creates a working cache.

    A value written to the cache must be readable back from it. The cache
    lives in a temporary directory so the test leaves nothing behind.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        test_cache = cache.get_cache(path=f"{tmpdir}/test_cache")
        try:
            test_cache["key"] = "value"
            assert test_cache["key"] == "value"
        finally:
            # Close even when the assertion fails: an open backing file can
            # prevent the temporary directory from being cleaned up.
            test_cache.close()


def test_get_cache_recovers_from_corrupted_cache():
    """Test that get_cache recovers gracefully from a corrupted cache file.

    Corruption is simulated by patching ``PersistentCache.__getattr__`` so
    that the first lookup of ``expire`` raises ``SystemError``; every later
    lookup is delegated to the real implementation. ``get_cache`` is expected
    to discard the existing cache files and hand back a fresh, usable cache.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        cache_path = f"{tmpdir}/test_cache"

        # Create and populate a cache so there is on-disk state to discard.
        test_cache = cache.get_cache(path=cache_path)
        try:
            test_cache["key"] = "value"
        finally:
            test_cache.close()

        expire_lookups = 0
        original_getattr = cache.shelved_cache.PersistentCache.__getattr__

        def mock_getattr(self, item):
            # Fail only the first ``expire`` lookup; everything else (and
            # every later ``expire``) goes to the real ``__getattr__``.
            nonlocal expire_lookups
            if item == "expire":
                expire_lookups += 1
                if expire_lookups == 1:
                    raise SystemError("Negative size passed to PyBytes")
            return original_getattr(self, item)

        # Mock __getattr__ to simulate corruption when expire() is called.
        with mock.patch.object(
            cache.shelved_cache.PersistentCache, "__getattr__", mock_getattr
        ):
            # This should recover and create a new cache.
            new_cache = cache.get_cache(path=cache_path)

        try:
            # Guard against a vacuous pass: the simulated failure must have
            # fired, and expire() must have been retried on the new cache.
            assert expire_lookups >= 2

            # Recovery throws the old contents away.
            with pytest.raises(KeyError):
                _ = new_cache["key"]

            # The replacement cache must be fully usable.
            new_cache["new_key"] = "new_value"
            assert new_cache["new_key"] == "new_value"
        finally:
            new_cache.close()


def test_remove_cache_files():
    """Test that _remove_cache_files removes all cache-related files.

    The check is done twice: directly on the filesystem (no file named
    ``cache_path`` or ``cache_path.<ext>`` may remain) and indirectly by
    reopening the cache and confirming the previously stored key is gone.
    """
    # Local import: only this test needs to inspect the filesystem.
    import glob

    with tempfile.TemporaryDirectory() as tmpdir:
        cache_path = f"{tmpdir}/test_cache"
        escaped = glob.escape(cache_path)

        # Create a cache which will create files.
        test_cache = cache.get_cache(path=cache_path)
        try:
            test_cache["key"] = "value"
        finally:
            test_cache.close()

        # Sanity check: there must be something on disk to remove, otherwise
        # the assertions below would pass trivially.
        assert glob.glob(f"{escaped}*")

        # Remove the cache files.
        cache._remove_cache_files(cache_path)

        # Verify the files themselves are gone.
        leftovers = glob.glob(escaped) + glob.glob(f"{escaped}.*")
        assert not leftovers, f"cache files left behind: {leftovers}"

        # Verify a cache reopened at the same path starts out empty.
        new_cache = cache.get_cache(path=cache_path)
        try:
            with pytest.raises(KeyError):
                _ = new_cache["key"]
        finally:
            new_cache.close()
48 changes: 40 additions & 8 deletions wbdata/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import datetime as dt
import glob
import logging
import os
from pathlib import Path
Expand All @@ -16,6 +17,23 @@

log = logging.getLogger(__name__)


def _remove_cache_files(path: str | Path) -> None:
"""Remove all files associated with a shelve cache.

Shelve databases can create files with various extensions depending on
the underlying dbm implementation (.db, .dir, .bak, .dat, etc.).
"""
path_str = str(path)
# Remove files with extensions that shelve might create
for pattern in [f"{path_str}", f"{path_str}.*"]:
for filepath in glob.glob(pattern):
try:
os.remove(filepath)
log.debug(f"Removed corrupted cache file: {filepath}")
except OSError as e:
log.warning(f"Failed to remove cache file {filepath}: {e}")

CACHE_PATH = os.getenv(
"WBDATA_CACHE_PATH",
os.path.join(
Expand Down Expand Up @@ -69,12 +87,26 @@ def get_cache(
Path(path).parent.mkdir(parents=True, exist_ok=True)
ttl_days = ttl_days or TTL_DAYS
max_size = max_size or MAX_SIZE
cache = shelved_cache.PersistentCache(
cachetools.TTLCache,
filename=str(path),
maxsize=max_size,
ttl=dt.timedelta(days=ttl_days),
timer=dt.datetime.now,
)
cache.expire()

def _create_cache() -> shelved_cache.PersistentCache:
return shelved_cache.PersistentCache(
cachetools.TTLCache,
filename=str(path),
maxsize=max_size,
ttl=dt.timedelta(days=ttl_days),
timer=dt.datetime.now,
)

cache = _create_cache()
try:
cache.expire()
except SystemError:
# Cache file is corrupted, remove it and create a new one
log.warning(
f"Cache at {path} appears to be corrupted. Removing and recreating."
)
cache.close()
_remove_cache_files(path)
cache = _create_cache()
cache.expire()
return cache
Loading