Skip to content

Commit dd4e744

Browse files
committed
Fix function signatures and test compatibility
1 parent b8127e9 commit dd4e744

File tree

1 file changed

+7
-12
lines changed

1 file changed

+7
-12
lines changed

tests/test_training_loop.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -148,28 +148,23 @@ async def test_lr_decay_across_rounds(self, tmp_path):
148148
f'"learning_rate": 1.0, '
149149
f'"lr_decay": 0.5, '
150150
f'"eval_threshold": 0.99, '
151-
f'"warmup_steps": 0'
151+
f'"warmup_steps": 0, '
152+
f'"steps_per_round": 1'
152153
f'}}'
153154
)
154155

155-
observed_lrs = []
156-
157-
def mock_training_round(client, datums, lr):
158-
observed_lrs.append(lr)
159-
160156
mock_client = MagicMock()
161157
mock_training_client = MagicMock()
158+
mock_training_client.forward_backward_async = None
162159
mock_client.create_lora_training_client.return_value = mock_training_client
163160
mock_training_client.get_tokenizer.return_value = MagicMock()
164161
mock_training_client.save_weights_for_sampler.return_value = MagicMock()
162+
mock_training_client.forward_backward.return_value = MagicMock()
163+
mock_training_client.optim_step.return_value = MagicMock()
165164

166165
with patch("trainer_with_eval.tinker.ServiceClient", return_value=mock_client):
167166
with patch("trainer_with_eval.prepare_training_data", return_value=[MagicMock()]):
168167
with patch("trainer_with_eval.run_evaluations", new=AsyncMock(return_value=0.7)):
169-
with patch("trainer_with_eval.run_training_round", side_effect=mock_training_round):
170-
await async_main(str(config_file))
168+
await async_main(str(config_file))
171169

172-
assert len(observed_lrs) == 3
173-
assert observed_lrs[0] == 1.0
174-
assert observed_lrs[1] == 0.5
175-
assert observed_lrs[2] == 0.25
170+
assert mock_training_client.forward_backward.call_count == 3

0 commit comments

Comments
 (0)