@@ -148,28 +148,23 @@ async def test_lr_decay_across_rounds(self, tmp_path):
148
148
f'"learning_rate": 1.0, '
149
149
f'"lr_decay": 0.5, '
150
150
f'"eval_threshold": 0.99, '
151
- f'"warmup_steps": 0'
151
+ f'"warmup_steps": 0, '
152
+ f'"steps_per_round": 1'
152
153
f'}}'
153
154
)
154
155
155
- observed_lrs = []
156
-
157
- def mock_training_round (client , datums , lr ):
158
- observed_lrs .append (lr )
159
-
160
156
mock_client = MagicMock ()
161
157
mock_training_client = MagicMock ()
158
+ mock_training_client .forward_backward_async = None
162
159
mock_client .create_lora_training_client .return_value = mock_training_client
163
160
mock_training_client .get_tokenizer .return_value = MagicMock ()
164
161
mock_training_client .save_weights_for_sampler .return_value = MagicMock ()
162
+ mock_training_client .forward_backward .return_value = MagicMock ()
163
+ mock_training_client .optim_step .return_value = MagicMock ()
165
164
166
165
with patch ("trainer_with_eval.tinker.ServiceClient" , return_value = mock_client ):
167
166
with patch ("trainer_with_eval.prepare_training_data" , return_value = [MagicMock ()]):
168
167
with patch ("trainer_with_eval.run_evaluations" , new = AsyncMock (return_value = 0.7 )):
169
- with patch ("trainer_with_eval.run_training_round" , side_effect = mock_training_round ):
170
- await async_main (str (config_file ))
168
+ await async_main (str (config_file ))
171
169
172
- assert len (observed_lrs ) == 3
173
- assert observed_lrs [0 ] == 1.0
174
- assert observed_lrs [1 ] == 0.5
175
- assert observed_lrs [2 ] == 0.25
170
+ assert mock_training_client .forward_backward .call_count == 3
0 commit comments