From 305c5d4114717aefeb17ac01482bcd291f76e547 Mon Sep 17 00:00:00 2001
From: RAHUL KUMAR
Date: Tue, 16 Sep 2025 13:18:58 +0530
Subject: [PATCH 1/5] Add accumulate option to Trainer.predict for memory control

Introduces an 'accumulate' argument to the predict method in all
backend trainers and the base Trainer class. When set to False,
predictions are not accumulated in memory and must be handled via
callbacks, helping to avoid memory issues with large datasets. Updates
method signatures, docstrings, and internal logic accordingly.
---
 keras/src/backend/jax/trainer.py        | 12 ++++++++----
 keras/src/backend/numpy/trainer.py      | 12 ++++++++----
 keras/src/backend/openvino/trainer.py   | 12 ++++++++----
 keras/src/backend/tensorflow/trainer.py | 18 +++++++++++-------
 keras/src/backend/torch/trainer.py      | 14 +++++++++-----
 keras/src/trainers/trainer.py           |  8 ++++++--
 6 files changed, 50 insertions(+), 26 deletions(-)

diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py
index 5f01505c2d47..46fbd900a694 100644
--- a/keras/src/backend/jax/trainer.py
+++ b/keras/src/backend/jax/trainer.py
@@ -638,7 +638,7 @@ def evaluate(
 
     @traceback_utils.filter_traceback
     def predict(
-        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
+        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None, accumulate=True
     ):
         # Create an iterator that yields batches of input data.
         epoch_iterator = JAXEpochIterator(
@@ -694,7 +694,7 @@ def append_to_outputs(batch_outputs, outputs):
             return outputs
 
         self._jax_state_synced = True
-        outputs = None
+        outputs = None if accumulate else []
         non_trainable_variables = None
         with epoch_iterator.catch_stop_iteration():
             for begin_step, end_step, iterator in epoch_iterator:
@@ -718,7 +718,8 @@ def append_to_outputs(batch_outputs, outputs):
                     # during predict(), but it's allowed.
                     "non_trainable_variables": non_trainable_variables,
                 }
-                outputs = append_to_outputs(batch_outputs, outputs)
+                if accumulate:
+                    outputs = append_to_outputs(batch_outputs, outputs)
 
                 # Dispatch callbacks. This takes care of async dispatch.
                 callbacks.on_predict_batch_end(
@@ -731,7 +732,10 @@ def append_to_outputs(batch_outputs, outputs):
         self.jax_state_sync()
         callbacks.on_predict_end()
         self._jax_state = None
-        return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
+        if accumulate:
+            return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
+        else:
+            return None
 
     def train_on_batch(
         self,
diff --git a/keras/src/backend/numpy/trainer.py b/keras/src/backend/numpy/trainer.py
index fd8c276a86d2..155bb16f17cc 100644
--- a/keras/src/backend/numpy/trainer.py
+++ b/keras/src/backend/numpy/trainer.py
@@ -170,7 +170,7 @@ def fit(
 
     @traceback_utils.filter_traceback
     def predict(
-        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
+        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None, accumulate=True
     ):
         # Create an iterator that yields batches of input data.
         epoch_iterator = EpochIterator(
@@ -210,16 +210,20 @@ def append_to_outputs(batch_outputs, outputs):
         self.make_predict_function()
         self.stop_predicting = False
         callbacks.on_predict_begin()
-        outputs = None
+        outputs = None if accumulate else []
         for begin_step, end_step, data in epoch_iterator:
             callbacks.on_predict_batch_begin(begin_step)
             batch_outputs = self.predict_function(data)
-            outputs = append_to_outputs(batch_outputs, outputs)
+            if accumulate:
+                outputs = append_to_outputs(batch_outputs, outputs)
             callbacks.on_predict_batch_end(end_step, {"outputs": batch_outputs})
             if self.stop_predicting:
                 break
         callbacks.on_predict_end()
-        return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
+        if accumulate:
+            return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
+        else:
+            return None
 
     @traceback_utils.filter_traceback
     def evaluate(
diff --git a/keras/src/backend/openvino/trainer.py b/keras/src/backend/openvino/trainer.py
index ac2e64a8060c..52bbedbb3e2d 100644
--- a/keras/src/backend/openvino/trainer.py
+++ b/keras/src/backend/openvino/trainer.py
@@ -171,7 +171,7 @@ def fit(
 
     @traceback_utils.filter_traceback
     def predict(
-        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
+        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None, accumulate=True
     ):
         # Create an iterator that yields batches of input data.
         epoch_iterator = EpochIterator(
@@ -212,16 +212,20 @@ def append_to_outputs(batch_outputs, outputs):
         self.make_predict_function()
         self.stop_predicting = False
         callbacks.on_predict_begin()
-        outputs = None
+        outputs = None if accumulate else []
         for begin_step, end_step, data in epoch_iterator.enumerate_epoch():
             callbacks.on_predict_batch_begin(begin_step)
             batch_outputs = self.predict_function(data)
-            outputs = append_to_outputs(batch_outputs, outputs)
+            if accumulate:
+                outputs = append_to_outputs(batch_outputs, outputs)
             callbacks.on_predict_batch_end(end_step, {"outputs": batch_outputs})
             if self.stop_predicting:
                 break
         callbacks.on_predict_end()
-        return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
+        if accumulate:
+            return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
+        else:
+            return None
 
     @traceback_utils.filter_traceback
     def evaluate(
diff --git a/keras/src/backend/tensorflow/trainer.py b/keras/src/backend/tensorflow/trainer.py
index cd6410999dd2..7b50e4e41bf1 100644
--- a/keras/src/backend/tensorflow/trainer.py
+++ b/keras/src/backend/tensorflow/trainer.py
@@ -521,7 +521,7 @@ def evaluate(
 
     @traceback_utils.filter_traceback
     def predict(
-        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
+        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None, accumulate=True
     ):
         # Create an iterator that yields batches of input data.
         epoch_iterator = TFEpochIterator(
@@ -580,23 +580,27 @@ def get_data(iterator):
         self.make_predict_function()
         self.stop_predicting = False
         callbacks.on_predict_begin()
-        outputs = None
+        outputs = None if accumulate else []
         with epoch_iterator.catch_stop_iteration():
             for begin_step, end_step, iterator in epoch_iterator:
                 callbacks.on_predict_batch_begin(begin_step)
                 data = get_data(iterator)
                 batch_outputs = self.predict_function(data)
-                outputs = append_to_outputs(batch_outputs, outputs)
+                if accumulate:
+                    outputs = append_to_outputs(batch_outputs, outputs)
                 callbacks.on_predict_batch_end(
                     end_step, {"outputs": batch_outputs}
                 )
                 if self.stop_predicting:
                     break
         callbacks.on_predict_end()
-        outputs = tree.map_structure_up_to(
-            batch_outputs, potentially_ragged_concat, outputs
-        )
-        return tree.map_structure(convert_to_np_if_not_ragged, outputs)
+        if accumulate:
+            outputs = tree.map_structure_up_to(
+                batch_outputs, potentially_ragged_concat, outputs
+            )
+            return tree.map_structure(convert_to_np_if_not_ragged, outputs)
+        else:
+            return None
 
     def train_on_batch(
         self,
diff --git a/keras/src/backend/torch/trainer.py b/keras/src/backend/torch/trainer.py
index ad68c2f3a7ec..6e7e267ddab6 100644
--- a/keras/src/backend/torch/trainer.py
+++ b/keras/src/backend/torch/trainer.py
@@ -395,7 +395,7 @@ def evaluate(
 
     @traceback_utils.filter_traceback
     def predict(
-        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
+        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None, accumulate=True
     ):
         # Create an iterator that yields batches of input data.
         epoch_iterator = TorchEpochIterator(
@@ -438,17 +438,21 @@ def append_to_outputs(batch_outputs, outputs):
         self.make_predict_function()
         self.stop_predicting = False
         callbacks.on_predict_begin()
-        outputs = None
+        outputs = None if accumulate else []
         for begin_step, end_step, data in epoch_iterator:
             callbacks.on_predict_batch_begin(begin_step)
             batch_outputs = self.predict_function(data)
-            outputs = append_to_outputs(batch_outputs, outputs)
+            if accumulate:
+                outputs = append_to_outputs(batch_outputs, outputs)
             callbacks.on_predict_batch_end(end_step, {"outputs": batch_outputs})
             if self.stop_predicting:
                 break
         callbacks.on_predict_end()
-        outputs = tree.map_structure(backend.convert_to_numpy, outputs)
-        return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
+        if accumulate:
+            outputs = tree.map_structure(backend.convert_to_numpy, outputs)
+            return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
+        else:
+            return None
 
     def train_on_batch(
         self,
diff --git a/keras/src/trainers/trainer.py b/keras/src/trainers/trainer.py
index bac422db249c..5429c3989edd 100644
--- a/keras/src/trainers/trainer.py
+++ b/keras/src/trainers/trainer.py
@@ -807,7 +807,7 @@ def evaluate(
         raise NotImplementedError
 
     def predict(
-        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
+        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None, accumulate=True
     ):
         """Generates output predictions for the input samples.
 
@@ -858,9 +858,13 @@ def predict(
                 repeating dataset, it will run indefinitely.
             callbacks: List of `keras.callbacks.Callback` instances.
                 List of callbacks to apply during prediction.
+            accumulate: Boolean. Whether to accumulate predictions in memory.
+                If `False`, predictions are not returned and must be handled
+                via callbacks to avoid memory issues with large datasets.
+                Defaults to `True`.
 
         Returns:
-            NumPy array(s) of predictions.
+            NumPy array(s) of predictions if `accumulate=True`, otherwise `None`.
""" raise NotImplementedError From 87d6dbdf66eab827838e65b15936fa648f50ccba Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Wed, 17 Sep 2025 09:35:14 +0530 Subject: [PATCH 2/5] Refactor predict output handling in trainer backends Standardizes the handling of outputs in the predict methods across all backend trainers by always initializing outputs as None and returning None when no outputs are accumulated. This simplifies the logic and ensures consistent behavior when accumulate is False or when no predictions are made. --- keras/src/backend/jax/trainer.py | 7 ++++--- keras/src/backend/numpy/trainer.py | 7 ++++--- keras/src/backend/openvino/trainer.py | 17 ++++++++++------- keras/src/backend/tensorflow/trainer.py | 7 ++++--- keras/src/backend/torch/trainer.py | 7 ++++--- 5 files changed, 26 insertions(+), 19 deletions(-) diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py index 46fbd900a694..c8dc414e9a59 100644 --- a/keras/src/backend/jax/trainer.py +++ b/keras/src/backend/jax/trainer.py @@ -694,7 +694,7 @@ def append_to_outputs(batch_outputs, outputs): return outputs self._jax_state_synced = True - outputs = None if accumulate else [] + outputs = None non_trainable_variables = None with epoch_iterator.catch_stop_iteration(): for begin_step, end_step, iterator in epoch_iterator: @@ -733,9 +733,10 @@ def append_to_outputs(batch_outputs, outputs): callbacks.on_predict_end() self._jax_state = None if accumulate: + if outputs is None: + return None return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs) - else: - return None + return outputs def train_on_batch( self, diff --git a/keras/src/backend/numpy/trainer.py b/keras/src/backend/numpy/trainer.py index 155bb16f17cc..fc9640e7866d 100644 --- a/keras/src/backend/numpy/trainer.py +++ b/keras/src/backend/numpy/trainer.py @@ -210,7 +210,7 @@ def append_to_outputs(batch_outputs, outputs): self.make_predict_function() self.stop_predicting = False callbacks.on_predict_begin() - outputs = None if accumulate else [] + outputs = None for begin_step, end_step, data in epoch_iterator: callbacks.on_predict_batch_begin(begin_step) batch_outputs = self.predict_function(data) @@ -221,9 +221,10 @@ def append_to_outputs(batch_outputs, outputs): break callbacks.on_predict_end() if accumulate: + if outputs is None: + return None return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs) - else: - return None + return outputs @traceback_utils.filter_traceback def evaluate( diff --git a/keras/src/backend/openvino/trainer.py b/keras/src/backend/openvino/trainer.py index 52bbedbb3e2d..4f1baad5d482 100644 --- a/keras/src/backend/openvino/trainer.py +++ b/keras/src/backend/openvino/trainer.py @@ -212,20 +212,23 @@ def append_to_outputs(batch_outputs, outputs): self.make_predict_function() self.stop_predicting = False callbacks.on_predict_begin() - outputs = None if accumulate else [] - for begin_step, end_step, data in epoch_iterator.enumerate_epoch(): - callbacks.on_predict_batch_begin(begin_step) - batch_outputs = self.predict_function(data) + outputs = None + for _, iterator in epoch_iterator: + callbacks.on_predict_batch_begin(epoch_iterator.current_step) + batch_outputs = self.predict_function(iterator) if accumulate: outputs = append_to_outputs(batch_outputs, outputs) - callbacks.on_predict_batch_end(end_step, {"outputs": batch_outputs}) + callbacks.on_predict_batch_end( + epoch_iterator.current_step, {"outputs": batch_outputs} + ) if self.stop_predicting: break callbacks.on_predict_end() if accumulate: 
+            if outputs is None:
+                return None
             return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
-        else:
-            return None
+        return outputs
 
     @traceback_utils.filter_traceback
     def evaluate(
diff --git a/keras/src/backend/tensorflow/trainer.py b/keras/src/backend/tensorflow/trainer.py
index 7b50e4e41bf1..423ae4242f21 100644
--- a/keras/src/backend/tensorflow/trainer.py
+++ b/keras/src/backend/tensorflow/trainer.py
@@ -580,7 +580,7 @@ def get_data(iterator):
         self.make_predict_function()
         self.stop_predicting = False
         callbacks.on_predict_begin()
-        outputs = None if accumulate else []
+        outputs = None
         with epoch_iterator.catch_stop_iteration():
             for begin_step, end_step, iterator in epoch_iterator:
                 callbacks.on_predict_batch_begin(begin_step)
@@ -595,12 +595,13 @@ def get_data(iterator):
                     break
         callbacks.on_predict_end()
         if accumulate:
+            if outputs is None:
+                return None
             outputs = tree.map_structure_up_to(
                 batch_outputs, potentially_ragged_concat, outputs
             )
             return tree.map_structure(convert_to_np_if_not_ragged, outputs)
-        else:
-            return None
+        return outputs
 
     def train_on_batch(
         self,
diff --git a/keras/src/backend/torch/trainer.py b/keras/src/backend/torch/trainer.py
index 6e7e267ddab6..fc97203ccc0c 100644
--- a/keras/src/backend/torch/trainer.py
+++ b/keras/src/backend/torch/trainer.py
@@ -438,7 +438,7 @@ def append_to_outputs(batch_outputs, outputs):
         self.make_predict_function()
         self.stop_predicting = False
         callbacks.on_predict_begin()
-        outputs = None if accumulate else []
+        outputs = None
         for begin_step, end_step, data in epoch_iterator:
             callbacks.on_predict_batch_begin(begin_step)
             batch_outputs = self.predict_function(data)
@@ -449,10 +449,11 @@ def append_to_outputs(batch_outputs, outputs):
                 break
         callbacks.on_predict_end()
         if accumulate:
+            if outputs is None:
+                return None
             outputs = tree.map_structure(backend.convert_to_numpy, outputs)
             return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
-        else:
-            return None
+        return outputs
 
     def train_on_batch(
         self,

From c3c12405d41a334b9cf8e90edbab12cbc9e23a35 Mon Sep 17 00:00:00 2001
From: RAHUL KUMAR
Date: Wed, 17 Sep 2025 09:55:31 +0530
Subject: [PATCH 3/5] Format predict method signatures for readability

Refactored the predict method signatures in all backend trainer classes
and the base Trainer to use one argument per line. Also reformatted long
return statements for better readability. No functional changes were
made.
---
 keras/src/backend/jax/trainer.py        | 12 ++++++++++--
 keras/src/backend/numpy/trainer.py      | 12 ++++++++++--
 keras/src/backend/openvino/trainer.py   | 12 ++++++++++--
 keras/src/backend/tensorflow/trainer.py |  8 +++++++-
 keras/src/backend/torch/trainer.py      | 12 ++++++++++--
 keras/src/trainers/trainer.py           | 11 +++++++++--
 6 files changed, 56 insertions(+), 11 deletions(-)

diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py
index c8dc414e9a59..a68ef15444cf 100644
--- a/keras/src/backend/jax/trainer.py
+++ b/keras/src/backend/jax/trainer.py
@@ -638,7 +638,13 @@ def evaluate(
 
     @traceback_utils.filter_traceback
     def predict(
-        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None, accumulate=True
+        self,
+        x,
+        batch_size=None,
+        verbose="auto",
+        steps=None,
+        callbacks=None,
+        accumulate=True,
     ):
         # Create an iterator that yields batches of input data.
         epoch_iterator = JAXEpochIterator(
@@ -741,7 +741,9 @@ def append_to_outputs(batch_outputs, outputs):
         if accumulate:
             if outputs is None:
                 return None
-            return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
+            return tree.map_structure_up_to(
+                batch_outputs, np.concatenate, outputs
+            )
         return outputs
 
     def train_on_batch(
diff --git a/keras/src/backend/numpy/trainer.py b/keras/src/backend/numpy/trainer.py
index fc9640e7866d..852a1d35e60a 100644
--- a/keras/src/backend/numpy/trainer.py
+++ b/keras/src/backend/numpy/trainer.py
@@ -170,7 +170,13 @@ def fit(
 
     @traceback_utils.filter_traceback
     def predict(
-        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None, accumulate=True
+        self,
+        x,
+        batch_size=None,
+        verbose="auto",
+        steps=None,
+        callbacks=None,
+        accumulate=True,
     ):
         # Create an iterator that yields batches of input data.
         epoch_iterator = EpochIterator(
@@ -223,7 +229,9 @@ def append_to_outputs(batch_outputs, outputs):
         if accumulate:
             if outputs is None:
                 return None
-            return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
+            return tree.map_structure_up_to(
+                batch_outputs, np.concatenate, outputs
+            )
         return outputs
 
     @traceback_utils.filter_traceback
diff --git a/keras/src/backend/openvino/trainer.py b/keras/src/backend/openvino/trainer.py
index 4f1baad5d482..c6a62568a978 100644
--- a/keras/src/backend/openvino/trainer.py
+++ b/keras/src/backend/openvino/trainer.py
@@ -171,7 +171,13 @@ def fit(
 
     @traceback_utils.filter_traceback
     def predict(
-        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None, accumulate=True
+        self,
+        x,
+        batch_size=None,
+        verbose="auto",
+        steps=None,
+        callbacks=None,
+        accumulate=True,
     ):
         # Create an iterator that yields batches of input data.
         epoch_iterator = EpochIterator(
@@ -227,7 +233,9 @@ def append_to_outputs(batch_outputs, outputs):
         if accumulate:
             if outputs is None:
                 return None
-            return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
+            return tree.map_structure_up_to(
+                batch_outputs, np.concatenate, outputs
+            )
         return outputs
 
     @traceback_utils.filter_traceback
diff --git a/keras/src/backend/tensorflow/trainer.py b/keras/src/backend/tensorflow/trainer.py
index 423ae4242f21..c77bdac1a130 100644
--- a/keras/src/backend/tensorflow/trainer.py
+++ b/keras/src/backend/tensorflow/trainer.py
@@ -521,7 +521,13 @@ def evaluate(
 
     @traceback_utils.filter_traceback
     def predict(
-        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None, accumulate=True
+        self,
+        x,
+        batch_size=None,
+        verbose="auto",
+        steps=None,
+        callbacks=None,
+        accumulate=True,
     ):
         # Create an iterator that yields batches of input data.
         epoch_iterator = TFEpochIterator(
diff --git a/keras/src/backend/torch/trainer.py b/keras/src/backend/torch/trainer.py
index fc97203ccc0c..71b1a33c07cc 100644
--- a/keras/src/backend/torch/trainer.py
+++ b/keras/src/backend/torch/trainer.py
@@ -395,7 +395,13 @@ def evaluate(
 
     @traceback_utils.filter_traceback
     def predict(
-        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None, accumulate=True
+        self,
+        x,
+        batch_size=None,
+        verbose="auto",
+        steps=None,
+        callbacks=None,
+        accumulate=True,
     ):
         # Create an iterator that yields batches of input data.
         epoch_iterator = TorchEpochIterator(
@@ -452,7 +458,9 @@ def append_to_outputs(batch_outputs, outputs):
             if outputs is None:
                 return None
             outputs = tree.map_structure(backend.convert_to_numpy, outputs)
-            return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
+            return tree.map_structure_up_to(
+                batch_outputs, np.concatenate, outputs
+            )
         return outputs
 
     def train_on_batch(
diff --git a/keras/src/trainers/trainer.py b/keras/src/trainers/trainer.py
index 5429c3989edd..8b5178891f09 100644
--- a/keras/src/trainers/trainer.py
+++ b/keras/src/trainers/trainer.py
@@ -807,7 +807,13 @@ def evaluate(
         raise NotImplementedError
 
     def predict(
-        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None, accumulate=True
+        self,
+        x,
+        batch_size=None,
+        verbose="auto",
+        steps=None,
+        callbacks=None,
+        accumulate=True,
     ):
         """Generates output predictions for the input samples.
 
@@ -864,7 +870,8 @@ def predict(
                 Defaults to `True`.
 
         Returns:
-            NumPy array(s) of predictions if `accumulate=True`, otherwise `None`.
+            NumPy array(s) of predictions if `accumulate=True`,
+            otherwise `None`.
         """
         raise NotImplementedError
 
From 5a79a5582c69987e5abb168805dab9aa433b3b4f Mon Sep 17 00:00:00 2001
From: RAHUL KUMAR
Date: Wed, 17 Sep 2025 10:01:45 +0530
Subject: [PATCH 4/5] Update trainer.py

---
 keras/src/backend/openvino/trainer.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/keras/src/backend/openvino/trainer.py b/keras/src/backend/openvino/trainer.py
index c6a62568a978..c3d5c81d26af 100644
--- a/keras/src/backend/openvino/trainer.py
+++ b/keras/src/backend/openvino/trainer.py
@@ -219,14 +219,12 @@ def append_to_outputs(batch_outputs, outputs):
         self.stop_predicting = False
         callbacks.on_predict_begin()
         outputs = None
-        for _, iterator in epoch_iterator:
-            callbacks.on_predict_batch_begin(epoch_iterator.current_step)
+        for begin_step, end_step, iterator in epoch_iterator:
+            callbacks.on_predict_batch_begin(begin_step)
             batch_outputs = self.predict_function(iterator)
             if accumulate:
                 outputs = append_to_outputs(batch_outputs, outputs)
-            callbacks.on_predict_batch_end(
-                epoch_iterator.current_step, {"outputs": batch_outputs}
-            )
+            callbacks.on_predict_batch_end(end_step, {"outputs": batch_outputs})
             if self.stop_predicting:
                 break
         callbacks.on_predict_end()

From 6b32547c97c838abb507c11351f40c82fe582195 Mon Sep 17 00:00:00 2001
From: RAHUL KUMAR
Date: Wed, 17 Sep 2025 10:40:05 +0530
Subject: [PATCH 5/5] Update trainer_test.py

---
 keras/src/trainers/trainer_test.py | 87 ++++++++++++++++++++++++++++++
 1 file changed, 87 insertions(+)

diff --git a/keras/src/trainers/trainer_test.py b/keras/src/trainers/trainer_test.py
index 05e910aa6038..d0014e59d384 100644
--- a/keras/src/trainers/trainer_test.py
+++ b/keras/src/trainers/trainer_test.py
@@ -2207,6 +2207,93 @@ def test_predict_dropout(self):
         out3 = model.predict_on_batch(np.ones((2, 20)))
         self.assertGreater(5, np.sum(np.abs(out2 - out3)))
 
+    def test_predict_accumulate_parameter(self):
+        # Test that `predict` with accumulate=True/False works correctly
+        model = ExampleModel(units=3)
+        x = np.ones((10, 4))
+
+        # Test accumulate=True (default behavior)
+        outputs_accumulated = model.predict(x, batch_size=2, accumulate=True)
+        self.assertIsInstance(outputs_accumulated, np.ndarray)
+        self.assertEqual(outputs_accumulated.shape, (10, 3))
+        self.assertAllClose(outputs_accumulated, 4 * np.ones((10, 3)))
+
+        # Test accumulate=False with callback to capture outputs
+        class OutputCaptureCallback(Callback):
+            def __init__(self):
+                super().__init__()
+                self.outputs = []
+
+            def on_predict_batch_end(self, batch, logs=None):
+                if logs and "outputs" in logs:
+                    self.outputs.append(logs["outputs"])
+
+        callback = OutputCaptureCallback()
+        outputs_none = model.predict(
+            x, batch_size=2, accumulate=False, callbacks=[callback]
+        )
+
+        # Verify accumulate=False returns None
+        self.assertIsNone(outputs_none)
+
+        # Verify callback captured the correct number of batches
+        self.assertEqual(
+            len(callback.outputs), 5
+        )  # 10 samples / 2 batch_size = 5 batches
+
+        # Verify callback outputs match accumulated outputs when concatenated
+        concatenated_outputs = np.concatenate(callback.outputs, axis=0)
+        self.assertAllClose(outputs_accumulated, concatenated_outputs)
+
+    def test_predict_accumulate_parameter_multi_output(self):
+        # Test accumulate parameter with multi-output model
+        inputs = layers.Input((4,))
+        output1 = layers.Dense(3, name="out1")(inputs)
+        output2 = layers.Dense(2, name="out2")(inputs)
+        model = models.Model(inputs=inputs, outputs=[output1, output2])
+
+        x = np.ones((8, 4))
+
+        # Test accumulate=True (default behavior)
+        outputs_accumulated = model.predict(x, batch_size=2, accumulate=True)
+        self.assertIsInstance(outputs_accumulated, list)
+        self.assertEqual(len(outputs_accumulated), 2)
+        self.assertEqual(outputs_accumulated[0].shape, (8, 3))
+        self.assertEqual(outputs_accumulated[1].shape, (8, 2))
+
+        # Test accumulate=False with callback
+        class OutputCaptureCallback(Callback):
+            def __init__(self):
+                super().__init__()
+                self.outputs = []
+
+            def on_predict_batch_end(self, batch, logs=None):
+                if logs and "outputs" in logs:
+                    self.outputs.append(logs["outputs"])
+
+        callback = OutputCaptureCallback()
+        outputs_none = model.predict(
+            x, batch_size=2, accumulate=False, callbacks=[callback]
+        )
+
+        # Verify accumulate=False returns None
+        self.assertIsNone(outputs_none)
+
+        # Verify callback captured the correct outputs
+        self.assertEqual(
+            len(callback.outputs), 4
+        )  # 8 samples / 2 batch_size = 4 batches
+
+        # Verify callback outputs match accumulated outputs when concatenated
+        concatenated_outputs_1 = np.concatenate(
+            [out[0] for out in callback.outputs], axis=0
+        )
+        concatenated_outputs_2 = np.concatenate(
+            [out[1] for out in callback.outputs], axis=0
+        )
+        self.assertAllClose(outputs_accumulated[0], concatenated_outputs_1)
+        self.assertAllClose(outputs_accumulated[1], concatenated_outputs_2)
+
     @pytest.mark.requires_trainable_backend
     def test_recompile(self):
         model = ExampleModel(units=3)
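
Usage sketch (illustrative; not part of the patch series above): the snippet
below shows the intended pattern for `accumulate=False`, where each batch of
predictions is consumed in `on_predict_batch_end` instead of being
concatenated and returned. It assumes a Keras build that includes these
patches; the callback class name and the streaming comment are placeholders,
not part of the Keras API.

import numpy as np
import keras
from keras import layers

class StreamingPredictionHandler(keras.callbacks.Callback):
    """Consumes each batch of predictions as it is produced.

    With accumulate=False, on_predict_batch_end is the only place the
    outputs are available; predict() itself returns None.
    """

    def __init__(self):
        super().__init__()
        self.num_batches = 0

    def on_predict_batch_end(self, batch, logs=None):
        if logs and "outputs" in logs:
            batch_preds = logs["outputs"]
            # Stream batch_preds to disk or a database here instead of
            # holding all predictions in RAM.
            self.num_batches += 1

model = keras.Sequential([keras.Input(shape=(4,)), layers.Dense(3)])
x = np.ones((1000, 4))

handler = StreamingPredictionHandler()
result = model.predict(
    x, batch_size=100, accumulate=False, callbacks=[handler]
)
assert result is None  # nothing was accumulated or returned
assert handler.num_batches == 10  # every batch reached the callback

Because the per-batch outputs are already dispatched through the existing
callback mechanism, this keeps peak memory bounded by a single batch rather
than the full dataset, which is the motivation stated in PATCH 1.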