Add Muon optimizer implementation and integration #39541

Open · wants to merge 9 commits into `main`
4 changes: 4 additions & 0 deletions docs/source/en/main_classes/optimizer_schedules.md
@@ -27,6 +27,10 @@ The `.optimization` module provides:

[[autodoc]] Adafactor

## Muon (PyTorch)

[[autodoc]] Muon

## AdamWeightDecay (TensorFlow)

[[autodoc]] AdamWeightDecay
2 changes: 1 addition & 1 deletion docs/source/en/optimizers.md
@@ -16,7 +16,7 @@ rendered properly in your Markdown viewer.

# Optimizers

Transformers offers two native optimizers, AdamW and AdaFactor. It also provides integrations for more specialized optimizers. Install the library that offers the optimizer and drop it in the `optim` parameter in [`TrainingArguments`].
Transformers offers three native optimizers: AdamW, AdaFactor, and Muon. It also provides integrations for more specialized optimizers. Install the library that offers the optimizer and drop it in the `optim` parameter in [`TrainingArguments`].

This guide will show you how to use these optimizers with [`Trainer`] using [`TrainingArguments`] shown below.
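
As a quick orientation for the Muon entry this PR adds, the optimizer is selected the same way as the existing native ones; a minimal sketch, with `output_dir` as a placeholder and the model and dataset omitted:

```python
from transformers import TrainingArguments

# "muon" is the OptimizerNames value introduced in this PR; learning_rate is
# forwarded to Muon's lr, and the Trainer integration pins momentum to 0.95.
args = TrainingArguments(
    output_dir="muon-run",   # placeholder
    optim="muon",
    learning_rate=0.02,      # Muon's suggested spectral-norm learning rate
)
```

Passing `args` to [`Trainer`] together with a model and dataset then instantiates `Muon` through the `trainer.py` branch shown further down in this diff.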

2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -462,6 +462,7 @@
_import_structure["masking_utils"] = ["AttentionMaskInterface"]
_import_structure["optimization"] = [
"Adafactor",
"Muon",
"get_constant_schedule",
"get_constant_schedule_with_warmup",
"get_cosine_schedule_with_warmup",
@@ -788,6 +789,7 @@
# Optimization
from .optimization import (
Adafactor,
Muon,
get_constant_schedule,
get_constant_schedule_with_warmup,
get_cosine_schedule_with_warmup,
164 changes: 164 additions & 0 deletions src/transformers/optimization.py
@@ -971,3 +971,167 @@ def get_adafactor_schedule(optimizer, initial_lr=0.0):

"""
return AdafactorSchedule(optimizer, initial_lr)

Member:

We don't need these changes here; just import the optimizer in trainer.py, as you did for Muon. We can add more choices like `muon_adam` when specifying the `optim` value.

Contributor Author:

I deleted the `_check_muon_available` function and added the other variations as well.


def zeropower_via_newtonschulz5(G, steps: int):
"""
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
zero even beyond the point where the iteration no longer converges all the way to one everywhere
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
performance at all relative to UV^T, where USV^T = G is the SVD.
"""
assert (
G.ndim >= 2
) # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
a, b, c = (3.4445, -4.7750, 2.0315)
X = G.bfloat16()
if G.size(-2) > G.size(-1):
X = X.mT

# Ensure spectral norm is at most 1
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
# Perform the NS iterations
for _ in range(steps):
A = X @ X.mT
B = (
b * A + c * A @ A
) # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
X = a * X + B @ X

if G.size(-2) > G.size(-1):
X = X.mT
return X
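
A small sanity check, not part of the PR, can make the docstring's claim concrete: after a few Newton-Schulz steps the singular values of the output should sit in a band around 1 rather than exactly at 1 (roughly the Uniform(0.5, 1.5) range mentioned above). Assuming the function above is in scope:

```python
import torch

G = torch.randn(256, 128)
X = zeropower_via_newtonschulz5(G, steps=5)

# Cast back to float32 for the SVD; the iteration itself runs in bfloat16.
svals = torch.linalg.svdvals(X.float())
print(f"min {svals.min():.3f}, max {svals.max():.3f}")  # expected: clustered in a band around 1
```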


def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
momentum.lerp_(grad, 1 - beta)
update = grad.lerp_(momentum, beta) if nesterov else momentum

# Only apply orthogonalization to 2D+ parameters (matrices)
if update.ndim >= 2:
if update.ndim == 4: # for the case of conv filters
update = update.view(len(update), -1)
update = zeropower_via_newtonschulz5(update, steps=ns_steps)
update *= max(1, grad.size(-2) / grad.size(-1)) ** 0.5

return update


class Muon(Optimizer):
"""
    Muon - MomentUm Orthogonalized by Newton-Schulz

https://kellerjordan.github.io/posts/muon/

Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
matrix. For efficient orthogonalization we use a Newton-Schulz iteration, which has the
advantage that it can be stably run in bfloat16 on the GPU.

Muon should only be used for hidden weight layers. The input embedding, final output layer,
and any internal gains or biases should be optimized using a standard method such as AdamW.
Hidden convolutional weights can be trained using Muon by viewing them as 2D and then
collapsing their last 3 dimensions.

Arguments:
params (`Iterable[nn.parameter.Parameter]`):
Iterable of parameters to optimize or dictionaries defining parameter groups.
lr (`float`, *optional*, defaults to 0.02):
The learning rate, in units of spectral norm per update.
weight_decay (`float`, *optional*, defaults to 0.0):
The AdamW-style weight decay.
momentum (`float`, *optional*, defaults to 0.95):
The momentum. A value of 0.95 here is usually fine.
ns_steps (`int`, *optional*, defaults to 5):
Number of Newton-Schulz iteration steps for orthogonalization.
nesterov (`bool`, *optional*, defaults to `True`):
Whether to use Nesterov momentum.

Example:

```python
# Optimizer for hidden weight layers only
hidden_params = [p for n, p in model.named_parameters() if p.ndim >= 2 and "embed" not in n and "head" not in n]
optimizer = Muon(hidden_params, lr=0.02, momentum=0.95)
```
"""

def __init__(
self,
params,
lr: float = 0.02,
weight_decay: float = 0.0,
momentum: float = 0.95,
ns_steps: int = 5,
nesterov: bool = True,
):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if not 0.0 <= momentum <= 1.0:
raise ValueError(f"Invalid momentum value: {momentum}")
if not isinstance(ns_steps, int) or ns_steps < 1:
raise ValueError(f"Invalid ns_steps value: {ns_steps}")

defaults = {
"lr": lr,
"weight_decay": weight_decay,
"momentum": momentum,
"ns_steps": ns_steps,
"nesterov": nesterov,
}
super().__init__(params, defaults)

@torch.no_grad()
def step(self, closure=None):
"""
Performs a single optimization step.

Args:
closure (`Callable`, *optional*):
A closure that reevaluates the model and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue

grad = p.grad
if grad.dtype in (torch.float16, torch.bfloat16):
grad = grad.float()

state = self.state[p]

                # State initialization (buffer dtype follows the possibly upcast gradient)
                if len(state) == 0:
                    state["momentum_buffer"] = torch.zeros_like(grad)

# Apply weight decay
if group["weight_decay"] != 0:
p.mul_(1 - group["lr"] * group["weight_decay"])

# Get momentum buffer
momentum_buffer = state["momentum_buffer"]

# Compute update
update = muon_update(
grad,
momentum_buffer,
beta=group["momentum"],
ns_steps=group["ns_steps"],
nesterov=group["nesterov"],
)

# Apply update
p.add_(update.reshape(p.shape), alpha=-group["lr"])

return loss
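
Outside of [`Trainer`], the split the class docstring recommends (Muon for hidden matrices, AdamW for embeddings, the output head, and 1D parameters) could look like the following; a hedged sketch using a toy model whose module names mirror the "embed"/"head" convention assumed in the docstring example:

```python
import torch
from torch import nn
from transformers.optimization import Muon  # available once this PR is merged

# Toy stand-in for a transformer: an embedding, one hidden matrix, an output head.
model = nn.Sequential()
model.add_module("embed", nn.Embedding(100, 32))
model.add_module("hidden", nn.Linear(32, 32))
model.add_module("head", nn.Linear(32, 100))

hidden_params, other_params = [], []
for name, p in model.named_parameters():
    if p.ndim >= 2 and "embed" not in name and "head" not in name:
        hidden_params.append(p)   # hidden weight matrices -> Muon
    else:
        other_params.append(p)    # embeddings, output head, biases -> AdamW

muon = Muon(hidden_params, lr=0.02, momentum=0.95)
adamw = torch.optim.AdamW(other_params, lr=3e-4, weight_decay=0.01)
```

In a training loop, both optimizers are then stepped and zeroed after each backward pass.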
5 changes: 5 additions & 0 deletions src/transformers/trainer.py
@@ -1416,6 +1416,11 @@ def optimizer_hook(param):
if args.optim == OptimizerNames.ADAFACTOR:
optimizer_cls = Adafactor
optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
elif args.optim == OptimizerNames.MUON:
from .optimization import Muon

optimizer_cls = Muon
optimizer_kwargs.update({"momentum": 0.95})
elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]:
from torch.optim import AdamW

1 change: 1 addition & 0 deletions src/transformers/training_args.py
@@ -187,6 +187,7 @@ class OptimizerNames(ExplicitEnum):
APOLLO_ADAMW = "apollo_adamw"
APOLLO_ADAMW_LAYERWISE = "apollo_adamw_layerwise"
STABLE_ADAMW = "stable_adamw"
MUON = "muon"


def _convert_str_dict(passed_value: dict):