grain/_src/python/dataset/transformations/packing.py (21 changes: 14 additions & 7 deletions)
@@ -170,6 +170,7 @@ def _finalize_current_batch(self, element_for_shapes):
def __next__(self):
timer = dataset_stats.Timer()
if self._packed_batch is not None:
# We have a packed batch, so we emit one row at a time until it is exhausted.
with self._stats.record_self_time(offset_ns=timer.value()):
if self._shuffle_bins:
next_row = self._shuffled_rows[self._next_row]
@@ -181,6 +182,8 @@ def __next__(self):
self._next_row += 1
self._counter += 1
if self._next_row >= self._packed_batch_num_bins:
# We've emitted all the rows in the packed batch, so we reset the
# state and start from scratch.
self._packed_batch = None
self._last_emitted_batch_parent_state = (
self._current_batch_parent_state
@@ -190,6 +193,7 @@ def __next__(self):
self._shuffled_rows = None
return self._stats.record_output_spec(element)

# We don't have a packed batch, so we loop until we do.
while True:
prior_iterator_state = self._parent.get_state()
assert prior_iterator_state is not None
@@ -203,7 +207,8 @@ def __next__(self):
self._current_batch_parent_state = prior_iterator_state
return next(self)
else:
# The inner iterator is exhausted and there is no current batch.
# The inner iterator is exhausted and there is no current batch, so
# the packed iterator is also exhausted.
raise StopIteration() from e

with timer:
@@ -213,9 +218,8 @@
)

if self._current_batch is None:
# Initialize the batch manager with the specific packer class: use
# `element` to set dtypes + trailing dimensions. We are not adding the
# element to the batch, just initializing it.
# Use `element` to set dtypes + trailing dimensions.
# We are not adding the element to the batch, just initializing it.
self._current_batch = self._packer_cls(
element,
self._num_packing_bins,
@@ -230,22 +234,25 @@
# Try adding element to the current packed batch.
failing_components = self._current_batch.try_add_to_batch(element)

# When the batch is full, yield the packed data and start a new batch.
# When we have a full batch, yield the current packed data,
# and then start a new batch with this element.
if failing_components is not None:
with timer:
self._finalize_current_batch(element)
self._current_batch_parent_state = prior_iterator_state
assert self._current_batch is not None

if self._current_batch.try_add_to_batch(element) is not None:
# If a single example can't fit in an empty batch, it's an error.
# If we can't pack a single example into an empty batch then we
# can't continue at all.
element_shape = tree_lib.map_structure(lambda x: x.shape, element)
raise ValueError(
"Could not add element to empty packed batch! Packed batch has"
f" packing sequence_lengths: {self._length_struct} while"
f" element has shape: {element_shape}"
)
# We now have a packed batch.

# We now have a packed batch.
return next(self)


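The comments added in this diff describe a two-phase control flow for `__next__`: while a packed batch exists, emit one bin (row) per call until all bins are consumed; otherwise keep pulling elements from the parent iterator and packing them until an element no longer fits, then finalize the batch and carry the failing element over into the next one. Below is a minimal, self-contained sketch of that pattern, not Grain's implementation; `SimplePacker`, `packed_rows`, the integer-length "token" model, and the `num_bins`/`capacity` defaults are all illustrative assumptions.

```python
# A minimal sketch of the pack-then-emit control flow described above.
# NOT Grain's implementation: names and the packing model are illustrative.
from typing import Iterator, List, Sequence


class SimplePacker:
  """Packs variable-length integer sequences into fixed-capacity bins."""

  def __init__(self, num_bins: int, capacity: int):
    self._bins: List[List[int]] = [[] for _ in range(num_bins)]
    self._capacity = capacity

  def try_add(self, element: Sequence[int]) -> bool:
    """Adds `element` to the first bin with room; returns False if none fits."""
    for bin_ in self._bins:
      if len(bin_) + len(element) <= self._capacity:
        bin_.extend(element)
        return True
    return False

  def rows(self) -> List[List[int]]:
    return self._bins


def packed_rows(
    parent: Iterator[Sequence[int]], num_bins: int = 2, capacity: int = 8
) -> Iterator[List[int]]:
  """Yields one packed row (bin) at a time, mirroring the two-phase __next__."""
  pending = None  # Element that did not fit and must start the next batch.
  while True:
    batch = SimplePacker(num_bins, capacity)
    added_any = False
    if pending is not None:
      if not batch.try_add(pending):
        # Mirrors the ValueError in the diff: an element that cannot fit into
        # an empty batch can never be packed.
        raise ValueError("element does not fit into an empty batch")
      added_any, pending = True, None
    # Phase 1: keep pulling parent elements until one no longer fits.
    for element in parent:
      if batch.try_add(element):
        added_any = True
      else:
        pending = element
        break
    # Phase 2: emit every non-empty row of the finalized batch.
    if added_any:
      for row in batch.rows():
        if row:  # A real implementation would instead pad every bin.
          yield row
    if pending is None:
      return  # Parent exhausted and nothing carried over: we are done.


# Usage: pack five variable-length elements into two bins of capacity 8.
for row in packed_rows(iter([[1, 2, 3], [4, 5], [6, 7, 8, 9], [10], [11] * 5])):
  print(row)
```

The carried-over `pending` element plays the same role as re-adding `element` after `_finalize_current_batch` in the diff: the element that overflowed the previous batch becomes the first entry of the next one.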