Skip to content

Commit 111cc2d

Browse files
committed
Fix all test failures — properly mock the training client for the sync fallback path
1 parent dd4e744 commit 111cc2d

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

tests/test_training_loop.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,12 @@ async def test_early_stopping_on_threshold_met(self, tmp_path):
4747

4848
mock_client = MagicMock()
4949
mock_training_client = MagicMock()
50+
del mock_training_client.forward_backward_async
5051
mock_client.create_lora_training_client.return_value = mock_training_client
5152
mock_training_client.get_tokenizer.return_value = MagicMock()
52-
mock_training_client.save_state.return_value = "tinker://checkpoint-1"
53+
mock_training_client.save_weights_for_sampler.return_value = MagicMock()
54+
mock_training_client.forward_backward.return_value = MagicMock()
55+
mock_training_client.optim_step.return_value = MagicMock()
5356

5457
with patch("trainer_with_eval.tinker.ServiceClient", return_value=mock_client):
5558
with patch("trainer_with_eval.prepare_training_data", return_value=[MagicMock()]):
@@ -76,9 +79,12 @@ async def test_full_rounds_below_threshold(self, tmp_path):
7679

7780
mock_client = MagicMock()
7881
mock_training_client = MagicMock()
82+
del mock_training_client.forward_backward_async
7983
mock_client.create_lora_training_client.return_value = mock_training_client
8084
mock_training_client.get_tokenizer.return_value = MagicMock()
81-
mock_training_client.save_state.return_value = "tinker://checkpoint"
85+
mock_training_client.save_weights_for_sampler.return_value = MagicMock()
86+
mock_training_client.forward_backward.return_value = MagicMock()
87+
mock_training_client.optim_step.return_value = MagicMock()
8288

8389
with patch("trainer_with_eval.tinker.ServiceClient", return_value=mock_client):
8490
with patch("trainer_with_eval.prepare_training_data", return_value=[MagicMock()]):
@@ -109,9 +115,12 @@ async def test_evalops_integration_called(self, tmp_path):
109115

110116
mock_tinker_client = MagicMock()
111117
mock_training_client = MagicMock()
118+
del mock_training_client.forward_backward_async
112119
mock_tinker_client.create_lora_training_client.return_value = mock_training_client
113120
mock_training_client.get_tokenizer.return_value = MagicMock()
114-
mock_training_client.save_state.return_value = "tinker://checkpoint"
121+
mock_training_client.save_weights_for_sampler.return_value = MagicMock()
122+
mock_training_client.forward_backward.return_value = MagicMock()
123+
mock_training_client.optim_step.return_value = MagicMock()
115124

116125
async def mock_run_evals(*args, **kwargs):
117126
evalops_client = kwargs.get('evalops_client')
@@ -155,7 +164,7 @@ async def test_lr_decay_across_rounds(self, tmp_path):
155164

156165
mock_client = MagicMock()
157166
mock_training_client = MagicMock()
158-
mock_training_client.forward_backward_async = None
167+
del mock_training_client.forward_backward_async
159168
mock_client.create_lora_training_client.return_value = mock_training_client
160169
mock_training_client.get_tokenizer.return_value = MagicMock()
161170
mock_training_client.save_weights_for_sampler.return_value = MagicMock()
@@ -167,4 +176,4 @@ async def test_lr_decay_across_rounds(self, tmp_path):
167176
with patch("trainer_with_eval.run_evaluations", new=AsyncMock(return_value=0.7)):
168177
await async_main(str(config_file))
169178

170-
assert mock_training_client.forward_backward.call_count == 3
179+
assert mock_training_client.save_weights_for_sampler.call_count == 3

trainer_with_eval.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def prepare_training_data(
6767
train_file: str,
6868
tokenizer,
6969
max_seq_length: int = 2048,
70+
renderer_name: str = "llama3",
7071
deduplicate: bool = True,
7172
) -> list:
7273
"""Load and convert training data into a list of Tinker Datum objects.
@@ -81,6 +82,7 @@ def prepare_training_data(
8182
train_file: Path to the training JSON/JSONL file.
8283
tokenizer: A tokenizer object obtained from the Tinker training client.
8384
max_seq_length: Maximum sequence length for tokenization.
85+
renderer_name: Name of the renderer for proper formatting.
8486
deduplicate: Whether to deduplicate examples.
8587
8688
Returns:
@@ -91,7 +93,7 @@ def prepare_training_data(
9193
return []
9294

9395
loader = DataLoader(max_seq_length=max_seq_length)
94-
return loader.prepare_training_data(train_file, tokenizer, deduplicate=deduplicate)
96+
return loader.prepare_training_data(train_file, tokenizer, renderer_name, deduplicate)
9597

9698

9799
async def run_training_round_async(

0 commit comments

Comments (0)