From 713505576f877bb6396e3bb56670f0bb3945d736 Mon Sep 17 00:00:00 2001 From: Nithin Tatikonda Date: Thu, 5 Feb 2026 11:02:30 -0800 Subject: [PATCH] Internal PiperOrigin-RevId: 866025320 --- .../dataset/transformations/interleave.py | 12 +++++++++- .../transformations/interleave_test.py | 18 ++++++++++++++ .../dataset/transformations/prefetch_test.py | 1 - .../transformations/process_prefetch_test.py | 24 ++++++++++++++++++- 4 files changed, 52 insertions(+), 3 deletions(-) diff --git a/grain/_src/python/dataset/transformations/interleave.py b/grain/_src/python/dataset/transformations/interleave.py index 9a387ae07..ab6a068ef 100644 --- a/grain/_src/python/dataset/transformations/interleave.py +++ b/grain/_src/python/dataset/transformations/interleave.py @@ -218,7 +218,7 @@ def set_state(self, state): iterator = _add_prefetch_and_make_iterator( self._datasets[index_in_datasets], interleave_iterator=weakref.ref(self), - start_prefetch=False, + start_prefetch=self._started, ) iterator.set_state(it_state) self._iterators_in_use[index_in_cycle] = iterator @@ -266,6 +266,13 @@ def set_keep_iterators_after_stop_iteration( # continuing iteration without recreating the iterators. self._keep_iterators_after_stop_iteration = keep_iterators + def start_prefetch(self) -> None: + self._prefetch_ds_iter.start_prefetch() + for iterator in self._iterators_in_use: + if iterator is not None: + iterator.start_prefetch() + self._started = True + def close(self) -> None: """Closes the iterator and shuts down the iterator prefetching.""" if self._closed: @@ -275,6 +282,9 @@ def close(self) -> None: for iterator in self._iterators_in_use: if iterator is not None: iterator.close() + for index_iterator_pair in self._exhausted_iterators: + if index_iterator_pair is not None: + index_iterator_pair[1].close() def _initialize_stats( self, execution_tracking_mode: base.ExecutionTrackingMode diff --git a/grain/_src/python/dataset/transformations/interleave_test.py b/grain/_src/python/dataset/transformations/interleave_test.py index c5d647307..af2b0eff0 100644 --- a/grain/_src/python/dataset/transformations/interleave_test.py +++ b/grain/_src/python/dataset/transformations/interleave_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import time + from absl.testing import absltest from absl.testing import flagsaver from absl.testing import parameterized @@ -291,6 +293,22 @@ def test_set_next_index_with_multiple_datasets(self): ): dataset.set_next_index(ds_iter, 0) + def test_start_prefetch(self): + count = 0 + + def map_fn(x): + nonlocal count + count += 1 + return x + + ds = dataset.MapDataset.range(10).to_iter_dataset() + ds = ds.map(map_fn) + ds = interleave.InterleaveIterDataset([ds], cycle_length=1) + ds_iter = ds.__iter__() + ds_iter.start_prefetch() + while count == 0: + time.sleep(0.1) + if __name__ == "__main__": absltest.main() diff --git a/grain/_src/python/dataset/transformations/prefetch_test.py b/grain/_src/python/dataset/transformations/prefetch_test.py index 0855b0b4b..a5b0ddefd 100644 --- a/grain/_src/python/dataset/transformations/prefetch_test.py +++ b/grain/_src/python/dataset/transformations/prefetch_test.py @@ -980,7 +980,6 @@ def map(self, features): @parameterized.parameters(0, 0.5, 30) def test_prefetch_but_no_read(self, sleep_s): ds = dataset.MapDataset.source([1, 2, 3]).repeat() - ds = ds.filter(lambda x: x > 3) ds = ds.to_iter_dataset() ds = prefetch.multithread_prefetch(ds, num_threads=1, buffer_size=1) it = ds.__iter__() diff --git a/grain/_src/python/dataset/transformations/process_prefetch_test.py b/grain/_src/python/dataset/transformations/process_prefetch_test.py index 935d10007..797d14df9 100644 --- a/grain/_src/python/dataset/transformations/process_prefetch_test.py +++ b/grain/_src/python/dataset/transformations/process_prefetch_test.py @@ -48,6 +48,16 @@ def filter(self, element: int) -> bool: return bool(element % 2) +@dataclasses.dataclass(frozen=True) +class WriteMarker(transforms.Map): + path: str + + def map(self, element: int) -> int: + with open(self.path, 'w') as f: + f.write(str(element)) + return element + + class ProcessPrefetchIterDatasetTest(parameterized.TestCase): def setUp(self): @@ -851,10 +861,22 @@ def map(self, features): if not start_prefetch_calls: self.assertGreater(time_to_fetch, 1) + def test_start_prefetch_prefetches_without_next_call(self): + marker_file = os.path.join(self.create_tempdir().full_path, 'marker') + ds = dataset.MapDataset.range(10) + ds = ds.map(WriteMarker(marker_file)) + ds = ds.to_iter_dataset() + ds = process_prefetch.multiprocess_prefetch(ds, num_workers=1) + it = ds.__iter__() + it.start_prefetch() + + # Wait for prefetch to happen. + while not os.path.exists(marker_file): + time.sleep(0.5) + @parameterized.parameters(0, 0.5, 30) def test_prefetch_but_no_read(self, sleep_s): ds = dataset.MapDataset.source([1, 2, 3]).repeat() - ds = ds.filter(lambda x: x > 3) ds = ds.to_iter_dataset() ds = process_prefetch.multiprocess_prefetch(ds, num_workers=1) it = ds.__iter__()