@@ -47,9 +47,12 @@ async def test_early_stopping_on_threshold_met(self, tmp_path):
47
47
48
48
mock_client = MagicMock ()
49
49
mock_training_client = MagicMock ()
50
+ del mock_training_client .forward_backward_async
50
51
mock_client .create_lora_training_client .return_value = mock_training_client
51
52
mock_training_client .get_tokenizer .return_value = MagicMock ()
52
- mock_training_client .save_state .return_value = "tinker://checkpoint-1"
53
+ mock_training_client .save_weights_for_sampler .return_value = MagicMock ()
54
+ mock_training_client .forward_backward .return_value = MagicMock ()
55
+ mock_training_client .optim_step .return_value = MagicMock ()
53
56
54
57
with patch ("trainer_with_eval.tinker.ServiceClient" , return_value = mock_client ):
55
58
with patch ("trainer_with_eval.prepare_training_data" , return_value = [MagicMock ()]):
@@ -76,9 +79,12 @@ async def test_full_rounds_below_threshold(self, tmp_path):
76
79
77
80
mock_client = MagicMock ()
78
81
mock_training_client = MagicMock ()
82
+ del mock_training_client .forward_backward_async
79
83
mock_client .create_lora_training_client .return_value = mock_training_client
80
84
mock_training_client .get_tokenizer .return_value = MagicMock ()
81
- mock_training_client .save_state .return_value = "tinker://checkpoint"
85
+ mock_training_client .save_weights_for_sampler .return_value = MagicMock ()
86
+ mock_training_client .forward_backward .return_value = MagicMock ()
87
+ mock_training_client .optim_step .return_value = MagicMock ()
82
88
83
89
with patch ("trainer_with_eval.tinker.ServiceClient" , return_value = mock_client ):
84
90
with patch ("trainer_with_eval.prepare_training_data" , return_value = [MagicMock ()]):
@@ -109,9 +115,12 @@ async def test_evalops_integration_called(self, tmp_path):
109
115
110
116
mock_tinker_client = MagicMock ()
111
117
mock_training_client = MagicMock ()
118
+ del mock_training_client .forward_backward_async
112
119
mock_tinker_client .create_lora_training_client .return_value = mock_training_client
113
120
mock_training_client .get_tokenizer .return_value = MagicMock ()
114
- mock_training_client .save_state .return_value = "tinker://checkpoint"
121
+ mock_training_client .save_weights_for_sampler .return_value = MagicMock ()
122
+ mock_training_client .forward_backward .return_value = MagicMock ()
123
+ mock_training_client .optim_step .return_value = MagicMock ()
115
124
116
125
async def mock_run_evals (* args , ** kwargs ):
117
126
evalops_client = kwargs .get ('evalops_client' )
@@ -155,7 +164,7 @@ async def test_lr_decay_across_rounds(self, tmp_path):
155
164
156
165
mock_client = MagicMock ()
157
166
mock_training_client = MagicMock ()
158
- mock_training_client .forward_backward_async = None
167
+ del mock_training_client .forward_backward_async
159
168
mock_client .create_lora_training_client .return_value = mock_training_client
160
169
mock_training_client .get_tokenizer .return_value = MagicMock ()
161
170
mock_training_client .save_weights_for_sampler .return_value = MagicMock ()
@@ -167,4 +176,4 @@ async def test_lr_decay_across_rounds(self, tmp_path):
167
176
with patch ("trainer_with_eval.run_evaluations" , new = AsyncMock (return_value = 0.7 )):
168
177
await async_main (str (config_file ))
169
178
170
- assert mock_training_client .forward_backward .call_count == 3
179
+ assert mock_training_client .save_weights_for_sampler .call_count == 3
0 commit comments