Description
The `TerminateOnNaN` callback currently terminates training by setting `self.model.stop_training = True` when a NaN or Inf loss is detected. While this gracefully stops the training loop, it also triggers the `on_train_end()` method of every other callback, which can have unintended negative consequences.
Examples of problematic interactions:
- `BackupAndRestore`: deletes its backup directory in `on_train_end()`, preventing recovery from the last good epoch when a NaN occurs (see the sketch below).
- `EarlyStopping`: may restore the best weights based on an incomplete training run.
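For illustration, here is a minimal sketch of the `BackupAndRestore` case. The model, data, backup path, and the deliberately huge learning rate are only assumptions chosen to provoke a divergent loss, not part of any real setup:

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
# Exaggerated learning rate so the loss quickly becomes Inf/NaN.
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=1e6), loss="mse")

x = np.random.rand(256, 1).astype("float32")
y = np.random.rand(256, 1).astype("float32")

model.fit(
    x,
    y,
    epochs=10,
    callbacks=[
        tf.keras.callbacks.BackupAndRestore(backup_dir="/tmp/nan_backup"),
        tf.keras.callbacks.TerminateOnNaN(),
    ],
)
# TerminateOnNaN sets model.stop_training = True, fit() exits "normally",
# and BackupAndRestore.on_train_end() cleans up /tmp/nan_backup, so the run
# can no longer be resumed from the last good epoch.
```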
There might be more cases than these; I initially ran into this issue when using `BackupAndRestore`. I've been using the following custom callback to stop training without running the other callbacks:
```python
import tensorflow as tf


class HardTerminateOnNaN(tf.keras.callbacks.Callback):
    def on_batch_end(self, batch, logs=None):
        loss = (logs or {}).get("loss")
        if loss is not None and (tf.math.is_nan(loss) or tf.math.is_inf(loss)):
            print(f"\nNaN detected at batch {batch}. Terminating immediately.")
            # Raising aborts fit() outright, so on_train_end() never runs
            # for the other callbacks.
            raise RuntimeError("NaN loss encountered.")
```
I propose either adding a new callback similar to the one above, or adding an option to the existing `TerminateOnNaN`. I'd be happy to open a PR and implement whichever option you think is better.
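For concreteness, the option-based variant could look roughly like the sketch below. `raise_error` and the subclass name are placeholders, not an existing Keras argument or class; the final naming and behaviour would follow whatever the maintainers prefer:

```python
import tensorflow as tf


class TerminateOnNaNWithRaise(tf.keras.callbacks.TerminateOnNaN):
    """Rough sketch of the proposed option; `raise_error` is a placeholder."""

    def __init__(self, raise_error=False):
        super().__init__()
        self.raise_error = raise_error

    def on_batch_end(self, batch, logs=None):
        loss = (logs or {}).get("loss")
        if self.raise_error and loss is not None and (
            tf.math.is_nan(loss) or tf.math.is_inf(loss)
        ):
            # Abort fit() entirely so no other callback sees on_train_end().
            raise RuntimeError(f"NaN/Inf loss encountered at batch {batch}.")
        # Otherwise keep today's behaviour: set model.stop_training = True.
        super().on_batch_end(batch, logs)
```

With something like this, `raise_error=True` would behave like the `HardTerminateOnNaN` workaround above, while the default keeps the current semantics.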