predict_step in Trainer should pass num_items_in_batch #41108

@pramodith

Feature request

predict_step in Trainer.py doesn't currently pass num_items_in_batch to compute_loss:

loss, outputs = self.compute_loss(model, inputs, return_outputs=True)

This is inconsistent with training_step, which does pass it:

loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

Motivation

During training, the Trainer uses get_batch_samples to compute num_items_in_batch, which training_step then passes to compute_loss so that the loss is scaled by the number of items in the batch.
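Why the count matters can be sketched in plain Python. This assumes that, as in recent Trainer versions, num_items_in_batch is roughly the number of label positions not equal to -100 across all micro-batches of one gradient-accumulation step, and that the summed loss is divided by it. The helper names and numbers are hypothetical; the per-token losses stand in for whatever the model's loss produces.

```python
# Hypothetical illustration (not Trainer code): two micro-batches that are
# accumulated into a single optimizer step. Label positions set to -100
# (the Hugging Face ignore index) contribute nothing to the loss.
IGNORE_INDEX = -100


def count_items(batches):
    """Count label positions that are not ignored, across all micro-batches."""
    return sum(
        1 for batch in batches for label in batch["labels"] if label != IGNORE_INDEX
    )


def summed_loss(batch):
    """Sum the per-token losses over non-ignored positions of one micro-batch."""
    return sum(
        loss
        for loss, label in zip(batch["token_losses"], batch["labels"])
        if label != IGNORE_INDEX
    )


def mean_of_batch_means(batches):
    """Average each micro-batch on its own, then average those averages."""
    means = []
    for batch in batches:
        n = sum(1 for label in batch["labels"] if label != IGNORE_INDEX)
        means.append(summed_loss(batch) / n)
    return sum(means) / len(means)


def token_level_mean(batches):
    """Divide every micro-batch's summed loss by the same global item count."""
    num_items_in_batch = count_items(batches)
    return sum(summed_loss(batch) / num_items_in_batch for batch in batches)


batches = [
    {"labels": [1, 2, IGNORE_INDEX, IGNORE_INDEX], "token_losses": [1.0, 3.0, 0.0, 0.0]},
    {"labels": [5, 6, 7, 8], "token_losses": [1.0, 1.0, 1.0, 1.0]},
]
print(count_items(batches))          # 6
print(mean_of_batch_means(batches))  # 1.5
print(token_level_mean(batches))     # 8 / 6, about 1.333
```

Averaging per micro-batch gives the two real tokens of the padded first batch the same total weight as the four tokens of the second; dividing by a shared count weights every token equally.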

However, because predict_step does not pass this value, a user with a custom loss function has to handle num_items_in_batch being None at evaluation time but not at training time, which is confusing. Computing num_items_in_batch the same way in both steps would keep logged training and evaluation losses on the same scale, so they can be compared accurately.
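The extra branch a custom loss currently needs can be sketched as a standalone function. Only the keyword name num_items_in_batch mirrors the Trainer; custom_loss and its other parameters are hypothetical stand-ins for logic that would live inside an overridden compute_loss.

```python
# Hypothetical custom loss: a standalone stand-in for the logic a user would
# put inside an overridden compute_loss.
IGNORE_INDEX = -100


def custom_loss(token_losses, labels, num_items_in_batch=None):
    """Sum per-token losses over non-ignored labels, then normalize."""
    kept = [loss for loss, label in zip(token_losses, labels) if label != IGNORE_INDEX]
    total = sum(kept)
    if num_items_in_batch is None:
        # Evaluation path described in this issue: predict_step passes no
        # count, so fall back to this batch's own item count (at least 1).
        return total / max(len(kept), 1)
    # Training path: normalize by the count supplied for the whole
    # gradient-accumulation step.
    return total / num_items_in_batch


labels = [1, 2, IGNORE_INDEX, IGNORE_INDEX]
token_losses = [1.0, 3.0, 0.0, 0.0]
print(custom_loss(token_losses, labels))                        # 2.0 (eval-style call)
print(custom_loss(token_losses, labels, num_items_in_batch=6))  # 4 / 6 (train-style call)
```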

Your contribution

I'm happy to submit a PR for this, unless there's a strong reason why predict_step shouldn't pass num_items_in_batch to compute_loss.
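One possible shape for the change is sketched below on a toy class, not the real transformers Trainer. ToyTrainer, count_eval_items and toy_model are hypothetical names; the sketch only shows the evaluation path computing a count from the batch and forwarding it to compute_loss the same way the training path does.

```python
# Toy sketch of the proposed change, NOT the real transformers Trainer.
IGNORE_INDEX = -100


def count_eval_items(inputs):
    """Count non-ignored label positions in a single evaluation batch."""
    labels = inputs.get("labels")
    if labels is None:
        return None  # nothing to count, e.g. pure inference without labels
    return sum(1 for row in labels for label in row if label != IGNORE_INDEX)


def toy_model(inputs, num_items_in_batch):
    """Toy 'model': summed per-token loss divided by the given item count."""
    total = sum(
        loss
        for row_losses, row_labels in zip(inputs["token_losses"], inputs["labels"])
        for loss, label in zip(row_losses, row_labels)
        if label != IGNORE_INDEX
    )
    return total / num_items_in_batch


class ToyTrainer:
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # Record what was passed so the behavior is easy to inspect.
        self.last_num_items = num_items_in_batch
        loss = model(inputs, num_items_in_batch)
        return (loss, None) if return_outputs else loss

    def predict_step(self, model, inputs):
        # Compute the count for this batch and forward it, mirroring training.
        num_items_in_batch = count_eval_items(inputs)
        loss, outputs = self.compute_loss(
            model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
        )
        return loss, outputs


inputs = {"labels": [[1, 2, IGNORE_INDEX]], "token_losses": [[1.0, 3.0, 0.0]]}
trainer = ToyTrainer()
loss, _ = trainer.predict_step(toy_model, inputs)
print(trainer.last_num_items, loss)  # 2 2.0
```

Counting per evaluation batch is the simplest choice here because evaluation does not accumulate gradients across batches.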
