diff --git a/changes/3547.misc.md b/changes/3547.misc.md new file mode 100644 index 0000000000..771bfe8861 --- /dev/null +++ b/changes/3547.misc.md @@ -0,0 +1 @@ +Moved concurrency limits to a global per-event loop setting instead of per-array call. \ No newline at end of file diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py index d41c457b4e..69d6c3082e 100644 --- a/src/zarr/abc/codec.py +++ b/src/zarr/abc/codec.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from abc import abstractmethod from collections.abc import Mapping from typing import TYPE_CHECKING, Generic, TypeGuard, TypeVar @@ -8,8 +9,7 @@ from zarr.abc.metadata import Metadata from zarr.core.buffer import Buffer, NDBuffer -from zarr.core.common import NamedConfig, concurrent_map -from zarr.core.config import config +from zarr.core.common import NamedConfig if TYPE_CHECKING: from collections.abc import Awaitable, Callable, Iterable @@ -225,11 +225,8 @@ async def decode_partial( ------- Iterable[NDBuffer | None] """ - return await concurrent_map( - list(batch_info), - self._decode_partial_single, - config.get("async.concurrency"), - ) + # Store handles concurrency limiting internally + return await asyncio.gather(*[self._decode_partial_single(*info) for info in batch_info]) class ArrayBytesCodecPartialEncodeMixin: @@ -262,11 +259,8 @@ async def encode_partial( The ByteSetter is used to write the necessary bytes and fetch bytes for existing chunk data. The chunk spec contains information about the chunk. """ - await concurrent_map( - list(batch_info), - self._encode_partial_single, - config.get("async.concurrency"), - ) + # Store handles concurrency limiting internally + await asyncio.gather(*[self._encode_partial_single(*info) for info in batch_info]) class CodecPipeline: @@ -464,11 +458,8 @@ async def _batching_helper( func: Callable[[CodecInput, ArraySpec], Awaitable[CodecOutput | None]], batch_info: Iterable[tuple[CodecInput | None, ArraySpec]], ) -> list[CodecOutput | None]: - return await concurrent_map( - list(batch_info), - _noop_for_none(func), - config.get("async.concurrency"), - ) + # Store handles concurrency limiting internally + return await asyncio.gather(*[_noop_for_none(func)(chunk, spec) for chunk, spec in batch_info]) def _noop_for_none( diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 87df89a683..49f0e90ace 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -670,13 +670,8 @@ async def getsize_prefix(self, prefix: str) -> int: # improve tail latency and might reduce memory pressure (since not all keys # would be in memory at once). 
- # avoid circular import - from zarr.core.common import concurrent_map - from zarr.core.config import config - - keys = [(x,) async for x in self.list_prefix(prefix)] - limit = config.get("async.concurrency") - sizes = await concurrent_map(keys, self.getsize, limit=limit) + keys = [x async for x in self.list_prefix(prefix)] + sizes = await asyncio.gather(*[self.getsize(key) for key in keys]) return sum(sizes) diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index a3a2aff250..64f79e68cd 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import json import warnings from asyncio import gather @@ -22,7 +23,6 @@ import numpy as np from typing_extensions import deprecated -import zarr from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec from zarr.abc.numcodec import Numcodec, _is_numcodec from zarr.codecs._v2 import V2Codec @@ -60,7 +60,6 @@ _default_zarr_format, _warn_order_kwarg, ceildiv, - concurrent_map, parse_shapelike, product, ) @@ -1895,13 +1894,12 @@ async def resize(self, new_shape: ShapeLike, delete_outside_chunks: bool = True) async def _delete_key(key: str) -> None: await (self.store_path / key).delete() - await concurrent_map( - [ - (self.metadata.encode_chunk_key(chunk_coords),) + # Store handles concurrency limiting internally + await asyncio.gather( + *[ + _delete_key(self.metadata.encode_chunk_key(chunk_coords)) for chunk_coords in old_chunk_coords.difference(new_chunk_coords) - ], - _delete_key, - zarr_config.get("async.concurrency"), + ] ) # Write new metadata @@ -4625,10 +4623,9 @@ async def _copy_array_region( await result.setitem(chunk_coords, arr) # Stream data from the source array to the new array - await concurrent_map( - [(region, data) for region in result._iter_shard_regions()], - _copy_array_region, - zarr.core.config.config.get("async.concurrency"), + # Store handles concurrency limiting internally + await asyncio.gather( + *[_copy_array_region(region, data) for region in result._iter_shard_regions()] ) else: @@ -4636,10 +4633,9 @@ async def _copy_arraylike_region(chunk_coords: slice, _data: NDArrayLike) -> Non await result.setitem(chunk_coords, _data[chunk_coords]) # Stream data from the source array to the new array - await concurrent_map( - [(region, data) for region in result._iter_shard_regions()], - _copy_arraylike_region, - zarr.core.config.config.get("async.concurrency"), + # Store handles concurrency limiting internally + await asyncio.gather( + *[_copy_arraylike_region(region, data) for region in result._iter_shard_regions()] ) return result diff --git a/src/zarr/core/codec_pipeline.py b/src/zarr/core/codec_pipeline.py index fd557ac43e..0f8350f7ea 100644 --- a/src/zarr/core/codec_pipeline.py +++ b/src/zarr/core/codec_pipeline.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from dataclasses import dataclass from itertools import islice, pairwise from typing import TYPE_CHECKING, Any, TypeVar @@ -14,7 +15,6 @@ Codec, CodecPipeline, ) -from zarr.core.common import concurrent_map from zarr.core.config import config from zarr.core.indexing import SelectorTuple, is_scalar from zarr.errors import ZarrUserWarning @@ -267,10 +267,12 @@ async def read_batch( else: out[out_selection] = fill_value_or_default(chunk_spec) else: - chunk_bytes_batch = await concurrent_map( - [(byte_getter, array_spec.prototype) for byte_getter, array_spec, *_ in batch_info], - lambda byte_getter, prototype: byte_getter.get(prototype), - 
config.get("async.concurrency"), + # Store handles concurrency limiting internally + chunk_bytes_batch = await asyncio.gather( + *[ + byte_getter.get(array_spec.prototype) + for byte_getter, array_spec, *_ in batch_info + ] ) chunk_array_batch = await self.decode_batch( [ @@ -368,16 +370,15 @@ async def _read_key( return await byte_setter.get(prototype=prototype) chunk_bytes_batch: Iterable[Buffer | None] - chunk_bytes_batch = await concurrent_map( - [ - ( + # Store handles concurrency limiting internally + chunk_bytes_batch = await asyncio.gather( + *[ + _read_key( None if is_complete_chunk else byte_setter, chunk_spec.prototype, ) for byte_setter, chunk_spec, chunk_selection, _, is_complete_chunk in batch_info - ], - _read_key, - config.get("async.concurrency"), + ] ) chunk_array_decoded = await self.decode_batch( [ @@ -435,15 +436,14 @@ async def _write_key(byte_setter: ByteSetter, chunk_bytes: Buffer | None) -> Non else: await byte_setter.set(chunk_bytes) - await concurrent_map( - [ - (byte_setter, chunk_bytes) + # Store handles concurrency limiting internally + await asyncio.gather( + *[ + _write_key(byte_setter, chunk_bytes) for chunk_bytes, (byte_setter, *_) in zip( chunk_bytes_batch, batch_info, strict=False ) - ], - _write_key, - config.get("async.concurrency"), + ] ) async def decode( @@ -470,13 +470,12 @@ async def read( out: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: - await concurrent_map( - [ - (single_batch_info, out, drop_axes) + # Process mini-batches concurrently - stores handle I/O concurrency internally + await asyncio.gather( + *[ + self.read_batch(single_batch_info, out, drop_axes) for single_batch_info in batched(batch_info, self.batch_size) - ], - self.read_batch, - config.get("async.concurrency"), + ] ) async def write( @@ -485,13 +484,12 @@ async def write( value: NDBuffer, drop_axes: tuple[int, ...] 
= (), ) -> None: - await concurrent_map( - [ - (single_batch_info, value, drop_axes) + # Process mini-batches concurrently - stores handle I/O concurrency internally + await asyncio.gather( + *[ + self.write_batch(single_batch_info, value, drop_axes) for single_batch_info in batched(batch_info, self.batch_size) - ], - self.write_batch, - config.get("async.concurrency"), + ] ) diff --git a/src/zarr/core/common.py b/src/zarr/core/common.py index d38949657e..e45c256310 100644 --- a/src/zarr/core/common.py +++ b/src/zarr/core/common.py @@ -1,13 +1,11 @@ from __future__ import annotations -import asyncio import functools import math import operator import warnings from collections.abc import Iterable, Mapping, Sequence from enum import Enum -from itertools import starmap from typing import ( TYPE_CHECKING, Any, @@ -27,7 +25,7 @@ from zarr.errors import ZarrRuntimeWarning if TYPE_CHECKING: - from collections.abc import Awaitable, Callable, Iterator + from collections.abc import Iterator ZARR_JSON = "zarr.json" @@ -94,28 +92,6 @@ def ceildiv(a: float, b: float) -> int: return math.ceil(a / b) -T = TypeVar("T", bound=tuple[Any, ...]) -V = TypeVar("V") - - -async def concurrent_map( - items: Iterable[T], - func: Callable[..., Awaitable[V]], - limit: int | None = None, -) -> list[V]: - if limit is None: - return await asyncio.gather(*list(starmap(func, items))) - - else: - sem = asyncio.Semaphore(limit) - - async def run(item: tuple[Any]) -> V: - async with sem: - return await func(*item) - - return await asyncio.gather(*[asyncio.ensure_future(run(item)) for item in items]) - - E = TypeVar("E", bound=Enum) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 9b5fee275b..658de7ef81 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -1440,13 +1440,10 @@ async def _members( ) raise ValueError(msg) - # enforce a concurrency limit by passing a semaphore to all the recursive functions - semaphore = asyncio.Semaphore(config.get("async.concurrency")) async for member in _iter_members_deep( self, max_depth=max_depth, skip_keys=skip_keys, - semaphore=semaphore, use_consolidated_for_children=use_consolidated_for_children, ): yield member @@ -3323,14 +3320,11 @@ async def create_nodes( The created nodes in the order they are created. """ - # Note: the only way to alter this value is via the config. If that's undesirable for some reason, - # then we should consider adding a keyword argument this this function - semaphore = asyncio.Semaphore(config.get("async.concurrency")) create_tasks: list[Coroutine[None, None, str]] = [] for key, value in nodes.items(): # make the key absolute - create_tasks.extend(_persist_metadata(store, key, value, semaphore=semaphore)) + create_tasks.extend(_persist_metadata(store, key, value)) created_object_keys = [] @@ -3476,28 +3470,16 @@ def _ensure_consistent_zarr_format( ) -async def _getitem_semaphore( - node: AsyncGroup, key: str, semaphore: asyncio.Semaphore | None -) -> AnyAsyncArray | AsyncGroup: +async def _getitem(node: AsyncGroup, key: str) -> AnyAsyncArray | AsyncGroup: """ - Wrap Group.getitem with an optional semaphore. - - If the semaphore parameter is an - asyncio.Semaphore instance, then the getitem operation is performed inside an async context - manager provided by that semaphore. If the semaphore parameter is None, then getitem is invoked - without a context manager. + Fetch a child node from a group by key. 
""" - if semaphore is not None: - async with semaphore: - return await node.getitem(key) - else: - return await node.getitem(key) + return await node.getitem(key) async def _iter_members( node: AsyncGroup, skip_keys: tuple[str, ...], - semaphore: asyncio.Semaphore | None, ) -> AsyncGenerator[tuple[str, AnyAsyncArray | AsyncGroup], None]: """ Iterate over the arrays and groups contained in a group. @@ -3508,8 +3490,6 @@ async def _iter_members( The group to traverse. skip_keys : tuple[str, ...] A tuple of keys to skip when iterating over the possible members of the group. - semaphore : asyncio.Semaphore | None - An optional semaphore to use for concurrency control. Yields ------ @@ -3520,10 +3500,7 @@ async def _iter_members( keys = [key async for key in node.store.list_dir(node.path)] keys_filtered = tuple(filter(lambda v: v not in skip_keys, keys)) - node_tasks = tuple( - asyncio.create_task(_getitem_semaphore(node, key, semaphore), name=key) - for key in keys_filtered - ) + node_tasks = tuple(asyncio.create_task(_getitem(node, key), name=key) for key in keys_filtered) for fetched_node_coro in asyncio.as_completed(node_tasks): try: @@ -3550,7 +3527,6 @@ async def _iter_members_deep( *, max_depth: int | None, skip_keys: tuple[str, ...], - semaphore: asyncio.Semaphore | None = None, use_consolidated_for_children: bool = True, ) -> AsyncGenerator[tuple[str, AnyAsyncArray | AsyncGroup], None]: """ @@ -3565,8 +3541,6 @@ async def _iter_members_deep( The maximum depth of recursion. skip_keys : tuple[str, ...] A tuple of keys to skip when iterating over the possible members of the group. - semaphore : asyncio.Semaphore | None - An optional semaphore to use for concurrency control. use_consolidated_for_children : bool, default True Whether to use the consolidated metadata of child groups loaded from the store. Note that this only affects groups loaded from the @@ -3585,7 +3559,7 @@ async def _iter_members_deep( new_depth = None else: new_depth = max_depth - 1 - async for name, node in _iter_members(group, skip_keys=skip_keys, semaphore=semaphore): + async for name, node in _iter_members(group, skip_keys=skip_keys): is_group = isinstance(node, AsyncGroup) if ( is_group @@ -3599,9 +3573,7 @@ async def _iter_members_deep( yield name, node if is_group and do_recursion: node = cast("AsyncGroup", node) - to_recurse[name] = _iter_members_deep( - node, max_depth=new_depth, skip_keys=skip_keys, semaphore=semaphore - ) + to_recurse[name] = _iter_members_deep(node, max_depth=new_depth, skip_keys=skip_keys) for prefix, subgroup_iter in to_recurse.items(): async for name, node in subgroup_iter: @@ -3811,9 +3783,7 @@ async def get_node(store: Store, path: str, zarr_format: ZarrFormat) -> AnyAsync raise ValueError(f"Unexpected zarr format: {zarr_format}") # pragma: no cover -async def _set_return_key( - *, store: Store, key: str, value: Buffer, semaphore: asyncio.Semaphore | None = None -) -> str: +async def _set_return_key(*, store: Store, key: str, value: Buffer) -> str: """ Write a value to storage at the given key. The key is returned. Useful when saving values via routines that return results in execution order, @@ -3828,15 +3798,8 @@ async def _set_return_key( The key to save the value to. value : Buffer The value to save. - semaphore : asyncio.Semaphore | None - An optional semaphore to use to limit the number of concurrent writes. 
""" - - if semaphore is not None: - async with semaphore: - await store.set(key, value) - else: - await store.set(key, value) + await store.set(key, value) return key @@ -3844,7 +3807,6 @@ def _persist_metadata( store: Store, path: str, metadata: ArrayV2Metadata | ArrayV3Metadata | GroupMetadata, - semaphore: asyncio.Semaphore | None = None, ) -> tuple[Coroutine[None, None, str], ...]: """ Prepare to save a metadata document to storage, returning a tuple of coroutines that must be awaited. @@ -3852,7 +3814,7 @@ def _persist_metadata( to_save = metadata.to_buffer_dict(default_buffer_prototype()) return tuple( - _set_return_key(store=store, key=_join_paths([path, key]), value=value, semaphore=semaphore) + _set_return_key(store=store, key=_join_paths([path, key]), value=value) for key, value in to_save.items() ) diff --git a/src/zarr/storage/_fsspec.py b/src/zarr/storage/_fsspec.py index f9e4ed375d..21f96d87d5 100644 --- a/src/zarr/storage/_fsspec.py +++ b/src/zarr/storage/_fsspec.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import json import warnings from contextlib import suppress @@ -17,6 +18,7 @@ from zarr.core.buffer import Buffer from zarr.errors import ZarrUserWarning from zarr.storage._common import _dereference_path +from zarr.storage._utils import with_concurrency_limit if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterable @@ -82,6 +84,9 @@ class FsspecStore(Store): filesystem scheme. allowed_exceptions : tuple[type[Exception], ...] When fetching data, these cases will be deemed to correspond to missing keys. + concurrency_limit : int, optional + Maximum number of concurrent I/O operations. Default is 50. + Set to None for unlimited concurrency. Attributes ---------- @@ -117,18 +122,24 @@ class FsspecStore(Store): fs: AsyncFileSystem allowed_exceptions: tuple[type[Exception], ...] path: str + _semaphore: asyncio.Semaphore | None def __init__( self, fs: AsyncFileSystem, + *, read_only: bool = False, path: str = "/", allowed_exceptions: tuple[type[Exception], ...] = ALLOWED_EXCEPTIONS, + concurrency_limit: int | None = 50, ) -> None: super().__init__(read_only=read_only) self.fs = fs self.path = path self.allowed_exceptions = allowed_exceptions + self._semaphore = ( + asyncio.Semaphore(concurrency_limit) if concurrency_limit is not None else None + ) if not self.fs.async_impl: raise TypeError("Filesystem needs to support async operations.") @@ -273,6 +284,7 @@ def __eq__(self, other: object) -> bool: and self.fs == other.fs ) + @with_concurrency_limit() async def get( self, key: str, @@ -315,6 +327,7 @@ async def get( else: return value + @with_concurrency_limit() async def set( self, key: str, @@ -335,6 +348,27 @@ async def set( raise NotImplementedError await self.fs._pipe_file(path, value.to_bytes()) + async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None: + # Override to avoid deadlock from calling decorated set() method + if not self._is_open: + await self._open() + self._check_writable() + + async def _set_with_limit(key: str, value: Buffer) -> None: + if not isinstance(value, Buffer): + raise TypeError( + f"FsspecStore.set(): `value` must be a Buffer instance. Got an instance of {type(value)} instead." 
+ ) + path = _dereference_path(self.path, key) + if self._semaphore: + async with self._semaphore: + await self.fs._pipe_file(path, value.to_bytes()) + else: + await self.fs._pipe_file(path, value.to_bytes()) + + await asyncio.gather(*[_set_with_limit(key, value) for key, value in values]) + + @with_concurrency_limit() async def delete(self, key: str) -> None: # docstring inherited self._check_writable() diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index 80233a112d..842cab41ca 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -19,12 +19,13 @@ ) from zarr.core.buffer import Buffer from zarr.core.buffer.core import default_buffer_prototype -from zarr.core.common import AccessModeLiteral, concurrent_map +from zarr.storage._utils import with_concurrency_limit if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterable, Iterator from zarr.core.buffer import BufferPrototype + from zarr.core.common import AccessModeLiteral def _get(path: Path, prototype: BufferPrototype, byte_range: ByteRequest | None) -> Buffer: @@ -95,6 +96,9 @@ class LocalStore(Store): Directory to use as root of store. read_only : bool Whether the store is read-only + concurrency_limit : int, optional + Maximum number of concurrent I/O operations. Default is 100. + Set to None for unlimited concurrency. Attributes ---------- @@ -109,8 +113,15 @@ class LocalStore(Store): supports_listing: bool = True root: Path + _semaphore: asyncio.Semaphore | None - def __init__(self, root: Path | str, *, read_only: bool = False) -> None: + def __init__( + self, + root: Path | str, + *, + read_only: bool = False, + concurrency_limit: int | None = 100, + ) -> None: super().__init__(read_only=read_only) if isinstance(root, str): root = Path(root) @@ -119,12 +130,17 @@ def __init__(self, root: Path | str, *, read_only: bool = False) -> None: f"'root' must be a string or Path instance. Got an instance of {type(root)} instead." 
) self.root = root + self._semaphore = ( + asyncio.Semaphore(concurrency_limit) if concurrency_limit is not None else None + ) def with_read_only(self, read_only: bool = False) -> Self: # docstring inherited + concurrency_limit = self._semaphore._value if self._semaphore else None return type(self)( root=self.root, read_only=read_only, + concurrency_limit=concurrency_limit, ) @classmethod @@ -187,6 +203,7 @@ def __repr__(self) -> str: def __eq__(self, other: object) -> bool: return isinstance(other, type(self)) and self.root == other.root + @with_concurrency_limit() async def get( self, key: str, @@ -212,12 +229,23 @@ async def get_partial_values( key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited - args = [] - for key, byte_range in key_ranges: - assert isinstance(key, str) + # Note: We directly call the I/O functions here, wrapped with semaphore + # to avoid deadlock from calling the decorated get() method + + async def _get_with_limit(key: str, byte_range: ByteRequest | None) -> Buffer | None: path = self.root / key - args.append((_get, path, prototype, byte_range)) - return await concurrent_map(args, asyncio.to_thread, limit=None) # TODO: fix limit + try: + if self._semaphore: + async with self._semaphore: + return await asyncio.to_thread(_get, path, prototype, byte_range) + else: + return await asyncio.to_thread(_get, path, prototype, byte_range) + except (FileNotFoundError, IsADirectoryError, NotADirectoryError): + return None + + return await asyncio.gather( + *[_get_with_limit(key, byte_range) for key, byte_range in key_ranges] + ) async def set(self, key: str, value: Buffer) -> None: # docstring inherited @@ -230,6 +258,7 @@ async def set_if_not_exists(self, key: str, value: Buffer) -> None: except FileExistsError: pass + @with_concurrency_limit() async def _set(self, key: str, value: Buffer, exclusive: bool = False) -> None: if not self._is_open: await self._open() @@ -242,6 +271,7 @@ async def _set(self, key: str, value: Buffer, exclusive: bool = False) -> None: path = self.root / key await asyncio.to_thread(_put, path, value, exclusive=exclusive) + @with_concurrency_limit() async def delete(self, key: str) -> None: """ Remove a key from the store. 
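Note: the `LocalStore` changes above follow the same shape used for `FsspecStore` and `ObjectStore` elsewhere in this diff — a per-instance `asyncio.Semaphore`, single-key methods wrapped with `with_concurrency_limit`, and batch helpers that apply the semaphore around the innermost I/O call instead of re-entering the decorated single-key method. The sketch below is a minimal, self-contained illustration of that pattern in plain asyncio; it is not part of the patch, and `FakeStore`, `_read_raw`, and `get_many` are hypothetical names.

```python
import asyncio


class FakeStore:
    """Toy store illustrating the per-instance concurrency-limit pattern."""

    def __init__(self, concurrency_limit: int | None = 4) -> None:
        # None means unlimited, mirroring the stores in this diff
        self._semaphore = (
            asyncio.Semaphore(concurrency_limit) if concurrency_limit is not None else None
        )
        self._data: dict[str, bytes] = {}

    async def _read_raw(self, key: str) -> bytes | None:
        # Stand-in for the real I/O call (filesystem, fsspec, obstore, ...)
        await asyncio.sleep(0.01)
        return self._data.get(key)

    async def get(self, key: str) -> bytes | None:
        # Single-key path: bounded by the per-store semaphore
        if self._semaphore is None:
            return await self._read_raw(key)
        async with self._semaphore:
            return await self._read_raw(key)

    async def get_many(self, keys: list[str]) -> list[bytes | None]:
        # Batch path: issue everything with asyncio.gather and let the same
        # semaphore cap how many raw reads are in flight at once.
        async def bounded(key: str) -> bytes | None:
            if self._semaphore is None:
                return await self._read_raw(key)
            async with self._semaphore:
                return await self._read_raw(key)

        return await asyncio.gather(*[bounded(k) for k in keys])


async def main() -> None:
    store = FakeStore(concurrency_limit=4)
    store._data = {f"k{i}": b"v" for i in range(16)}
    print(await store.get("k0"))
    print(len(await store.get_many(list(store._data))))


if __name__ == "__main__":
    asyncio.run(main())
```

Keeping the `async with` at the innermost call means a coroutine never awaits another permit-guarded coroutine while already holding a permit, which appears to be the nested-acquisition hazard the `# avoid deadlock` comments in this diff are guarding against.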
diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index e6f9b7a512..bb0f81d2f5 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -1,12 +1,12 @@ from __future__ import annotations +import asyncio from logging import getLogger from typing import TYPE_CHECKING, Any, Self from zarr.abc.store import ByteRequest, Store from zarr.core.buffer import Buffer, gpu from zarr.core.buffer.core import default_buffer_prototype -from zarr.core.common import concurrent_map from zarr.storage._utils import _normalize_byte_range_index if TYPE_CHECKING: @@ -102,12 +102,10 @@ async def get_partial_values( key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited - - # All the key-ranges arguments goes with the same prototype - async def _get(key: str, byte_range: ByteRequest | None) -> Buffer | None: - return await self.get(key, prototype=prototype, byte_range=byte_range) - - return await concurrent_map(key_ranges, _get, limit=None) + # In-memory operations are fast and don't need concurrency limiting + return await asyncio.gather( + *[self.get(key, prototype, byte_range) for key, byte_range in key_ranges] + ) async def exists(self, key: str) -> bool: # docstring inherited diff --git a/src/zarr/storage/_obstore.py b/src/zarr/storage/_obstore.py index 5c2197ecf6..697f51ddb0 100644 --- a/src/zarr/storage/_obstore.py +++ b/src/zarr/storage/_obstore.py @@ -4,7 +4,7 @@ import contextlib import pickle from collections import defaultdict -from typing import TYPE_CHECKING, Generic, Self, TypedDict, TypeVar +from typing import TYPE_CHECKING, Generic, Self, TypeVar from zarr.abc.store import ( ByteRequest, @@ -13,14 +13,13 @@ Store, SuffixByteRequest, ) -from zarr.core.common import concurrent_map -from zarr.core.config import config +from zarr.storage._utils import with_concurrency_limit if TYPE_CHECKING: from collections.abc import AsyncGenerator, Coroutine, Iterable, Sequence from typing import Any - from obstore import ListResult, ListStream, ObjectMeta, OffsetRange, SuffixRange + from obstore import ListResult, ListStream, ObjectMeta from obstore.store import ObjectStore as _UpstreamObjectStore from zarr.core.buffer import Buffer, BufferPrototype @@ -47,6 +46,9 @@ class ObjectStore(Store, Generic[T_Store]): An obstore store instance that is set up with the proper credentials. read_only : bool Whether to open the store in read-only mode. + concurrency_limit : int, optional + Maximum number of concurrent I/O operations. Default is 50. + Set to None for unlimited concurrency. 
Warnings -------- @@ -56,6 +58,7 @@ class ObjectStore(Store, Generic[T_Store]): store: T_Store """The underlying obstore instance.""" + _semaphore: asyncio.Semaphore | None def __eq__(self, value: object) -> bool: if not isinstance(value, ObjectStore): @@ -66,17 +69,28 @@ def __eq__(self, value: object) -> bool: return self.store == value.store # type: ignore[no-any-return] - def __init__(self, store: T_Store, *, read_only: bool = False) -> None: + def __init__( + self, + store: T_Store, + *, + read_only: bool = False, + concurrency_limit: int | None = 50, + ) -> None: if not store.__class__.__module__.startswith("obstore"): raise TypeError(f"expected ObjectStore class, got {store!r}") super().__init__(read_only=read_only) self.store = store + self._semaphore = ( + asyncio.Semaphore(concurrency_limit) if concurrency_limit is not None else None + ) def with_read_only(self, read_only: bool = False) -> Self: # docstring inherited + concurrency_limit = self._semaphore._value if self._semaphore else None return type(self)( store=self.store, read_only=read_only, + concurrency_limit=concurrency_limit, ) def __str__(self) -> str: @@ -94,6 +108,7 @@ def __setstate__(self, state: dict[Any, Any]) -> None: state["store"] = pickle.loads(state["store"]) self.__dict__.update(state) + @with_concurrency_limit() async def get( self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None ) -> Buffer | None: @@ -101,41 +116,7 @@ async def get( import obstore as obs try: - if byte_range is None: - resp = await obs.get_async(self.store, key) - return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] - elif isinstance(byte_range, RangeByteRequest): - bytes = await obs.get_range_async( - self.store, key, start=byte_range.start, end=byte_range.end - ) - return prototype.buffer.from_bytes(bytes) # type: ignore[arg-type] - elif isinstance(byte_range, OffsetByteRequest): - resp = await obs.get_async( - self.store, key, options={"range": {"offset": byte_range.offset}} - ) - return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] - elif isinstance(byte_range, SuffixByteRequest): - # some object stores (Azure) don't support suffix requests. In this - # case, our workaround is to first get the length of the object and then - # manually request the byte range at the end. - try: - resp = await obs.get_async( - self.store, key, options={"range": {"suffix": byte_range.suffix}} - ) - return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] - except obs.exceptions.NotSupportedError: - head_resp = await obs.head_async(self.store, key) - file_size = head_resp["size"] - suffix_len = byte_range.suffix - buffer = await obs.get_range_async( - self.store, - key, - start=file_size - suffix_len, - length=suffix_len, - ) - return prototype.buffer.from_bytes(buffer) # type: ignore[arg-type] - else: - raise ValueError(f"Unexpected byte_range, got {byte_range}") + return await self._get_impl(key, prototype, byte_range, obs) except _ALLOWED_EXCEPTIONS: return None @@ -145,7 +126,96 @@ async def get_partial_values( key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited - return await _get_partial_values(self.store, prototype=prototype, key_ranges=key_ranges) + # We override to: + # 1. Avoid deadlock from calling the decorated get() method + # 2. 
Batch RangeByteRequests per-file using get_ranges_async for performance + import obstore as obs + + key_ranges = list(key_ranges) + # Group bounded range requests by path for batched fetching + per_file_bounded: dict[str, list[tuple[int, RangeByteRequest]]] = defaultdict(list) + other_requests: list[tuple[int, str, ByteRequest | None]] = [] + + for idx, (path, byte_range) in enumerate(key_ranges): + if isinstance(byte_range, RangeByteRequest): + per_file_bounded[path].append((idx, byte_range)) + else: + other_requests.append((idx, path, byte_range)) + + buffers: list[Buffer | None] = [None] * len(key_ranges) + + async def _fetch_ranges(path: str, requests: list[tuple[int, RangeByteRequest]]) -> None: + """Batch multiple range requests for the same file using get_ranges_async.""" + starts = [r.start for _, r in requests] + ends = [r.end for _, r in requests] + if self._semaphore: + async with self._semaphore: + responses = await obs.get_ranges_async( + self.store, path=path, starts=starts, ends=ends + ) + else: + responses = await obs.get_ranges_async( + self.store, path=path, starts=starts, ends=ends + ) + for (idx, _), response in zip(requests, responses, strict=True): + buffers[idx] = prototype.buffer.from_bytes(response) # type: ignore[arg-type] + + async def _fetch_one(idx: int, path: str, byte_range: ByteRequest | None) -> None: + """Fetch a single non-range request with semaphore limiting.""" + try: + if self._semaphore: + async with self._semaphore: + buffers[idx] = await self._get_impl(path, prototype, byte_range, obs) + else: + buffers[idx] = await self._get_impl(path, prototype, byte_range, obs) + except _ALLOWED_EXCEPTIONS: + pass # buffers[idx] stays None + + futs: list[Coroutine[Any, Any, None]] = [] + for path, requests in per_file_bounded.items(): + futs.append(_fetch_ranges(path, requests)) + for idx, path, byte_range in other_requests: + futs.append(_fetch_one(idx, path, byte_range)) + + await asyncio.gather(*futs) + return buffers + + async def _get_impl( + self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None, obs: Any + ) -> Buffer: + """Implementation of get without semaphore decoration.""" + if byte_range is None: + resp = await obs.get_async(self.store, key) + return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] + elif isinstance(byte_range, RangeByteRequest): + bytes = await obs.get_range_async( + self.store, key, start=byte_range.start, end=byte_range.end + ) + return prototype.buffer.from_bytes(bytes) # type: ignore[arg-type] + elif isinstance(byte_range, OffsetByteRequest): + resp = await obs.get_async( + self.store, key, options={"range": {"offset": byte_range.offset}} + ) + return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] + elif isinstance(byte_range, SuffixByteRequest): + try: + resp = await obs.get_async( + self.store, key, options={"range": {"suffix": byte_range.suffix}} + ) + return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] + except obs.exceptions.NotSupportedError: + head_resp = await obs.head_async(self.store, key) + file_size = head_resp["size"] + suffix_len = byte_range.suffix + buffer = await obs.get_range_async( + self.store, + key, + start=file_size - suffix_len, + length=suffix_len, + ) + return prototype.buffer.from_bytes(buffer) # type: ignore[arg-type] + else: + raise ValueError(f"Unexpected byte_range, got {byte_range}") async def exists(self, key: str) -> bool: # docstring inherited @@ -163,6 +233,7 @@ def 
supports_writes(self) -> bool: # docstring inherited return True + @with_concurrency_limit() async def set(self, key: str, value: Buffer) -> None: # docstring inherited import obstore as obs @@ -172,20 +243,43 @@ async def set(self, key: str, value: Buffer) -> None: buf = value.as_buffer_like() await obs.put_async(self.store, key, buf) + async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None: + # Override to avoid deadlock from calling decorated set() method + import obstore as obs + + self._check_writable() + + async def _set_with_limit(key: str, value: Buffer) -> None: + buf = value.as_buffer_like() + if self._semaphore: + async with self._semaphore: + await obs.put_async(self.store, key, buf) + else: + await obs.put_async(self.store, key, buf) + + await asyncio.gather(*[_set_with_limit(key, value) for key, value in values]) + async def set_if_not_exists(self, key: str, value: Buffer) -> None: # docstring inherited + # Note: Not decorated to avoid deadlock when called in batch via gather() import obstore as obs self._check_writable() buf = value.as_buffer_like() - with contextlib.suppress(obs.exceptions.AlreadyExistsError): - await obs.put_async(self.store, key, buf, mode="create") + if self._semaphore: + async with self._semaphore: + with contextlib.suppress(obs.exceptions.AlreadyExistsError): + await obs.put_async(self.store, key, buf, mode="create") + else: + with contextlib.suppress(obs.exceptions.AlreadyExistsError): + await obs.put_async(self.store, key, buf, mode="create") @property def supports_deletes(self) -> bool: # docstring inherited return True + @with_concurrency_limit() async def delete(self, key: str) -> None: # docstring inherited import obstore as obs @@ -208,8 +302,18 @@ async def delete_dir(self, prefix: str) -> None: prefix += "/" metas = await obs.list(self.store, prefix).collect_async() - keys = [(m["path"],) for m in metas] - await concurrent_map(keys, self.delete, limit=config.get("async.concurrency")) + + # Delete with semaphore limiting to avoid deadlock + async def _delete_with_limit(path: str) -> None: + if self._semaphore: + async with self._semaphore: + with contextlib.suppress(FileNotFoundError): + await obs.delete_async(self.store, path) + else: + with contextlib.suppress(FileNotFoundError): + await obs.delete_async(self.store, path) + + await asyncio.gather(*[_delete_with_limit(m["path"]) for m in metas]) @property def supports_listing(self) -> bool: @@ -267,242 +371,3 @@ async def _transform_list_dir( objects = [obj["path"].removeprefix(prefix).lstrip("/") for obj in list_result["objects"]] for item in prefixes + objects: yield item - - -class _BoundedRequest(TypedDict): - """Range request with a known start and end byte. - - These requests can be multiplexed natively on the Rust side with - `obstore.get_ranges_async`. - """ - - original_request_index: int - """The positional index in the original key_ranges input""" - - start: int - """Start byte offset.""" - - end: int - """End byte offset.""" - - -class _OtherRequest(TypedDict): - """Offset or suffix range requests. - - These requests cannot be concurrent on the Rust side, and each need their own call - to `obstore.get_async`, passing in the `range` parameter. 
- """ - - original_request_index: int - """The positional index in the original key_ranges input""" - - path: str - """The path to request from.""" - - range: OffsetRange | None - # Note: suffix requests are handled separately because some object stores (Azure) - # don't support them - """The range request type.""" - - -class _SuffixRequest(TypedDict): - """Offset or suffix range requests. - - These requests cannot be concurrent on the Rust side, and each need their own call - to `obstore.get_async`, passing in the `range` parameter. - """ - - original_request_index: int - """The positional index in the original key_ranges input""" - - path: str - """The path to request from.""" - - range: SuffixRange - """The suffix range.""" - - -class _Response(TypedDict): - """A response buffer associated with the original index that it should be restored to.""" - - original_request_index: int - """The positional index in the original key_ranges input""" - - buffer: Buffer - """The buffer returned from obstore's range request.""" - - -async def _make_bounded_requests( - store: _UpstreamObjectStore, - path: str, - requests: list[_BoundedRequest], - prototype: BufferPrototype, - semaphore: asyncio.Semaphore, -) -> list[_Response]: - """Make all bounded requests for a specific file. - - `obstore.get_ranges_async` allows for making concurrent requests for multiple ranges - within a single file, and will e.g. merge concurrent requests. This only uses one - single Python coroutine. - """ - import obstore as obs - - starts = [r["start"] for r in requests] - ends = [r["end"] for r in requests] - async with semaphore: - responses = await obs.get_ranges_async(store, path=path, starts=starts, ends=ends) - - buffer_responses: list[_Response] = [] - for request, response in zip(requests, responses, strict=True): - buffer_responses.append( - { - "original_request_index": request["original_request_index"], - "buffer": prototype.buffer.from_bytes(response), # type: ignore[arg-type] - } - ) - - return buffer_responses - - -async def _make_other_request( - store: _UpstreamObjectStore, - request: _OtherRequest, - prototype: BufferPrototype, - semaphore: asyncio.Semaphore, -) -> list[_Response]: - """Make offset or full-file requests. - - We return a `list[_Response]` for symmetry with `_make_bounded_requests` so that all - futures can be gathered together. - """ - import obstore as obs - - async with semaphore: - if request["range"] is None: - resp = await obs.get_async(store, request["path"]) - else: - resp = await obs.get_async(store, request["path"], options={"range": request["range"]}) - buffer = await resp.bytes_async() - - return [ - { - "original_request_index": request["original_request_index"], - "buffer": prototype.buffer.from_bytes(buffer), # type: ignore[arg-type] - } - ] - - -async def _make_suffix_request( - store: _UpstreamObjectStore, - request: _SuffixRequest, - prototype: BufferPrototype, - semaphore: asyncio.Semaphore, -) -> list[_Response]: - """Make suffix requests. - - This is separated out from `_make_other_request` because some object stores (Azure) - don't support suffix requests. In this case, our workaround is to first get the - length of the object and then manually request the byte range at the end. - - We return a `list[_Response]` for symmetry with `_make_bounded_requests` so that all - futures can be gathered together. 
- """ - import obstore as obs - - async with semaphore: - try: - resp = await obs.get_async(store, request["path"], options={"range": request["range"]}) - buffer = await resp.bytes_async() - except obs.exceptions.NotSupportedError: - head_resp = await obs.head_async(store, request["path"]) - file_size = head_resp["size"] - suffix_len = request["range"]["suffix"] - buffer = await obs.get_range_async( - store, - request["path"], - start=file_size - suffix_len, - length=suffix_len, - ) - - return [ - { - "original_request_index": request["original_request_index"], - "buffer": prototype.buffer.from_bytes(buffer), # type: ignore[arg-type] - } - ] - - -async def _get_partial_values( - store: _UpstreamObjectStore, - prototype: BufferPrototype, - key_ranges: Iterable[tuple[str, ByteRequest | None]], -) -> list[Buffer | None]: - """Make multiple range requests. - - ObjectStore has a `get_ranges` method that will additionally merge nearby ranges, - but it's _per_ file. So we need to split these key_ranges into **per-file** key - ranges, and then reassemble the results in the original order. - - We separate into different requests: - - - One call to `obstore.get_ranges_async` **per target file** - - One call to `obstore.get_async` for each other request. - """ - key_ranges = list(key_ranges) - per_file_bounded_requests: dict[str, list[_BoundedRequest]] = defaultdict(list) - other_requests: list[_OtherRequest] = [] - suffix_requests: list[_SuffixRequest] = [] - - for idx, (path, byte_range) in enumerate(key_ranges): - if byte_range is None: - other_requests.append( - { - "original_request_index": idx, - "path": path, - "range": None, - } - ) - elif isinstance(byte_range, RangeByteRequest): - per_file_bounded_requests[path].append( - {"original_request_index": idx, "start": byte_range.start, "end": byte_range.end} - ) - elif isinstance(byte_range, OffsetByteRequest): - other_requests.append( - { - "original_request_index": idx, - "path": path, - "range": {"offset": byte_range.offset}, - } - ) - elif isinstance(byte_range, SuffixByteRequest): - suffix_requests.append( - { - "original_request_index": idx, - "path": path, - "range": {"suffix": byte_range.suffix}, - } - ) - else: - raise ValueError(f"Unsupported range input: {byte_range}") - - semaphore = asyncio.Semaphore(config.get("async.concurrency")) - - futs: list[Coroutine[Any, Any, list[_Response]]] = [] - for path, bounded_ranges in per_file_bounded_requests.items(): - futs.append( - _make_bounded_requests(store, path, bounded_ranges, prototype, semaphore=semaphore) - ) - - for request in other_requests: - futs.append(_make_other_request(store, request, prototype, semaphore=semaphore)) # noqa: PERF401 - - for suffix_request in suffix_requests: - futs.append(_make_suffix_request(store, suffix_request, prototype, semaphore=semaphore)) # noqa: PERF401 - - buffers: list[Buffer | None] = [None] * len(key_ranges) - - for responses in await asyncio.gather(*futs): - for resp in responses: - buffers[resp["original_request_index"]] = resp["buffer"] - - return buffers diff --git a/src/zarr/storage/_utils.py b/src/zarr/storage/_utils.py index 10ac395b36..80bce250e9 100644 --- a/src/zarr/storage/_utils.py +++ b/src/zarr/storage/_utils.py @@ -1,17 +1,87 @@ from __future__ import annotations +import functools import re from pathlib import Path -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar from zarr.abc.store import OffsetByteRequest, RangeByteRequest, SuffixByteRequest if TYPE_CHECKING: - from 
collections.abc import Iterable, Mapping
+    import asyncio
+    from collections.abc import Callable, Coroutine, Iterable, Mapping
 
     from zarr.abc.store import ByteRequest
     from zarr.core.buffer import Buffer
 
+P = ParamSpec("P")
+T_co = TypeVar("T_co", covariant=True)
+
+
+def with_concurrency_limit(
+    semaphore_attr: str = "_semaphore",
+) -> Callable[[Callable[P, Coroutine[Any, Any, T_co]]], Callable[P, Coroutine[Any, Any, T_co]]]:
+    """
+    Decorator that applies a semaphore-based concurrency limit to an async method.
+
+    This decorator is designed for Store methods that need to limit concurrent operations.
+    The store instance should have a `_semaphore` attribute (or custom attribute name)
+    that is either an asyncio.Semaphore or None (for unlimited concurrency).
+
+    Parameters
+    ----------
+    semaphore_attr : str, optional
+        Name of the semaphore attribute on the class instance. Default is "_semaphore".
+
+    Returns
+    -------
+    Callable
+        The decorated async function with concurrency limiting applied.
+
+    Examples
+    --------
+    ```python
+    class MyStore(Store):
+        def __init__(self, concurrency_limit: int = 100):
+            self._semaphore = asyncio.Semaphore(concurrency_limit) if concurrency_limit else None
+
+        @with_concurrency_limit()
+        async def get(self, key: str) -> Buffer | None:
+            # This will only run when semaphore permits
+            return await expensive_io_operation(key)
+    ```
+    """
+
+    def decorator(
+        func: Callable[P, Coroutine[Any, Any, T_co]],
+    ) -> Callable[P, Coroutine[Any, Any, T_co]]:
+        """
+        Wrap the invocation of `func` in an `async with semaphore` context manager.
+        The semaphore object is resolved by getting the `semaphore_attr` attribute from the first
+        argument to func. When this decorator is used on a method of a class, that first argument
+        is a reference to the class instance (`self`).
+        """
+
+        @functools.wraps(func)
+        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T_co:
+            # First arg should be 'self'
+            if not args:
+                raise TypeError(f"{func.__name__} requires at least one argument (self)")
+
+            self = args[0]
+
+            semaphore: asyncio.Semaphore | None = getattr(self, semaphore_attr)
+
+            if semaphore is not None:
+                async with semaphore:
+                    return await func(*args, **kwargs)
+            else:
+                return await func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
 
 def normalize_path(path: str | bytes | Path | None) -> str:
     if path is None:
diff --git a/src/zarr/testing/store_concurrency.py b/src/zarr/testing/store_concurrency.py
new file mode 100644
index 0000000000..0dd6dcff17
--- /dev/null
+++ b/src/zarr/testing/store_concurrency.py
@@ -0,0 +1,249 @@
+"""Base test class for store concurrency limiting behavior."""
+
+from __future__ import annotations
+
+import asyncio
+from typing import TYPE_CHECKING, Generic, TypeVar
+
+import pytest
+
+from zarr.core.buffer import Buffer, default_buffer_prototype
+
+if TYPE_CHECKING:
+    from zarr.abc.store import Store
+
+__all__ = ["StoreConcurrencyTests"]
+
+
+S = TypeVar("S", bound="Store")
+B = TypeVar("B", bound="Buffer")
+
+
+class StoreConcurrencyTests(Generic[S, B]):
+    """Base class for testing store concurrency limiting behavior.
+
+    This mixin provides tests for verifying that stores correctly implement
+    concurrency limiting.
+ + Subclasses should set: + - store_cls: The store class being tested + - buffer_cls: The buffer class to use (e.g., cpu.Buffer) + - expected_concurrency_limit: Expected default concurrency limit (or None for unlimited) + """ + + store_cls: type[S] + buffer_cls: type[B] + expected_concurrency_limit: int | None + + @pytest.fixture + async def store(self, store_kwargs: dict) -> S: + """Create and open a store instance.""" + return await self.store_cls.open(**store_kwargs) + + def test_concurrency_limit_default(self, store: S) -> None: + """Test that store has the expected default concurrency limit.""" + if hasattr(store, "_semaphore"): + if self.expected_concurrency_limit is None: + assert store._semaphore is None, "Expected no concurrency limit" + else: + assert store._semaphore is not None, "Expected concurrency limit to be set" + assert store._semaphore._value == self.expected_concurrency_limit, ( + f"Expected limit {self.expected_concurrency_limit}, got {store._semaphore._value}" + ) + + def test_concurrency_limit_custom(self, store_kwargs: dict) -> None: + """Test that custom concurrency limits can be set.""" + if "concurrency_limit" not in self.store_cls.__init__.__code__.co_varnames: + pytest.skip("Store does not support custom concurrency limits") + + # Test with custom limit + store = self.store_cls(**store_kwargs, concurrency_limit=42) + if hasattr(store, "_semaphore"): + assert store._semaphore is not None + assert store._semaphore._value == 42 + + # Test with None (unlimited) + store = self.store_cls(**store_kwargs, concurrency_limit=None) + if hasattr(store, "_semaphore"): + assert store._semaphore is None + + async def test_concurrency_limit_enforced(self, store: S) -> None: + """Test that the concurrency limit is actually enforced during execution. + + This test verifies that when many operations are submitted concurrently, + only up to the concurrency limit are actually executing at once. + """ + if not hasattr(store, "_semaphore") or store._semaphore is None: + pytest.skip("Store has no concurrency limit") + + limit = store._semaphore._value + + # We'll monitor the semaphore's available count + # When it reaches 0, that means `limit` operations are running + min_available = limit + + async def monitored_operation(key: str, value: B) -> None: + nonlocal min_available + # Check semaphore state right after we're scheduled + await asyncio.sleep(0) # Yield to ensure we're in the queue + available = store._semaphore._value + min_available = min(min_available, available) + + # Now do the actual operation (which will acquire the semaphore) + await store.set(key, value) + + # Launch more operations than the limit to ensure contention + num_ops = limit * 2 + items = [ + (f"limit_test_key_{i}", self.buffer_cls.from_bytes(f"value_{i}".encode())) + for i in range(num_ops) + ] + + await asyncio.gather(*[monitored_operation(k, v) for k, v in items]) + + # The semaphore should have been fully utilized (reached 0 or close to it) + # This indicates that `limit` operations were running concurrently + assert min_available < limit, ( + f"Semaphore was never fully utilized. " + f"Min available: {min_available}, Limit: {limit}. " + f"This suggests operations aren't running concurrently." + ) + + # Ideally it should reach 0, but allow some slack for timing + assert min_available <= 5, ( + f"Semaphore only reached {min_available} available slots. " + f"Expected close to 0 with limit {limit}." 
+ ) + + async def test_batch_write_no_deadlock(self, store: S) -> None: + """Test that batch writes don't deadlock when exceeding concurrency limit.""" + # Create more items than any reasonable concurrency limit + num_items = 200 + items = [ + (f"test_key_{i}", self.buffer_cls.from_bytes(f"test_value_{i}".encode())) + for i in range(num_items) + ] + + # This should complete without deadlock, even if num_items > concurrency_limit + await asyncio.wait_for(store._set_many(items), timeout=30.0) + + # Verify all items were written correctly + for key, expected_value in items: + result = await store.get(key, default_buffer_prototype()) + assert result is not None + assert result.to_bytes() == expected_value.to_bytes() + + async def test_batch_read_no_deadlock(self, store: S) -> None: + """Test that batch reads don't deadlock when exceeding concurrency limit.""" + # Write test data + num_items = 200 + test_data = { + f"test_key_{i}": self.buffer_cls.from_bytes(f"test_value_{i}".encode()) + for i in range(num_items) + } + + for key, value in test_data.items(): + await store.set(key, value) + + # Read all items concurrently - should not deadlock + keys_and_ranges = [(key, None) for key in test_data] + results = await asyncio.wait_for( + store.get_partial_values(default_buffer_prototype(), keys_and_ranges), + timeout=30.0, + ) + + # Verify results + assert len(results) == num_items + for result, (_key, expected_value) in zip(results, test_data.items(), strict=True): + assert result is not None + assert result.to_bytes() == expected_value.to_bytes() + + async def test_batch_delete_no_deadlock(self, store: S) -> None: + """Test that batch deletes don't deadlock when exceeding concurrency limit.""" + if not store.supports_deletes: + pytest.skip("Store does not support deletes") + + # Write test data + num_items = 200 + keys = [f"test_key_{i}" for i in range(num_items)] + for key in keys: + await store.set(key, self.buffer_cls.from_bytes(b"test_value")) + + # Delete all items concurrently - should not deadlock + await asyncio.wait_for(asyncio.gather(*[store.delete(key) for key in keys]), timeout=30.0) + + # Verify all items were deleted + for key in keys: + result = await store.get(key, default_buffer_prototype()) + assert result is None + + async def test_concurrent_operations_correctness(self, store: S) -> None: + """Test that concurrent operations produce correct results.""" + num_operations = 100 + + # Mix of reads and writes + write_keys = [f"write_key_{i}" for i in range(num_operations)] + write_values = [ + self.buffer_cls.from_bytes(f"value_{i}".encode()) for i in range(num_operations) + ] + + # Write all concurrently + await asyncio.gather( + *[store.set(k, v) for k, v in zip(write_keys, write_values, strict=True)] + ) + + # Read all concurrently + results = await asyncio.gather( + *[store.get(k, default_buffer_prototype()) for k in write_keys] + ) + + # Verify correctness + for result, expected in zip(results, write_values, strict=True): + assert result is not None + assert result.to_bytes() == expected.to_bytes() + + @pytest.mark.parametrize("batch_size", [1, 10, 50, 100]) + async def test_various_batch_sizes(self, store: S, batch_size: int) -> None: + """Test that various batch sizes work correctly.""" + items = [ + (f"batch_key_{i}", self.buffer_cls.from_bytes(f"batch_value_{i}".encode())) + for i in range(batch_size) + ] + + # Should complete without issues for any batch size + await asyncio.wait_for(store._set_many(items), timeout=10.0) + + # Verify + for key, expected_value in items: + 
result = await store.get(key, default_buffer_prototype()) + assert result is not None + assert result.to_bytes() == expected_value.to_bytes() + + async def test_empty_batch_operations(self, store: S) -> None: + """Test that empty batch operations don't cause issues.""" + # Empty batch should not raise + await store._set_many([]) + + # Empty read batch + results = await store.get_partial_values(default_buffer_prototype(), []) + assert results == [] + + async def test_mixed_success_failure_batch(self, store: S) -> None: + """Test batch operations with mix of successful and failing items.""" + # Write some initial data + await store.set("existing_key", self.buffer_cls.from_bytes(b"existing_value")) + + # Try to read mix of existing and non-existing keys + key_ranges = [ + ("existing_key", None), + ("non_existing_key_1", None), + ("non_existing_key_2", None), + ] + + results = await store.get_partial_values(default_buffer_prototype(), key_ranges) + + # First should exist, others should be None + assert results[0] is not None + assert results[0].to_bytes() == b"existing_value" + assert results[1] is None + assert results[2] is None diff --git a/tests/test_common.py b/tests/test_common.py index 0944c3375a..9484d15ca3 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -31,10 +31,6 @@ def test_access_modes() -> None: assert set(ANY_ACCESS_MODE) == set(get_args(AccessModeLiteral)) -# todo: test -def test_concurrent_map() -> None: ... - - # todo: test def test_to_thread() -> None: ... diff --git a/tests/test_global_concurrency.py b/tests/test_global_concurrency.py new file mode 100644 index 0000000000..3cfca5052c --- /dev/null +++ b/tests/test_global_concurrency.py @@ -0,0 +1,43 @@ +""" +Tests for store-level concurrency limiting through the array API. 
+""" + +import numpy as np + +import zarr + + +class TestStoreConcurrencyThroughArrayAPI: + """Tests that store-level concurrency limiting works through the array API.""" + + def test_array_operations_with_store_concurrency(self) -> None: + """Test that array read/write works correctly with store-level concurrency limits.""" + store = zarr.storage.MemoryStore() + arr = zarr.create( + shape=(20, 20), + chunks=(10, 10), + dtype="i4", + store=store, + zarr_format=3, + ) + arr[:] = 42 + + data = arr[:] + + assert np.all(data == 42) + + def test_array_operations_with_local_store_concurrency(self, tmp_path: object) -> None: + """Test that array read/write works correctly with LocalStore concurrency limits.""" + store = zarr.storage.LocalStore(str(tmp_path), concurrency_limit=10) + arr = zarr.create( + shape=(20, 20), + chunks=(10, 10), + dtype="i4", + store=store, + zarr_format=3, + ) + arr[:] = 42 + + data = arr[:] + + assert np.all(data == 42) diff --git a/tests/test_group.py b/tests/test_group.py index 6f1f4e68fa..9f25036298 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -23,7 +23,6 @@ from zarr.core import sync_group from zarr.core._info import GroupInfo from zarr.core.buffer import default_buffer_prototype -from zarr.core.config import config as zarr_config from zarr.core.dtype.common import unpack_dtype_json from zarr.core.dtype.npy.int import UInt8 from zarr.core.group import ( @@ -1738,29 +1737,6 @@ async def test_create_nodes( assert node_spec == {k: v.metadata for k, v in observed_nodes.items()} -@pytest.mark.parametrize("store", ["memory"], indirect=True) -def test_create_nodes_concurrency_limit(store: MemoryStore) -> None: - """ - Test that the execution time of create_nodes can be constrained by the async concurrency - configuration setting. - """ - set_latency = 0.02 - num_groups = 10 - groups = {str(idx): GroupMetadata() for idx in range(num_groups)} - - latency_store = LatencyStore(store, set_latency=set_latency) - - # check how long it takes to iterate over the groups - # if create_nodes is sensitive to IO latency, - # this should take (num_groups * get_latency) seconds - # otherwise, it should take only marginally more than get_latency seconds - with zarr_config.set({"async.concurrency": 1}): - start = time.time() - _ = tuple(sync_group.create_nodes(store=latency_store, nodes=groups)) - elapsed = time.time() - start - assert elapsed > num_groups * set_latency - - @pytest.mark.parametrize( ("a_func", "b_func"), [ @@ -2250,38 +2226,6 @@ def test_group_members_performance(store: Store) -> None: assert elapsed < (num_groups * get_latency) -@pytest.mark.parametrize("store", ["memory"], indirect=True) -def test_group_members_concurrency_limit(store: MemoryStore) -> None: - """ - Test that the execution time of Group.members can be constrained by the async concurrency - configuration setting. 
- """ - get_latency = 0.02 - - # use the input store to create some groups - group_create = zarr.group(store=store) - num_groups = 10 - - # Create some groups - for i in range(num_groups): - group_create.create_group(f"group{i}") - - latency_store = LatencyStore(store, get_latency=get_latency) - # create a group with some latency on get operations - group_read = zarr.group(store=latency_store) - - # check how long it takes to iterate over the groups - # if .members is sensitive to IO latency, - # this should take (num_groups * get_latency) seconds - # otherwise, it should take only marginally more than get_latency seconds - with zarr_config.set({"async.concurrency": 1}): - start = time.time() - _ = group_read.members() - elapsed = time.time() - start - - assert elapsed > num_groups * get_latency - - @pytest.mark.parametrize("option", ["array", "group", "invalid"]) def test_build_metadata_v3(option: Literal["array", "group", "invalid"]) -> None: """ diff --git a/tests/test_store/test_local.py b/tests/test_store/test_local.py index bdc9b48121..ca9759bd9a 100644 --- a/tests/test_store/test_local.py +++ b/tests/test_store/test_local.py @@ -15,6 +15,7 @@ from zarr.storage import LocalStore from zarr.storage._local import _atomic_write from zarr.testing.store import StoreTests +from zarr.testing.store_concurrency import StoreConcurrencyTests from zarr.testing.utils import assert_bytes_equal if TYPE_CHECKING: @@ -204,3 +205,15 @@ def test_atomic_write_exclusive_preexisting(tmp_path: pathlib.Path) -> None: f.write(b"abc") assert path.read_bytes() == b"xyz" assert list(path.parent.iterdir()) == [path] # no temp files + + +class TestLocalStoreConcurrency(StoreConcurrencyTests[LocalStore, cpu.Buffer]): + """Test LocalStore concurrency limiting behavior.""" + + store_cls = LocalStore + buffer_cls = cpu.Buffer + expected_concurrency_limit = 100 # LocalStore default + + @pytest.fixture + def store_kwargs(self, tmpdir: str) -> dict[str, str]: + return {"root": str(tmpdir)} diff --git a/tests/test_store/test_memory.py b/tests/test_store/test_memory.py index 03c8b24271..1004ca20bb 100644 --- a/tests/test_store/test_memory.py +++ b/tests/test_store/test_memory.py @@ -14,6 +14,7 @@ from zarr.errors import ZarrUserWarning from zarr.storage import GpuMemoryStore, MemoryStore from zarr.testing.store import StoreTests +from zarr.testing.store_concurrency import StoreConcurrencyTests from zarr.testing.utils import gpu_test if TYPE_CHECKING: @@ -181,3 +182,15 @@ def test_from_dict(self) -> None: result = GpuMemoryStore.from_dict(d) for v in result._store_dict.values(): assert type(v) is gpu.Buffer + + +class TestMemoryStoreConcurrency(StoreConcurrencyTests[MemoryStore, cpu.Buffer]): + """Test MemoryStore concurrency limiting behavior.""" + + store_cls = MemoryStore + buffer_cls = cpu.Buffer + expected_concurrency_limit = None # MemoryStore has no limit (fast in-memory ops) + + @pytest.fixture + def store_kwargs(self) -> dict[str, Any]: + return {"store_dict": None}