TerminateOnNaN callback triggers on_train_end(), causing unintended side effects in other callbacks #21771

@PhilLord0000

Description

The TerminateOnNaN callback currently terminates training by setting self.model.stop_training = True when NaN or Inf loss is detected. Whilst this gracefully stops the training loop, it also triggers the on_train_end() method for all other callbacks, which can have unintended negative consequences.
Examples of problematic interactions:

- BackupAndRestore: deletes its backup directory in on_train_end(), preventing recovery from the last good epoch when a NaN occurs
- EarlyStopping: may restore "best" weights based on an incomplete training run

There may be more interactions than these; I initially ran into the issue when using BackupAndRestore. I've been using the following custom callback to stop training without running the other callbacks' on_train_end() hooks.

import tensorflow as tf


class HardTerminateOnNaN(tf.keras.callbacks.Callback):
    """Stop training by raising, so on_train_end() is never dispatched to other callbacks."""

    def on_batch_end(self, batch, logs=None):
        loss = (logs or {}).get("loss")
        if loss is not None and (tf.math.is_nan(loss) or tf.math.is_inf(loss)):
            print(f"\nNaN detected at batch {batch}. Terminating immediately.")
            # Raising aborts fit() outright instead of setting self.model.stop_training.
            raise RuntimeError("NaN loss encountered.")
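
For context, this is roughly how I use it (model, x_train and y_train are placeholders); the RuntimeError propagates out of fit() before on_train_end() is dispatched, so the BackupAndRestore directory is left intact:

try:
    model.fit(
        x_train, y_train, epochs=10,
        callbacks=[HardTerminateOnNaN(),
                   tf.keras.callbacks.BackupAndRestore("backup_dir")],
    )
except RuntimeError as err:
    # Training can later be resumed from the backup of the last completed epoch.
    print(f"Training aborted: {err}")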

I propose either adding a new callback similar to the one above, or adding an option to the existing TerminateOnNaN. I'd be happy to open a PR and implement whichever you think is the better option.
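
As a rough sketch of the second option, the flag could look something like this (raise_error is only a placeholder name, not an existing argument, and this is not the actual Keras implementation):

import numpy as np
import tensorflow as tf


class TerminateOnNaNWithOption(tf.keras.callbacks.Callback):
    """Sketch: behaves like TerminateOnNaN, with an opt-in hard stop."""

    def __init__(self, raise_error=False):
        super().__init__()
        self.raise_error = raise_error

    def on_batch_end(self, batch, logs=None):
        loss = (logs or {}).get("loss")
        if loss is not None and (np.isnan(loss) or np.isinf(loss)):
            if self.raise_error:
                # Hard stop: abort fit() without triggering on_train_end() elsewhere.
                raise RuntimeError(f"NaN/Inf loss encountered at batch {batch}.")
            print(f"Batch {batch}: Invalid loss, terminating training")
            # Soft stop: current behaviour, which does trigger on_train_end().
            self.model.stop_training = True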
