From 6c4e53f84d9efc9fa08870b58f2b67276b0d3e95 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 24 Oct 2025 21:31:05 +0200 Subject: [PATCH 1/7] add global concurrency limit instead of per-routine concurrency limits --- src/zarr/abc/codec.py | 4 -- src/zarr/abc/store.py | 4 +- src/zarr/core/array.py | 4 -- src/zarr/core/codec_pipeline.py | 5 -- src/zarr/core/common.py | 115 +++++++++++++++++++++++++++++++- src/zarr/core/group.py | 10 +-- src/zarr/storage/_local.py | 3 +- src/zarr/storage/_memory.py | 3 +- src/zarr/storage/_obstore.py | 8 +-- 9 files changed, 128 insertions(+), 28 deletions(-) diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py index d41c457b4e..b5f7819a91 100644 --- a/src/zarr/abc/codec.py +++ b/src/zarr/abc/codec.py @@ -9,7 +9,6 @@ from zarr.abc.metadata import Metadata from zarr.core.buffer import Buffer, NDBuffer from zarr.core.common import NamedConfig, concurrent_map -from zarr.core.config import config if TYPE_CHECKING: from collections.abc import Awaitable, Callable, Iterable @@ -228,7 +227,6 @@ async def decode_partial( return await concurrent_map( list(batch_info), self._decode_partial_single, - config.get("async.concurrency"), ) @@ -265,7 +263,6 @@ async def encode_partial( await concurrent_map( list(batch_info), self._encode_partial_single, - config.get("async.concurrency"), ) @@ -467,7 +464,6 @@ async def _batching_helper( return await concurrent_map( list(batch_info), _noop_for_none(func), - config.get("async.concurrency"), ) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 4b3edf78d1..30602edf34 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -464,11 +464,9 @@ async def getsize_prefix(self, prefix: str) -> int: # avoid circular import from zarr.core.common import concurrent_map - from zarr.core.config import config keys = [(x,) async for x in self.list_prefix(prefix)] - limit = config.get("async.concurrency") - sizes = await concurrent_map(keys, self.getsize, limit=limit) + sizes = await concurrent_map(keys, self.getsize) return sum(sizes) diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index 8bd8be40b2..2f42836fc2 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -22,7 +22,6 @@ import numpy as np from typing_extensions import deprecated -import zarr from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec from zarr.abc.numcodec import Numcodec, _is_numcodec from zarr.codecs._v2 import V2Codec @@ -1853,7 +1852,6 @@ async def _delete_key(key: str) -> None: for chunk_coords in old_chunk_coords.difference(new_chunk_coords) ], _delete_key, - zarr_config.get("async.concurrency"), ) # Write new metadata @@ -4530,7 +4528,6 @@ async def _copy_array_region( await concurrent_map( [(region, data) for region in result._iter_shard_regions()], _copy_array_region, - zarr.core.config.config.get("async.concurrency"), ) else: @@ -4541,7 +4538,6 @@ async def _copy_arraylike_region(chunk_coords: slice, _data: NDArrayLike) -> Non await concurrent_map( [(region, data) for region in result._iter_shard_regions()], _copy_arraylike_region, - zarr.core.config.config.get("async.concurrency"), ) return result diff --git a/src/zarr/core/codec_pipeline.py b/src/zarr/core/codec_pipeline.py index 63fcda7065..e6864c607e 100644 --- a/src/zarr/core/codec_pipeline.py +++ b/src/zarr/core/codec_pipeline.py @@ -270,7 +270,6 @@ async def read_batch( chunk_bytes_batch = await concurrent_map( [(byte_getter, array_spec.prototype) for byte_getter, array_spec, *_ in batch_info], lambda byte_getter, prototype: byte_getter.get(prototype), - config.get("async.concurrency"), ) chunk_array_batch = await self.decode_batch( [ @@ -375,7 +374,6 @@ async def _read_key( for byte_setter, chunk_spec, chunk_selection, _, is_complete_chunk in batch_info ], _read_key, - config.get("async.concurrency"), ) chunk_array_decoded = await self.decode_batch( [ @@ -441,7 +439,6 @@ async def _write_key(byte_setter: ByteSetter, chunk_bytes: Buffer | None) -> Non ) ], _write_key, - config.get("async.concurrency"), ) async def decode( @@ -474,7 +471,6 @@ async def read( for single_batch_info in batched(batch_info, self.batch_size) ], self.read_batch, - config.get("async.concurrency"), ) async def write( @@ -489,7 +485,6 @@ async def write( for single_batch_info in batched(batch_info, self.batch_size) ], self.write_batch, - config.get("async.concurrency"), ) diff --git a/src/zarr/core/common.py b/src/zarr/core/common.py index 651ebd72f3..8f6f899cf8 100644 --- a/src/zarr/core/common.py +++ b/src/zarr/core/common.py @@ -4,7 +4,9 @@ import functools import math import operator +import threading import warnings +import weakref from collections.abc import Iterable, Mapping, Sequence from enum import Enum from itertools import starmap @@ -82,15 +84,126 @@ def ceildiv(a: float, b: float) -> int: V = TypeVar("V") +# Global semaphore management for per-process concurrency limiting +# Use WeakKeyDictionary to automatically clean up semaphores when event loops are garbage collected +_global_semaphores: weakref.WeakKeyDictionary[asyncio.AbstractEventLoop, asyncio.Semaphore] = ( + weakref.WeakKeyDictionary() +) +# Use threading.Lock instead of asyncio.Lock to coordinate across event loops +_global_semaphore_lock = threading.Lock() + + +def get_global_semaphore() -> asyncio.Semaphore: + """ + Get the global semaphore for the current event loop. + + This ensures that all concurrent operations across the process share the same + concurrency limit, preventing excessive concurrent task creation when multiple + arrays or operations are running simultaneously. + + The semaphore is lazily created per event loop and uses the configured + `async.concurrency` value from zarr config. The semaphore is cached per event + loop, so subsequent calls return the same semaphore instance. + + Note: Config changes after the first call will not affect the semaphore limit. + To apply new config values, use :func:`reset_global_semaphores` to clear the cache. + + Returns + ------- + asyncio.Semaphore + The global semaphore for this event loop. + + Raises + ------ + RuntimeError + If called outside of an async context (no running event loop). + + See Also + -------- + reset_global_semaphores : Clear the global semaphore cache + """ + loop = asyncio.get_running_loop() + + # Acquire lock FIRST to prevent TOCTOU race condition + with _global_semaphore_lock: + if loop not in _global_semaphores: + limit = zarr_config.get("async.concurrency") + _global_semaphores[loop] = asyncio.Semaphore(limit) + return _global_semaphores[loop] + + +def reset_global_semaphores() -> None: + """ + Clear all cached global semaphores. + + This is useful when you want config changes to take effect, or for testing. + The next call to :func:`get_global_semaphore` will create a new semaphore + using the current configuration. + + Warning: This should only be called when no async operations are in progress, + as it will invalidate all existing semaphore references. + + Examples + -------- + >>> import zarr + >>> zarr.config.set({"async.concurrency": 50}) + >>> reset_global_semaphores() # Apply new config + """ + with _global_semaphore_lock: + _global_semaphores.clear() + + async def concurrent_map( items: Iterable[T], func: Callable[..., Awaitable[V]], limit: int | None = None, + *, + use_global_semaphore: bool = True, ) -> list[V]: - if limit is None: + """ + Execute an async function concurrently over multiple items with concurrency limiting. + + Parameters + ---------- + items : Iterable[T] + Items to process, where each item is a tuple of arguments to pass to func. + func : Callable[..., Awaitable[V]] + Async function to execute for each item. + limit : int | None, optional + If provided and use_global_semaphore is False, creates a local semaphore + with this limit. If None, no concurrency limiting is applied. + use_global_semaphore : bool, default True + If True, uses the global per-process semaphore for concurrency limiting, + ensuring all concurrent operations share the same limit. If False, uses + the `limit` parameter for local limiting (legacy behavior). + + Returns + ------- + list[V] + Results from executing func on all items. + """ + if use_global_semaphore: + if limit is not None: + raise ValueError( + "Cannot specify both use_global_semaphore=True and a limit value. " + "Either use the global semaphore (use_global_semaphore=True, limit=None) " + "or specify a local limit (use_global_semaphore=False, limit=)." + ) + # Use the global semaphore for process-wide concurrency limiting + sem = get_global_semaphore() + + async def run(item: tuple[Any]) -> V: + async with sem: + return await func(*item) + + return await asyncio.gather(*[asyncio.ensure_future(run(item)) for item in items]) + + elif limit is None: + # No concurrency limiting return await asyncio.gather(*list(starmap(func, items))) else: + # Legacy mode: create local semaphore with specified limit sem = asyncio.Semaphore(limit) async def run(item: tuple[Any]) -> V: diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 26aed4fd60..2f381431a3 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -44,6 +44,7 @@ NodeType, ShapeLike, ZarrFormat, + get_global_semaphore, parse_shapelike, ) from zarr.core.config import config @@ -1441,8 +1442,8 @@ async def _members( ) raise ValueError(msg) - # enforce a concurrency limit by passing a semaphore to all the recursive functions - semaphore = asyncio.Semaphore(config.get("async.concurrency")) + # Use global semaphore for process-wide concurrency limiting + semaphore = get_global_semaphore() async for member in _iter_members_deep( self, max_depth=max_depth, @@ -3338,9 +3339,8 @@ async def create_nodes( The created nodes in the order they are created. """ - # Note: the only way to alter this value is via the config. If that's undesirable for some reason, - # then we should consider adding a keyword argument this this function - semaphore = asyncio.Semaphore(config.get("async.concurrency")) + # Use global semaphore for process-wide concurrency limiting + semaphore = get_global_semaphore() create_tasks: list[Coroutine[None, None, str]] = [] for key, value in nodes.items(): diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index f64da71bb4..d6f10be862 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -217,7 +217,8 @@ async def get_partial_values( assert isinstance(key, str) path = self.root / key args.append((_get, path, prototype, byte_range)) - return await concurrent_map(args, asyncio.to_thread, limit=None) # TODO: fix limit + # Use global semaphore to limit concurrent thread spawning + return await concurrent_map(args, asyncio.to_thread) async def set(self, key: str, value: Buffer) -> None: # docstring inherited diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index a3fd058680..12d7424185 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -104,7 +104,8 @@ async def get_partial_values( async def _get(key: str, byte_range: ByteRequest | None) -> Buffer | None: return await self.get(key, prototype=prototype, byte_range=byte_range) - return await concurrent_map(key_ranges, _get, limit=None) + # In-memory operations are fast and don't benefit from concurrency limiting + return await concurrent_map(key_ranges, _get, use_global_semaphore=False) async def exists(self, key: str) -> bool: # docstring inherited diff --git a/src/zarr/storage/_obstore.py b/src/zarr/storage/_obstore.py index 5c2197ecf6..e1d1bde672 100644 --- a/src/zarr/storage/_obstore.py +++ b/src/zarr/storage/_obstore.py @@ -13,8 +13,7 @@ Store, SuffixByteRequest, ) -from zarr.core.common import concurrent_map -from zarr.core.config import config +from zarr.core.common import concurrent_map, get_global_semaphore if TYPE_CHECKING: from collections.abc import AsyncGenerator, Coroutine, Iterable, Sequence @@ -209,7 +208,7 @@ async def delete_dir(self, prefix: str) -> None: metas = await obs.list(self.store, prefix).collect_async() keys = [(m["path"],) for m in metas] - await concurrent_map(keys, self.delete, limit=config.get("async.concurrency")) + await concurrent_map(keys, self.delete) @property def supports_listing(self) -> bool: @@ -485,7 +484,8 @@ async def _get_partial_values( else: raise ValueError(f"Unsupported range input: {byte_range}") - semaphore = asyncio.Semaphore(config.get("async.concurrency")) + # Use global semaphore for process-wide concurrency limiting + semaphore = get_global_semaphore() futs: list[Coroutine[Any, Any, list[_Response]]] = [] for path, bounded_ranges in per_file_bounded_requests.items(): From e98e6c0a1f4f731bfb76eaa8ffdbe050b8b2c7e6 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 24 Oct 2025 21:32:17 +0200 Subject: [PATCH 2/7] add test --- tests/test_global_concurrency.py | 330 +++++++++++++++++++++++++++++++ 1 file changed, 330 insertions(+) create mode 100644 tests/test_global_concurrency.py diff --git a/tests/test_global_concurrency.py b/tests/test_global_concurrency.py new file mode 100644 index 0000000000..5df1d68a39 --- /dev/null +++ b/tests/test_global_concurrency.py @@ -0,0 +1,330 @@ +""" +Tests for global per-process concurrency limiting. +""" + +import asyncio +from typing import TYPE_CHECKING, Any + +import numpy as np +import pytest + +import zarr +from zarr.core.common import get_global_semaphore, reset_global_semaphores +from zarr.core.config import config + +if TYPE_CHECKING: + from numpy.typing import NDArray + + +class TestGlobalSemaphore: + """Tests for the global semaphore management.""" + + async def test_get_global_semaphore_creates_per_loop(self) -> None: + """Test that each event loop gets its own semaphore.""" + sem1 = get_global_semaphore() + assert sem1 is not None + assert isinstance(sem1, asyncio.Semaphore) + + # Getting it again should return the same instance + sem2 = get_global_semaphore() + assert sem1 is sem2 + + async def test_global_semaphore_uses_config_limit(self) -> None: + """Test that the global semaphore respects the configured limit.""" + # Set a custom concurrency limit + original_limit: Any = config.get("async.concurrency") + try: + config.set({"async.concurrency": 5}) + + # Clear existing semaphores to force recreation + reset_global_semaphores() + + sem = get_global_semaphore() + + # The semaphore should have the configured limit + # We can verify this by acquiring all tokens and checking the semaphore is locked + for i in range(5): + await sem.acquire() + if i < 4: + assert not sem.locked() # Should still have capacity + else: + assert sem.locked() # All tokens acquired, semaphore is now locked + + # Release all tokens + for _ in range(5): + sem.release() + + finally: + # Restore original config + config.set({"async.concurrency": original_limit}) + # Clear semaphores again to reset state + reset_global_semaphores() + + async def test_global_semaphore_shared_across_operations(self) -> None: + """Test that multiple concurrent operations share the same semaphore.""" + # Track the maximum number of concurrent tasks + max_concurrent = 0 + current_concurrent = 0 + lock = asyncio.Lock() + + async def tracked_operation() -> None: + """An operation that tracks concurrency.""" + nonlocal max_concurrent, current_concurrent + + async with lock: + current_concurrent += 1 + max_concurrent = max(max_concurrent, current_concurrent) + + # Small delay to ensure overlap + await asyncio.sleep(0.01) + + async with lock: + current_concurrent -= 1 + + # Set a low concurrency limit to make the test observable + original_limit: Any = config.get("async.concurrency") + try: + config.set({"async.concurrency": 5}) + + # Clear existing semaphores + reset_global_semaphores() + + # Get the global semaphore + sem = get_global_semaphore() + + # Create many tasks that use the semaphore + async def task_with_semaphore() -> None: + async with sem: + await tracked_operation() + + # Launch 20 tasks (4x the limit) + tasks = [task_with_semaphore() for _ in range(20)] + await asyncio.gather(*tasks) + + # Maximum concurrent should respect the limit + assert max_concurrent <= 5, f"Max concurrent was {max_concurrent}, expected <= 5" + assert max_concurrent >= 3, ( + f"Max concurrent was {max_concurrent}, expected some concurrency" + ) + + finally: + config.set({"async.concurrency": original_limit}) + reset_global_semaphores() + + async def test_semaphore_reuse_across_calls(self) -> None: + """Test that repeated calls to get_global_semaphore return the same instance.""" + reset_global_semaphores() + + # Call multiple times and verify we get the same instance + sem1 = get_global_semaphore() + sem2 = get_global_semaphore() + sem3 = get_global_semaphore() + + assert sem1 is sem2 is sem3, "Should return same semaphore instance on repeated calls" + + # Verify it's still the same after using it + async with sem1: + sem4 = get_global_semaphore() + assert sem1 is sem4 + + def test_config_change_after_creation(self) -> None: + """Test and document that config changes don't affect existing semaphores.""" + original_limit: Any = config.get("async.concurrency") + try: + # Set initial config + config.set({"async.concurrency": 5}) + + async def check_limit() -> None: + reset_global_semaphores() + + # Create semaphore with limit=5 + sem1 = get_global_semaphore() + initial_capacity: int = sem1._value + + # Change config + config.set({"async.concurrency": 50}) + + # Get semaphore again - should be same instance with old limit + sem2 = get_global_semaphore() + assert sem1 is sem2, "Should return same semaphore instance" + assert sem2._value == initial_capacity, ( + f"Semaphore limit changed from {initial_capacity} to {sem2._value}. " + "Config changes should not affect existing semaphores." + ) + + # Clean up + reset_global_semaphores() + + asyncio.run(check_limit()) + + finally: + config.set({"async.concurrency": original_limit}) + + +class TestArrayConcurrency: + """Tests that array operations use global concurrency limiting.""" + + @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") + async def test_multiple_arrays_share_concurrency_limit(self) -> None: + """Test that reading from multiple arrays shares the global concurrency limit.""" + from zarr.core.common import concurrent_map + + # Track concurrent task executions + max_concurrent_tasks = 0 + current_concurrent_tasks = 0 + task_lock = asyncio.Lock() + + async def tracked_chunk_operation(chunk_id: int) -> int: + """Simulate a chunk operation with tracking.""" + nonlocal max_concurrent_tasks, current_concurrent_tasks + + async with task_lock: + current_concurrent_tasks += 1 + max_concurrent_tasks = max(max_concurrent_tasks, current_concurrent_tasks) + + # Small delay to simulate I/O + await asyncio.sleep(0.001) + + async with task_lock: + current_concurrent_tasks -= 1 + + return chunk_id + + # Set a low concurrency limit + original_limit: Any = config.get("async.concurrency") + try: + config.set({"async.concurrency": 10}) + + # Clear existing semaphores + reset_global_semaphores() + + # Simulate reading many chunks using concurrent_map (which uses the global semaphore) + # This simulates what happens when reading from multiple arrays + chunk_ids = [(i,) for i in range(100)] + await concurrent_map(chunk_ids, tracked_chunk_operation) + + # The maximum concurrent tasks should respect the global limit + assert max_concurrent_tasks <= 10, ( + f"Max concurrent tasks was {max_concurrent_tasks}, expected <= 10" + ) + + assert max_concurrent_tasks >= 5, ( + f"Max concurrent tasks was {max_concurrent_tasks}, " + f"expected at least some concurrency" + ) + + finally: + config.set({"async.concurrency": original_limit}) + # Note: We don't reset_global_semaphores() here because doing so while + # many tasks are still cleaning up can trigger ResourceWarnings from + # asyncio internals. The semaphore will be reused by subsequent tests. + + def test_sync_api_uses_global_concurrency(self) -> None: + """Test that synchronous API also benefits from global concurrency limiting.""" + # This test verifies that the sync API (which wraps async) uses global limiting + + # Set a low concurrency limit + original_limit: Any = config.get("async.concurrency") + try: + config.set({"async.concurrency": 8}) + + # Create a small array - the key is that zarr internally uses + # concurrent_map which now uses the global semaphore + store = zarr.storage.MemoryStore() + arr = zarr.create( + shape=(20, 20), + chunks=(10, 10), + dtype="i4", + store=store, + zarr_format=3, + ) + arr[:] = 42 + + # Read data (synchronously) + data: NDArray[Any] = arr[:] + + # Verify we got the right data + assert np.all(data == 42) + + # The test passes if no errors occurred + # The concurrency limiting is happening under the hood + + finally: + config.set({"async.concurrency": original_limit}) + + +class TestConcurrentMapGlobal: + """Tests for concurrent_map using global semaphore.""" + + async def test_concurrent_map_uses_global_by_default(self) -> None: + """Test that concurrent_map uses global semaphore by default.""" + from zarr.core.common import concurrent_map + + # Track concurrent executions + max_concurrent = 0 + current_concurrent = 0 + lock = asyncio.Lock() + + async def tracked_task(x: int) -> int: + nonlocal max_concurrent, current_concurrent + + async with lock: + current_concurrent += 1 + max_concurrent = max(max_concurrent, current_concurrent) + + await asyncio.sleep(0.01) + + async with lock: + current_concurrent -= 1 + + return x * 2 + + # Set a low limit + original_limit: Any = config.get("async.concurrency") + try: + config.set({"async.concurrency": 5}) + + # Clear existing semaphores + reset_global_semaphores() + + # Use concurrent_map with default settings (use_global_semaphore=True) + items = [(i,) for i in range(20)] + results = await concurrent_map(items, tracked_task) + + assert len(results) == 20 + assert max_concurrent <= 5 + assert max_concurrent >= 3 # Should have some concurrency + + finally: + config.set({"async.concurrency": original_limit}) + reset_global_semaphores() + + async def test_concurrent_map_legacy_mode(self) -> None: + """Test that concurrent_map legacy mode still works.""" + from zarr.core.common import concurrent_map + + async def simple_task(x: int) -> int: + await asyncio.sleep(0.001) + return x * 2 + + # Use legacy mode with local limit + items = [(i,) for i in range(10)] + results = await concurrent_map(items, simple_task, limit=3, use_global_semaphore=False) + + assert len(results) == 10 + assert results == [i * 2 for i in range(10)] + + async def test_concurrent_map_parameter_validation(self) -> None: + """Test that concurrent_map validates conflicting parameters.""" + from zarr.core.common import concurrent_map + + async def simple_task(x: int) -> int: + return x * 2 + + items = [(i,) for i in range(10)] + + # Should raise ValueError when both limit and use_global_semaphore=True + with pytest.raises( + ValueError, match="Cannot specify both use_global_semaphore=True and a limit" + ): + await concurrent_map(items, simple_task, limit=5, use_global_semaphore=True) From 735ee8e68232da2ddeeb74c89ea902926869ae3c Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 24 Oct 2025 21:51:14 +0200 Subject: [PATCH 3/7] lint --- tests/test_global_concurrency.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/test_global_concurrency.py b/tests/test_global_concurrency.py index 5df1d68a39..f6366e3c53 100644 --- a/tests/test_global_concurrency.py +++ b/tests/test_global_concurrency.py @@ -3,7 +3,7 @@ """ import asyncio -from typing import TYPE_CHECKING, Any +from typing import Any import numpy as np import pytest @@ -12,9 +12,6 @@ from zarr.core.common import get_global_semaphore, reset_global_semaphores from zarr.core.config import config -if TYPE_CHECKING: - from numpy.typing import NDArray - class TestGlobalSemaphore: """Tests for the global semaphore management.""" @@ -241,7 +238,7 @@ def test_sync_api_uses_global_concurrency(self) -> None: arr[:] = 42 # Read data (synchronously) - data: NDArray[Any] = arr[:] + data = arr[:] # Verify we got the right data assert np.all(data == 42) From 3ef6cfba55f2e7203551424792e55005230cee90 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Mon, 27 Oct 2025 16:31:44 +0100 Subject: [PATCH 4/7] changelog --- changes/3547.misc.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changes/3547.misc.md diff --git a/changes/3547.misc.md b/changes/3547.misc.md new file mode 100644 index 0000000000..771bfe8861 --- /dev/null +++ b/changes/3547.misc.md @@ -0,0 +1 @@ +Moved concurrency limits to a global per-event loop setting instead of per-array call. \ No newline at end of file From 229e3b3eeab054de6fa506facd4466f3343fa8b4 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 18 Dec 2025 10:04:06 +0100 Subject: [PATCH 5/7] move concurrency limiting logic to stores --- src/zarr/abc/codec.py | 21 ++--- src/zarr/abc/store.py | 8 +- src/zarr/core/array.py | 24 ++--- src/zarr/core/codec_pipeline.py | 51 ++++++----- src/zarr/storage/_fsspec.py | 34 +++++++ src/zarr/storage/_local.py | 45 ++++++++-- src/zarr/storage/_memory.py | 13 ++- src/zarr/storage/_obstore.py | 154 +++++++++++++++++++++++--------- src/zarr/storage/_utils.py | 68 +++++++++++++- tests/test_group.py | 56 ------------ 10 files changed, 304 insertions(+), 170 deletions(-) diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py index b5f7819a91..69d6c3082e 100644 --- a/src/zarr/abc/codec.py +++ b/src/zarr/abc/codec.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from abc import abstractmethod from collections.abc import Mapping from typing import TYPE_CHECKING, Generic, TypeGuard, TypeVar @@ -8,7 +9,7 @@ from zarr.abc.metadata import Metadata from zarr.core.buffer import Buffer, NDBuffer -from zarr.core.common import NamedConfig, concurrent_map +from zarr.core.common import NamedConfig if TYPE_CHECKING: from collections.abc import Awaitable, Callable, Iterable @@ -224,10 +225,8 @@ async def decode_partial( ------- Iterable[NDBuffer | None] """ - return await concurrent_map( - list(batch_info), - self._decode_partial_single, - ) + # Store handles concurrency limiting internally + return await asyncio.gather(*[self._decode_partial_single(*info) for info in batch_info]) class ArrayBytesCodecPartialEncodeMixin: @@ -260,10 +259,8 @@ async def encode_partial( The ByteSetter is used to write the necessary bytes and fetch bytes for existing chunk data. The chunk spec contains information about the chunk. """ - await concurrent_map( - list(batch_info), - self._encode_partial_single, - ) + # Store handles concurrency limiting internally + await asyncio.gather(*[self._encode_partial_single(*info) for info in batch_info]) class CodecPipeline: @@ -461,10 +458,8 @@ async def _batching_helper( func: Callable[[CodecInput, ArraySpec], Awaitable[CodecOutput | None]], batch_info: Iterable[tuple[CodecInput | None, ArraySpec]], ) -> list[CodecOutput | None]: - return await concurrent_map( - list(batch_info), - _noop_for_none(func), - ) + # Store handles concurrency limiting internally + return await asyncio.gather(*[_noop_for_none(func)(chunk, spec) for chunk, spec in batch_info]) def _noop_for_none( diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 30602edf34..4ccab1877f 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from abc import ABC, abstractmethod from asyncio import gather from dataclasses import dataclass @@ -462,11 +463,8 @@ async def getsize_prefix(self, prefix: str) -> int: # improve tail latency and might reduce memory pressure (since not all keys # would be in memory at once). - # avoid circular import - from zarr.core.common import concurrent_map - - keys = [(x,) async for x in self.list_prefix(prefix)] - sizes = await concurrent_map(keys, self.getsize) + keys = [x async for x in self.list_prefix(prefix)] + sizes = await asyncio.gather(*[self.getsize(key) for key in keys]) return sum(sizes) diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index d036cd7974..01ff74f38f 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import json import warnings from asyncio import gather @@ -59,7 +60,6 @@ _default_zarr_format, _warn_order_kwarg, ceildiv, - concurrent_map, parse_shapelike, product, ) @@ -1847,12 +1847,12 @@ async def resize(self, new_shape: ShapeLike, delete_outside_chunks: bool = True) async def _delete_key(key: str) -> None: await (self.store_path / key).delete() - await concurrent_map( - [ - (self.metadata.encode_chunk_key(chunk_coords),) + # Store handles concurrency limiting internally + await asyncio.gather( + *[ + _delete_key(self.metadata.encode_chunk_key(chunk_coords)) for chunk_coords in old_chunk_coords.difference(new_chunk_coords) - ], - _delete_key, + ] ) # Write new metadata @@ -4533,9 +4533,9 @@ async def _copy_array_region( await result.setitem(chunk_coords, arr) # Stream data from the source array to the new array - await concurrent_map( - [(region, data) for region in result._iter_shard_regions()], - _copy_array_region, + # Store handles concurrency limiting internally + await asyncio.gather( + *[_copy_array_region(region, data) for region in result._iter_shard_regions()] ) else: @@ -4543,9 +4543,9 @@ async def _copy_arraylike_region(chunk_coords: slice, _data: NDArrayLike) -> Non await result.setitem(chunk_coords, _data[chunk_coords]) # Stream data from the source array to the new array - await concurrent_map( - [(region, data) for region in result._iter_shard_regions()], - _copy_arraylike_region, + # Store handles concurrency limiting internally + await asyncio.gather( + *[_copy_arraylike_region(region, data) for region in result._iter_shard_regions()] ) return result diff --git a/src/zarr/core/codec_pipeline.py b/src/zarr/core/codec_pipeline.py index e99080acec..0f8350f7ea 100644 --- a/src/zarr/core/codec_pipeline.py +++ b/src/zarr/core/codec_pipeline.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from dataclasses import dataclass from itertools import islice, pairwise from typing import TYPE_CHECKING, Any, TypeVar @@ -14,7 +15,6 @@ Codec, CodecPipeline, ) -from zarr.core.common import concurrent_map from zarr.core.config import config from zarr.core.indexing import SelectorTuple, is_scalar from zarr.errors import ZarrUserWarning @@ -267,9 +267,12 @@ async def read_batch( else: out[out_selection] = fill_value_or_default(chunk_spec) else: - chunk_bytes_batch = await concurrent_map( - [(byte_getter, array_spec.prototype) for byte_getter, array_spec, *_ in batch_info], - lambda byte_getter, prototype: byte_getter.get(prototype), + # Store handles concurrency limiting internally + chunk_bytes_batch = await asyncio.gather( + *[ + byte_getter.get(array_spec.prototype) + for byte_getter, array_spec, *_ in batch_info + ] ) chunk_array_batch = await self.decode_batch( [ @@ -367,15 +370,15 @@ async def _read_key( return await byte_setter.get(prototype=prototype) chunk_bytes_batch: Iterable[Buffer | None] - chunk_bytes_batch = await concurrent_map( - [ - ( + # Store handles concurrency limiting internally + chunk_bytes_batch = await asyncio.gather( + *[ + _read_key( None if is_complete_chunk else byte_setter, chunk_spec.prototype, ) for byte_setter, chunk_spec, chunk_selection, _, is_complete_chunk in batch_info - ], - _read_key, + ] ) chunk_array_decoded = await self.decode_batch( [ @@ -433,14 +436,14 @@ async def _write_key(byte_setter: ByteSetter, chunk_bytes: Buffer | None) -> Non else: await byte_setter.set(chunk_bytes) - await concurrent_map( - [ - (byte_setter, chunk_bytes) + # Store handles concurrency limiting internally + await asyncio.gather( + *[ + _write_key(byte_setter, chunk_bytes) for chunk_bytes, (byte_setter, *_) in zip( chunk_bytes_batch, batch_info, strict=False ) - ], - _write_key, + ] ) async def decode( @@ -467,12 +470,12 @@ async def read( out: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: - await concurrent_map( - [ - (single_batch_info, out, drop_axes) + # Process mini-batches concurrently - stores handle I/O concurrency internally + await asyncio.gather( + *[ + self.read_batch(single_batch_info, out, drop_axes) for single_batch_info in batched(batch_info, self.batch_size) - ], - self.read_batch, + ] ) async def write( @@ -481,12 +484,12 @@ async def write( value: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: - await concurrent_map( - [ - (single_batch_info, value, drop_axes) + # Process mini-batches concurrently - stores handle I/O concurrency internally + await asyncio.gather( + *[ + self.write_batch(single_batch_info, value, drop_axes) for single_batch_info in batched(batch_info, self.batch_size) - ], - self.write_batch, + ] ) diff --git a/src/zarr/storage/_fsspec.py b/src/zarr/storage/_fsspec.py index 7945fba467..e1ca718784 100644 --- a/src/zarr/storage/_fsspec.py +++ b/src/zarr/storage/_fsspec.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import json import warnings from contextlib import suppress @@ -17,6 +18,7 @@ from zarr.core.buffer import Buffer from zarr.errors import ZarrUserWarning from zarr.storage._common import _dereference_path +from zarr.storage._utils import with_concurrency_limit if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterable @@ -82,6 +84,9 @@ class FsspecStore(Store): filesystem scheme. allowed_exceptions : tuple[type[Exception], ...] When fetching data, these cases will be deemed to correspond to missing keys. + concurrency_limit : int, optional + Maximum number of concurrent I/O operations. Default is 50. + Set to None for unlimited concurrency. Attributes ---------- @@ -117,18 +122,24 @@ class FsspecStore(Store): fs: AsyncFileSystem allowed_exceptions: tuple[type[Exception], ...] path: str + _semaphore: asyncio.Semaphore | None def __init__( self, fs: AsyncFileSystem, + *, read_only: bool = False, path: str = "/", allowed_exceptions: tuple[type[Exception], ...] = ALLOWED_EXCEPTIONS, + concurrency_limit: int | None = 50, ) -> None: super().__init__(read_only=read_only) self.fs = fs self.path = path self.allowed_exceptions = allowed_exceptions + self._semaphore = ( + asyncio.Semaphore(concurrency_limit) if concurrency_limit is not None else None + ) if not self.fs.async_impl: raise TypeError("Filesystem needs to support async operations.") @@ -273,6 +284,7 @@ def __eq__(self, other: object) -> bool: and self.fs == other.fs ) + @with_concurrency_limit() async def get( self, key: str, @@ -315,6 +327,7 @@ async def get( else: return value + @with_concurrency_limit() async def set( self, key: str, @@ -335,6 +348,27 @@ async def set( raise NotImplementedError await self.fs._pipe_file(path, value.to_bytes()) + async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None: + # Override to avoid deadlock from calling decorated set() method + if not self._is_open: + await self._open() + self._check_writable() + + async def _set_with_limit(key: str, value: Buffer) -> None: + if not isinstance(value, Buffer): + raise TypeError( + f"FsspecStore.set(): `value` must be a Buffer instance. Got an instance of {type(value)} instead." + ) + path = _dereference_path(self.path, key) + if self._semaphore: + async with self._semaphore: + await self.fs._pipe_file(path, value.to_bytes()) + else: + await self.fs._pipe_file(path, value.to_bytes()) + + await asyncio.gather(*[_set_with_limit(key, value) for key, value in values]) + + @with_concurrency_limit() async def delete(self, key: str) -> None: # docstring inherited self._check_writable() diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index d6f10be862..ea48c756d3 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -19,12 +19,13 @@ ) from zarr.core.buffer import Buffer from zarr.core.buffer.core import default_buffer_prototype -from zarr.core.common import AccessModeLiteral, concurrent_map +from zarr.storage._utils import with_concurrency_limit if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterable, Iterator from zarr.core.buffer import BufferPrototype + from zarr.core.common import AccessModeLiteral def _get(path: Path, prototype: BufferPrototype, byte_range: ByteRequest | None) -> Buffer: @@ -95,6 +96,9 @@ class LocalStore(Store): Directory to use as root of store. read_only : bool Whether the store is read-only + concurrency_limit : int, optional + Maximum number of concurrent I/O operations. Default is 100. + Set to None for unlimited concurrency. Attributes ---------- @@ -109,8 +113,15 @@ class LocalStore(Store): supports_listing: bool = True root: Path + _semaphore: asyncio.Semaphore | None - def __init__(self, root: Path | str, *, read_only: bool = False) -> None: + def __init__( + self, + root: Path | str, + *, + read_only: bool = False, + concurrency_limit: int | None = 100, + ) -> None: super().__init__(read_only=read_only) if isinstance(root, str): root = Path(root) @@ -119,12 +130,17 @@ def __init__(self, root: Path | str, *, read_only: bool = False) -> None: f"'root' must be a string or Path instance. Got an instance of {type(root)} instead." ) self.root = root + self._semaphore = ( + asyncio.Semaphore(concurrency_limit) if concurrency_limit is not None else None + ) def with_read_only(self, read_only: bool = False) -> Self: # docstring inherited + concurrency_limit = self._semaphore._value if self._semaphore else None return type(self)( root=self.root, read_only=read_only, + concurrency_limit=concurrency_limit, ) @classmethod @@ -187,6 +203,7 @@ def __repr__(self) -> str: def __eq__(self, other: object) -> bool: return isinstance(other, type(self)) and self.root == other.root + @with_concurrency_limit() async def get( self, key: str, @@ -212,13 +229,23 @@ async def get_partial_values( key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited - args = [] - for key, byte_range in key_ranges: - assert isinstance(key, str) + # Note: We directly call the I/O functions here, wrapped with semaphore + # to avoid deadlock from calling the decorated get() method + + async def _get_with_limit(key: str, byte_range: ByteRequest | None) -> Buffer | None: path = self.root / key - args.append((_get, path, prototype, byte_range)) - # Use global semaphore to limit concurrent thread spawning - return await concurrent_map(args, asyncio.to_thread) + try: + if self._semaphore: + async with self._semaphore: + return await asyncio.to_thread(_get, path, prototype, byte_range) + else: + return await asyncio.to_thread(_get, path, prototype, byte_range) + except (FileNotFoundError, IsADirectoryError, NotADirectoryError): + return None + + return await asyncio.gather( + *[_get_with_limit(key, byte_range) for key, byte_range in key_ranges] + ) async def set(self, key: str, value: Buffer) -> None: # docstring inherited @@ -231,6 +258,7 @@ async def set_if_not_exists(self, key: str, value: Buffer) -> None: except FileExistsError: pass + @with_concurrency_limit() async def _set(self, key: str, value: Buffer, exclusive: bool = False) -> None: if not self._is_open: await self._open() @@ -243,6 +271,7 @@ async def _set(self, key: str, value: Buffer, exclusive: bool = False) -> None: path = self.root / key await asyncio.to_thread(_put, path, value, exclusive=exclusive) + @with_concurrency_limit() async def delete(self, key: str) -> None: """ Remove a key from the store. diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index e968e3cd26..be222c96b7 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -1,12 +1,12 @@ from __future__ import annotations +import asyncio from logging import getLogger from typing import TYPE_CHECKING, Self from zarr.abc.store import ByteRequest, Store from zarr.core.buffer import Buffer, gpu from zarr.core.buffer.core import default_buffer_prototype -from zarr.core.common import concurrent_map from zarr.storage._utils import _normalize_byte_range_index if TYPE_CHECKING: @@ -102,13 +102,10 @@ async def get_partial_values( key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited - - # All the key-ranges arguments goes with the same prototype - async def _get(key: str, byte_range: ByteRequest | None) -> Buffer | None: - return await self.get(key, prototype=prototype, byte_range=byte_range) - - # In-memory operations are fast and don't benefit from concurrency limiting - return await concurrent_map(key_ranges, _get, use_global_semaphore=False) + # In-memory operations are fast and don't need concurrency limiting + return await asyncio.gather( + *[self.get(key, prototype, byte_range) for key, byte_range in key_ranges] + ) async def exists(self, key: str) -> bool: # docstring inherited diff --git a/src/zarr/storage/_obstore.py b/src/zarr/storage/_obstore.py index e1d1bde672..223142d371 100644 --- a/src/zarr/storage/_obstore.py +++ b/src/zarr/storage/_obstore.py @@ -13,7 +13,8 @@ Store, SuffixByteRequest, ) -from zarr.core.common import concurrent_map, get_global_semaphore +from zarr.core.common import get_global_semaphore +from zarr.storage._utils import with_concurrency_limit if TYPE_CHECKING: from collections.abc import AsyncGenerator, Coroutine, Iterable, Sequence @@ -46,6 +47,9 @@ class ObjectStore(Store, Generic[T_Store]): An obstore store instance that is set up with the proper credentials. read_only : bool Whether to open the store in read-only mode. + concurrency_limit : int, optional + Maximum number of concurrent I/O operations. Default is 50. + Set to None for unlimited concurrency. Warnings -------- @@ -55,6 +59,7 @@ class ObjectStore(Store, Generic[T_Store]): store: T_Store """The underlying obstore instance.""" + _semaphore: asyncio.Semaphore | None def __eq__(self, value: object) -> bool: if not isinstance(value, ObjectStore): @@ -65,17 +70,28 @@ def __eq__(self, value: object) -> bool: return self.store == value.store # type: ignore[no-any-return] - def __init__(self, store: T_Store, *, read_only: bool = False) -> None: + def __init__( + self, + store: T_Store, + *, + read_only: bool = False, + concurrency_limit: int | None = 50, + ) -> None: if not store.__class__.__module__.startswith("obstore"): raise TypeError(f"expected ObjectStore class, got {store!r}") super().__init__(read_only=read_only) self.store = store + self._semaphore = ( + asyncio.Semaphore(concurrency_limit) if concurrency_limit is not None else None + ) def with_read_only(self, read_only: bool = False) -> Self: # docstring inherited + concurrency_limit = self._semaphore._value if self._semaphore else None return type(self)( store=self.store, read_only=read_only, + concurrency_limit=concurrency_limit, ) def __str__(self) -> str: @@ -93,6 +109,7 @@ def __setstate__(self, state: dict[Any, Any]) -> None: state["store"] = pickle.loads(state["store"]) self.__dict__.update(state) + @with_concurrency_limit() async def get( self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None ) -> Buffer | None: @@ -100,41 +117,7 @@ async def get( import obstore as obs try: - if byte_range is None: - resp = await obs.get_async(self.store, key) - return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] - elif isinstance(byte_range, RangeByteRequest): - bytes = await obs.get_range_async( - self.store, key, start=byte_range.start, end=byte_range.end - ) - return prototype.buffer.from_bytes(bytes) # type: ignore[arg-type] - elif isinstance(byte_range, OffsetByteRequest): - resp = await obs.get_async( - self.store, key, options={"range": {"offset": byte_range.offset}} - ) - return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] - elif isinstance(byte_range, SuffixByteRequest): - # some object stores (Azure) don't support suffix requests. In this - # case, our workaround is to first get the length of the object and then - # manually request the byte range at the end. - try: - resp = await obs.get_async( - self.store, key, options={"range": {"suffix": byte_range.suffix}} - ) - return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] - except obs.exceptions.NotSupportedError: - head_resp = await obs.head_async(self.store, key) - file_size = head_resp["size"] - suffix_len = byte_range.suffix - buffer = await obs.get_range_async( - self.store, - key, - start=file_size - suffix_len, - length=suffix_len, - ) - return prototype.buffer.from_bytes(buffer) # type: ignore[arg-type] - else: - raise ValueError(f"Unexpected byte_range, got {byte_range}") + return await self._get_impl(key, prototype, byte_range, obs) except _ALLOWED_EXCEPTIONS: return None @@ -144,7 +127,60 @@ async def get_partial_values( key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited - return await _get_partial_values(self.store, prototype=prototype, key_ranges=key_ranges) + # Note: We directly call obs operations here, wrapped with semaphore + # to avoid deadlock from calling the decorated get() method + import obstore as obs + + async def _get_with_limit(key: str, byte_range: ByteRequest | None) -> Buffer | None: + try: + if self._semaphore: + async with self._semaphore: + return await self._get_impl(key, prototype, byte_range, obs) + else: + return await self._get_impl(key, prototype, byte_range, obs) + except _ALLOWED_EXCEPTIONS: + return None + + return await asyncio.gather( + *[_get_with_limit(key, byte_range) for key, byte_range in key_ranges] + ) + + async def _get_impl( + self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None, obs: Any + ) -> Buffer: + """Implementation of get without semaphore decoration.""" + if byte_range is None: + resp = await obs.get_async(self.store, key) + return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] + elif isinstance(byte_range, RangeByteRequest): + bytes = await obs.get_range_async( + self.store, key, start=byte_range.start, end=byte_range.end + ) + return prototype.buffer.from_bytes(bytes) # type: ignore[arg-type] + elif isinstance(byte_range, OffsetByteRequest): + resp = await obs.get_async( + self.store, key, options={"range": {"offset": byte_range.offset}} + ) + return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] + elif isinstance(byte_range, SuffixByteRequest): + try: + resp = await obs.get_async( + self.store, key, options={"range": {"suffix": byte_range.suffix}} + ) + return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] + except obs.exceptions.NotSupportedError: + head_resp = await obs.head_async(self.store, key) + file_size = head_resp["size"] + suffix_len = byte_range.suffix + buffer = await obs.get_range_async( + self.store, + key, + start=file_size - suffix_len, + length=suffix_len, + ) + return prototype.buffer.from_bytes(buffer) # type: ignore[arg-type] + else: + raise ValueError(f"Unexpected byte_range, got {byte_range}") async def exists(self, key: str) -> bool: # docstring inherited @@ -162,6 +198,7 @@ def supports_writes(self) -> bool: # docstring inherited return True + @with_concurrency_limit() async def set(self, key: str, value: Buffer) -> None: # docstring inherited import obstore as obs @@ -171,20 +208,43 @@ async def set(self, key: str, value: Buffer) -> None: buf = value.as_buffer_like() await obs.put_async(self.store, key, buf) + async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None: + # Override to avoid deadlock from calling decorated set() method + import obstore as obs + + self._check_writable() + + async def _set_with_limit(key: str, value: Buffer) -> None: + buf = value.as_buffer_like() + if self._semaphore: + async with self._semaphore: + await obs.put_async(self.store, key, buf) + else: + await obs.put_async(self.store, key, buf) + + await asyncio.gather(*[_set_with_limit(key, value) for key, value in values]) + async def set_if_not_exists(self, key: str, value: Buffer) -> None: # docstring inherited + # Note: Not decorated to avoid deadlock when called in batch via gather() import obstore as obs self._check_writable() buf = value.as_buffer_like() - with contextlib.suppress(obs.exceptions.AlreadyExistsError): - await obs.put_async(self.store, key, buf, mode="create") + if self._semaphore: + async with self._semaphore: + with contextlib.suppress(obs.exceptions.AlreadyExistsError): + await obs.put_async(self.store, key, buf, mode="create") + else: + with contextlib.suppress(obs.exceptions.AlreadyExistsError): + await obs.put_async(self.store, key, buf, mode="create") @property def supports_deletes(self) -> bool: # docstring inherited return True + @with_concurrency_limit() async def delete(self, key: str) -> None: # docstring inherited import obstore as obs @@ -207,8 +267,18 @@ async def delete_dir(self, prefix: str) -> None: prefix += "/" metas = await obs.list(self.store, prefix).collect_async() - keys = [(m["path"],) for m in metas] - await concurrent_map(keys, self.delete) + + # Delete with semaphore limiting to avoid deadlock + async def _delete_with_limit(path: str) -> None: + if self._semaphore: + async with self._semaphore: + with contextlib.suppress(FileNotFoundError): + await obs.delete_async(self.store, path) + else: + with contextlib.suppress(FileNotFoundError): + await obs.delete_async(self.store, path) + + await asyncio.gather(*[_delete_with_limit(m["path"]) for m in metas]) @property def supports_listing(self) -> bool: diff --git a/src/zarr/storage/_utils.py b/src/zarr/storage/_utils.py index 39c28d44c3..9ce01c2d99 100644 --- a/src/zarr/storage/_utils.py +++ b/src/zarr/storage/_utils.py @@ -1,17 +1,81 @@ from __future__ import annotations +import functools import re from pathlib import Path -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar from zarr.abc.store import OffsetByteRequest, RangeByteRequest, SuffixByteRequest if TYPE_CHECKING: - from collections.abc import Iterable, Mapping + import asyncio + from collections.abc import Callable, Coroutine, Iterable, Mapping from zarr.abc.store import ByteRequest from zarr.core.buffer import Buffer +P = ParamSpec("P") +T_co = TypeVar("T_co", covariant=True) + + +def with_concurrency_limit( + semaphore_attr: str = "_semaphore", +) -> Callable[[Callable[P, Coroutine[Any, Any, T_co]]], Callable[P, Coroutine[Any, Any, T_co]]]: + """ + Decorator that applies a semaphore-based concurrency limit to an async method. + + This decorator is designed for Store methods that need to limit concurrent operations. + The store instance should have a `_semaphore` attribute (or custom attribute name) + that is either an asyncio.Semaphore or None (for unlimited concurrency). + + Parameters + ---------- + semaphore_attr : str, optional + Name of the semaphore attribute on the class instance. Default is "_semaphore". + + Returns + ------- + Callable + The decorated async function with concurrency limiting applied. + + Examples + -------- + ```python + class MyStore(Store): + def __init__(self, concurrency_limit: int = 100): + self._semaphore = asyncio.Semaphore(concurrency_limit) if concurrency_limit else None + + @with_concurrency_limit() + async def get(self, key: str) -> Buffer | None: + # This will only run when semaphore permits + return await expensive_io_operation(key) + ``` + """ + + def decorator( + func: Callable[P, Coroutine[Any, Any, T_co]], + ) -> Callable[P, Coroutine[Any, Any, T_co]]: + @functools.wraps(func) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T_co: + # First arg should be 'self' + if not args: + raise TypeError(f"{func.__name__} requires at least one argument (self)") + + self = args[0] + semaphore: asyncio.Semaphore | None = getattr(self, semaphore_attr, None) + + if semaphore is None: + # No concurrency limit - run directly + return await func(*args, **kwargs) + else: + # Apply concurrency limit + async with semaphore: + return await func(*args, **kwargs) + + return wrapper + + return decorator + def normalize_path(path: str | bytes | Path | None) -> str: if path is None: diff --git a/tests/test_group.py b/tests/test_group.py index 6f1f4e68fa..9f25036298 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -23,7 +23,6 @@ from zarr.core import sync_group from zarr.core._info import GroupInfo from zarr.core.buffer import default_buffer_prototype -from zarr.core.config import config as zarr_config from zarr.core.dtype.common import unpack_dtype_json from zarr.core.dtype.npy.int import UInt8 from zarr.core.group import ( @@ -1738,29 +1737,6 @@ async def test_create_nodes( assert node_spec == {k: v.metadata for k, v in observed_nodes.items()} -@pytest.mark.parametrize("store", ["memory"], indirect=True) -def test_create_nodes_concurrency_limit(store: MemoryStore) -> None: - """ - Test that the execution time of create_nodes can be constrained by the async concurrency - configuration setting. - """ - set_latency = 0.02 - num_groups = 10 - groups = {str(idx): GroupMetadata() for idx in range(num_groups)} - - latency_store = LatencyStore(store, set_latency=set_latency) - - # check how long it takes to iterate over the groups - # if create_nodes is sensitive to IO latency, - # this should take (num_groups * get_latency) seconds - # otherwise, it should take only marginally more than get_latency seconds - with zarr_config.set({"async.concurrency": 1}): - start = time.time() - _ = tuple(sync_group.create_nodes(store=latency_store, nodes=groups)) - elapsed = time.time() - start - assert elapsed > num_groups * set_latency - - @pytest.mark.parametrize( ("a_func", "b_func"), [ @@ -2250,38 +2226,6 @@ def test_group_members_performance(store: Store) -> None: assert elapsed < (num_groups * get_latency) -@pytest.mark.parametrize("store", ["memory"], indirect=True) -def test_group_members_concurrency_limit(store: MemoryStore) -> None: - """ - Test that the execution time of Group.members can be constrained by the async concurrency - configuration setting. - """ - get_latency = 0.02 - - # use the input store to create some groups - group_create = zarr.group(store=store) - num_groups = 10 - - # Create some groups - for i in range(num_groups): - group_create.create_group(f"group{i}") - - latency_store = LatencyStore(store, get_latency=get_latency) - # create a group with some latency on get operations - group_read = zarr.group(store=latency_store) - - # check how long it takes to iterate over the groups - # if .members is sensitive to IO latency, - # this should take (num_groups * get_latency) seconds - # otherwise, it should take only marginally more than get_latency seconds - with zarr_config.set({"async.concurrency": 1}): - start = time.time() - _ = group_read.members() - elapsed = time.time() - start - - assert elapsed > num_groups * get_latency - - @pytest.mark.parametrize("option", ["array", "group", "invalid"]) def test_build_metadata_v3(option: Literal["array", "group", "invalid"]) -> None: """ From 05d191ae3672cbf0f61bb7b08b1f56cc045eee05 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 19 Dec 2025 15:59:54 +0100 Subject: [PATCH 6/7] add store concurrency tests --- src/zarr/storage/_utils.py | 18 +- src/zarr/testing/store_concurrency.py | 247 ++++++++++++++++++++++++++ tests/test_store/test_local.py | 13 ++ tests/test_store/test_memory.py | 13 ++ 4 files changed, 284 insertions(+), 7 deletions(-) create mode 100644 src/zarr/testing/store_concurrency.py diff --git a/src/zarr/storage/_utils.py b/src/zarr/storage/_utils.py index 9ce01c2d99..d156a06891 100644 --- a/src/zarr/storage/_utils.py +++ b/src/zarr/storage/_utils.py @@ -55,6 +55,13 @@ async def get(self, key: str) -> Buffer | None: def decorator( func: Callable[P, Coroutine[Any, Any, T_co]], ) -> Callable[P, Coroutine[Any, Any, T_co]]: + """ + This decorator wraps the invocation of `func` in an `async with semaphore` context manager. + The semaphore object is resolved by getting the `semaphor_attr` attribute from the first + argument to func. When this decorator is used on a method of a class, that first argument + is a reference to the class instance (`self`). + """ + @functools.wraps(func) async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T_co: # First arg should be 'self' @@ -62,15 +69,12 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T_co: raise TypeError(f"{func.__name__} requires at least one argument (self)") self = args[0] - semaphore: asyncio.Semaphore | None = getattr(self, semaphore_attr, None) - if semaphore is None: - # No concurrency limit - run directly + semaphore: asyncio.Semaphore = getattr(self, semaphore_attr) + + # Apply concurrency limit + async with semaphore: return await func(*args, **kwargs) - else: - # Apply concurrency limit - async with semaphore: - return await func(*args, **kwargs) return wrapper diff --git a/src/zarr/testing/store_concurrency.py b/src/zarr/testing/store_concurrency.py new file mode 100644 index 0000000000..06cf23857d --- /dev/null +++ b/src/zarr/testing/store_concurrency.py @@ -0,0 +1,247 @@ +"""Base test class for store concurrency limiting behavior.""" + +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING, Generic, TypeVar + +import pytest + +from zarr.core.buffer import Buffer, default_buffer_prototype + +if TYPE_CHECKING: + from zarr.abc.store import Store + +__all__ = ["StoreConcurrencyTests"] + + +S = TypeVar("S", bound="Store") +B = TypeVar("B", bound="Buffer") + + +class StoreConcurrencyTests(Generic[S, B]): + """Base class for testing store concurrency limiting behavior. + + This mixin provides tests for verifying that stores correctly implement + concurrency limiting. + + Subclasses should set: + - store_cls: The store class being tested + - buffer_cls: The buffer class to use (e.g., cpu.Buffer) + - expected_concurrency_limit: Expected default concurrency limit (or None for unlimited) + """ + + store_cls: type[S] + buffer_cls: type[B] + expected_concurrency_limit: int | None + + @pytest.fixture + async def store(self, store_kwargs: dict) -> S: + """Create and open a store instance.""" + return await self.store_cls.open(**store_kwargs) + + def test_concurrency_limit_default(self, store: S) -> None: + """Test that store has the expected default concurrency limit.""" + if hasattr(store, "_semaphore"): + if self.expected_concurrency_limit is None: + assert store._semaphore is None, "Expected no concurrency limit" + else: + assert store._semaphore is not None, "Expected concurrency limit to be set" + assert store._semaphore._value == self.expected_concurrency_limit, ( + f"Expected limit {self.expected_concurrency_limit}, got {store._semaphore._value}" + ) + + def test_concurrency_limit_custom(self, store_kwargs: dict) -> None: + """Test that custom concurrency limits can be set.""" + if "concurrency_limit" not in self.store_cls.__init__.__code__.co_varnames: + pytest.skip("Store does not support custom concurrency limits") + + # Test with custom limit + store = self.store_cls(**store_kwargs, concurrency_limit=42) + if hasattr(store, "_semaphore"): + assert store._semaphore is not None + assert store._semaphore._value == 42 + + # Test with None (unlimited) + store = self.store_cls(**store_kwargs, concurrency_limit=None) + if hasattr(store, "_semaphore"): + assert store._semaphore is None + + async def test_concurrency_limit_enforced(self, store: S) -> None: + """Test that the concurrency limit is actually enforced during execution. + + This test verifies that when many operations are submitted concurrently, + only up to the concurrency limit are actually executing at once. + """ + if not hasattr(store, "_semaphore") or store._semaphore is None: + pytest.skip("Store has no concurrency limit") + + limit = store._semaphore._value + + # We'll monitor the semaphore's available count + # When it reaches 0, that means `limit` operations are running + min_available = limit + + async def monitored_operation(key: str, value: B) -> None: + nonlocal min_available + # Check semaphore state right after we're scheduled + await asyncio.sleep(0) # Yield to ensure we're in the queue + available = store._semaphore._value + min_available = min(min_available, available) + + # Now do the actual operation (which will acquire the semaphore) + await store.set(key, value) + + # Launch more operations than the limit to ensure contention + num_ops = limit * 2 + items = [ + (f"limit_test_key_{i}", self.buffer_cls.from_bytes(f"value_{i}".encode())) + for i in range(num_ops) + ] + + await asyncio.gather(*[monitored_operation(k, v) for k, v in items]) + + # The semaphore should have been fully utilized (reached 0 or close to it) + # This indicates that `limit` operations were running concurrently + assert min_available < limit, ( + f"Semaphore was never fully utilized. " + f"Min available: {min_available}, Limit: {limit}. " + f"This suggests operations aren't running concurrently." + ) + + # Ideally it should reach 0, but allow some slack for timing + assert min_available <= 5, ( + f"Semaphore only reached {min_available} available slots. " + f"Expected close to 0 with limit {limit}." + ) + + async def test_batch_write_no_deadlock(self, store: S) -> None: + """Test that batch writes don't deadlock when exceeding concurrency limit.""" + # Create more items than any reasonable concurrency limit + num_items = 200 + items = [ + (f"test_key_{i}", self.buffer_cls.from_bytes(f"test_value_{i}".encode())) + for i in range(num_items) + ] + + # This should complete without deadlock, even if num_items > concurrency_limit + await asyncio.wait_for(store._set_many(items), timeout=30.0) + + # Verify all items were written correctly + for key, expected_value in items: + result = await store.get(key, default_buffer_prototype()) + assert result is not None + assert result.to_bytes() == expected_value.to_bytes() + + async def test_batch_read_no_deadlock(self, store: S) -> None: + """Test that batch reads don't deadlock when exceeding concurrency limit.""" + # Write test data + num_items = 200 + test_data = { + f"test_key_{i}": self.buffer_cls.from_bytes(f"test_value_{i}".encode()) + for i in range(num_items) + } + + for key, value in test_data.items(): + await store.set(key, value) + + # Read all items concurrently - should not deadlock + keys_and_ranges = [(key, None) for key in test_data] + results = await asyncio.wait_for( + store.get_partial_values(default_buffer_prototype(), keys_and_ranges), + timeout=30.0, + ) + + # Verify results + assert len(results) == num_items + for result, (key, expected_value) in zip(results, test_data.items()): + assert result is not None + assert result.to_bytes() == expected_value.to_bytes() + + async def test_batch_delete_no_deadlock(self, store: S) -> None: + """Test that batch deletes don't deadlock when exceeding concurrency limit.""" + if not store.supports_deletes: + pytest.skip("Store does not support deletes") + + # Write test data + num_items = 200 + keys = [f"test_key_{i}" for i in range(num_items)] + for key in keys: + await store.set(key, self.buffer_cls.from_bytes(b"test_value")) + + # Delete all items concurrently - should not deadlock + await asyncio.wait_for(asyncio.gather(*[store.delete(key) for key in keys]), timeout=30.0) + + # Verify all items were deleted + for key in keys: + result = await store.get(key, default_buffer_prototype()) + assert result is None + + async def test_concurrent_operations_correctness(self, store: S) -> None: + """Test that concurrent operations produce correct results.""" + num_operations = 100 + + # Mix of reads and writes + write_keys = [f"write_key_{i}" for i in range(num_operations)] + write_values = [ + self.buffer_cls.from_bytes(f"value_{i}".encode()) for i in range(num_operations) + ] + + # Write all concurrently + await asyncio.gather(*[store.set(k, v) for k, v in zip(write_keys, write_values)]) + + # Read all concurrently + results = await asyncio.gather( + *[store.get(k, default_buffer_prototype()) for k in write_keys] + ) + + # Verify correctness + for result, expected in zip(results, write_values): + assert result is not None + assert result.to_bytes() == expected.to_bytes() + + @pytest.mark.parametrize("batch_size", [1, 10, 50, 100]) + async def test_various_batch_sizes(self, store: S, batch_size: int) -> None: + """Test that various batch sizes work correctly.""" + items = [ + (f"batch_key_{i}", self.buffer_cls.from_bytes(f"batch_value_{i}".encode())) + for i in range(batch_size) + ] + + # Should complete without issues for any batch size + await asyncio.wait_for(store._set_many(items), timeout=10.0) + + # Verify + for key, expected_value in items: + result = await store.get(key, default_buffer_prototype()) + assert result is not None + assert result.to_bytes() == expected_value.to_bytes() + + async def test_empty_batch_operations(self, store: S) -> None: + """Test that empty batch operations don't cause issues.""" + # Empty batch should not raise + await store._set_many([]) + + # Empty read batch + results = await store.get_partial_values(default_buffer_prototype(), []) + assert results == [] + + async def test_mixed_success_failure_batch(self, store: S) -> None: + """Test batch operations with mix of successful and failing items.""" + # Write some initial data + await store.set("existing_key", self.buffer_cls.from_bytes(b"existing_value")) + + # Try to read mix of existing and non-existing keys + key_ranges = [ + ("existing_key", None), + ("non_existing_key_1", None), + ("non_existing_key_2", None), + ] + + results = await store.get_partial_values(default_buffer_prototype(), key_ranges) + + # First should exist, others should be None + assert results[0] is not None + assert results[0].to_bytes() == b"existing_value" + assert results[1] is None + assert results[2] is None diff --git a/tests/test_store/test_local.py b/tests/test_store/test_local.py index 6756bc83d9..73eec991f8 100644 --- a/tests/test_store/test_local.py +++ b/tests/test_store/test_local.py @@ -12,6 +12,7 @@ from zarr.storage import LocalStore from zarr.storage._local import _atomic_write from zarr.testing.store import StoreTests +from zarr.testing.store_concurrency import StoreConcurrencyTests from zarr.testing.utils import assert_bytes_equal @@ -150,3 +151,15 @@ def test_atomic_write_exclusive_preexisting(tmp_path: pathlib.Path) -> None: f.write(b"abc") assert path.read_bytes() == b"xyz" assert list(path.parent.iterdir()) == [path] # no temp files + + +class TestLocalStoreConcurrency(StoreConcurrencyTests[LocalStore, cpu.Buffer]): + """Test LocalStore concurrency limiting behavior.""" + + store_cls = LocalStore + buffer_cls = cpu.Buffer + expected_concurrency_limit = 100 # LocalStore default + + @pytest.fixture + def store_kwargs(self, tmpdir: str) -> dict[str, str]: + return {"root": str(tmpdir)} diff --git a/tests/test_store/test_memory.py b/tests/test_store/test_memory.py index 29fa9b2964..2222905745 100644 --- a/tests/test_store/test_memory.py +++ b/tests/test_store/test_memory.py @@ -12,6 +12,7 @@ from zarr.errors import ZarrUserWarning from zarr.storage import GpuMemoryStore, MemoryStore from zarr.testing.store import StoreTests +from zarr.testing.store_concurrency import StoreConcurrencyTests from zarr.testing.utils import gpu_test if TYPE_CHECKING: @@ -130,3 +131,15 @@ def test_from_dict(self) -> None: result = GpuMemoryStore.from_dict(d) for v in result._store_dict.values(): assert type(v) is gpu.Buffer + + +class TestMemoryStoreConcurrency(StoreConcurrencyTests[MemoryStore, cpu.Buffer]): + """Test MemoryStore concurrency limiting behavior.""" + + store_cls = MemoryStore + buffer_cls = cpu.Buffer + expected_concurrency_limit = None # MemoryStore has no limit (fast in-memory ops) + + @pytest.fixture + def store_kwargs(self) -> dict[str, Any]: + return {"store_dict": None} From 21fc1d547a31c067800227c6dabb095ee6e9ee94 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Sun, 8 Feb 2026 20:52:32 +0100 Subject: [PATCH 7/7] remove more references to the global concurrency limit --- src/zarr/core/common.py | 139 +--------- src/zarr/core/group.py | 58 +---- src/zarr/storage/_obstore.py | 299 ++++------------------ src/zarr/storage/_utils.py | 8 +- src/zarr/testing/store_concurrency.py | 8 +- tests/test_common.py | 4 - tests/test_global_concurrency.py | 354 +++----------------------- 7 files changed, 103 insertions(+), 767 deletions(-) diff --git a/src/zarr/core/common.py b/src/zarr/core/common.py index 88f1388091..e45c256310 100644 --- a/src/zarr/core/common.py +++ b/src/zarr/core/common.py @@ -1,15 +1,11 @@ from __future__ import annotations -import asyncio import functools import math import operator -import threading import warnings -import weakref from collections.abc import Iterable, Mapping, Sequence from enum import Enum -from itertools import starmap from typing import ( TYPE_CHECKING, Any, @@ -29,7 +25,7 @@ from zarr.errors import ZarrRuntimeWarning if TYPE_CHECKING: - from collections.abc import Awaitable, Callable, Iterator + from collections.abc import Iterator ZARR_JSON = "zarr.json" @@ -96,139 +92,6 @@ def ceildiv(a: float, b: float) -> int: return math.ceil(a / b) -T = TypeVar("T", bound=tuple[Any, ...]) -V = TypeVar("V") - - -# Global semaphore management for per-process concurrency limiting -# Use WeakKeyDictionary to automatically clean up semaphores when event loops are garbage collected -_global_semaphores: weakref.WeakKeyDictionary[asyncio.AbstractEventLoop, asyncio.Semaphore] = ( - weakref.WeakKeyDictionary() -) -# Use threading.Lock instead of asyncio.Lock to coordinate across event loops -_global_semaphore_lock = threading.Lock() - - -def get_global_semaphore() -> asyncio.Semaphore: - """ - Get the global semaphore for the current event loop. - - This ensures that all concurrent operations across the process share the same - concurrency limit, preventing excessive concurrent task creation when multiple - arrays or operations are running simultaneously. - - The semaphore is lazily created per event loop and uses the configured - `async.concurrency` value from zarr config. The semaphore is cached per event - loop, so subsequent calls return the same semaphore instance. - - Note: Config changes after the first call will not affect the semaphore limit. - To apply new config values, use :func:`reset_global_semaphores` to clear the cache. - - Returns - ------- - asyncio.Semaphore - The global semaphore for this event loop. - - Raises - ------ - RuntimeError - If called outside of an async context (no running event loop). - - See Also - -------- - reset_global_semaphores : Clear the global semaphore cache - """ - loop = asyncio.get_running_loop() - - # Acquire lock FIRST to prevent TOCTOU race condition - with _global_semaphore_lock: - if loop not in _global_semaphores: - limit = zarr_config.get("async.concurrency") - _global_semaphores[loop] = asyncio.Semaphore(limit) - return _global_semaphores[loop] - - -def reset_global_semaphores() -> None: - """ - Clear all cached global semaphores. - - This is useful when you want config changes to take effect, or for testing. - The next call to :func:`get_global_semaphore` will create a new semaphore - using the current configuration. - - Warning: This should only be called when no async operations are in progress, - as it will invalidate all existing semaphore references. - - Examples - -------- - >>> import zarr - >>> zarr.config.set({"async.concurrency": 50}) - >>> reset_global_semaphores() # Apply new config - """ - with _global_semaphore_lock: - _global_semaphores.clear() - - -async def concurrent_map( - items: Iterable[T], - func: Callable[..., Awaitable[V]], - limit: int | None = None, - *, - use_global_semaphore: bool = True, -) -> list[V]: - """ - Execute an async function concurrently over multiple items with concurrency limiting. - - Parameters - ---------- - items : Iterable[T] - Items to process, where each item is a tuple of arguments to pass to func. - func : Callable[..., Awaitable[V]] - Async function to execute for each item. - limit : int | None, optional - If provided and use_global_semaphore is False, creates a local semaphore - with this limit. If None, no concurrency limiting is applied. - use_global_semaphore : bool, default True - If True, uses the global per-process semaphore for concurrency limiting, - ensuring all concurrent operations share the same limit. If False, uses - the `limit` parameter for local limiting (legacy behavior). - - Returns - ------- - list[V] - Results from executing func on all items. - """ - if use_global_semaphore: - if limit is not None: - raise ValueError( - "Cannot specify both use_global_semaphore=True and a limit value. " - "Either use the global semaphore (use_global_semaphore=True, limit=None) " - "or specify a local limit (use_global_semaphore=False, limit=)." - ) - # Use the global semaphore for process-wide concurrency limiting - sem = get_global_semaphore() - - async def run(item: tuple[Any]) -> V: - async with sem: - return await func(*item) - - return await asyncio.gather(*[asyncio.ensure_future(run(item)) for item in items]) - - elif limit is None: - # No concurrency limiting - return await asyncio.gather(*list(starmap(func, items))) - - else: - # Legacy mode: create local semaphore with specified limit - sem = asyncio.Semaphore(limit) - - async def run(item: tuple[Any]) -> V: - async with sem: - return await func(*item) - - return await asyncio.gather(*[asyncio.ensure_future(run(item)) for item in items]) - - E = TypeVar("E", bound=Enum) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 50b57a569f..658de7ef81 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -44,7 +44,6 @@ NodeType, ShapeLike, ZarrFormat, - get_global_semaphore, parse_shapelike, ) from zarr.core.config import config @@ -1441,13 +1440,10 @@ async def _members( ) raise ValueError(msg) - # Use global semaphore for process-wide concurrency limiting - semaphore = get_global_semaphore() async for member in _iter_members_deep( self, max_depth=max_depth, skip_keys=skip_keys, - semaphore=semaphore, use_consolidated_for_children=use_consolidated_for_children, ): yield member @@ -3324,13 +3320,11 @@ async def create_nodes( The created nodes in the order they are created. """ - # Use global semaphore for process-wide concurrency limiting - semaphore = get_global_semaphore() create_tasks: list[Coroutine[None, None, str]] = [] for key, value in nodes.items(): # make the key absolute - create_tasks.extend(_persist_metadata(store, key, value, semaphore=semaphore)) + create_tasks.extend(_persist_metadata(store, key, value)) created_object_keys = [] @@ -3476,28 +3470,16 @@ def _ensure_consistent_zarr_format( ) -async def _getitem_semaphore( - node: AsyncGroup, key: str, semaphore: asyncio.Semaphore | None -) -> AnyAsyncArray | AsyncGroup: +async def _getitem(node: AsyncGroup, key: str) -> AnyAsyncArray | AsyncGroup: """ - Wrap Group.getitem with an optional semaphore. - - If the semaphore parameter is an - asyncio.Semaphore instance, then the getitem operation is performed inside an async context - manager provided by that semaphore. If the semaphore parameter is None, then getitem is invoked - without a context manager. + Fetch a child node from a group by key. """ - if semaphore is not None: - async with semaphore: - return await node.getitem(key) - else: - return await node.getitem(key) + return await node.getitem(key) async def _iter_members( node: AsyncGroup, skip_keys: tuple[str, ...], - semaphore: asyncio.Semaphore | None, ) -> AsyncGenerator[tuple[str, AnyAsyncArray | AsyncGroup], None]: """ Iterate over the arrays and groups contained in a group. @@ -3508,8 +3490,6 @@ async def _iter_members( The group to traverse. skip_keys : tuple[str, ...] A tuple of keys to skip when iterating over the possible members of the group. - semaphore : asyncio.Semaphore | None - An optional semaphore to use for concurrency control. Yields ------ @@ -3520,10 +3500,7 @@ async def _iter_members( keys = [key async for key in node.store.list_dir(node.path)] keys_filtered = tuple(filter(lambda v: v not in skip_keys, keys)) - node_tasks = tuple( - asyncio.create_task(_getitem_semaphore(node, key, semaphore), name=key) - for key in keys_filtered - ) + node_tasks = tuple(asyncio.create_task(_getitem(node, key), name=key) for key in keys_filtered) for fetched_node_coro in asyncio.as_completed(node_tasks): try: @@ -3550,7 +3527,6 @@ async def _iter_members_deep( *, max_depth: int | None, skip_keys: tuple[str, ...], - semaphore: asyncio.Semaphore | None = None, use_consolidated_for_children: bool = True, ) -> AsyncGenerator[tuple[str, AnyAsyncArray | AsyncGroup], None]: """ @@ -3565,8 +3541,6 @@ async def _iter_members_deep( The maximum depth of recursion. skip_keys : tuple[str, ...] A tuple of keys to skip when iterating over the possible members of the group. - semaphore : asyncio.Semaphore | None - An optional semaphore to use for concurrency control. use_consolidated_for_children : bool, default True Whether to use the consolidated metadata of child groups loaded from the store. Note that this only affects groups loaded from the @@ -3585,7 +3559,7 @@ async def _iter_members_deep( new_depth = None else: new_depth = max_depth - 1 - async for name, node in _iter_members(group, skip_keys=skip_keys, semaphore=semaphore): + async for name, node in _iter_members(group, skip_keys=skip_keys): is_group = isinstance(node, AsyncGroup) if ( is_group @@ -3599,9 +3573,7 @@ async def _iter_members_deep( yield name, node if is_group and do_recursion: node = cast("AsyncGroup", node) - to_recurse[name] = _iter_members_deep( - node, max_depth=new_depth, skip_keys=skip_keys, semaphore=semaphore - ) + to_recurse[name] = _iter_members_deep(node, max_depth=new_depth, skip_keys=skip_keys) for prefix, subgroup_iter in to_recurse.items(): async for name, node in subgroup_iter: @@ -3811,9 +3783,7 @@ async def get_node(store: Store, path: str, zarr_format: ZarrFormat) -> AnyAsync raise ValueError(f"Unexpected zarr format: {zarr_format}") # pragma: no cover -async def _set_return_key( - *, store: Store, key: str, value: Buffer, semaphore: asyncio.Semaphore | None = None -) -> str: +async def _set_return_key(*, store: Store, key: str, value: Buffer) -> str: """ Write a value to storage at the given key. The key is returned. Useful when saving values via routines that return results in execution order, @@ -3828,15 +3798,8 @@ async def _set_return_key( The key to save the value to. value : Buffer The value to save. - semaphore : asyncio.Semaphore | None - An optional semaphore to use to limit the number of concurrent writes. """ - - if semaphore is not None: - async with semaphore: - await store.set(key, value) - else: - await store.set(key, value) + await store.set(key, value) return key @@ -3844,7 +3807,6 @@ def _persist_metadata( store: Store, path: str, metadata: ArrayV2Metadata | ArrayV3Metadata | GroupMetadata, - semaphore: asyncio.Semaphore | None = None, ) -> tuple[Coroutine[None, None, str], ...]: """ Prepare to save a metadata document to storage, returning a tuple of coroutines that must be awaited. @@ -3852,7 +3814,7 @@ def _persist_metadata( to_save = metadata.to_buffer_dict(default_buffer_prototype()) return tuple( - _set_return_key(store=store, key=_join_paths([path, key]), value=value, semaphore=semaphore) + _set_return_key(store=store, key=_join_paths([path, key]), value=value) for key, value in to_save.items() ) diff --git a/src/zarr/storage/_obstore.py b/src/zarr/storage/_obstore.py index 223142d371..697f51ddb0 100644 --- a/src/zarr/storage/_obstore.py +++ b/src/zarr/storage/_obstore.py @@ -4,7 +4,7 @@ import contextlib import pickle from collections import defaultdict -from typing import TYPE_CHECKING, Generic, Self, TypedDict, TypeVar +from typing import TYPE_CHECKING, Generic, Self, TypeVar from zarr.abc.store import ( ByteRequest, @@ -13,14 +13,13 @@ Store, SuffixByteRequest, ) -from zarr.core.common import get_global_semaphore from zarr.storage._utils import with_concurrency_limit if TYPE_CHECKING: from collections.abc import AsyncGenerator, Coroutine, Iterable, Sequence from typing import Any - from obstore import ListResult, ListStream, ObjectMeta, OffsetRange, SuffixRange + from obstore import ListResult, ListStream, ObjectMeta from obstore.store import ObjectStore as _UpstreamObjectStore from zarr.core.buffer import Buffer, BufferPrototype @@ -127,23 +126,59 @@ async def get_partial_values( key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited - # Note: We directly call obs operations here, wrapped with semaphore - # to avoid deadlock from calling the decorated get() method + # We override to: + # 1. Avoid deadlock from calling the decorated get() method + # 2. Batch RangeByteRequests per-file using get_ranges_async for performance import obstore as obs - async def _get_with_limit(key: str, byte_range: ByteRequest | None) -> Buffer | None: + key_ranges = list(key_ranges) + # Group bounded range requests by path for batched fetching + per_file_bounded: dict[str, list[tuple[int, RangeByteRequest]]] = defaultdict(list) + other_requests: list[tuple[int, str, ByteRequest | None]] = [] + + for idx, (path, byte_range) in enumerate(key_ranges): + if isinstance(byte_range, RangeByteRequest): + per_file_bounded[path].append((idx, byte_range)) + else: + other_requests.append((idx, path, byte_range)) + + buffers: list[Buffer | None] = [None] * len(key_ranges) + + async def _fetch_ranges(path: str, requests: list[tuple[int, RangeByteRequest]]) -> None: + """Batch multiple range requests for the same file using get_ranges_async.""" + starts = [r.start for _, r in requests] + ends = [r.end for _, r in requests] + if self._semaphore: + async with self._semaphore: + responses = await obs.get_ranges_async( + self.store, path=path, starts=starts, ends=ends + ) + else: + responses = await obs.get_ranges_async( + self.store, path=path, starts=starts, ends=ends + ) + for (idx, _), response in zip(requests, responses, strict=True): + buffers[idx] = prototype.buffer.from_bytes(response) # type: ignore[arg-type] + + async def _fetch_one(idx: int, path: str, byte_range: ByteRequest | None) -> None: + """Fetch a single non-range request with semaphore limiting.""" try: if self._semaphore: async with self._semaphore: - return await self._get_impl(key, prototype, byte_range, obs) + buffers[idx] = await self._get_impl(path, prototype, byte_range, obs) else: - return await self._get_impl(key, prototype, byte_range, obs) + buffers[idx] = await self._get_impl(path, prototype, byte_range, obs) except _ALLOWED_EXCEPTIONS: - return None + pass # buffers[idx] stays None - return await asyncio.gather( - *[_get_with_limit(key, byte_range) for key, byte_range in key_ranges] - ) + futs: list[Coroutine[Any, Any, None]] = [] + for path, requests in per_file_bounded.items(): + futs.append(_fetch_ranges(path, requests)) + for idx, path, byte_range in other_requests: + futs.append(_fetch_one(idx, path, byte_range)) + + await asyncio.gather(*futs) + return buffers async def _get_impl( self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None, obs: Any @@ -336,243 +371,3 @@ async def _transform_list_dir( objects = [obj["path"].removeprefix(prefix).lstrip("/") for obj in list_result["objects"]] for item in prefixes + objects: yield item - - -class _BoundedRequest(TypedDict): - """Range request with a known start and end byte. - - These requests can be multiplexed natively on the Rust side with - `obstore.get_ranges_async`. - """ - - original_request_index: int - """The positional index in the original key_ranges input""" - - start: int - """Start byte offset.""" - - end: int - """End byte offset.""" - - -class _OtherRequest(TypedDict): - """Offset or suffix range requests. - - These requests cannot be concurrent on the Rust side, and each need their own call - to `obstore.get_async`, passing in the `range` parameter. - """ - - original_request_index: int - """The positional index in the original key_ranges input""" - - path: str - """The path to request from.""" - - range: OffsetRange | None - # Note: suffix requests are handled separately because some object stores (Azure) - # don't support them - """The range request type.""" - - -class _SuffixRequest(TypedDict): - """Offset or suffix range requests. - - These requests cannot be concurrent on the Rust side, and each need their own call - to `obstore.get_async`, passing in the `range` parameter. - """ - - original_request_index: int - """The positional index in the original key_ranges input""" - - path: str - """The path to request from.""" - - range: SuffixRange - """The suffix range.""" - - -class _Response(TypedDict): - """A response buffer associated with the original index that it should be restored to.""" - - original_request_index: int - """The positional index in the original key_ranges input""" - - buffer: Buffer - """The buffer returned from obstore's range request.""" - - -async def _make_bounded_requests( - store: _UpstreamObjectStore, - path: str, - requests: list[_BoundedRequest], - prototype: BufferPrototype, - semaphore: asyncio.Semaphore, -) -> list[_Response]: - """Make all bounded requests for a specific file. - - `obstore.get_ranges_async` allows for making concurrent requests for multiple ranges - within a single file, and will e.g. merge concurrent requests. This only uses one - single Python coroutine. - """ - import obstore as obs - - starts = [r["start"] for r in requests] - ends = [r["end"] for r in requests] - async with semaphore: - responses = await obs.get_ranges_async(store, path=path, starts=starts, ends=ends) - - buffer_responses: list[_Response] = [] - for request, response in zip(requests, responses, strict=True): - buffer_responses.append( - { - "original_request_index": request["original_request_index"], - "buffer": prototype.buffer.from_bytes(response), # type: ignore[arg-type] - } - ) - - return buffer_responses - - -async def _make_other_request( - store: _UpstreamObjectStore, - request: _OtherRequest, - prototype: BufferPrototype, - semaphore: asyncio.Semaphore, -) -> list[_Response]: - """Make offset or full-file requests. - - We return a `list[_Response]` for symmetry with `_make_bounded_requests` so that all - futures can be gathered together. - """ - import obstore as obs - - async with semaphore: - if request["range"] is None: - resp = await obs.get_async(store, request["path"]) - else: - resp = await obs.get_async(store, request["path"], options={"range": request["range"]}) - buffer = await resp.bytes_async() - - return [ - { - "original_request_index": request["original_request_index"], - "buffer": prototype.buffer.from_bytes(buffer), # type: ignore[arg-type] - } - ] - - -async def _make_suffix_request( - store: _UpstreamObjectStore, - request: _SuffixRequest, - prototype: BufferPrototype, - semaphore: asyncio.Semaphore, -) -> list[_Response]: - """Make suffix requests. - - This is separated out from `_make_other_request` because some object stores (Azure) - don't support suffix requests. In this case, our workaround is to first get the - length of the object and then manually request the byte range at the end. - - We return a `list[_Response]` for symmetry with `_make_bounded_requests` so that all - futures can be gathered together. - """ - import obstore as obs - - async with semaphore: - try: - resp = await obs.get_async(store, request["path"], options={"range": request["range"]}) - buffer = await resp.bytes_async() - except obs.exceptions.NotSupportedError: - head_resp = await obs.head_async(store, request["path"]) - file_size = head_resp["size"] - suffix_len = request["range"]["suffix"] - buffer = await obs.get_range_async( - store, - request["path"], - start=file_size - suffix_len, - length=suffix_len, - ) - - return [ - { - "original_request_index": request["original_request_index"], - "buffer": prototype.buffer.from_bytes(buffer), # type: ignore[arg-type] - } - ] - - -async def _get_partial_values( - store: _UpstreamObjectStore, - prototype: BufferPrototype, - key_ranges: Iterable[tuple[str, ByteRequest | None]], -) -> list[Buffer | None]: - """Make multiple range requests. - - ObjectStore has a `get_ranges` method that will additionally merge nearby ranges, - but it's _per_ file. So we need to split these key_ranges into **per-file** key - ranges, and then reassemble the results in the original order. - - We separate into different requests: - - - One call to `obstore.get_ranges_async` **per target file** - - One call to `obstore.get_async` for each other request. - """ - key_ranges = list(key_ranges) - per_file_bounded_requests: dict[str, list[_BoundedRequest]] = defaultdict(list) - other_requests: list[_OtherRequest] = [] - suffix_requests: list[_SuffixRequest] = [] - - for idx, (path, byte_range) in enumerate(key_ranges): - if byte_range is None: - other_requests.append( - { - "original_request_index": idx, - "path": path, - "range": None, - } - ) - elif isinstance(byte_range, RangeByteRequest): - per_file_bounded_requests[path].append( - {"original_request_index": idx, "start": byte_range.start, "end": byte_range.end} - ) - elif isinstance(byte_range, OffsetByteRequest): - other_requests.append( - { - "original_request_index": idx, - "path": path, - "range": {"offset": byte_range.offset}, - } - ) - elif isinstance(byte_range, SuffixByteRequest): - suffix_requests.append( - { - "original_request_index": idx, - "path": path, - "range": {"suffix": byte_range.suffix}, - } - ) - else: - raise ValueError(f"Unsupported range input: {byte_range}") - - # Use global semaphore for process-wide concurrency limiting - semaphore = get_global_semaphore() - - futs: list[Coroutine[Any, Any, list[_Response]]] = [] - for path, bounded_ranges in per_file_bounded_requests.items(): - futs.append( - _make_bounded_requests(store, path, bounded_ranges, prototype, semaphore=semaphore) - ) - - for request in other_requests: - futs.append(_make_other_request(store, request, prototype, semaphore=semaphore)) # noqa: PERF401 - - for suffix_request in suffix_requests: - futs.append(_make_suffix_request(store, suffix_request, prototype, semaphore=semaphore)) # noqa: PERF401 - - buffers: list[Buffer | None] = [None] * len(key_ranges) - - for responses in await asyncio.gather(*futs): - for resp in responses: - buffers[resp["original_request_index"]] = resp["buffer"] - - return buffers diff --git a/src/zarr/storage/_utils.py b/src/zarr/storage/_utils.py index 0deac52dd9..80bce250e9 100644 --- a/src/zarr/storage/_utils.py +++ b/src/zarr/storage/_utils.py @@ -70,10 +70,12 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T_co: self = args[0] - semaphore: asyncio.Semaphore = getattr(self, semaphore_attr) + semaphore: asyncio.Semaphore | None = getattr(self, semaphore_attr) - # Apply concurrency limit - async with semaphore: + if semaphore is not None: + async with semaphore: + return await func(*args, **kwargs) + else: return await func(*args, **kwargs) return wrapper diff --git a/src/zarr/testing/store_concurrency.py b/src/zarr/testing/store_concurrency.py index 06cf23857d..0dd6dcff17 100644 --- a/src/zarr/testing/store_concurrency.py +++ b/src/zarr/testing/store_concurrency.py @@ -154,7 +154,7 @@ async def test_batch_read_no_deadlock(self, store: S) -> None: # Verify results assert len(results) == num_items - for result, (key, expected_value) in zip(results, test_data.items()): + for result, (_key, expected_value) in zip(results, test_data.items(), strict=True): assert result is not None assert result.to_bytes() == expected_value.to_bytes() @@ -188,7 +188,9 @@ async def test_concurrent_operations_correctness(self, store: S) -> None: ] # Write all concurrently - await asyncio.gather(*[store.set(k, v) for k, v in zip(write_keys, write_values)]) + await asyncio.gather( + *[store.set(k, v) for k, v in zip(write_keys, write_values, strict=True)] + ) # Read all concurrently results = await asyncio.gather( @@ -196,7 +198,7 @@ async def test_concurrent_operations_correctness(self, store: S) -> None: ) # Verify correctness - for result, expected in zip(results, write_values): + for result, expected in zip(results, write_values, strict=True): assert result is not None assert result.to_bytes() == expected.to_bytes() diff --git a/tests/test_common.py b/tests/test_common.py index 0944c3375a..9484d15ca3 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -31,10 +31,6 @@ def test_access_modes() -> None: assert set(ANY_ACCESS_MODE) == set(get_args(AccessModeLiteral)) -# todo: test -def test_concurrent_map() -> None: ... - - # todo: test def test_to_thread() -> None: ... diff --git a/tests/test_global_concurrency.py b/tests/test_global_concurrency.py index f6366e3c53..3cfca5052c 100644 --- a/tests/test_global_concurrency.py +++ b/tests/test_global_concurrency.py @@ -1,327 +1,43 @@ """ -Tests for global per-process concurrency limiting. +Tests for store-level concurrency limiting through the array API. """ -import asyncio -from typing import Any - import numpy as np -import pytest import zarr -from zarr.core.common import get_global_semaphore, reset_global_semaphores -from zarr.core.config import config - - -class TestGlobalSemaphore: - """Tests for the global semaphore management.""" - - async def test_get_global_semaphore_creates_per_loop(self) -> None: - """Test that each event loop gets its own semaphore.""" - sem1 = get_global_semaphore() - assert sem1 is not None - assert isinstance(sem1, asyncio.Semaphore) - - # Getting it again should return the same instance - sem2 = get_global_semaphore() - assert sem1 is sem2 - - async def test_global_semaphore_uses_config_limit(self) -> None: - """Test that the global semaphore respects the configured limit.""" - # Set a custom concurrency limit - original_limit: Any = config.get("async.concurrency") - try: - config.set({"async.concurrency": 5}) - - # Clear existing semaphores to force recreation - reset_global_semaphores() - - sem = get_global_semaphore() - - # The semaphore should have the configured limit - # We can verify this by acquiring all tokens and checking the semaphore is locked - for i in range(5): - await sem.acquire() - if i < 4: - assert not sem.locked() # Should still have capacity - else: - assert sem.locked() # All tokens acquired, semaphore is now locked - - # Release all tokens - for _ in range(5): - sem.release() - - finally: - # Restore original config - config.set({"async.concurrency": original_limit}) - # Clear semaphores again to reset state - reset_global_semaphores() - - async def test_global_semaphore_shared_across_operations(self) -> None: - """Test that multiple concurrent operations share the same semaphore.""" - # Track the maximum number of concurrent tasks - max_concurrent = 0 - current_concurrent = 0 - lock = asyncio.Lock() - - async def tracked_operation() -> None: - """An operation that tracks concurrency.""" - nonlocal max_concurrent, current_concurrent - - async with lock: - current_concurrent += 1 - max_concurrent = max(max_concurrent, current_concurrent) - - # Small delay to ensure overlap - await asyncio.sleep(0.01) - - async with lock: - current_concurrent -= 1 - - # Set a low concurrency limit to make the test observable - original_limit: Any = config.get("async.concurrency") - try: - config.set({"async.concurrency": 5}) - - # Clear existing semaphores - reset_global_semaphores() - - # Get the global semaphore - sem = get_global_semaphore() - - # Create many tasks that use the semaphore - async def task_with_semaphore() -> None: - async with sem: - await tracked_operation() - - # Launch 20 tasks (4x the limit) - tasks = [task_with_semaphore() for _ in range(20)] - await asyncio.gather(*tasks) - - # Maximum concurrent should respect the limit - assert max_concurrent <= 5, f"Max concurrent was {max_concurrent}, expected <= 5" - assert max_concurrent >= 3, ( - f"Max concurrent was {max_concurrent}, expected some concurrency" - ) - - finally: - config.set({"async.concurrency": original_limit}) - reset_global_semaphores() - - async def test_semaphore_reuse_across_calls(self) -> None: - """Test that repeated calls to get_global_semaphore return the same instance.""" - reset_global_semaphores() - - # Call multiple times and verify we get the same instance - sem1 = get_global_semaphore() - sem2 = get_global_semaphore() - sem3 = get_global_semaphore() - - assert sem1 is sem2 is sem3, "Should return same semaphore instance on repeated calls" - - # Verify it's still the same after using it - async with sem1: - sem4 = get_global_semaphore() - assert sem1 is sem4 - - def test_config_change_after_creation(self) -> None: - """Test and document that config changes don't affect existing semaphores.""" - original_limit: Any = config.get("async.concurrency") - try: - # Set initial config - config.set({"async.concurrency": 5}) - - async def check_limit() -> None: - reset_global_semaphores() - - # Create semaphore with limit=5 - sem1 = get_global_semaphore() - initial_capacity: int = sem1._value - - # Change config - config.set({"async.concurrency": 50}) - - # Get semaphore again - should be same instance with old limit - sem2 = get_global_semaphore() - assert sem1 is sem2, "Should return same semaphore instance" - assert sem2._value == initial_capacity, ( - f"Semaphore limit changed from {initial_capacity} to {sem2._value}. " - "Config changes should not affect existing semaphores." - ) - - # Clean up - reset_global_semaphores() - - asyncio.run(check_limit()) - - finally: - config.set({"async.concurrency": original_limit}) - - -class TestArrayConcurrency: - """Tests that array operations use global concurrency limiting.""" - - @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") - async def test_multiple_arrays_share_concurrency_limit(self) -> None: - """Test that reading from multiple arrays shares the global concurrency limit.""" - from zarr.core.common import concurrent_map - - # Track concurrent task executions - max_concurrent_tasks = 0 - current_concurrent_tasks = 0 - task_lock = asyncio.Lock() - - async def tracked_chunk_operation(chunk_id: int) -> int: - """Simulate a chunk operation with tracking.""" - nonlocal max_concurrent_tasks, current_concurrent_tasks - - async with task_lock: - current_concurrent_tasks += 1 - max_concurrent_tasks = max(max_concurrent_tasks, current_concurrent_tasks) - - # Small delay to simulate I/O - await asyncio.sleep(0.001) - - async with task_lock: - current_concurrent_tasks -= 1 - - return chunk_id - - # Set a low concurrency limit - original_limit: Any = config.get("async.concurrency") - try: - config.set({"async.concurrency": 10}) - - # Clear existing semaphores - reset_global_semaphores() - - # Simulate reading many chunks using concurrent_map (which uses the global semaphore) - # This simulates what happens when reading from multiple arrays - chunk_ids = [(i,) for i in range(100)] - await concurrent_map(chunk_ids, tracked_chunk_operation) - - # The maximum concurrent tasks should respect the global limit - assert max_concurrent_tasks <= 10, ( - f"Max concurrent tasks was {max_concurrent_tasks}, expected <= 10" - ) - - assert max_concurrent_tasks >= 5, ( - f"Max concurrent tasks was {max_concurrent_tasks}, " - f"expected at least some concurrency" - ) - - finally: - config.set({"async.concurrency": original_limit}) - # Note: We don't reset_global_semaphores() here because doing so while - # many tasks are still cleaning up can trigger ResourceWarnings from - # asyncio internals. The semaphore will be reused by subsequent tests. - - def test_sync_api_uses_global_concurrency(self) -> None: - """Test that synchronous API also benefits from global concurrency limiting.""" - # This test verifies that the sync API (which wraps async) uses global limiting - - # Set a low concurrency limit - original_limit: Any = config.get("async.concurrency") - try: - config.set({"async.concurrency": 8}) - - # Create a small array - the key is that zarr internally uses - # concurrent_map which now uses the global semaphore - store = zarr.storage.MemoryStore() - arr = zarr.create( - shape=(20, 20), - chunks=(10, 10), - dtype="i4", - store=store, - zarr_format=3, - ) - arr[:] = 42 - - # Read data (synchronously) - data = arr[:] - - # Verify we got the right data - assert np.all(data == 42) - - # The test passes if no errors occurred - # The concurrency limiting is happening under the hood - - finally: - config.set({"async.concurrency": original_limit}) - - -class TestConcurrentMapGlobal: - """Tests for concurrent_map using global semaphore.""" - - async def test_concurrent_map_uses_global_by_default(self) -> None: - """Test that concurrent_map uses global semaphore by default.""" - from zarr.core.common import concurrent_map - - # Track concurrent executions - max_concurrent = 0 - current_concurrent = 0 - lock = asyncio.Lock() - - async def tracked_task(x: int) -> int: - nonlocal max_concurrent, current_concurrent - - async with lock: - current_concurrent += 1 - max_concurrent = max(max_concurrent, current_concurrent) - - await asyncio.sleep(0.01) - - async with lock: - current_concurrent -= 1 - - return x * 2 - - # Set a low limit - original_limit: Any = config.get("async.concurrency") - try: - config.set({"async.concurrency": 5}) - - # Clear existing semaphores - reset_global_semaphores() - - # Use concurrent_map with default settings (use_global_semaphore=True) - items = [(i,) for i in range(20)] - results = await concurrent_map(items, tracked_task) - - assert len(results) == 20 - assert max_concurrent <= 5 - assert max_concurrent >= 3 # Should have some concurrency - - finally: - config.set({"async.concurrency": original_limit}) - reset_global_semaphores() - - async def test_concurrent_map_legacy_mode(self) -> None: - """Test that concurrent_map legacy mode still works.""" - from zarr.core.common import concurrent_map - - async def simple_task(x: int) -> int: - await asyncio.sleep(0.001) - return x * 2 - - # Use legacy mode with local limit - items = [(i,) for i in range(10)] - results = await concurrent_map(items, simple_task, limit=3, use_global_semaphore=False) - - assert len(results) == 10 - assert results == [i * 2 for i in range(10)] - - async def test_concurrent_map_parameter_validation(self) -> None: - """Test that concurrent_map validates conflicting parameters.""" - from zarr.core.common import concurrent_map - - async def simple_task(x: int) -> int: - return x * 2 - items = [(i,) for i in range(10)] - # Should raise ValueError when both limit and use_global_semaphore=True - with pytest.raises( - ValueError, match="Cannot specify both use_global_semaphore=True and a limit" - ): - await concurrent_map(items, simple_task, limit=5, use_global_semaphore=True) +class TestStoreConcurrencyThroughArrayAPI: + """Tests that store-level concurrency limiting works through the array API.""" + + def test_array_operations_with_store_concurrency(self) -> None: + """Test that array read/write works correctly with store-level concurrency limits.""" + store = zarr.storage.MemoryStore() + arr = zarr.create( + shape=(20, 20), + chunks=(10, 10), + dtype="i4", + store=store, + zarr_format=3, + ) + arr[:] = 42 + + data = arr[:] + + assert np.all(data == 42) + + def test_array_operations_with_local_store_concurrency(self, tmp_path: object) -> None: + """Test that array read/write works correctly with LocalStore concurrency limits.""" + store = zarr.storage.LocalStore(str(tmp_path), concurrency_limit=10) + arr = zarr.create( + shape=(20, 20), + chunks=(10, 10), + dtype="i4", + store=store, + zarr_format=3, + ) + arr[:] = 42 + + data = arr[:] + + assert np.all(data == 42)