From fa350545df1b436dec661987f5ca5cb99648b5ca Mon Sep 17 00:00:00 2001 From: Claudio Fantacci Date: Tue, 21 Oct 2025 05:29:14 -0700 Subject: [PATCH] Increase profiling in packing PiperOrigin-RevId: 822071089 --- .../python/dataset/transformations/packing.py | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/grain/_src/python/dataset/transformations/packing.py b/grain/_src/python/dataset/transformations/packing.py index 51af29d92..97a58c85e 100644 --- a/grain/_src/python/dataset/transformations/packing.py +++ b/grain/_src/python/dataset/transformations/packing.py @@ -170,6 +170,7 @@ def _finalize_current_batch(self, element_for_shapes): def __next__(self): timer = dataset_stats.Timer() if self._packed_batch is not None: + # We have a packed batch, we emit one row at a time until exhausted. with self._stats.record_self_time(offset_ns=timer.value()): if self._shuffle_bins: next_row = self._shuffled_rows[self._next_row] @@ -181,6 +182,8 @@ def __next__(self): self._next_row += 1 self._counter += 1 if self._next_row >= self._packed_batch_num_bins: + # We've emitted all the rows in the packed batch, so we reset the + # state and start from scratch. self._packed_batch = None self._last_emitted_batch_parent_state = ( self._current_batch_parent_state @@ -190,6 +193,7 @@ def __next__(self): self._shuffled_rows = None return self._stats.record_output_spec(element) + # We don't have a packed batch, so we loop until we do. while True: prior_iterator_state = self._parent.get_state() assert prior_iterator_state is not None @@ -203,7 +207,8 @@ def __next__(self): self._current_batch_parent_state = prior_iterator_state return next(self) else: - # The inner iterator is exhausted and there is no current batch. + # The inner iterator is exhausted and there is no current batch, so + # the packed iterator is also exhausted. 
raise StopIteration() from e with timer: @@ -213,9 +218,8 @@ def __next__(self): ) if self._current_batch is None: - # Initialize the batch manager with the specific packer class: use - # `element` to set dtypes + trailing dimensions. We are not adding the - # element to the batch, just initializing it. + # Use `element` to set dtypes + trailing dimensions. + # We are not adding the element to the batch, just initializing it. self._current_batch = self._packer_cls( element, self._num_packing_bins, @@ -230,7 +234,8 @@ def __next__(self): # Try adding element to the current packed batch. failing_components = self._current_batch.try_add_to_batch(element) - # When the batch is full, yield the packed data and start a new batch. + # When we have a full batch, yield the current packed data, + # and then start a new batch with this element. if failing_components is not None: with timer: self._finalize_current_batch(element) @@ -238,14 +243,16 @@ def __next__(self): assert self._current_batch is not None if self._current_batch.try_add_to_batch(element) is not None: - # If a single example can't fit in an empty batch, it's an error. + # If we can't pack a single example into an empty batch then we + # can't continue at all. element_shape = tree_lib.map_structure(lambda x: x.shape, element) raise ValueError( "Could not add element to empty packed batch! Packed batch has" f" packing sequence_lengths: {self._length_struct} while" f" element has shape: {element_shape}" ) - # We now have a packed batch. + + # We now have a packed batch. return next(self)