[feature request] Make z_loss factor configurable #2352

@bzantium

Description

Feature or Model Request

What problem are you trying to solve?
The cross_entropy_with_logits function in src/MaxText/max_utils.py already includes an implementation of z_loss regularization (weight * log(z)^2, where z is the softmax normalizer sum(exp(logits))), a technique designed to improve model stability by penalizing large logits.

However, the loss_fn in src/MaxText/train.py currently hardcodes the z_loss argument to 0.0, effectively disabling this feature.

xent, _ = max_utils.cross_entropy_with_logits(logits, one_hot_targets, 0.0)

Furthermore, the code currently ignores the second return value (_) from cross_entropy_with_logits, which is the z_loss component itself. This prevents the z_loss value from being monitored or logged even if the feature were enabled.
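
For reference, the penalty reduces to the following per-token term (a minimal sketch of the math only, written with standard JAX ops; z_loss_term is an illustrative name, and the in-repo implementation may differ in numerical details):

    import jax.numpy as jnp
    from jax.scipy.special import logsumexp

    def z_loss_term(logits, z_loss_weight):
        # log_z = log(sum(exp(logits))): the log of the softmax normalizer z.
        log_z = logsumexp(logits, axis=-1)
        # Penalize large normalizers; this is the weight * log(z)^2 term above.
        return z_loss_weight * jnp.square(log_z)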

Why is this problem important?
z_loss is a regularization technique that can materially affect model stability and performance. Hardcoding its weight to 0.0 prevents users from tuning this hyperparameter for their model.

Enabling this feature via a config flag (like z_loss_weight) and logging its value are essential for experimentation and monitoring the impact of this regularization during training.

Describe your requested feature or solution.
We request the following functional changes to enable and monitor the z_loss feature (a code sketch follows the list):

  1. Expose Configuration: Add a new flag named z_loss_weight to the configuration (e.g., in src/MaxText/configs/base.yml). The default value should be 0.0 to maintain existing behavior.

  2. Enable in Loss Function: In train.py's loss_fn, pass config.z_loss_weight to max_utils.cross_entropy_with_logits instead of the hardcoded 0.0.

  3. Capture and Log Metrics:

    • Capture the second return value (the z_loss component) from cross_entropy_with_logits into a z_loss variable (i.e., replace _ with z_loss).
    • Ensure this z_loss value is correctly masked and normalized (averaged), similar to how the main xent loss is handled.
    • Add this normalized z_loss value to the aux dictionary in loss_fn (e.g., aux["z_loss"] = ...).
    • This will allow the z_loss value to be propagated through train_step and eval_step to be logged as metrics (e.g., learning/z_loss and evaluation/z_loss) for monitoring in tools like TensorBoard.
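
Taken together, the change might look roughly like this (an illustrative sketch, not a patch: loss_fn_sketch, target_mask, and EPS stand in for the repo's actual names and masking logic, and the import path for max_utils is assumed):

    # Proposed addition to src/MaxText/configs/base.yml:
    #   z_loss_weight: 0.0  # default keeps existing behavior

    import jax.numpy as jnp
    from MaxText import max_utils  # import path assumed from the src/MaxText layout

    EPS = 1e-8  # illustrative guard against division by zero

    def loss_fn_sketch(logits, one_hot_targets, target_mask, config):
        # Pass the configurable weight instead of the hardcoded 0.0, and
        # capture the z_loss component instead of discarding it.
        xent, z_loss = max_utils.cross_entropy_with_logits(
            logits, one_hot_targets, config.z_loss_weight)

        # Mask padding tokens and average, mirroring how xent is handled.
        total_weights = jnp.sum(target_mask)
        loss = jnp.sum(xent * target_mask) / (total_weights + EPS)

        aux = {
            # Propagated through train_step/eval_step and logged as
            # learning/z_loss and evaluation/z_loss.
            "z_loss": jnp.sum(z_loss * target_mask) / (total_weights + EPS),
        }
        return loss, aux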

Additional Context

No response
