|
| 1 | +""" |
| 2 | +Mock Tinker client for offline demos and testing. |
| 3 | +
|
| 4 | +Allows running the evaluation loop without cloud API access. |
| 5 | +""" |
| 6 | + |
| 7 | +import asyncio |
| 8 | +import json |
| 9 | +from pathlib import Path |
| 10 | +from typing import Any, Dict, List, Optional |
| 11 | +import numpy as np |
| 12 | + |
| 13 | + |
| 14 | +class MockTokenizer: |
| 15 | + """Mock tokenizer for offline mode.""" |
| 16 | + |
| 17 | + def encode(self, text: str, add_special_tokens: bool = True) -> List[int]: |
| 18 | + tokens = text.split() |
| 19 | + return list(range(len(tokens))) |
| 20 | + |
| 21 | + def decode(self, tokens: List[int]) -> str: |
| 22 | + return f"<decoded_{len(tokens)}_tokens>" |
| 23 | + |
| 24 | + |
| 25 | +class MockFuture: |
| 26 | + """Mock future for sync API.""" |
| 27 | + |
| 28 | + def __init__(self, value: Any): |
| 29 | + self.value = value |
| 30 | + |
| 31 | + def result(self): |
| 32 | + return self.value |
| 33 | + |
| 34 | + |
| 35 | +class MockSaveResult: |
| 36 | + """Mock save result with path.""" |
| 37 | + |
| 38 | + def __init__(self, name: str, checkpoint_dir: Path): |
| 39 | + self.path = f"mock://checkpoint/{name}" |
| 40 | + checkpoint_dir.mkdir(parents=True, exist_ok=True) |
| 41 | + checkpoint_file = checkpoint_dir / f"{name}.json" |
| 42 | + checkpoint_file.write_text(json.dumps({"name": name, "mock": True})) |
| 43 | + |
| 44 | + |
| 45 | +class MockTrainingClient: |
| 46 | + """Mock LoRA training client for offline demos.""" |
| 47 | + |
| 48 | + def __init__(self, base_model: str, rank: int, checkpoint_dir: Path): |
| 49 | + self.base_model = base_model |
| 50 | + self.rank = rank |
| 51 | + self.checkpoint_dir = checkpoint_dir |
| 52 | + self.step_count = 0 |
| 53 | + self.current_loss = 2.5 |
| 54 | + |
| 55 | + def get_tokenizer(self): |
| 56 | + return MockTokenizer() |
| 57 | + |
| 58 | + def forward_backward(self, datums: List[Any], loss_fn: str = "cross_entropy"): |
| 59 | + self.step_count += 1 |
| 60 | + self.current_loss *= 0.95 |
| 61 | + return MockFuture({"loss": self.current_loss}) |
| 62 | + |
| 63 | + async def forward_backward_async(self, datums: List[Any], loss_fn: str = "cross_entropy"): |
| 64 | + await asyncio.sleep(0.01) |
| 65 | + self.step_count += 1 |
| 66 | + self.current_loss *= 0.95 |
| 67 | + |
| 68 | + class AsyncFuture: |
| 69 | + async def __await__(self): |
| 70 | + return {"loss": self.current_loss} |
| 71 | + |
| 72 | + def __await__(self): |
| 73 | + async def _wait(): |
| 74 | + return {"loss": self.current_loss} |
| 75 | + return _wait().__await__() |
| 76 | + |
| 77 | + return AsyncFuture() |
| 78 | + |
| 79 | + def optim_step(self, params): |
| 80 | + return MockFuture({"success": True}) |
| 81 | + |
| 82 | + async def optim_step_async(self, params): |
| 83 | + await asyncio.sleep(0.01) |
| 84 | + |
| 85 | + class AsyncFuture: |
| 86 | + def __await__(self): |
| 87 | + async def _wait(): |
| 88 | + return {"success": True} |
| 89 | + return _wait().__await__() |
| 90 | + |
| 91 | + return AsyncFuture() |
| 92 | + |
| 93 | + def save_weights_for_sampler(self, name: str = "checkpoint"): |
| 94 | + result = MockSaveResult(name, self.checkpoint_dir) |
| 95 | + return MockFuture(result) |
| 96 | + |
| 97 | + def save_state(self, name: str = "checkpoint"): |
| 98 | + result = MockSaveResult(f"{name}_state", self.checkpoint_dir) |
| 99 | + return MockFuture(result) |
| 100 | + |
| 101 | + def load_state(self, path: str): |
| 102 | + print(f"Loaded state from {path}") |
| 103 | + return MockFuture({"success": True}) |
| 104 | + |
| 105 | + |
| 106 | +class MockSamplingClient: |
| 107 | + """Mock sampling client for evaluations.""" |
| 108 | + |
| 109 | + def sample(self, prompt, sampling_params, num_samples=1): |
| 110 | + return MockFuture({"sequences": [{"tokens": [1, 2, 3]}]}) |
| 111 | + |
| 112 | + |
| 113 | +class MockServiceClient: |
| 114 | + """Mock Tinker service client.""" |
| 115 | + |
| 116 | + def __init__(self, checkpoint_dir: Optional[Path] = None): |
| 117 | + self.checkpoint_dir = checkpoint_dir or Path("./mock_checkpoints") |
| 118 | + self.checkpoint_dir.mkdir(exist_ok=True) |
| 119 | + |
| 120 | + def create_lora_training_client(self, base_model: str, rank: int = 16): |
| 121 | + print(f"[MOCK MODE] Creating LoRA training client for {base_model} (rank={rank})") |
| 122 | + return MockTrainingClient(base_model, rank, self.checkpoint_dir) |
| 123 | + |
| 124 | + def create_sampling_client(self, base_model: Optional[str] = None, model_path: Optional[str] = None): |
| 125 | + print(f"[MOCK MODE] Creating sampling client for {model_path or base_model}") |
| 126 | + return MockSamplingClient() |
| 127 | + |
| 128 | + def get_server_capabilities(self): |
| 129 | + class Capabilities: |
| 130 | + supported_models = [ |
| 131 | + type('Model', (), {'model_name': 'meta-llama/Llama-3.1-8B-Instruct'}), |
| 132 | + type('Model', (), {'model_name': 'meta-llama/Llama-3.1-70B'}), |
| 133 | + ] |
| 134 | + return Capabilities() |
| 135 | + |
| 136 | + |
| 137 | +class MockTypes: |
| 138 | + """Mock types module.""" |
| 139 | + |
| 140 | + class Datum: |
| 141 | + def __init__(self, model_input, loss_fn_inputs): |
| 142 | + self.model_input = model_input |
| 143 | + self.loss_fn_inputs = loss_fn_inputs |
| 144 | + |
| 145 | + class ModelInput: |
| 146 | + @staticmethod |
| 147 | + def from_ints(tokens): |
| 148 | + return tokens |
| 149 | + |
| 150 | + def to_ints(self): |
| 151 | + return self if isinstance(self, list) else [] |
| 152 | + |
| 153 | + class AdamParams: |
| 154 | + def __init__(self, learning_rate: float): |
| 155 | + self.learning_rate = learning_rate |
| 156 | + |
| 157 | + class SamplingParams: |
| 158 | + def __init__(self, max_tokens: int = 100, temperature: float = 0.7, stop: List[str] = None, top_p: float = 1.0): |
| 159 | + self.max_tokens = max_tokens |
| 160 | + self.temperature = temperature |
| 161 | + self.stop = stop or [] |
| 162 | + self.top_p = top_p |
| 163 | + |
| 164 | + |
| 165 | +def create_mock_tinker_module(): |
| 166 | + """Create a mock tinker module for offline use.""" |
| 167 | + |
| 168 | + class MockTinker: |
| 169 | + ServiceClient = MockServiceClient |
| 170 | + types = MockTypes |
| 171 | + |
| 172 | + return MockTinker |
0 commit comments