From 41643b4be9f71de92754da54016348ecedf9a046 Mon Sep 17 00:00:00 2001 From: Nithin Tatikonda Date: Fri, 13 Feb 2026 10:24:11 -0800 Subject: [PATCH] Internal PiperOrigin-RevId: 869789772 --- grain/_src/python/BUILD | 4 +- .../dataset/transformations/prefetch.py | 61 ++++++++++--------- .../dataset/transformations/prefetch_test.py | 54 ++++++++-------- grain/_src/python/options.py | 26 +++++--- 4 files changed, 83 insertions(+), 62 deletions(-) diff --git a/grain/_src/python/BUILD b/grain/_src/python/BUILD index 45cb690fe..4fe4f017b 100644 --- a/grain/_src/python/BUILD +++ b/grain/_src/python/BUILD @@ -222,7 +222,9 @@ py_library( name = "options", srcs = ["options.py"], srcs_version = "PY3", - deps = ["@abseil-py//absl/logging"], + deps = [ + "@abseil-py//absl/logging", + ], ) py_test( diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py index 46a756eab..b2c309a92 100644 --- a/grain/_src/python/dataset/transformations/prefetch.py +++ b/grain/_src/python/dataset/transformations/prefetch.py @@ -26,6 +26,7 @@ import typing from typing import Any, Optional, Protocol, TypeVar +from absl import logging from concurrent import futures from grain._src.core import monitoring as grain_monitoring from grain._src.python import options as grain_options @@ -144,14 +145,16 @@ def __init__( self._next_buffered_index = 0 self._buffer = collections.deque() self._lock = threading.Lock() - self._prefetch_buffer_size = ( - read_options.prefetch_buffer_size if read_options.num_threads > 0 else 0 - ) - self._num_threads = read_options.num_threads + + assert isinstance(read_options.num_threads, int) + assert isinstance(read_options.prefetch_buffer_size, int) + self._target_num_threads = read_options.num_threads + self._target_prefetch_buffer_size = read_options.prefetch_buffer_size + self._allow_nones = allow_nones - if self._prefetch_buffer_size > 0: + if self._target_prefetch_buffer_size > 0 and self._target_num_threads > 0: self._executor = futures.ThreadPoolExecutor( - self._num_threads, thread_name_prefix="grain-prefetch" + self._target_num_threads, thread_name_prefix="grain-prefetch" ) def _initialize_stats( @@ -195,7 +198,10 @@ def __next__(self) -> T: if self._next_returned_index == self._dataset_length: break with self._lock, timer: - if self._prefetch_buffer_size > 0: + if ( + self._target_prefetch_buffer_size > 0 + and self._target_num_threads > 0 + ): if not self._buffer: # Fill the buffer on the first iteration. self._fill_buffer() @@ -237,11 +243,11 @@ def set_state(self, state): f"Checkpoint `next_index` {self._next_returned_index} is out of" f" range for dataset of length {self._dataset_length}." ) - if self._prefetch_buffer_size > 0: - # Cancel all pending futures in the buffer. - while self._buffer: - future = self._buffer.popleft() - future.cancel() + + # Cancel all pending futures in the buffer. + while self._buffer: + future = self._buffer.popleft() + future.cancel() def _get_next_index(self) -> int: return self._next_returned_index @@ -255,34 +261,33 @@ def __str__(self) -> str: f" allow_nones={self._allow_nones})" ) - def set_prefetch_buffer_size(self, buffer_size: int): - self._prefetch_buffer_size = buffer_size + def _set_prefetch_buffer_size(self, buffer_size: int): + self._target_prefetch_buffer_size = buffer_size # The executor is created in the constructor only if the prefetch buffer # size is greater than 0. If the user changes the prefetch buffer size, we # need to create or destroy the executor accordingly. - if self._prefetch_buffer_size > 0 and not hasattr(self, "_executor"): - if self._num_threads == 0: - raise ValueError( - "num_threads must be greater than 0 when prefetch buffer size is" - " greater than 0." - ) + if ( + self._target_prefetch_buffer_size > 0 + and self._target_num_threads > 0 + and not hasattr(self, "_executor") + ): self._executor = futures.ThreadPoolExecutor( - self._num_threads, thread_name_prefix="grain-prefetch" + self._target_num_threads, thread_name_prefix="grain-prefetch" ) - elif self._prefetch_buffer_size == 0 and hasattr(self, "_executor"): + elif self._target_prefetch_buffer_size == 0 and hasattr(self, "_executor"): self._executor.shutdown() delattr(self, "_executor") - def set_num_threads(self, num_threads: int) -> None: - self._num_threads = num_threads + def _set_num_threads(self, num_threads: int) -> None: + self._target_num_threads = num_threads old_executor = None # Accounts for the case where the executor does not exit. This can # happen if the prefetch buffer size is set to 0. if hasattr(self, "_executor"): old_executor = self._executor - if self._num_threads > 0: + if self._target_num_threads > 0 and self._target_prefetch_buffer_size > 0: self._executor = futures.ThreadPoolExecutor( - self._num_threads, thread_name_prefix="grain-prefetch" + self._target_num_threads, thread_name_prefix="grain-prefetch" ) else: delattr(self, "_executor") @@ -293,7 +298,7 @@ def set_num_threads(self, num_threads: int) -> None: def _fill_buffer(self): while ( - len(self._buffer) < self._prefetch_buffer_size + len(self._buffer) < self._target_prefetch_buffer_size and self._next_buffered_index < self._dataset_length ): # Note that we trigger creation of `_stats` in this (single) thread, it is @@ -307,7 +312,7 @@ def _fill_buffer(self): self._next_buffered_index += 1 def start_prefetch(self): - if self._prefetch_buffer_size > 0: + if self._target_prefetch_buffer_size > 0 and self._target_num_threads > 0: self._fill_buffer() def close(self) -> None: diff --git a/grain/_src/python/dataset/transformations/prefetch_test.py b/grain/_src/python/dataset/transformations/prefetch_test.py index 0855b0b4b..80ae0a723 100644 --- a/grain/_src/python/dataset/transformations/prefetch_test.py +++ b/grain/_src/python/dataset/transformations/prefetch_test.py @@ -158,12 +158,12 @@ def test_set_prefetch_buffer_size_0_to_positive(self): # With prefetch_buffer_size=0, executor is not created. self.assertFalse(hasattr(ds_iter, '_executor')) - self.assertEqual(ds_iter._prefetch_buffer_size, 0) + self.assertEqual(ds_iter._target_prefetch_buffer_size, 0) self.assertEqual(next(ds_iter), 0) # Setting prefetch_buffer_size to 2. - ds_iter.set_prefetch_buffer_size(2) - self.assertEqual(ds_iter._prefetch_buffer_size, 2) + ds_iter._set_prefetch_buffer_size(2) + self.assertEqual(ds_iter._target_prefetch_buffer_size, 2) self.assertEqual(next(ds_iter), 1) self.assertTrue(hasattr(ds_iter, '_executor')) self.assertLen(ds_iter._buffer, 2) @@ -178,13 +178,13 @@ def test_set_prefetch_buffer_size_positive_to_0(self): self.assertIsInstance(ds_iter, prefetch.PrefetchDatasetIterator) ds_iter = cast(prefetch.PrefetchDatasetIterator, ds_iter) - self.assertEqual(ds_iter._prefetch_buffer_size, 2) + self.assertEqual(ds_iter._target_prefetch_buffer_size, 2) self.assertEqual(next(ds_iter), 0) self.assertLen(ds_iter._buffer, 2) # Setting prefetch_buffer_size to 0. - ds_iter.set_prefetch_buffer_size(0) - self.assertEqual(ds_iter._prefetch_buffer_size, 0) + ds_iter._set_prefetch_buffer_size(0) + self.assertEqual(ds_iter._target_prefetch_buffer_size, 0) # Should consume buffer first. self.assertEqual(next(ds_iter), 1) self.assertLen(ds_iter._buffer, 1) @@ -202,13 +202,13 @@ def test_set_prefetch_buffer_size_increase(self): self.assertIsInstance(ds_iter, prefetch.PrefetchDatasetIterator) ds_iter = cast(prefetch.PrefetchDatasetIterator, ds_iter) - self.assertEqual(ds_iter._prefetch_buffer_size, 1) + self.assertEqual(ds_iter._target_prefetch_buffer_size, 1) self.assertEqual(next(ds_iter), 0) self.assertLen(ds_iter._buffer, 1) # Setting prefetch_buffer_size to 2. - ds_iter.set_prefetch_buffer_size(2) - self.assertEqual(ds_iter._prefetch_buffer_size, 2) + ds_iter._set_prefetch_buffer_size(2) + self.assertEqual(ds_iter._target_prefetch_buffer_size, 2) self.assertEqual(next(ds_iter), 1) self.assertLen(ds_iter._buffer, 2) self.assertEqual(next(ds_iter), 2) @@ -222,13 +222,13 @@ def test_set_prefetch_buffer_size_decrease(self): self.assertIsInstance(ds_iter, prefetch.PrefetchDatasetIterator) ds_iter = cast(prefetch.PrefetchDatasetIterator, ds_iter) - self.assertEqual(ds_iter._prefetch_buffer_size, 2) + self.assertEqual(ds_iter._target_prefetch_buffer_size, 2) self.assertEqual(next(ds_iter), 0) self.assertLen(ds_iter._buffer, 2) # Setting prefetch_buffer_size to 1. - ds_iter.set_prefetch_buffer_size(1) - self.assertEqual(ds_iter._prefetch_buffer_size, 1) + ds_iter._set_prefetch_buffer_size(1) + self.assertEqual(ds_iter._target_prefetch_buffer_size, 1) self.assertEqual(next(ds_iter), 1) self.assertLen(ds_iter._buffer, 1) self.assertEqual(next(ds_iter), 2) @@ -321,15 +321,17 @@ def test_set_num_threads_decrease_threads(self): ds_iter = iter(self.prefetch_lazy_iter_ds) self.assertIsInstance(ds_iter, prefetch.PrefetchDatasetIterator) ds_iter = cast(prefetch.PrefetchDatasetIterator, ds_iter) - self.assertEqual(ds_iter._num_threads, options.ReadOptions().num_threads) + self.assertEqual( + ds_iter._target_num_threads, options.ReadOptions().num_threads + ) self.assertEqual( ds_iter._executor._max_workers, options.ReadOptions().num_threads ) self.assertEqual([next(ds_iter) for _ in range(5)], list(range(5))) # Decrease threads - ds_iter.set_num_threads(5) - self.assertEqual(ds_iter._num_threads, 5) + ds_iter._set_num_threads(5) + self.assertEqual(ds_iter._target_num_threads, 5) self.assertEqual(ds_iter._executor._max_workers, 5) self.assertEqual([next(ds_iter) for _ in range(15)], list(range(5, 20))) @@ -340,13 +342,13 @@ def test_set_num_threads_increase_threads(self): ds_iter = iter(ds) self.assertIsInstance(ds_iter, prefetch.PrefetchDatasetIterator) ds_iter = cast(prefetch.PrefetchDatasetIterator, ds_iter) - self.assertEqual(ds_iter._num_threads, 5) + self.assertEqual(ds_iter._target_num_threads, 5) self.assertEqual(ds_iter._executor._max_workers, 5) self.assertEqual([next(ds_iter) for _ in range(5)], list(range(5))) # Increase threads - ds_iter.set_num_threads(10) - self.assertEqual(ds_iter._num_threads, 10) + ds_iter._set_num_threads(10) + self.assertEqual(ds_iter._target_num_threads, 10) self.assertEqual(ds_iter._executor._max_workers, 10) self.assertEqual([next(ds_iter) for _ in range(15)], list(range(5, 20))) @@ -354,14 +356,16 @@ def test_set_num_threads_decrease_to_zero(self): ds_iter = iter(self.prefetch_lazy_iter_ds) self.assertIsInstance(ds_iter, prefetch.PrefetchDatasetIterator) ds_iter = cast(prefetch.PrefetchDatasetIterator, ds_iter) - self.assertEqual(ds_iter._num_threads, options.ReadOptions().num_threads) + self.assertEqual( + ds_iter._target_num_threads, options.ReadOptions().num_threads + ) self.assertEqual( ds_iter._executor._max_workers, options.ReadOptions().num_threads ) self.assertEqual([next(ds_iter) for _ in range(5)], list(range(5))) # Decrease threads to 0 - ds_iter.set_num_threads(0) - self.assertEqual(ds_iter._num_threads, 0) + ds_iter._set_num_threads(0) + self.assertEqual(ds_iter._target_num_threads, 0) self.assertFalse(hasattr(ds_iter, '_executor')) self.assertEqual([next(ds_iter) for _ in range(15)], list(range(5, 20))) @@ -370,14 +374,14 @@ def test_set_num_threads_increase_from_zero(self): self.assertIsInstance(ds_iter, prefetch.PrefetchDatasetIterator) ds_iter = cast(prefetch.PrefetchDatasetIterator, ds_iter) self.assertEqual([next(ds_iter) for _ in range(5)], list(range(5))) - ds_iter.set_num_threads(0) - self.assertEqual(ds_iter._num_threads, 0) + ds_iter._set_num_threads(0) + self.assertEqual(ds_iter._target_num_threads, 0) self.assertFalse(hasattr(ds_iter, '_executor')) self.assertEqual([next(ds_iter) for _ in range(5)], list(range(5, 10))) # Increase threads from 0 - ds_iter.set_num_threads(5) - self.assertEqual(ds_iter._num_threads, 5) + ds_iter._set_num_threads(5) + self.assertEqual(ds_iter._target_num_threads, 5) self.assertEqual(ds_iter._executor._max_workers, 5) self.assertEqual([next(ds_iter) for _ in range(10)], list(range(10, 20))) diff --git a/grain/_src/python/options.py b/grain/_src/python/options.py index 293e3441a..f855aa932 100644 --- a/grain/_src/python/options.py +++ b/grain/_src/python/options.py @@ -12,9 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. """Dataclasses for holdings options.""" +from __future__ import annotations + import dataclasses from absl import logging +class AutotuneParameter: + + def __init__(self, *args, **kwargs): + raise NotImplementedError @dataclasses.dataclass(slots=True) @@ -41,25 +47,29 @@ class ReadOptions: # benchmarks reading from remote hard drives. # These values should work well for datasets with elements between 1 and # 10 KiB on disk. - num_threads: int = 16 - prefetch_buffer_size: int = 500 + num_threads: int | AutotuneParameter = 16 + prefetch_buffer_size: int | AutotuneParameter = 500 def __post_init__(self): - if self.num_threads < 0: + if isinstance(self.num_threads, int) and self.num_threads < 0: raise ValueError( f'num_threads must be non-negative, got {self.num_threads}' ) - if self.prefetch_buffer_size < 0: + + if ( + isinstance(self.prefetch_buffer_size, int) + and self.prefetch_buffer_size < 0 + ): raise ValueError( 'prefetch_buffer_size must be non-negative, got' f' {self.prefetch_buffer_size}' ) + # Avoid warning when setting prefetch_buffer_size=0, since this is commonly # used to disable prefetching. - if ( - self.prefetch_buffer_size < self.num_threads - and self.prefetch_buffer_size != 0 - ): + buffer_size = int(self.prefetch_buffer_size) + num_threads = int(self.num_threads) + if buffer_size < num_threads and buffer_size != 0: logging.warning( 'prefetch_buffer_size=%s is smaller than num_threads=%s. This will' ' limit the number of threads that can actually be used in parallel'