2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -31,6 +31,8 @@ changes. Best viewed [here](https://google-grain.readthedocs.io/en/latest/change
specification inference.
* Adds support for changing `IterDataset.mix` components and weights after a
checkpoint.
* Adds experimental support for `get_next_index` and `set_next_index` to fetch
the index of the next element a `grain.DatasetIterator` will produce and to
advance the iterator to a given element index.

## Grain 0.2.14 (October 30, 2025)

31 changes: 31 additions & 0 deletions grain/_src/python/dataset/dataset.py
@@ -1574,6 +1574,27 @@ def _stats(self) -> dataset_stats.Stats:
self._ctx.dataset_options.execution_tracking_mode
)

def _set_next_index(self, index: int) -> None:
"""Sets the next index for the dataset iterator.

Note: This index is the index of the element that will be produced next by
the iterator. Implementations of this method should not process any data. If
advancing to the index without processing is not possible (e.g. for filter
and packing), then the implementation should raise a ValueError.

Args:
index: The index of the next element to be produced.
"""
raise NotImplementedError

def _get_next_index(self) -> int:
"""Returns the next index for the dataset iterator.

Note: This index is the index of the element that will be produced next by
the iterator.
"""
raise NotImplementedError

# pytype: enable=attribute-error
# pylint: enable=protected-access

@@ -1783,3 +1804,13 @@ def get_element_spec(
ds: `MapDataset` or `IterDataset` to get the element spec from.
"""
return ds._element_spec # pylint: disable=protected-access


def set_next_index(ds_iter: DatasetIterator, index: int) -> None:
"""Sets the next index for the dataset iterator."""
return ds_iter._set_next_index(index) # pylint: disable=protected-access


def get_next_index(ds_iter: DatasetIterator) -> int:
"""Returns the next index for the dataset iterator."""
return ds_iter._get_next_index() # pylint: disable=protected-access
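
For orientation, here is a minimal sketch of how the new module-level helpers above could be used to skip ahead in an iterator, mirroring the batch tests later in this PR. The internal import paths are taken from those tests; whether the helpers are also re-exported under a public namespace (e.g. grain.experimental) is an assumption not confirmed by this diff.

from grain._src.python.dataset import dataset
from grain._src.python.dataset.transformations import batch

# Build a small batched pipeline, as in the tests below.
ds = dataset.MapDataset.range(0, 10).to_iter_dataset()
ds = batch.BatchIterDataset(ds, batch_size=2, drop_remainder=False)
ds_iter = ds.__iter__()

assert dataset.get_next_index(ds_iter) == 0  # nothing produced yet
dataset.set_next_index(ds_iter, 3)  # jump to the fourth batch without processing data
assert dataset.get_next_index(ds_iter) == 3
print(next(ds_iter))  # the batch containing elements 6 and 7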
8 changes: 8 additions & 0 deletions grain/_src/python/dataset/transformations/batch.py
@@ -249,6 +249,14 @@ def get_state(self):
def set_state(self, state):
self._parent.set_state(state)

def _get_next_index(self) -> int:
return (
dataset.get_next_index(self._parent) + self._batch_size - 1
) // self._batch_size

def _set_next_index(self, index: int) -> None:
dataset.set_next_index(self._parent, index * self._batch_size)

def __str__(self) -> str:
return (
f"BatchDatasetIterator(batch_size={self._batch_size},"
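
The index translation above hinges on ceiling division: the parent iterator counts individual elements while the batch iterator counts whole batches, so the batch index is ceil(parent_index / batch_size) and setting batch index i rewinds the parent to i * batch_size. A small standalone check of that arithmetic, using hypothetical values independent of Grain:

# Integer ceiling division: (n + d - 1) // d == ceil(n / d) for positive d.
batch_size = 3
for parent_index, expected_batch_index in [(0, 0), (3, 1), (9, 3), (10, 4)]:
  assert (parent_index + batch_size - 1) // batch_size == expected_batch_index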
56 changes: 56 additions & 0 deletions grain/_src/python/dataset/transformations/batch_test.py
@@ -797,6 +797,62 @@ def test_element_spec_nested(self):
},
)

@parameterized.parameters(
dict(
batch_size=2,
drop_remainder=False,
expected=[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]],
),
dict(
batch_size=3,
drop_remainder=False,
expected=[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]],
),
dict(
batch_size=3,
drop_remainder=True,
expected=[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
),
)
def test_get_next_index(self, batch_size, drop_remainder, expected):
del expected
ds = dataset.MapDataset.range(0, 10).to_iter_dataset()
ds = batch.BatchIterDataset(
ds, batch_size=batch_size, drop_remainder=drop_remainder
)
ds_iter = ds.__iter__()
self.assertEqual(dataset.get_next_index(ds_iter), 0)
for i, _ in enumerate(ds_iter):
self.assertEqual(dataset.get_next_index(ds_iter), i + 1)

@parameterized.parameters(
dict(
batch_size=2,
drop_remainder=False,
expected=[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]],
),
dict(
batch_size=3,
drop_remainder=False,
expected=[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]],
),
dict(
batch_size=3,
drop_remainder=True,
expected=[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
),
)
def test_set_next_index(self, batch_size, drop_remainder, expected):
ds = dataset.MapDataset.range(0, 10).to_iter_dataset()
ds = batch.BatchIterDataset(
ds, batch_size=batch_size, drop_remainder=drop_remainder
)
for i in reversed(range(len(expected))):
ds_iter = ds.__iter__()
dataset.set_next_index(ds_iter, i)
actual = next(ds_iter)
np.testing.assert_allclose(actual, expected[i])


if __name__ == "__main__":
absltest.main()
24 changes: 24 additions & 0 deletions grain/_src/python/dataset/transformations/interleave.py
@@ -206,6 +206,30 @@ def set_state(self, state):
self._exhausted_iterator_state[index_in_cycle] = it_state
self._iterators_in_use[index_in_cycle] = None

def _get_next_index(self) -> int:
if len(self._datasets) == 1:
it = self._iterators_in_use[0]
if it is None:
return 0
return dataset.get_next_index(it)
raise NotImplementedError(
"get_next_index is not supported for InterleaveDatasetIterator with"
" more than one dataset."
)

def _set_next_index(self, index: int) -> None:
if len(self._datasets) == 1:
# Ensure iterator is created by calling get_state.
_ = self.get_state()
it = self._iterators_in_use[0]
assert it is not None
dataset.set_next_index(it, index)
else:
raise NotImplementedError(
"set_next_index is not supported for InterleaveDatasetIterator with"
" more than one dataset."
)

def close(self) -> None:
"""Closes the iterator and shuts down the iterator prefetching."""
if self._closed:
39 changes: 39 additions & 0 deletions grain/_src/python/dataset/transformations/interleave_test.py
@@ -213,6 +213,45 @@ def test_interleave_stats_with_mismatched_dataset_structures(self):
self.assertLen(node_names, 1)
self.assertIn("InterleaveDatasetIterator", node_names[0])

def test_get_next_index(self):
ds = dataset.MapDataset.range(10).to_iter_dataset()
ds = interleave.InterleaveIterDataset([ds], cycle_length=1)
ds_iter = ds.__iter__()
self.assertEqual(dataset.get_next_index(ds_iter), 0)
for i in range(10):
next(ds_iter)
self.assertEqual(dataset.get_next_index(ds_iter), i + 1)

def test_set_next_index(self):
ds = dataset.MapDataset.range(10).to_iter_dataset()
ds = interleave.InterleaveIterDataset([ds], cycle_length=1)
ds_iter = ds.__iter__()
for i in reversed(range(10)):
dataset.set_next_index(ds_iter, i)
self.assertEqual(next(ds_iter), i)

def test_get_next_index_with_multiple_datasets(self):
ds = dataset.MapDataset.range(10).to_iter_dataset()
ds = interleave.InterleaveIterDataset([ds, ds], cycle_length=2)
ds_iter = ds.__iter__()
with self.assertRaisesRegex(
NotImplementedError,
"get_next_index is not supported for InterleaveDatasetIterator with"
" more than one dataset.",
):
dataset.get_next_index(ds_iter)

def test_set_next_index_with_multiple_datasets(self):
ds = dataset.MapDataset.range(10).to_iter_dataset()
ds = interleave.InterleaveIterDataset([ds, ds], cycle_length=2)
ds_iter = ds.__iter__()
with self.assertRaisesRegex(
NotImplementedError,
"set_next_index is not supported for InterleaveDatasetIterator with"
" more than one dataset.",
):
dataset.set_next_index(ds_iter, 0)


if __name__ == "__main__":
absltest.main()
7 changes: 7 additions & 0 deletions grain/_src/python/dataset/transformations/limit.py
@@ -53,6 +53,13 @@ def set_state(self, state):
self._parent.set_state(state["parent"])
self._count_elements_read = state["count_elements_read"]

def _get_next_index(self) -> int:
return self._count_elements_read

def _set_next_index(self, index: int) -> None:
dataset.set_next_index(self._parent, index)
self._count_elements_read = index


class LimitIterDataset(dataset.IterDataset[T]):
"""Limits the number of elements in the dataset.
23 changes: 23 additions & 0 deletions grain/_src/python/dataset/transformations/limit_test.py
@@ -84,6 +84,29 @@ def test_element_spec(self):
self.assertEqual(spec.dtype, np.int64)
self.assertEqual(spec.shape, ())

def test_get_next_index(self):
ds = dataset.MapDataset.range(0, 20).batch(3).to_iter_dataset()
limited_ds = limit.LimitIterDataset(ds, count=2)
ds_iter = limited_ds.__iter__()
self.assertEqual(dataset.get_next_index(ds_iter), 0)
_ = next(ds_iter)
self.assertEqual(dataset.get_next_index(ds_iter), 1)
_ = next(ds_iter)
with self.assertRaises(StopIteration):
next(ds_iter)
self.assertEqual(dataset.get_next_index(ds_iter), 2)

def test_set_next_index(self):
ds = dataset.MapDataset.range(0, 20).batch(3).to_iter_dataset()
limited_ds = limit.LimitIterDataset(ds, count=2)
ds_iter = limited_ds.__iter__()
dataset.set_next_index(ds_iter, 1)
self.assertEqual(dataset.get_next_index(ds_iter), 1)
_ = next(ds_iter)
self.assertEqual(dataset.get_next_index(ds_iter), 2)
with self.assertRaises(StopIteration):
next(ds_iter)


if __name__ == "__main__":
absltest.main()
20 changes: 20 additions & 0 deletions grain/_src/python/dataset/transformations/map.py
@@ -320,6 +320,12 @@ def get_state(self):
def set_state(self, state):
self._parent.set_state(state)

def _get_next_index(self) -> int:
return dataset.get_next_index(self._parent)

def _set_next_index(self, next_index: int):
dataset.set_next_index(self._parent, next_index)

def __str__(self) -> str:
return f"MapDatasetIterator(transform={self._transform_name})"

@@ -370,6 +376,13 @@ def set_state(self, state):
self._parent.set_state(state["parent"])
self._index_for_rng = state["index_for_rng"]

def _get_next_index(self) -> int:
return self._index_for_rng

def _set_next_index(self, next_index: int):
dataset.set_next_index(self._parent, next_index)
self._index_for_rng = next_index

def __str__(self) -> str:
return f"RandomMapDatasetIterator(transform={self._transform_name})"

@@ -408,6 +421,13 @@ def set_state(self, state):
self._parent.set_state(state["parent"])
self._counter = state["counter"]

def _get_next_index(self) -> int:
return self._counter

def _set_next_index(self, next_index: int):
dataset.set_next_index(self._parent, next_index)
self._counter = next_index

def __str__(self) -> str:
return f"MapWithIndexDatasetIterator(transform={self._transform_name})"

53 changes: 53 additions & 0 deletions grain/_src/python/dataset/transformations/map_test.py
@@ -368,6 +368,22 @@ def test_map_element_spec_inference_raises_error(self):
with self.assertRaisesRegex(ValueError, "does not implement `output_spec`"):
_ = ds._element_spec

def test_get_next_index(self):
ds = dataset.MapDataset.range(0, 20).to_iter_dataset()
mapped_ds = map_ds.MapIterDataset(ds, MapWithNoTransform())
ds_iter = mapped_ds.__iter__()
for i in range(20):
self.assertEqual(dataset.get_next_index(ds_iter), i)
_ = next(ds_iter)

def test_set_next_index(self):
ds = dataset.MapDataset.range(0, 20).to_iter_dataset()
mapped_ds = map_ds.MapIterDataset(ds, MapWithNoTransform())
ds_iter = mapped_ds.__iter__()
for i in reversed(range(20)):
dataset.set_next_index(ds_iter, i)
self.assertEqual(next(ds_iter), i)


class RandomMapIterDatasetTest(parameterized.TestCase):

@@ -450,6 +466,27 @@ def test_random_map_element_spec_inference_raises_error(self):
with self.assertRaisesRegex(ValueError, "does not implement `output_spec`"):
_ = ds._element_spec

def test_get_next_index(self):
ds = dataset.MapDataset.range(0, 20).to_iter_dataset()
mapped_ds = map_ds.RandomMapIterDataset(
ds, RandomMapWithTransform(), seed=0
)
ds_iter = mapped_ds.__iter__()
for i in range(20):
self.assertEqual(dataset.get_next_index(ds_iter), i)
_ = next(ds_iter)

def test_set_next_index(self):
ds = dataset.MapDataset.range(0, 20).to_iter_dataset()
mapped_ds = map_ds.RandomMapIterDataset(
ds, RandomMapWithTransform(), seed=0
)
expected = list(mapped_ds)
ds_iter = mapped_ds.__iter__()
for i in reversed(range(20)):
dataset.set_next_index(ds_iter, i)
self.assertEqual(next(ds_iter), expected[i])


class MapWithIndexMapDatasetTest(parameterized.TestCase):

@@ -542,6 +579,22 @@ def test_map_with_index_element_spec_inference_raises_error(self):
with self.assertRaisesRegex(ValueError, "does not implement `output_spec`"):
_ = ds._element_spec

def test_get_next_index(self):
ds = dataset.MapDataset.range(0, 20).to_iter_dataset()
mapped_ds = map_ds.MapWithIndexIterDataset(ds, AddIndexTransform())
ds_iter = mapped_ds.__iter__()
for i in range(20):
self.assertEqual(dataset.get_next_index(ds_iter), i)
_ = next(ds_iter)

def test_set_next_index(self):
ds = dataset.MapDataset.range(0, 20).to_iter_dataset()
mapped_ds = map_ds.MapWithIndexIterDataset(ds, AddIndexTransform())
ds_iter = mapped_ds.__iter__()
for i in reversed(range(20)):
dataset.set_next_index(ds_iter, i)
self.assertEqual(next(ds_iter), (i, i))


if __name__ == "__main__":
absltest.main()