Commit 220a7f8

Add impl of Muon optimizer. Fix #2580
1 parent: 68bc434

6 files changed: +611, -15 lines


tests/test_optim.py

Lines changed: 8 additions & 0 deletions
@@ -394,6 +394,14 @@ def test_kron(optimizer):
     _test_model(optimizer, dict(lr=1e-3))


+@pytest.mark.parametrize('optimizer', ['muon'])
+def test_muon(optimizer):
+    _test_rosenbrock(
+        lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
+    )
+    _test_model(optimizer, dict(lr=1e-3))
+
+
 @pytest.mark.parametrize('optimizer', ['adopt', 'adoptw'])
 def test_adopt(optimizer):
     _test_rosenbrock(
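
For reference, a minimal sketch of the factory call the new test exercises, usable outside pytest; the model here is a stand-in, not part of this commit:

import torch.nn as nn
from timm.optim import create_optimizer_v2

model = nn.Linear(10, 10)  # stand-in model, not part of the commit
# 'muon' resolves through the registry entry added in _optim_factory.py below
opt = create_optimizer_v2(model, 'muon', lr=1e-3, weight_decay=0.05)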

timm/optim/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@
 from .lookahead import Lookahead
 from .madgrad import MADGRAD
 from .mars import Mars
+from .muon import Muon
 from .nadam import NAdamLegacy
 from .nadamw import NAdamW
 from .nvnovograd import NvNovoGrad

timm/optim/_optim_factory.py

Lines changed: 9 additions & 0 deletions
@@ -31,6 +31,7 @@
 from .lookahead import Lookahead
 from .madgrad import MADGRAD
 from .mars import Mars
+from .muon import Muon
 from .nadam import NAdamLegacy
 from .nadamw import NAdamW
 from .nvnovograd import NvNovoGrad
@@ -871,6 +872,14 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
             description='Unleashing the Power of Variance Reduction for Training Large Models',
             has_betas=True,
         ),
+        OptimInfo(
+            name='muon',
+            opt_class=Muon,
+            description='MomentUm Orthogonalized by Newton-schulz with AdamW fallback for 1D params',
+            has_momentum=True,
+            has_eps=True,
+            has_betas=True,
+        ),
         OptimInfo(
             name='novograd',
             opt_class=NvNovoGrad,
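
The new timm/optim/muon.py is among the six changed files but is not shown in this excerpt. For context, below is a minimal sketch of the Newton-Schulz orthogonalization step that gives Muon its name, assuming the quintic iteration and coefficients from Keller Jordan's reference Muon implementation; the function name and details are illustrative and may differ from the committed file. This also explains the registry flags: 1D parameters (biases, norm scales) cannot be usefully orthogonalized, so Muon falls back to AdamW for them, and has_eps/has_betas expose that fallback's hyperparameters.

import torch

def zeropower_via_newtonschulz(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Approximately orthogonalize G (G = U S V^T -> ~U V^T) using only matmuls."""
    a, b, c = 3.4445, -4.7750, 2.0315  # quintic coefficients from the reference impl
    X = G.bfloat16()
    transposed = X.size(-2) > X.size(-1)
    if transposed:
        X = X.mT  # iterate on the wide orientation for stability
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)  # bound spectral norm near 1
    for _ in range(steps):
        A = X @ X.mT
        X = a * X + (b * A + c * A @ A) @ X
    if transposed:
        X = X.mT
    return X.to(G.dtype)

# Illustrative use inside an update step:
#   update = zeropower_via_newtonschulz(momentum_buffer)
#   p.add_(update, alpha=-lr)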

timm/optim/_param_groups.py

Lines changed: 50 additions & 14 deletions
@@ -1,3 +1,4 @@
+import fnmatch
 import logging
 from itertools import islice
 from typing import Collection, Optional
@@ -10,27 +11,54 @@
 _logger = logging.getLogger(__name__)


+def _matches_pattern(name: str, patterns: Collection[str]) -> bool:
+    """Check if parameter name matches any pattern (supports wildcards)."""
+    return any(fnmatch.fnmatch(name, pattern) for pattern in patterns)
+
+
 def param_groups_weight_decay(
         model: nn.Module,
         weight_decay: float = 1e-5,
         no_weight_decay_list: Collection[str] = (),
+        simple_params_list: Collection[str] = (),
 ):
-    no_weight_decay_list = set(no_weight_decay_list)
     decay = []
+    decay_simple = []
     no_decay = []
+    no_decay_simple = []
     for name, param in model.named_parameters():
         if not param.requires_grad:
             continue

-        if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list:
-            no_decay.append(param)
+        # Determine if this is a "simple" parameter for fallback optimizer (if available)
+        is_simple = _matches_pattern(name, simple_params_list)
+
+        # Determine weight decay
+        matches_pattern = _matches_pattern(name, no_weight_decay_list)
+        if param.ndim <= 1 or name.endswith(".bias") or matches_pattern:
+            # No weight decay
+            if is_simple:
+                no_decay_simple.append(param)
+            else:
+                no_decay.append(param)
         else:
-            decay.append(param)
-
-    return [
-        {'params': no_decay, 'weight_decay': 0.},
-        {'params': decay, 'weight_decay': weight_decay}]
-
+            # With weight decay
+            if is_simple:
+                decay_simple.append(param)
+            else:
+                decay.append(param)
+
+    groups = []
+    if decay:
+        groups.append({'params': decay, 'weight_decay': weight_decay})
+    if decay_simple:
+        groups.append({'params': decay_simple, 'weight_decay': weight_decay, 'simple': True})
+    if no_decay:
+        groups.append({'params': no_decay, 'weight_decay': 0.})
+    if no_decay_simple:
+        groups.append({'params': no_decay_simple, 'weight_decay': 0., 'simple': True})
+
+    return groups

 def _group(it, size):
     it = iter(it)
@@ -70,9 +98,9 @@ def param_groups_layer_decay(
         model: nn.Module,
         weight_decay: float = 0.05,
         no_weight_decay_list: Collection[str] = (),
+        simple_params_list: Collection[str] = (),
         weight_decay_exclude_1d: bool = True,
         layer_decay: float = .75,
-        end_layer_decay: Optional[float] = None,
         min_scale: float = 0.,
         no_opt_scale: Optional[float] = None,
         verbose: bool = False,
@@ -81,7 +109,6 @@ def param_groups_layer_decay(
     Parameter groups for layer-wise lr decay & weight decay
     Based on BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
     """
-    no_weight_decay_list = set(no_weight_decay_list)
     param_group_names = {}  # NOTE for debugging
     param_groups = {}

@@ -99,8 +126,12 @@ def param_groups_layer_decay(
         if not param.requires_grad:
             continue

-        # no decay: all 1D parameters and model specific ones
-        if (weight_decay_exclude_1d and param.ndim <= 1) or name in no_weight_decay_list:
+        # Determine if this is a "simple" parameter for fallback optimizer (if available)
+        is_simple = _matches_pattern(name, simple_params_list)
+
+        # Determine weight decay
+        if (weight_decay_exclude_1d and param.ndim <= 1) or _matches_pattern(name, no_weight_decay_list):
+            # no weight decay for 1D parameters and model specific ones
             g_decay = "no_decay"
             this_decay = 0.
         else:
@@ -114,18 +145,23 @@ def param_groups_layer_decay(
             param.requires_grad = False
             continue

-        group_name = "layer_%d_%s" % (layer_id, g_decay)
+        simple_suffix = "_simple" if is_simple else ""
+        group_name = "layer_%d_%s%s" % (layer_id, g_decay, simple_suffix)
+
         if group_name not in param_groups:
             param_group_names[group_name] = {
                 "lr_scale": this_scale,
                 "weight_decay": this_decay,
+                "simple": is_simple,
                 "param_names": [],
             }
             param_groups[group_name] = {
                 "lr_scale": this_scale,
                 "weight_decay": this_decay,
                 "params": [],
             }
+            if is_simple:
+                param_groups[group_name]["simple"] = True

         param_group_names[group_name]["param_names"].append(name)
         param_groups[group_name]["params"].append(param)
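
To see what the reworked grouping produces, a small usage sketch (the model and the '1.*' pattern are hypothetical): parameters matching simple_params_list land in groups tagged 'simple': True, which a hybrid optimizer like Muon can route to its AdamW fallback, and patterns may now contain fnmatch wildcards.

import torch.nn as nn
from timm.optim._param_groups import param_groups_weight_decay

# Hypothetical model: one 2D weight (Muon-eligible) plus 1D bias/norm params
model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16))

# '1.*' marks every LayerNorm parameter ('1.weight', '1.bias') as "simple",
# i.e. destined for the fallback optimizer
groups = param_groups_weight_decay(
    model,
    weight_decay=0.05,
    simple_params_list=['1.*'],
)
for g in groups:
    print(len(g['params']), g['weight_decay'], g.get('simple', False))
# -> 1 0.05 False   (Linear weight: decay group)
#    1 0.0 False    (Linear bias: no-decay group)
#    2 0.0 True     (LayerNorm weight & bias: no-decay, simple=True)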
