1+ import fnmatch
12import logging
23from itertools import islice
34from typing import Collection , Optional
1011_logger = logging .getLogger (__name__ )
1112
1213
14+ def _matches_pattern (name : str , patterns : Collection [str ]) -> bool :
15+ """Check if parameter name matches any pattern (supports wildcards)."""
16+ return any (fnmatch .fnmatch (name , pattern ) for pattern in patterns )
17+
18+
def param_groups_weight_decay(
        model: nn.Module,
        weight_decay: float = 1e-5,
        no_weight_decay_list: Collection[str] = (),
        simple_params_list: Collection[str] = (),
):
    """Split a model's trainable parameters into optimizer parameter groups.

    Parameters with ``requires_grad`` set to False are skipped. Every other
    parameter is classified along two independent axes:

    * weight decay: a parameter gets ``weight_decay=0.`` when it has at most
      one dimension, its name ends with ``.bias``, or its name matches a
      pattern in ``no_weight_decay_list``; otherwise it gets ``weight_decay``.
    * simple: a parameter is flagged ``'simple': True`` when its name matches
      a pattern in ``simple_params_list`` (intended for a fallback optimizer,
      if available).

    Args:
        model: Module whose ``named_parameters()`` are grouped.
        weight_decay: Weight decay value applied to the decayed groups.
        no_weight_decay_list: Name patterns (wildcards supported) excluded
            from weight decay.
        simple_params_list: Name patterns (wildcards supported) marking
            parameters as "simple".

    Returns:
        A list of param-group dicts in the order: decay, decay + simple,
        no-decay, no-decay + simple. Groups that would be empty are omitted,
        and the ``'simple'`` key is only present on the simple groups.
    """
    decay = []
    decay_simple = []
    no_decay = []
    no_decay_simple = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        # "Simple" membership is decided by its own pattern list; it must not
        # be derived from the no-weight-decay patterns.
        is_simple = _matches_pattern(name, simple_params_list)

        skip_decay = (
            param.ndim <= 1
            or name.endswith(".bias")
            or _matches_pattern(name, no_weight_decay_list)
        )
        if skip_decay:
            bucket = no_decay_simple if is_simple else no_decay
        else:
            bucket = decay_simple if is_simple else decay
        bucket.append(param)

    # Only emit non-empty groups.
    groups = []
    if decay:
        groups.append({'params': decay, 'weight_decay': weight_decay})
    if decay_simple:
        groups.append({'params': decay_simple, 'weight_decay': weight_decay, 'simple': True})
    if no_decay:
        groups.append({'params': no_decay, 'weight_decay': 0.})
    if no_decay_simple:
        groups.append({'params': no_decay_simple, 'weight_decay': 0., 'simple': True})

    return groups
3462
3563def _group (it , size ):
3664 it = iter (it )
@@ -70,9 +98,9 @@ def param_groups_layer_decay(
7098 model : nn .Module ,
7199 weight_decay : float = 0.05 ,
72100 no_weight_decay_list : Collection [str ] = (),
101+ simple_params_list : Collection [str ] = (),
73102 weight_decay_exclude_1d : bool = True ,
74103 layer_decay : float = .75 ,
75- end_layer_decay : Optional [float ] = None ,
76104 min_scale : float = 0. ,
77105 no_opt_scale : Optional [float ] = None ,
78106 verbose : bool = False ,
@@ -81,7 +109,6 @@ def param_groups_layer_decay(
81109 Parameter groups for layer-wise lr decay & weight decay
82110 Based on BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
83111 """
84- no_weight_decay_list = set (no_weight_decay_list )
85112 param_group_names = {} # NOTE for debugging
86113 param_groups = {}
87114
@@ -99,8 +126,12 @@ def param_groups_layer_decay(
99126 if not param .requires_grad :
100127 continue
101128
102- # no decay: all 1D parameters and model specific ones
103- if (weight_decay_exclude_1d and param .ndim <= 1 ) or name in no_weight_decay_list :
129+ # Determine if this is a "simple" parameter for fallback optimizer (if available)
130+ is_simple = _matches_pattern (name , simple_params_list )
131+
132+ # Determine weight decay
133+ if (weight_decay_exclude_1d and param .ndim <= 1 ) or _matches_pattern (name , no_weight_decay_list ):
134+ # no weight decay for 1D parameters and model specific ones
104135 g_decay = "no_decay"
105136 this_decay = 0.
106137 else :
@@ -114,18 +145,23 @@ def param_groups_layer_decay(
114145 param .requires_grad = False
115146 continue
116147
117- group_name = "layer_%d_%s" % (layer_id , g_decay )
148+ simple_suffix = "_simple" if is_simple else ""
149+ group_name = "layer_%d_%s%s" % (layer_id , g_decay , simple_suffix )
150+
118151 if group_name not in param_groups :
119152 param_group_names [group_name ] = {
120153 "lr_scale" : this_scale ,
121154 "weight_decay" : this_decay ,
155+ "simple" : is_simple ,
122156 "param_names" : [],
123157 }
124158 param_groups [group_name ] = {
125159 "lr_scale" : this_scale ,
126160 "weight_decay" : this_decay ,
127161 "params" : [],
128162 }
163+ if is_simple :
164+ param_groups [group_name ]["simple" ] = True
129165
130166 param_group_names [group_name ]["param_names" ].append (name )
131167 param_groups [group_name ]["params" ].append (param )
0 commit comments