
Conversation

ved1beta
Contributor

What does this PR do?

This PR fixes model checkpoint tracking and rotation in the Trainer when save_steps and eval_steps are misaligned, addressing issue #39476.

Before submitting

  • Did you write any new necessary tests?

@SunMarc

@SunMarc
Member

SunMarc commented Jul 22, 2025

Mind checking if this works @KrisTHL181 ?

@KrisTHL181

KrisTHL181 commented Jul 22, 2025

I ran a minimal reproduction with save_steps=3 and eval_steps=2, and observed that the checkpoint marked as best_model_checkpoint (e.g., checkpoint-6) was later deleted during checkpoint rotation, despite setting load_best_model_at_end=True.

At the end of training, this caused a FileNotFoundError when attempting to load the best model. From real-time monitoring of the output directory, I confirmed that the checkpoint did exist temporarily but was removed by the cleanup process before training completed.

The current logic in _sorted_checkpoints correctly identifies the best model and attempts to preserve it by adjusting the order of checkpoints. However, the method used — swapping elements iteratively toward the end — may not fully move the best checkpoint to the tail of the list, especially if it’s far from the end. As a result, when _rotate_checkpoints applies save_total_limit, the best model can still be included in the deletion list.

This suggests that while the best model is tracked correctly, its protection is incomplete due to insufficient handling in _sorted_checkpoints. The interaction between save_total_limit and older, non-consecutive best checkpoints may lead to premature deletion — even when metrics confirm a best model was found.

This behavior could result in failures during model recovery and warrants further investigation.
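
To make the failure mode concrete, here is a deliberately simplified sketch (illustrative only; the single swap and the fixed checkpoint list are assumptions on my part, not the actual Trainer code): whenever the best checkpoint is not moved past the cutoff implied by save_total_limit, it ends up in the slice that rotation deletes.

# Deliberately simplified illustration of the failure mode; the real logic
# lives in Trainer._sorted_checkpoints / Trainer._rotate_checkpoints and
# differs in detail. The single swap below stands in for an incomplete
# "move the best checkpoint toward the tail".
def checkpoints_to_delete(checkpoints, best, save_total_limit=2):
    ordered = sorted(checkpoints, key=lambda p: int(p.rsplit("-", 1)[-1]))
    i = ordered.index(best)
    if i < len(ordered) - 1:
        ordered[i], ordered[i + 1] = ordered[i + 1], ordered[i]
    # Rotation keeps only the newest `save_total_limit` entries and deletes the rest.
    return ordered[: len(ordered) - save_total_limit]

print(checkpoints_to_delete(
    ["checkpoint-3", "checkpoint-6", "checkpoint-9", "checkpoint-12", "checkpoint-15"],
    best="checkpoint-6",
))
# ['checkpoint-3', 'checkpoint-9', 'checkpoint-6']  <- the best checkpoint is still scheduled for deletion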

The full error traceback is:

Traceback (most recent call last):
  File "/tmp/transformers/example_", line 81, in <module>
    trainer.train()
  File "/tmp/transformers/src/transformers/trainer.py", line 2209, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/tmp/transformers/src/transformers/trainer.py", line 2720, in _inner_training_loop
    if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen genericpath>", line 113, in samefile
FileNotFoundError: [Errno 2] No such file or directory: './test-fix-pr/checkpoint-6'

@ved1beta
Contributor Author

On it 🫡, thanks for the detailed review ❤️

@ved1beta
Contributor Author

ved1beta commented Jul 23, 2025

@KrisTHL181 Could you please share the reproduction?

@KrisTHL181

KrisTHL181 commented Jul 23, 2025

Great fix! It looks like the best_model_checkpoint is indeed protected!
However, there's still an issue with the current code: the best model isn't necessarily saved, which ultimately leaves best_model_checkpoint as None. In my previous test, the FileNotFoundError occurred because the best model had been saved but later deleted; your latest code fixes that problem. It still doesn't force the Trainer to save a checkpoint when the best model is found during evaluation (one possible callback-based workaround is sketched at the end of this comment).

Here's the reproduction code:

import os
import tempfile
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    set_seed
)

set_seed(42)

dataset = load_dataset("imdb", split="train[:100]").shuffle(seed=42)
dataset = dataset.map(lambda x: {"label": int(x["label"])}, batched=False)

model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=64)

tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

training_args = TrainingArguments(
    output_dir="./pr-test",
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    save_steps=3,                  # misaligned |
    eval_steps=2,                  # misaligned |
    save_total_limit=2,            # only keep 2 checkpoints
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    greater_is_better=True, # NOTE: ONLY FOR TEST!
    eval_strategy="steps",
    logging_dir=os.path.join("pr-test", "logs"),
    logging_steps=1,
    report_to="none",
    preserve_best_model=True
)

# We intentionally set greater_is_better=True with 'loss' to force 
# an early checkpoint to be marked as best_model_checkpoint.
# This tests whether the Trainer correctly preserves it even when save_total_limit
# would otherwise delete older checkpoints.

import numpy as np
from sklearn.metrics import accuracy_score

def compute_metrics(pred):
    logits, labels = pred
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": accuracy_score(labels, predictions)}

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets,
    eval_dataset=tokenized_datasets,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)


trainer.train()

# check
best_ckpt = trainer.state.best_model_checkpoint
assert best_ckpt is not None, "[FAIL] best_ckpt is None"
print(f"best_ckpt = {best_ckpt}")
assert os.path.exists(best_ckpt), f"[FAIL] best_ckpt folder not found: {best_ckpt}"

To illustrate this, here's what happens in practice when eval_steps=2 and save_steps=3:

┌──(kris㉿KrisTHL181)-[/media/kris/SwapTemp/transformers]
└─$ head pr-test/checkpoint-13/trainer_state.json 
{
  "best_global_step": 2,
  "best_metric": 0.4929685592651367,
- "best_model_checkpoint": null,
+ // No checkpoint saved at step 2 → best model lost
  "epoch": 1.0,
  "eval_steps": 2,
  "global_step": 13,
  "is_hyper_param_search": false,
  "is_local_process_zero": true,
  "is_world_process_zero": true,
                                                                                                                                                                      
┌──(kris㉿KrisTHL181)-[/media/kris/SwapTemp/transformers]
└─$ ls pr-test           
checkpoint-12  checkpoint-13

@KrisTHL181

KrisTHL181 commented Jul 23, 2025

Quick update: I manually created checkpoint-2 (copied from checkpoint-12) and set best_model_checkpoint to point to it.
Test Result: the model was not deleted by checkpoint rotation, and load_best_model_at_end=True successfully loaded it at the end. Moreover, no assertion errors occurred.
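
For anyone who wants to repeat this check, the manual intervention was roughly as follows (a sketch only; the exact state file to patch and the resume step are assumptions on my part):

import json
import shutil

# Recreate the missing best checkpoint from an existing one (illustrative paths).
shutil.copytree("pr-test/checkpoint-12", "pr-test/checkpoint-2")

# Point best_model_checkpoint at the recreated directory in the latest state file.
state_path = "pr-test/checkpoint-13/trainer_state.json"
with open(state_path) as f:
    state = json.load(f)
state["best_model_checkpoint"] = "./pr-test/checkpoint-2"
with open(state_path, "w") as f:
    json.dump(state, f, indent=2)

# Resuming from the patched checkpoint then exercises rotation and
# load_best_model_at_end against the recreated best checkpoint.
# trainer.train(resume_from_checkpoint="pr-test/checkpoint-13")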

@KrisTHL181

@ved1beta This looks like it solves the issue! It's working as expected :D
The best model is now saved and preserved even when save_steps and eval_steps are misaligned. 👍

┌──(kris㉿KrisTHL181)-[/media/kris/SwapTemp/transformers/pr-test]
└─$  ls -1    
checkpoint-13
+ checkpoint-2
                                                                                                                                                                      
┌──(kris㉿KrisTHL181)-[/media/kris/SwapTemp/transformers/pr-test]
└─$ head checkpoint-13/trainer_state.json 
{
  "best_global_step": 2,
  "best_metric": 0.4929685592651367,
+  "best_model_checkpoint": "./pr-test/checkpoint-2",
  "epoch": 1.0,
  "eval_steps": 2,
  "global_step": 13,
  "is_hyper_param_search": false,
  "is_local_process_zero": true,
  "is_world_process_zero": true,
