@@ -29,6 +29,19 @@ def test_config(self):
         optimizer = LossScaleOptimizer(inner_optimizer)
         self.run_class_serialization_test(optimizer)
 
+    def test_apply_with_no_vars(self):
+        self._skip_test_for_stateless(False)
+
+        inner_optimizer = SGD(learning_rate=0.5)
+        optimizer = LossScaleOptimizer(inner_optimizer)
+        grads = [ops.array([1.0, 6.0, 7.0, 2.0]) * optimizer.initial_scale]
+        vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
+        optimizer.build(vars)
+        optimizer.apply(grads)
+        self.assertAllClose(
+            vars, [[0.5, -1.0, -0.5, 3.0]], rtol=1e-4, atol=1e-4
+        )
+
     @parameterized.named_parameters(("stateless", True), ("stateful", False))
     def test_finite_step(self, stateless):
         self._skip_test_for_stateless(stateless)
@@ -40,7 +53,9 @@ def test_finite_step(self, stateless):
         if stateless:
             optimizer.build(vars)
             vars, _ = optimizer.stateless_apply(
-                optimizer.variables, grads, vars
+                [v.value for v in optimizer.variables],
+                grads,
+                [v.value for v in vars],
             )
         else:
             optimizer.apply(grads, vars)
@@ -60,7 +75,9 @@ def test_finite_step_with_inner_loss_scale(self, stateless):
         if stateless:
             optimizer.build(vars)
             vars, _ = optimizer.stateless_apply(
-                optimizer.variables, grads, vars
+                [v.value for v in optimizer.variables],
+                grads,
+                [v.value for v in vars],
             )
         else:
             optimizer.apply(grads, vars)
@@ -79,7 +96,9 @@ def test_infinite_step(self, stateless):
         if stateless:
             optimizer.build(vars)
             vars, _ = optimizer.stateless_apply(
-                optimizer.variables, grads, vars
+                [v.value for v in optimizer.variables],
+                grads,
+                [v.value for v in vars],
            )
         else:
             optimizer.apply(grads, vars)
@@ -98,7 +117,9 @@ def test_finite_step_with_overwrite(self, stateless):
         if stateless:
             optimizer.build(vars)
             vars, _ = optimizer.stateless_apply(
-                optimizer.variables, grads, vars
+                [v.value for v in optimizer.variables],
+                grads,
+                [v.value for v in vars],
             )
         else:
             optimizer.apply(grads, vars)
@@ -112,12 +133,14 @@ def test_downscaling(self, stateless):
         optimizer = LossScaleOptimizer(inner_optimizer, initial_scale=400.0)
         vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
         optimizer.build(vars)
-        opt_vars = optimizer.variables
+        opt_var_values = [v.value for v in optimizer.variables]
         grads = [ops.array([np.inf, np.inf, np.inf, np.inf])]
         for _ in range(4):
             if stateless:
-                _, opt_vars = optimizer.stateless_apply(opt_vars, grads, vars)
-                for ref_v, v in zip(optimizer.variables, opt_vars):
+                _, opt_var_values = optimizer.stateless_apply(
+                    opt_var_values, grads, [v.value for v in vars]
+                )
+                for ref_v, v in zip(optimizer.variables, opt_var_values):
                     ref_v.assign(v)
             else:
                 optimizer.apply(grads, vars)
@@ -135,12 +158,14 @@ def test_upscaling(self, stateless):
         )
         vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
         optimizer.build(vars)
-        opt_vars = optimizer.variables
+        opt_var_values = [v.value for v in optimizer.variables]
         grads = [ops.array([1.0, 6.0, 7.0, 2.0])]
         for _ in range(8):
             if stateless:
-                _, opt_vars = optimizer.stateless_apply(opt_vars, grads, vars)
-                for ref_v, v in zip(optimizer.variables, opt_vars):
+                _, opt_var_values = optimizer.stateless_apply(
+                    opt_var_values, grads, [v.value for v in vars]
+                )
+                for ref_v, v in zip(optimizer.variables, opt_var_values):
                     ref_v.assign(v)
             else:
                 optimizer.apply(grads, vars)
@@ -154,16 +179,104 @@ def test_iterations_update(self, stateless):
         optimizer = LossScaleOptimizer(inner_optimizer)
         vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
         optimizer.build(vars)
-        opt_vars = optimizer.variables
+        opt_var_values = [v.value for v in optimizer.variables]
         grads = [ops.array([1.0, 6.0, 7.0, 2.0])]
 
         self.assertEqual(optimizer.iterations.value, 0)
 
         for i in range(3):
             if stateless:
-                _, opt_vars = optimizer.stateless_apply(opt_vars, grads, vars)
-                for ref_v, v in zip(optimizer.variables, opt_vars):
+                _, opt_var_values = optimizer.stateless_apply(
+                    opt_var_values, grads, [v.value for v in vars]
+                )
+                for ref_v, v in zip(optimizer.variables, opt_var_values):
                     ref_v.assign(v)
             else:
                 optimizer.apply(grads, vars)
             self.assertEqual(optimizer.iterations.value, i + 1)
+
+    def test_serialization(self):
+        inner_optimizer = SGD(learning_rate=0.5)
+        optimizer = LossScaleOptimizer(
+            inner_optimizer,
+            initial_scale=3.0,
+            dynamic_growth_steps=2,
+            name="test_opt",
+        )
+        config = optimizer.get_config()
+        self.assertLen(config, 4)
+        self.assertEqual(config["name"], "test_opt")
+        self.assertEqual(config["initial_scale"], 3.0)
+        self.assertEqual(config["dynamic_growth_steps"], 2)
+        self.assertIn("inner_optimizer", config)
+        LossScaleOptimizer.from_config(config)
+
+    def test_init_dynamic_arg(self):
+        inner_optimizer = SGD(learning_rate=0.5)
+
+        # dynamic=True is supported
+        LossScaleOptimizer(inner_optimizer, dynamic=True)
+
+        # dynamic=False is not supported
+        with self.assertRaisesRegex(ValueError, "set `loss_scale_factor`"):
+            LossScaleOptimizer(inner_optimizer, dynamic=False)
+
+    def test_init_unsupported_arg(self):
+        inner_optimizer = SGD(learning_rate=0.5)
+        with self.assertRaisesRegex(ValueError, "arguments: `foo`, `bar`"):
+            LossScaleOptimizer(inner_optimizer, foo=True, bar=3)
+
+    @parameterized.named_parameters(
+        ("weight_decay", "weight_decay", 0.5),
+        ("clipnorm", "clipnorm", 0.5),
+        ("global_clipnorm", "global_clipnorm", 0.5),
+        ("clipvalue", "clipvalue", 0.5),
+        ("use_ema", "use_ema", True),
+        ("ema_momentum", "ema_momentum", 0.5),
+        ("ema_overwrite_frequency", "ema_overwrite_frequency", 2),
+        ("loss_scale_factor", "loss_scale_factor", 0.5),
+        ("gradient_accumulation_steps", "gradient_accumulation_steps", 2),
+    )
+    def test_init_base_optimizer_unsupported_args(self, arg_name, arg_value):
+        inner_optimizer = SGD(learning_rate=0.5)
+        with self.assertRaisesRegex(ValueError, "on the `inner_optimizer`"):
+            LossScaleOptimizer(inner_optimizer, **{arg_name: arg_value})
+
+    def test_deserialization_backwards_compatibility(self):
+        # Test deserializing with a config that has all the unsupported
+        # arguments from the base optimizer (which are no longer serialized).
+        config = {
+            "name": "loss_scale_optimizer",
+            "weight_decay": None,
+            "clipnorm": None,
+            "global_clipnorm": None,
+            "clipvalue": None,
+            "use_ema": False,
+            "ema_momentum": 0.99,
+            "ema_overwrite_frequency": None,
+            "loss_scale_factor": None,
+            "gradient_accumulation_steps": None,
+            "inner_optimizer": {
+                "module": "keras.optimizers",
+                "class_name": "SGD",
+                "config": {
+                    "name": "SGD",
+                    "learning_rate": 0.5,
+                    "weight_decay": None,
+                    "clipnorm": None,
+                    "global_clipnorm": None,
+                    "clipvalue": None,
+                    "use_ema": False,
+                    "ema_momentum": 0.99,
+                    "ema_overwrite_frequency": None,
+                    "loss_scale_factor": None,
+                    "gradient_accumulation_steps": None,
+                    "momentum": 0.0,
+                    "nesterov": False,
+                },
+                "registered_name": None,
+            },
+            "initial_scale": 2.0,
+            "dynamic_growth_steps": 2,
+        }
+        LossScaleOptimizer.from_config(config)
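
For context, the pattern the updated stateless branches exercise is that `stateless_apply` consumes and returns plain tensor values rather than `Variable` objects, so callers pass `v.value` in and assign the results back themselves. A minimal standalone sketch of that usage, assuming the same symbols the test module imports (`backend`, `ops`, `SGD`, `LossScaleOptimizer`):

# Sketch only: assumes the test module's imports for backend, ops,
# SGD, and LossScaleOptimizer; not part of the diff above.
inner_optimizer = SGD(learning_rate=0.5)
optimizer = LossScaleOptimizer(inner_optimizer)

trainable_vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
optimizer.build(trainable_vars)
grads = [ops.array([1.0, 6.0, 7.0, 2.0])]

# stateless_apply takes and returns raw values, not Variable objects.
new_var_values, new_opt_var_values = optimizer.stateless_apply(
    [v.value for v in optimizer.variables],
    grads,
    [v.value for v in trainable_vars],
)

# The caller writes the returned values back into the variables.
for var, value in zip(trainable_vars, new_var_values):
    var.assign(value)
for var, value in zip(optimizer.variables, new_opt_var_values):
    var.assign(value)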