Skip to content

Commit 2af394c

Browse files
committed
Add mock Tinker client for offline demos
1 parent b5f1619 commit 2af394c

File tree

1 file changed

+172
-0
lines changed

1 file changed

+172
-0
lines changed

mock_tinker.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
"""
2+
Mock Tinker client for offline demos and testing.
3+
4+
Allows running the evaluation loop without cloud API access.
5+
"""
6+
7+
import asyncio
8+
import json
9+
from pathlib import Path
10+
from typing import Any, Dict, List, Optional
11+
import numpy as np
12+
13+
14+
class MockTokenizer:
15+
"""Mock tokenizer for offline mode."""
16+
17+
def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
18+
tokens = text.split()
19+
return list(range(len(tokens)))
20+
21+
def decode(self, tokens: List[int]) -> str:
22+
return f"<decoded_{len(tokens)}_tokens>"
23+
24+
25+
class MockFuture:
26+
"""Mock future for sync API."""
27+
28+
def __init__(self, value: Any):
29+
self.value = value
30+
31+
def result(self):
32+
return self.value
33+
34+
35+
class MockSaveResult:
36+
"""Mock save result with path."""
37+
38+
def __init__(self, name: str, checkpoint_dir: Path):
39+
self.path = f"mock://checkpoint/{name}"
40+
checkpoint_dir.mkdir(parents=True, exist_ok=True)
41+
checkpoint_file = checkpoint_dir / f"{name}.json"
42+
checkpoint_file.write_text(json.dumps({"name": name, "mock": True}))
43+
44+
45+
class MockTrainingClient:
46+
"""Mock LoRA training client for offline demos."""
47+
48+
def __init__(self, base_model: str, rank: int, checkpoint_dir: Path):
49+
self.base_model = base_model
50+
self.rank = rank
51+
self.checkpoint_dir = checkpoint_dir
52+
self.step_count = 0
53+
self.current_loss = 2.5
54+
55+
def get_tokenizer(self):
56+
return MockTokenizer()
57+
58+
def forward_backward(self, datums: List[Any], loss_fn: str = "cross_entropy"):
59+
self.step_count += 1
60+
self.current_loss *= 0.95
61+
return MockFuture({"loss": self.current_loss})
62+
63+
async def forward_backward_async(self, datums: List[Any], loss_fn: str = "cross_entropy"):
64+
await asyncio.sleep(0.01)
65+
self.step_count += 1
66+
self.current_loss *= 0.95
67+
68+
class AsyncFuture:
69+
async def __await__(self):
70+
return {"loss": self.current_loss}
71+
72+
def __await__(self):
73+
async def _wait():
74+
return {"loss": self.current_loss}
75+
return _wait().__await__()
76+
77+
return AsyncFuture()
78+
79+
def optim_step(self, params):
80+
return MockFuture({"success": True})
81+
82+
async def optim_step_async(self, params):
83+
await asyncio.sleep(0.01)
84+
85+
class AsyncFuture:
86+
def __await__(self):
87+
async def _wait():
88+
return {"success": True}
89+
return _wait().__await__()
90+
91+
return AsyncFuture()
92+
93+
def save_weights_for_sampler(self, name: str = "checkpoint"):
94+
result = MockSaveResult(name, self.checkpoint_dir)
95+
return MockFuture(result)
96+
97+
def save_state(self, name: str = "checkpoint"):
98+
result = MockSaveResult(f"{name}_state", self.checkpoint_dir)
99+
return MockFuture(result)
100+
101+
def load_state(self, path: str):
102+
print(f"Loaded state from {path}")
103+
return MockFuture({"success": True})
104+
105+
106+
class MockSamplingClient:
107+
"""Mock sampling client for evaluations."""
108+
109+
def sample(self, prompt, sampling_params, num_samples=1):
110+
return MockFuture({"sequences": [{"tokens": [1, 2, 3]}]})
111+
112+
113+
class MockServiceClient:
114+
"""Mock Tinker service client."""
115+
116+
def __init__(self, checkpoint_dir: Optional[Path] = None):
117+
self.checkpoint_dir = checkpoint_dir or Path("./mock_checkpoints")
118+
self.checkpoint_dir.mkdir(exist_ok=True)
119+
120+
def create_lora_training_client(self, base_model: str, rank: int = 16):
121+
print(f"[MOCK MODE] Creating LoRA training client for {base_model} (rank={rank})")
122+
return MockTrainingClient(base_model, rank, self.checkpoint_dir)
123+
124+
def create_sampling_client(self, base_model: Optional[str] = None, model_path: Optional[str] = None):
125+
print(f"[MOCK MODE] Creating sampling client for {model_path or base_model}")
126+
return MockSamplingClient()
127+
128+
def get_server_capabilities(self):
129+
class Capabilities:
130+
supported_models = [
131+
type('Model', (), {'model_name': 'meta-llama/Llama-3.1-8B-Instruct'}),
132+
type('Model', (), {'model_name': 'meta-llama/Llama-3.1-70B'}),
133+
]
134+
return Capabilities()
135+
136+
137+
class MockTypes:
138+
"""Mock types module."""
139+
140+
class Datum:
141+
def __init__(self, model_input, loss_fn_inputs):
142+
self.model_input = model_input
143+
self.loss_fn_inputs = loss_fn_inputs
144+
145+
class ModelInput:
146+
@staticmethod
147+
def from_ints(tokens):
148+
return tokens
149+
150+
def to_ints(self):
151+
return self if isinstance(self, list) else []
152+
153+
class AdamParams:
154+
def __init__(self, learning_rate: float):
155+
self.learning_rate = learning_rate
156+
157+
class SamplingParams:
158+
def __init__(self, max_tokens: int = 100, temperature: float = 0.7, stop: List[str] = None, top_p: float = 1.0):
159+
self.max_tokens = max_tokens
160+
self.temperature = temperature
161+
self.stop = stop or []
162+
self.top_p = top_p
163+
164+
165+
def create_mock_tinker_module():
166+
"""Create a mock tinker module for offline use."""
167+
168+
class MockTinker:
169+
ServiceClient = MockServiceClient
170+
types = MockTypes
171+
172+
return MockTinker

0 commit comments

Comments
 (0)