
Add Muon optimizer implementation and integration #39541


Open · wants to merge 9 commits into base `main`
4 changes: 4 additions & 0 deletions docs/source/en/main_classes/optimizer_schedules.md
@@ -27,6 +27,10 @@ The `.optimization` module provides:

[[autodoc]] Adafactor

## Muon (PyTorch)

[[autodoc]] Muon

## AdamWeightDecay (TensorFlow)

[[autodoc]] AdamWeightDecay
27 changes: 26 additions & 1 deletion docs/source/en/optimizers.md
@@ -16,7 +16,7 @@ rendered properly in your Markdown viewer.

# Optimizers

Transformers offers two native optimizers, AdamW and AdaFactor. It also provides integrations for more specialized optimizers. Install the library that offers the optimizer and drop it in the `optim` parameter in [`TrainingArguments`].
Transformers offers three native optimizers: AdamW, AdaFactor, and Muon. It also provides integrations for more specialized optimizers. Install the library that offers the optimizer and drop it in the `optim` parameter in [`TrainingArguments`].

This guide will show you how to use these optimizers with [`Trainer`] using [`TrainingArguments`] shown below.

@@ -199,4 +199,29 @@ args = TrainingArguments(
save_strategy="no",
run_name="stable-adamw",
)
```

## Muon

```bash
pip install git+https://github.com/KellerJordan/Muon.git
```

[Muon](https://kellerjordan.github.io/posts/muon/) (MomentUm Orthogonalized by Newton-Schulz) runs standard SGD-momentum and then applies an orthogonalization post-processing step, replacing each 2D parameter's update with the nearest orthogonal matrix. For efficient orthogonalization it uses a Newton-Schulz iteration, which can be run stably in bfloat16 on the GPU (a minimal sketch of this iteration appears after the configuration example below).

> [!TIP]
> Muon should only be used for hidden weight layers. The input embedding, final output layer, and any internal gains or biases should be optimized using a standard method such as AdamW.

```diff
args = TrainingArguments(
output_dir="./test-muon",
max_steps=1000,
per_device_train_batch_size=4,
+ optim="muon",
logging_strategy="steps",
logging_steps=1,
learning_rate=2e-2,
save_strategy="no",
run_name="muon",
)
```
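
For context, here is a minimal PyTorch sketch of the quintic Newton-Schulz orthogonalization described above. The coefficients follow the reference Muon implementation; the function name is illustrative and the snippet is not part of this PR.

```python
import torch


def newton_schulz_orthogonalize(update: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    """Approximately replace a 2D update matrix with the nearest (semi-)orthogonal matrix."""
    assert update.ndim == 2, "Muon orthogonalizes 2D (matrix) parameters only"
    a, b, c = 3.4445, -4.7750, 2.0315  # quintic iteration coefficients from the reference Muon code
    x = update.bfloat16()              # the iteration is stable enough to run in bfloat16
    x = x / (x.norm() + eps)           # scale so the spectral norm is bounded by ~1
    transpose = update.size(0) > update.size(1)
    if transpose:                      # work with the wide orientation for efficiency
        x = x.T
    for _ in range(steps):
        A = x @ x.T
        B = b * A + c * A @ A
        x = a * x + B @ x
    if transpose:
        x = x.T
    return x.to(update.dtype)


# Example: orthogonalize a random momentum buffer for a 512x256 weight.
orthogonal_update = newton_schulz_orthogonalize(torch.randn(512, 256))
```
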
131 changes: 131 additions & 0 deletions src/transformers/optimization.py
@@ -971,3 +971,134 @@ def get_adafactor_schedule(optimizer, initial_lr=0.0):

"""
return AdafactorSchedule(optimizer, initial_lr)

Member: We don't need these changes here; just import the optimizer in trainer.py as you did for muon. We can add more choices like muon_adam when specifying the optim value.

Contributor Author: I deleted the _check_muon_available function and added the other variations as well.


class MuonWithAuxAdam:
"""
Distributed Muon variant that can be used for all parameters in the network, since it runs an
internal AdamW for the parameters that are not compatible with Muon. The user must manually
specify which parameters shall be optimized with Muon and which with Adam by passing in a
list of param_groups with the `use_muon` flag set.

The point of this class is to allow the user to have a single optimizer in their code, rather
than having both a Muon and an Adam which each need to be stepped.

Example usage:

```python
from transformers.optimization import MuonWithAuxAdam

hidden_weights = [p for p in model.body.parameters() if p.ndim >= 2]
hidden_gains_biases = [p for p in model.body.parameters() if p.ndim < 2]
nonhidden_params = [*model.head.parameters(), *model.embed.parameters()]
param_groups = [
dict(params=hidden_weights, use_muon=True,
lr=0.02, weight_decay=0.01),
dict(params=hidden_gains_biases+nonhidden_params, use_muon=False,
lr=3e-4, betas=(0.9, 0.95), weight_decay=0.01),
]
optimizer = MuonWithAuxAdam(param_groups)
```

Note: This requires installing the Muon package first:
`pip install git+https://github.com/KellerJordan/Muon.git`
"""

def __new__(cls, *args, **kwargs):
from muon import MuonWithAuxAdam as _MuonWithAuxAdam

return _MuonWithAuxAdam(*args, **kwargs)


class SingleDeviceMuon:
"""
Single-device variant of Muon optimizer.

Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
matrix. For efficient orthogonalization we use a Newton-Schulz iteration, which has the
advantage that it can be stably run in bfloat16 on the GPU.

This single-device variant is optimized for training on a single GPU/device.

Example usage:

```python
from transformers.optimization import SingleDeviceMuon

# Optimizer for hidden weight layers only
hidden_params = [p for n, p in model.named_parameters() if p.ndim >= 2 and "embed" not in n and "head" not in n]
optimizer = SingleDeviceMuon(hidden_params, lr=0.02, momentum=0.95)
```

Note: This requires installing the Muon package first:
`pip install git+https://github.com/KellerJordan/Muon.git`
"""

def __new__(cls, *args, **kwargs):
from muon import SingleDeviceMuon as _SingleDeviceMuon

return _SingleDeviceMuon(*args, **kwargs)


class SingleDeviceMuonWithAuxAdam:
"""
Non-distributed variant of MuonWithAuxAdam.

Example usage:

```python
from transformers.optimization import SingleDeviceMuonWithAuxAdam

hidden_weights = [p for p in model.body.parameters() if p.ndim >= 2]
hidden_gains_biases = [p for p in model.body.parameters() if p.ndim < 2]
nonhidden_params = [*model.head.parameters(), *model.embed.parameters()]
param_groups = [
dict(params=hidden_weights, use_muon=True,
lr=0.02, weight_decay=0.01),
dict(params=hidden_gains_biases+nonhidden_params, use_muon=False,
lr=3e-4, betas=(0.9, 0.95), weight_decay=0.01),
]
optimizer = SingleDeviceMuonWithAuxAdam(param_groups)
```

Note: This requires installing the Muon package first:
`pip install git+https://github.com/KellerJordan/Muon.git`
"""

def __new__(cls, *args, **kwargs):
from muon import SingleDeviceMuonWithAuxAdam as _SingleDeviceMuonWithAuxAdam

return _SingleDeviceMuonWithAuxAdam(*args, **kwargs)


class MuonOptimizer:
"""
Muon - MomentUm Orthogonalized by Newton-Schulz

Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
matrix. For efficient orthogonalization we use a Newton-Schulz iteration, which has the
advantage that it can be stably run in bfloat16 on the GPU.

Muon should only be used for hidden weight layers. The input embedding, final output layer,
and any internal gains or biases should be optimized using a standard method such as AdamW.

Example usage:

```python
from transformers.optimization import MuonOptimizer

# Optimizer for hidden weight layers only
hidden_params = [p for n, p in model.named_parameters() if p.ndim >= 2 and "embed" not in n and "head" not in n]
optimizer = MuonOptimizer(hidden_params, lr=0.02, momentum=0.95)
```

Note: This requires installing the Muon package first:
`pip install git+https://github.com/KellerJordan/Muon.git`
"""

def __new__(cls, *args, **kwargs):
from muon import Muon as _Muon

return _Muon(*args, **kwargs)
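
To use the wrappers above with `Trainer` without going through the `optim` flag, one option is to build the optimizer manually and hand it over via the `optimizers` argument. A minimal sketch, assuming the `muon` package is installed and that `model` and `train_dataset` are already defined; the parameter-splitting heuristic is illustrative:

```python
from transformers import Trainer, TrainingArguments
from transformers.optimization import SingleDeviceMuonWithAuxAdam

# Muon handles the 2D hidden weights; AdamW handles embeddings, the output head, gains, and biases.
muon_params, adam_params = [], []
for name, param in model.named_parameters():
    if param.ndim >= 2 and "embed" not in name and "lm_head" not in name:
        muon_params.append(param)
    else:
        adam_params.append(param)

param_groups = [
    dict(params=muon_params, use_muon=True, lr=0.02, weight_decay=0.01),
    dict(params=adam_params, use_muon=False, lr=3e-4, betas=(0.9, 0.95), weight_decay=0.01),
]
optimizer = SingleDeviceMuonWithAuxAdam(param_groups)

args = TrainingArguments(output_dir="./test-muon-aux", max_steps=100, per_device_train_batch_size=4)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    optimizers=(optimizer, None),  # supply the optimizer directly; Trainer builds the scheduler
)
trainer.train()
```
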
9 changes: 9 additions & 0 deletions src/transformers/testing_utils.py
@@ -109,6 +109,7 @@
is_liger_kernel_available,
is_lomo_available,
is_mistral_common_available,
is_muon_available,
is_natten_available,
is_nltk_available,
is_onnx_available,
@@ -389,6 +390,14 @@ def require_torch_optimi(test_case):
return unittest.skipUnless(is_torch_optimi_available(), "test requires torch-optimi")(test_case)


def require_muon(test_case):
"""
Decorator marking a test that requires muon. These tests are skipped when muon isn't installed.
https://github.com/KellerJordan/Muon
"""
return unittest.skipUnless(is_muon_available(), "test requires muon")(test_case)


def require_lomo(test_case):
"""
Decorator marking a test that requires LOMO. These tests are skipped when LOMO-optim isn't installed.
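
A usage sketch for the new `require_muon` decorator; the test class and assertions below are illustrative and not part of the PR:

```python
import unittest

import torch

from transformers.testing_utils import require_muon, require_torch


@require_torch
@require_muon
class MuonOptimizerTest(unittest.TestCase):
    # The whole class is skipped automatically when the `muon` package is not installed.
    def test_single_device_muon_step(self):
        from muon import SingleDeviceMuon

        weight = torch.nn.Parameter(torch.randn(16, 16))
        optimizer = SingleDeviceMuon([weight], lr=0.02, momentum=0.95)
        weight.grad = torch.randn_like(weight)
        optimizer.step()
        self.assertFalse(torch.isnan(weight).any().item())
```
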
41 changes: 41 additions & 0 deletions src/transformers/trainer.py
@@ -158,6 +158,7 @@
is_in_notebook,
is_liger_kernel_available,
is_lomo_available,
is_muon_available,
is_peft_available,
is_safetensors_available,
is_sagemaker_dp_enabled,
@@ -1416,6 +1417,46 @@ def optimizer_hook(param):
if args.optim == OptimizerNames.ADAFACTOR:
optimizer_cls = Adafactor
optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
elif args.optim == OptimizerNames.MUON:
if not is_muon_available():
raise ImportError(
"You need to install `muon` in order to use muon optimizers. "
"Install it with `pip install git+https://github.com/KellerJordan/Muon.git`."
)
from muon import Muon

optimizer_cls = Muon
optimizer_kwargs.update({"momentum": 0.95})
Comment on lines +1420 to +1429

Member: The import logic should live here.

Contributor Author: There is a function called is_muon_available in the trainer code. Should I make changes here? If I need to, can you share an example?

elif args.optim == OptimizerNames.MUON_ADAM:
if not is_muon_available():
raise ImportError(
"You need to install `muon` in order to use muon optimizers. "
"Install it with `pip install git+https://github.com/KellerJordan/Muon.git`."
)
from muon import MuonWithAuxAdam

optimizer_cls = MuonWithAuxAdam
optimizer_kwargs.update({"lr": 3e-4, "betas": (0.9, 0.95)})
elif args.optim == OptimizerNames.MUON_SINGLE:
if not is_muon_available():
raise ImportError(
"You need to install `muon` in order to use muon optimizers. "
"Install it with `pip install git+https://github.com/KellerJordan/Muon.git`."
)
from muon import SingleDeviceMuon

optimizer_cls = SingleDeviceMuon
optimizer_kwargs.update({"momentum": 0.95})
elif args.optim == OptimizerNames.MUON_SINGLE_ADAM:
if not is_muon_available():
raise ImportError(
"You need to install `muon` in order to use muon optimizers. "
"Install it with `pip install git+https://github.com/KellerJordan/Muon.git`."
)
from muon import SingleDeviceMuonWithAuxAdam

optimizer_cls = SingleDeviceMuonWithAuxAdam
optimizer_kwargs.update({"lr": 3e-4, "betas": (0.9, 0.95)})
elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]:
from torch.optim import AdamW

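
A quick sketch of how the branches above resolve an `optim` value into an optimizer class and keyword defaults, assuming the `muon` package is installed:

```python
from transformers import Trainer, TrainingArguments

args = TrainingArguments(output_dir="./tmp-muon", optim="muon_single", learning_rate=2e-2)
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)

# With the branch above, this resolves to muon.SingleDeviceMuon,
# and optimizer_kwargs includes {"lr": 2e-2, "momentum": 0.95}.
print(optimizer_cls.__name__, optimizer_kwargs)
```
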
4 changes: 4 additions & 0 deletions src/transformers/training_args.py
@@ -187,6 +187,10 @@ class OptimizerNames(ExplicitEnum):
APOLLO_ADAMW = "apollo_adamw"
APOLLO_ADAMW_LAYERWISE = "apollo_adamw_layerwise"
STABLE_ADAMW = "stable_adamw"
MUON = "muon"
MUON_ADAM = "muon_adam"
MUON_SINGLE = "muon_single"
MUON_SINGLE_ADAM = "muon_single_adam"


def _convert_str_dict(passed_value: dict):
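
The new enum members map one-to-one to the strings users pass as `optim`; a minimal sketch (output directories are placeholders):

```python
from transformers import TrainingArguments
from transformers.training_args import OptimizerNames

for optim_name in ("muon", "muon_adam", "muon_single", "muon_single_adam"):
    args = TrainingArguments(output_dir=f"./tmp-{optim_name}", optim=optim_name, max_steps=10)
    # TrainingArguments normalizes the string to the matching OptimizerNames member.
    assert args.optim == OptimizerNames(optim_name)
```
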
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -184,6 +184,7 @@
is_matplotlib_available,
is_mistral_common_available,
is_mlx_available,
is_muon_available,
is_natten_available,
is_ninja_available,
is_nltk_available,
5 changes: 5 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -128,6 +128,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_grokadamw_available = _is_package_available("grokadamw")
_schedulefree_available, _schedulefree_version = _is_package_available("schedulefree", return_version=True)
_torch_optimi_available = importlib.util.find_spec("optimi") is not None
_muon_available = importlib.util.find_spec("muon") is not None
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
_bs4_available = importlib.util.find_spec("bs4") is not None
_coloredlogs_available = _is_package_available("coloredlogs")
@@ -479,6 +480,10 @@ def is_torch_optimi_available():
return _torch_optimi_available


def is_muon_available():
return _muon_available


def is_lomo_available():
return _lomo_available

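
The availability helper is meant to be checked before importing the optional dependency, mirroring the guard in `trainer.py`; a minimal sketch:

```python
from transformers.utils import is_muon_available

if is_muon_available():
    # Safe to import: importlib.util.find_spec("muon") located the package.
    from muon import SingleDeviceMuon
else:
    raise ImportError(
        "You need to install `muon` to use Muon optimizers. "
        "Install it with `pip install git+https://github.com/KellerJordan/Muon.git`."
    )
```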