Description
Feature or Model Request
What problem are you trying to solve?
The `cross_entropy_with_logits` function in `src/MaxText/max_utils.py` already includes an implementation of z_loss regularization (`weight * log(z)^2`), a feature designed to improve model stability by penalizing large logits. However, the `loss_fn` in `src/MaxText/train.py` currently hardcodes the `z_loss` argument to `0.0`, effectively disabling this feature.
`src/MaxText/train.py`, line 146 at commit `a55e18a`:

```python
xent, _ = max_utils.cross_entropy_with_logits(logits, one_hot_targets, 0.0)
```
Furthermore, the code currently ignores the second return value (`_`) from `cross_entropy_with_logits`, which is the z_loss component itself. This prevents the `z_loss` value from being monitored or logged, even if it were enabled.
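
For context, the z_loss term described above is typically derived from the log of the softmax normalizer. The sketch below is only an illustration of that common formulation, not the MaxText source; the function name and signature are assumptions, chosen to mirror the call shown above.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def cross_entropy_with_z_loss_sketch(logits, one_hot_targets, z_loss_weight):
  """Illustrative only -- not the MaxText implementation.

  Shows the usual formulation in which the z_loss term is
  z_loss_weight * log(Z)^2, with Z the softmax normalizer. The second
  return value is the z_loss component that loss_fn currently discards as `_`.
  """
  log_z = logsumexp(logits, axis=-1)            # log of the softmax normalizer Z
  log_softmax = logits - log_z[..., None]
  xent = -jnp.sum(one_hot_targets * log_softmax, axis=-1)
  z_loss = z_loss_weight * jnp.square(log_z)    # weight * log(Z)^2
  return xent + z_loss, z_loss
```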
Why is this problem important?
`z_loss` is an important regularization technique that can impact model stability and performance. Hardcoding this value prevents users from tuning this hyperparameter to optimize their model.
Enabling this feature via a config flag (like `z_loss_weight`) and logging its value are essential for experimentation and for monitoring the impact of this regularization during training.
Describe your requested feature or solution.
We request the following functional changes to enable and monitor the `z_loss` feature:
- Expose Configuration: Add a new flag named `z_loss_weight` to the configuration (e.g., in `src/MaxText/configs/base.yml`). The default value should be `0.0` to maintain existing behavior.
- Enable in Loss Function: In `train.py`'s `loss_fn`, pass `config.z_loss_weight` to `max_utils.cross_entropy_with_logits` instead of the hardcoded `0.0`.
- Capture and Log Metrics (see the sketch after this list):
  - Capture the second return value (the z_loss component) from `cross_entropy_with_logits` into a `z_loss` variable (i.e., replace `_` with `z_loss`).
  - Ensure this `z_loss` value is correctly masked and normalized (averaged), similar to how the main `xent` loss is handled.
  - Add this normalized `z_loss` value to the `aux` dictionary in `loss_fn` (e.g., `aux["z_loss"] = ...`).
  - This will allow the `z_loss` value to be propagated through `train_step` and `eval_step` and logged as metrics (e.g., `learning/z_loss` and `evaluation/z_loss`) for monitoring in tools like TensorBoard.
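
To make the request concrete, here is a minimal sketch of how the loss-function items above could fit together. The import path, the helper name `loss_fn_fragment`, and variable names such as `weights` and `aux` are illustrative assumptions rather than the actual MaxText code; only `max_utils.cross_entropy_with_logits` and the proposed `config.z_loss_weight` flag come from the request itself.

```python
import jax.numpy as jnp
from MaxText import max_utils  # import path assumed from the src/MaxText layout

def loss_fn_fragment(logits, one_hot_targets, weights, config):
  """Illustrative fragment of loss_fn; names like `weights` are stand-ins."""
  # Item 2: pass the new config flag instead of the hardcoded 0.0, and keep
  # the second return value instead of discarding it as `_`.
  xent, z_loss = max_utils.cross_entropy_with_logits(
      logits, one_hot_targets, config.z_loss_weight)

  # Item 3: apply the same padding/packing mask used for the main xent loss,
  # then normalize both terms by the same token count.
  xent = xent * weights
  z_loss = z_loss * weights
  denominator = jnp.sum(weights)
  total_loss = jnp.sum(xent) / denominator
  mean_z_loss = jnp.sum(z_loss) / denominator

  # Surface z_loss so train_step / eval_step can log it
  # (e.g., as learning/z_loss and evaluation/z_loss).
  aux = {"z_loss": mean_z_loss}
  return total_loss, aux
```

Because the proposed default is `z_loss_weight: 0.0`, the returned loss would be numerically unchanged for existing configs; raising the flag both enables the regularization and makes the new `learning/z_loss` and `evaluation/z_loss` metrics meaningful.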
Additional Context
No response