diff --git a/CHANGELOG.md b/CHANGELOG.md index aee025319..ddbce5b4c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,8 @@ changes. Best viewed [here](https://google-grain.readthedocs.io/en/latest/change specification inference. * Adds support for changing `IterDataset.mix` components and weights after a checkpoint. + * Adds experimental support for `get_next_index` and `set_next_index` to fetch + and advance a `grain.DatasetIterator` to the given produced element index. ## Grain 0.2.14 (October 30, 2025) diff --git a/grain/_src/python/dataset/dataset.py b/grain/_src/python/dataset/dataset.py index ba04c95bf..8508f2b96 100644 --- a/grain/_src/python/dataset/dataset.py +++ b/grain/_src/python/dataset/dataset.py @@ -1574,6 +1574,27 @@ def _stats(self) -> dataset_stats.Stats: self._ctx.dataset_options.execution_tracking_mode ) + def _set_next_index(self, index: int) -> None: + """Sets the next index for the dataset iterator. + + Note: This index is the index of the element that will be produced next by + the iterator. Implementations of this method should not process any data. If + advancing to the index without processing is not possible (e.g. for filter + and packing), then the implementation should raise a ValueError. + + Args: + index: The index of the next element to be produced. + """ + raise NotImplementedError + + def _get_next_index(self) -> int: + """Returns the next index for the dataset iterator. + + Note: This index is the index of the element that will be produced next by + the iterator. + """ + raise NotImplementedError + # pytype: enable=attribute-error # pylint: enable=protected-access @@ -1783,3 +1804,13 @@ def get_element_spec( ds: `MapDataset` or `IterDataset` to get the element spec from. """ return ds._element_spec # pylint: disable=protected-access + + +def set_next_index(ds_iter: DatasetIterator, index: int) -> None: + """Sets the next index for the dataset iterator.""" + return ds_iter._set_next_index(index) # pylint: disable=protected-access + + +def get_next_index(ds_iter: DatasetIterator) -> int: + """Returns the next index for the dataset iterator.""" + return ds_iter._get_next_index() # pylint: disable=protected-access diff --git a/grain/_src/python/dataset/sources/tfrecord_dataset.py b/grain/_src/python/dataset/sources/tfrecord_dataset.py index 5f7561ef4..0ba2623e6 100644 --- a/grain/_src/python/dataset/sources/tfrecord_dataset.py +++ b/grain/_src/python/dataset/sources/tfrecord_dataset.py @@ -52,7 +52,7 @@ def __next__(self) -> bytes: f" {codecs.encode(buf, 'hex')}" ) length, _ = struct.unpack(" bytes: f" {codecs.encode(buf, 'hex')}" ) data, _ = struct.unpack("<%dsI" % length, buf) - # TODO: b/412697846 - Add CRC check for data mask mismatch. + # TODO: Add CRC check for data mask mismatch. return data def seek(self, offset: int): diff --git a/grain/_src/python/dataset/transformations/batch.py b/grain/_src/python/dataset/transformations/batch.py index 41f098f9c..2ba2c6903 100644 --- a/grain/_src/python/dataset/transformations/batch.py +++ b/grain/_src/python/dataset/transformations/batch.py @@ -249,6 +249,14 @@ def get_state(self): def set_state(self, state): self._parent.set_state(state) + def _get_next_index(self) -> int: + return ( + dataset.get_next_index(self._parent) + self._batch_size - 1 + ) // self._batch_size + + def _set_next_index(self, index: int) -> None: + dataset.set_next_index(self._parent, index * self._batch_size) + def __str__(self) -> str: return ( f"BatchDatasetIterator(batch_size={self._batch_size}," diff --git a/grain/_src/python/dataset/transformations/batch_test.py b/grain/_src/python/dataset/transformations/batch_test.py index c9b373e26..ff791489e 100644 --- a/grain/_src/python/dataset/transformations/batch_test.py +++ b/grain/_src/python/dataset/transformations/batch_test.py @@ -797,6 +797,62 @@ def test_element_spec_nested(self): }, ) + @parameterized.parameters( + dict( + batch_size=2, + drop_remainder=False, + expected=[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]], + ), + dict( + batch_size=3, + drop_remainder=False, + expected=[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]], + ), + dict( + batch_size=3, + drop_remainder=True, + expected=[[0, 1, 2], [3, 4, 5], [6, 7, 8]], + ), + ) + def test_get_next_index(self, batch_size, drop_remainder, expected): + del expected + ds = dataset.MapDataset.range(0, 10).to_iter_dataset() + ds = batch.BatchIterDataset( + ds, batch_size=batch_size, drop_remainder=drop_remainder + ) + ds_iter = ds.__iter__() + self.assertEqual(dataset.get_next_index(ds_iter), 0) + for i, _ in enumerate(ds_iter): + self.assertEqual(dataset.get_next_index(ds_iter), i + 1) + + @parameterized.parameters( + dict( + batch_size=2, + drop_remainder=False, + expected=[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]], + ), + dict( + batch_size=3, + drop_remainder=False, + expected=[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]], + ), + dict( + batch_size=3, + drop_remainder=True, + expected=[[0, 1, 2], [3, 4, 5], [6, 7, 8]], + ), + ) + def test_set_next_index(self, batch_size, drop_remainder, expected): + ds = dataset.MapDataset.range(0, 10).to_iter_dataset() + ds = batch.BatchIterDataset( + ds, batch_size=batch_size, drop_remainder=drop_remainder + ) + for i in reversed(range(len(expected))): + ds_iter = ds.__iter__() + dataset.set_next_index(ds_iter, i) + actual = next(ds_iter) + np.testing.assert_allclose(actual, expected[i]) + if __name__ == "__main__": absltest.main() diff --git a/grain/_src/python/dataset/transformations/interleave.py b/grain/_src/python/dataset/transformations/interleave.py index 722a4d38b..0ff30977b 100644 --- a/grain/_src/python/dataset/transformations/interleave.py +++ b/grain/_src/python/dataset/transformations/interleave.py @@ -206,6 +206,30 @@ def set_state(self, state): self._exhausted_iterator_state[index_in_cycle] = it_state self._iterators_in_use[index_in_cycle] = None + def _get_next_index(self) -> int: + if len(self._datasets) == 1: + it = self._iterators_in_use[0] + if it is None: + return 0 + return dataset.get_next_index(it) + raise NotImplementedError( + "get_next_index is not supported for InterleaveDatasetIterator with" + " more than one dataset." + ) + + def _set_next_index(self, index: int) -> None: + if len(self._datasets) == 1: + # Ensure iterator is created by calling get_state. + _ = self.get_state() + it = self._iterators_in_use[0] + assert it is not None + dataset.set_next_index(it, index) + else: + raise NotImplementedError( + "set_next_index is not supported for InterleaveDatasetIterator with" + " more than one dataset." + ) + def close(self) -> None: """Closes the iterator and shuts down the iterator prefetching.""" if self._closed: diff --git a/grain/_src/python/dataset/transformations/interleave_test.py b/grain/_src/python/dataset/transformations/interleave_test.py index 1740ee24f..0862db287 100644 --- a/grain/_src/python/dataset/transformations/interleave_test.py +++ b/grain/_src/python/dataset/transformations/interleave_test.py @@ -213,6 +213,45 @@ def test_interleave_stats_with_mismatched_dataset_structures(self): self.assertLen(node_names, 1) self.assertIn("InterleaveDatasetIterator", node_names[0]) + def test_get_next_index(self): + ds = dataset.MapDataset.range(10).to_iter_dataset() + ds = interleave.InterleaveIterDataset([ds], cycle_length=1) + ds_iter = ds.__iter__() + self.assertEqual(dataset.get_next_index(ds_iter), 0) + for i in range(10): + next(ds_iter) + self.assertEqual(dataset.get_next_index(ds_iter), i + 1) + + def test_set_next_index(self): + ds = dataset.MapDataset.range(10).to_iter_dataset() + ds = interleave.InterleaveIterDataset([ds], cycle_length=1) + ds_iter = ds.__iter__() + for i in reversed(range(10)): + dataset.set_next_index(ds_iter, i) + self.assertEqual(next(ds_iter), i) + + def test_get_next_index_with_multiple_datasets(self): + ds = dataset.MapDataset.range(10).to_iter_dataset() + ds = interleave.InterleaveIterDataset([ds, ds], cycle_length=2) + ds_iter = ds.__iter__() + with self.assertRaisesRegex( + NotImplementedError, + "get_next_index is not supported for InterleaveDatasetIterator with" + " more than one dataset.", + ): + dataset.get_next_index(ds_iter) + + def test_set_next_index_with_multiple_datasets(self): + ds = dataset.MapDataset.range(10).to_iter_dataset() + ds = interleave.InterleaveIterDataset([ds, ds], cycle_length=2) + ds_iter = ds.__iter__() + with self.assertRaisesRegex( + NotImplementedError, + "set_next_index is not supported for InterleaveDatasetIterator with" + " more than one dataset.", + ): + dataset.set_next_index(ds_iter, 0) + if __name__ == "__main__": absltest.main() diff --git a/grain/_src/python/dataset/transformations/limit.py b/grain/_src/python/dataset/transformations/limit.py index 1239df350..a49631897 100644 --- a/grain/_src/python/dataset/transformations/limit.py +++ b/grain/_src/python/dataset/transformations/limit.py @@ -53,6 +53,13 @@ def set_state(self, state): self._parent.set_state(state["parent"]) self._count_elements_read = state["count_elements_read"] + def _get_next_index(self) -> int: + return self._count_elements_read + + def _set_next_index(self, index): + dataset.set_next_index(self._parent, index) + self._count_elements_read = index + class LimitIterDataset(dataset.IterDataset[T]): """Limits the number of elements in the dataset. diff --git a/grain/_src/python/dataset/transformations/limit_test.py b/grain/_src/python/dataset/transformations/limit_test.py index 93c8b9a3c..847b822b0 100644 --- a/grain/_src/python/dataset/transformations/limit_test.py +++ b/grain/_src/python/dataset/transformations/limit_test.py @@ -84,6 +84,29 @@ def test_element_spec(self): self.assertEqual(spec.dtype, np.int64) self.assertEqual(spec.shape, ()) + def test_get_next_index(self): + ds = dataset.MapDataset.range(0, 20).batch(3).to_iter_dataset() + limited_ds = limit.LimitIterDataset(ds, count=2) + ds_iter = limited_ds.__iter__() + self.assertEqual(dataset.get_next_index(ds_iter), 0) + _ = next(ds_iter) + self.assertEqual(dataset.get_next_index(ds_iter), 1) + _ = next(ds_iter) + with self.assertRaises(StopIteration): + next(ds_iter) + self.assertEqual(dataset.get_next_index(ds_iter), 2) + + def test_set_next_index(self): + ds = dataset.MapDataset.range(0, 20).batch(3).to_iter_dataset() + limited_ds = limit.LimitIterDataset(ds, count=2) + ds_iter = limited_ds.__iter__() + dataset.set_next_index(ds_iter, 1) + self.assertEqual(dataset.get_next_index(ds_iter), 1) + _ = next(ds_iter) + self.assertEqual(dataset.get_next_index(ds_iter), 2) + with self.assertRaises(StopIteration): + next(ds_iter) + if __name__ == "__main__": absltest.main() diff --git a/grain/_src/python/dataset/transformations/map.py b/grain/_src/python/dataset/transformations/map.py index 48c0ff45b..2c6ef3597 100644 --- a/grain/_src/python/dataset/transformations/map.py +++ b/grain/_src/python/dataset/transformations/map.py @@ -320,6 +320,12 @@ def get_state(self): def set_state(self, state): self._parent.set_state(state) + def _get_next_index(self) -> int: + return dataset.get_next_index(self._parent) + + def _set_next_index(self, next_index: int): + dataset.set_next_index(self._parent, next_index) + def __str__(self) -> str: return f"MapDatasetIterator(transform={self._transform_name})" @@ -370,6 +376,13 @@ def set_state(self, state): self._parent.set_state(state["parent"]) self._index_for_rng = state["index_for_rng"] + def _get_next_index(self) -> int: + return self._index_for_rng + + def _set_next_index(self, next_index: int): + dataset.set_next_index(self._parent, next_index) + self._index_for_rng = next_index + def __str__(self) -> str: return f"RandomMapDatasetIterator(transform={self._transform_name})" @@ -408,6 +421,13 @@ def set_state(self, state): self._parent.set_state(state["parent"]) self._counter = state["counter"] + def _get_next_index(self) -> int: + return self._counter + + def _set_next_index(self, next_index: int): + dataset.set_next_index(self._parent, next_index) + self._counter = next_index + def __str__(self) -> str: return f"MapWithIndexDatasetIterator(transform={self._transform_name})" diff --git a/grain/_src/python/dataset/transformations/map_test.py b/grain/_src/python/dataset/transformations/map_test.py index 507c01c24..d9d354724 100644 --- a/grain/_src/python/dataset/transformations/map_test.py +++ b/grain/_src/python/dataset/transformations/map_test.py @@ -368,6 +368,22 @@ def test_map_element_spec_inference_raises_error(self): with self.assertRaisesRegex(ValueError, "does not implement `output_spec`"): _ = ds._element_spec + def test_get_next_index(self): + ds = dataset.MapDataset.range(0, 20).to_iter_dataset() + mapped_ds = map_ds.MapIterDataset(ds, MapWithNoTransform()) + ds_iter = mapped_ds.__iter__() + for i in range(20): + self.assertEqual(dataset.get_next_index(ds_iter), i) + _ = next(ds_iter) + + def test_set_next_index(self): + ds = dataset.MapDataset.range(0, 20).to_iter_dataset() + mapped_ds = map_ds.MapIterDataset(ds, MapWithNoTransform()) + ds_iter = mapped_ds.__iter__() + for i in reversed(range(20)): + dataset.set_next_index(ds_iter, i) + self.assertEqual(next(ds_iter), i) + class RandomMapIterDatasetTest(parameterized.TestCase): @@ -450,6 +466,27 @@ def test_random_map_element_spec_inference_raises_error(self): with self.assertRaisesRegex(ValueError, "does not implement `output_spec`"): _ = ds._element_spec + def test_get_next_index(self): + ds = dataset.MapDataset.range(0, 20).to_iter_dataset() + mapped_ds = map_ds.RandomMapIterDataset( + ds, RandomMapWithTransform(), seed=0 + ) + ds_iter = mapped_ds.__iter__() + for i in range(20): + self.assertEqual(dataset.get_next_index(ds_iter), i) + _ = next(ds_iter) + + def test_set_next_index(self): + ds = dataset.MapDataset.range(0, 20).to_iter_dataset() + mapped_ds = map_ds.RandomMapIterDataset( + ds, RandomMapWithTransform(), seed=0 + ) + expected = list(mapped_ds) + ds_iter = mapped_ds.__iter__() + for i in reversed(range(20)): + dataset.set_next_index(ds_iter, i) + self.assertEqual(next(ds_iter), expected[i]) + class MapWithIndexMapDatasetTest(parameterized.TestCase): @@ -542,6 +579,22 @@ def test_map_with_index_element_spec_inference_raises_error(self): with self.assertRaisesRegex(ValueError, "does not implement `output_spec`"): _ = ds._element_spec + def test_get_next_index(self): + ds = dataset.MapDataset.range(0, 20).to_iter_dataset() + mapped_ds = map_ds.MapWithIndexIterDataset(ds, AddIndexTransform()) + ds_iter = mapped_ds.__iter__() + for i in range(20): + self.assertEqual(dataset.get_next_index(ds_iter), i) + _ = next(ds_iter) + + def test_set_next_index(self): + ds = dataset.MapDataset.range(0, 20).to_iter_dataset() + mapped_ds = map_ds.MapWithIndexIterDataset(ds, AddIndexTransform()) + ds_iter = mapped_ds.__iter__() + for i in reversed(range(20)): + dataset.set_next_index(ds_iter, i) + self.assertEqual(next(ds_iter), (i, i)) + if __name__ == "__main__": absltest.main() diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py index e7a6392fa..db28185ef 100644 --- a/grain/_src/python/dataset/transformations/prefetch.py +++ b/grain/_src/python/dataset/transformations/prefetch.py @@ -253,6 +253,12 @@ def set_state(self, state): future = self._buffer.popleft() future.cancel() + def _get_next_index(self) -> int: + return self._next_returned_index + + def _set_next_index(self, index: int) -> None: + self.set_state({"next_index": index}) + def __str__(self) -> str: return ( f"PrefetchDatasetIterator(read_options={self._read_options}," @@ -900,6 +906,7 @@ def __init__( self._prefetch_buffer_size = prefetch_buffer_size self._step_zero_state: StateT = parent.get_state() self._state: StateT | None = None + self._next_index: int | None = 0 self._prefetch_thread: threading.Thread | None = None self._prefetch_should_stop: threading.Event = threading.Event() @@ -970,6 +977,8 @@ def __next__(self): self._stop_prefetch() raise err self._state = state + if self._next_index is not None: + self._next_index += 1 with self._stats.record_self_time(offset_ns=timer.value()): element = self._stats.record_bytes_produced(element) return self._stats.record_output_spec(element) @@ -1008,6 +1017,25 @@ def set_state(self, state: StateT): self._stop_prefetch() self._maybe_nonnative_parent.set_state(state) self._state = self._maybe_nonnative_parent.get_state() + if isinstance(self._maybe_nonnative_parent, dataset.DatasetIterator): + try: + self._next_index = dataset.get_next_index(self._maybe_nonnative_parent) + except Exception: # pylint: disable=broad-except + self._next_index = None + else: + self._next_index = None + + def _get_next_index(self) -> int: + if self._next_index is not None: + return self._next_index + raise ValueError("Upstream iterator does not support get_next_index.") + + def _set_next_index(self, next_index: int): + assert isinstance(self._maybe_nonnative_parent, dataset.DatasetIterator) + self._stop_prefetch() + dataset.set_next_index(self._maybe_nonnative_parent, next_index) + self._next_index = next_index + self._state = self._maybe_nonnative_parent.get_state() def __str__(self) -> str: return ( diff --git a/grain/_src/python/dataset/transformations/prefetch_test.py b/grain/_src/python/dataset/transformations/prefetch_test.py index 3e45a8ad6..0b2068577 100644 --- a/grain/_src/python/dataset/transformations/prefetch_test.py +++ b/grain/_src/python/dataset/transformations/prefetch_test.py @@ -432,6 +432,18 @@ def test_element_spec(self): self.assertEqual(spec.shape, ()) self.assertEqual(spec.dtype, np.int64) + def test_get_next_index(self): + ds_iter = self.prefetch_lazy_iter_ds.__iter__() + for i in range(20): + self.assertEqual(dataset.get_next_index(ds_iter), i) + _ = next(ds_iter) + + def test_set_next_index(self): + ds_iter = self.prefetch_lazy_iter_ds.__iter__() + for i in reversed(range(20)): + dataset.set_next_index(ds_iter, i) + self.assertEqual(next(ds_iter), i) + class MultiprocessPrefetchIterDatasetTest(parameterized.TestCase): diff --git a/grain/experimental.py b/grain/experimental.py index c46862230..f927653d0 100644 --- a/grain/experimental.py +++ b/grain/experimental.py @@ -90,4 +90,6 @@ from grain._src.python.dataset.dataset import ( get_element_spec, + set_next_index, + get_next_index, )