Skip to content

Commit d71fcbc

Browse files
committed
Add nfsstore bandwidth testing script
1 parent 6af3964 commit d71fcbc

File tree

1 file changed

+128
-60
lines changed

1 file changed

+128
-60
lines changed

ucm/store/test/e2e/nfsstore_embed_fetch.py

Lines changed: 128 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@
2222
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2323
# SOFTWARE.
2424
#
25+
import csv
2526
import os
2627
import secrets
2728
import time
28-
import csv
29-
from typing import List, Dict
29+
from typing import Dict, List
3030

3131
import torch
3232

@@ -46,7 +46,9 @@ def setup_store(storage_backends, block_size, device_id, io_size) -> UcmKVStoreB
4646
return UcmNfsStore(config)
4747

4848

49-
def make_buffers(block_number, device_id, batch_size, head_dim, block_len, block_layer, num_head):
49+
def make_buffers(
50+
block_number, device_id, batch_size, head_dim, block_len, block_layer, num_head
51+
):
5052
hashes = [secrets.token_hex(16) for _ in range(block_number)]
5153
kv_caches = {}
5254
for i in range(block_layer):
@@ -65,13 +67,20 @@ def store_all_hashes(hashes: List[str]):
6567
f.write(h + "\n")
6668

6769

68-
def embed(store: UcmKVStoreBase, hashes: List[str], kvcaches: Dict[int, torch.Tensor],
69-
num_tokens: int, block_len: int, block_layer: int, block_dim: int):
70+
def embed(
71+
store: UcmKVStoreBase,
72+
hashes: List[str],
73+
kvcaches: Dict[int, torch.Tensor],
74+
num_tokens: int,
75+
block_len: int,
76+
block_layer: int,
77+
block_dim: int,
78+
):
7079
start_time = time.perf_counter()
71-
80+
7281
total_block_ids, total_offsets, total_tensors = [], [], []
7382
total_size = 0
74-
83+
7584
for i, hash_val in enumerate(hashes):
7685
offset = 0
7786
for layer_id, kv_layer in kvcaches.items():
@@ -85,27 +94,36 @@ def embed(store: UcmKVStoreBase, hashes: List[str], kvcaches: Dict[int, torch.Te
8594

8695
task = store.dump(total_block_ids, total_offsets, total_tensors)
8796
store.wait(task)
88-
97+
8998
elapsed_time = time.perf_counter() - start_time
90-
throughput_gbps = (total_size / (1024 ** 3)) / elapsed_time if elapsed_time > 0 else 0
91-
92-
print(f"WRITE: Data Size={(total_size / (1024 ** 3)):.4f} GB, Time={elapsed_time:.4f} s, "
93-
f"Speed={throughput_gbps:.4f} GB/s")
94-
99+
throughput_gbps = (total_size / (1024**3)) / elapsed_time if elapsed_time > 0 else 0
100+
101+
print(
102+
f"WRITE: Data Size={(total_size / (1024 ** 3)):.4f} GB, Time={elapsed_time:.4f} s, "
103+
f"Speed={throughput_gbps:.4f} GB/s"
104+
)
105+
95106
return total_size, elapsed_time, throughput_gbps
96107

97108

98-
def fetch(store: UcmKVStoreBase, hashes: List[str], kvcaches: Dict[int, torch.Tensor],
99-
num_tokens: int, block_len: int, block_layer: int, block_dim: int):
109+
def fetch(
110+
store: UcmKVStoreBase,
111+
hashes: List[str],
112+
kvcaches: Dict[int, torch.Tensor],
113+
num_tokens: int,
114+
block_len: int,
115+
block_layer: int,
116+
block_dim: int,
117+
):
100118
start_time = time.perf_counter()
101-
119+
102120
founds = store.lookup(hashes)
103121
for f in founds:
104122
assert f, "Cache block miss detected"
105123

106124
block_ids, offsets, tensors = [], [], []
107125
total_size = 0
108-
126+
109127
for i, hash_val in enumerate(hashes):
110128
offset = 0
111129
for layer_id, kv_layer in kvcaches.items():
@@ -120,33 +138,35 @@ def fetch(store: UcmKVStoreBase, hashes: List[str], kvcaches: Dict[int, torch.Te
120138
task = store.load(block_ids, offsets, tensors)
121139
ret = store.wait(task)
122140
assert ret == 0, "Load operation failed"
123-
141+
124142
elapsed_time = time.perf_counter() - start_time
125-
throughput_gbps = (total_size / (1024 ** 3)) / elapsed_time if elapsed_time > 0 else 0
126-
127-
print(f"READ: Data Size={(total_size / (1024 ** 3)):.4f} GB, Time={elapsed_time:.4f} s, "
128-
f"Speed={throughput_gbps:.4f} GB/s")
129-
143+
throughput_gbps = (total_size / (1024**3)) / elapsed_time if elapsed_time > 0 else 0
144+
145+
print(
146+
f"READ: Data Size={(total_size / (1024 ** 3)):.4f} GB, Time={elapsed_time:.4f} s, "
147+
f"Speed={throughput_gbps:.4f} GB/s"
148+
)
149+
130150
return total_size, elapsed_time, throughput_gbps
131151

132152

133153
def main():
134154
storage_backends = "."
135155
device_id = 1
136156
mla = False
137-
repeat = 3
157+
repeat = 3
138158
block_elem_size = 2
139159
num_tokens_list = [2048, 4096, 8192, 16384, 32768]
140-
160+
141161
if mla:
142-
block_lens = [ 64, 128 ]
162+
block_lens = [64, 128]
143163
block_layer = 61
144164
head_size = 576
145165
kv = 1
146166
model_name = "deepseek-v3"
147167
num_head_list = [1]
148168
else:
149-
block_lens = [ 128, 256 ]
169+
block_lens = [128, 256]
150170
block_layer = 64
151171
head_size = 128
152172
kv = 2
@@ -156,18 +176,31 @@ def main():
156176
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
157177
csv_file = os.path.join(SCRIPT_DIR, "embed_fetch_result1.csv")
158178
need_header = not os.path.exists(csv_file)
159-
179+
160180
with open(csv_file, "a", newline="", encoding="utf-8") as csv_fp:
161181
writer = csv.writer(csv_fp)
162-
182+
163183
if need_header:
164-
writer.writerow([
165-
"Model", "Sequence Length", "Batch Size", "Layers", "Element Size",
166-
"KV", "Num Head", "Block Size", "IO Count", "IO Size(B)",
167-
"Total Size(GB)", "Write Avg Time(s)", "Write Avg Bandwidth(GB/s)",
168-
"Read Avg Time(s)", "Read Avg Bandwidth(GB/s)"
169-
])
170-
184+
writer.writerow(
185+
[
186+
"Model",
187+
"Sequence Length",
188+
"Batch Size",
189+
"Layers",
190+
"Element Size",
191+
"KV",
192+
"Num Head",
193+
"Block Size",
194+
"IO Count",
195+
"IO Size(B)",
196+
"Total Size(GB)",
197+
"Write Avg Time(s)",
198+
"Write Avg Bandwidth(GB/s)",
199+
"Read Avg Time(s)",
200+
"Read Avg Bandwidth(GB/s)",
201+
]
202+
)
203+
171204
for num_head in num_head_list:
172205
for block_len in block_lens:
173206
block_dim = head_size * num_head
@@ -177,35 +210,58 @@ def main():
177210

178211
for num_tokens in num_tokens_list:
179212
sep = "=" * 60
180-
print(f"\n{sep}\n= num_head={num_head} | num_tokens={num_tokens:>6} | Repeat {repeat} times =\n{sep}\n")
213+
print(
214+
f"\n{sep}\n= num_head={num_head} | num_tokens={num_tokens:>6} | Repeat {repeat} times =\n{sep}\n"
215+
)
181216

182217
batch_size = int(num_tokens / block_len)
183218
io_num = int(num_tokens / block_len * block_layer)
184-
219+
185220
w_bw_list, r_bw_list = [], []
186221
w_time_list, r_time_list = [], []
187222
w_size_sum, r_size_sum = 0.0, 0.0
188223

189224
for r in range(repeat):
190225
print(f"\n--- Round {r+1} ---")
191-
store = setup_store(storage_backends, block_size, device_id, io_size)
192-
226+
store = setup_store(
227+
storage_backends, block_size, device_id, io_size
228+
)
229+
193230
hashes, kvcaches = make_buffers(
194-
real_blocks, device_id, batch_size, head_size,
195-
block_len, block_layer, num_head
231+
real_blocks,
232+
device_id,
233+
batch_size,
234+
head_size,
235+
block_len,
236+
block_layer,
237+
num_head,
196238
)
197239

198240
results = store.create(hashes[:batch_size])
199241
assert sum(results) == 0, "Create operation failed"
200242

201-
w_size, w_time, w_bw = embed(store, hashes[:batch_size], kvcaches,
202-
num_tokens, block_len, block_layer, block_dim)
243+
w_size, w_time, w_bw = embed(
244+
store,
245+
hashes[:batch_size],
246+
kvcaches,
247+
num_tokens,
248+
block_len,
249+
block_layer,
250+
block_dim,
251+
)
203252
store.commit(hashes[:batch_size], True)
204-
253+
205254
store_all_hashes(hashes[:batch_size])
206255

207-
r_size, r_time, r_bw = fetch(store, hashes[:batch_size], kvcaches,
208-
num_tokens, block_len, block_layer, block_dim)
256+
r_size, r_time, r_bw = fetch(
257+
store,
258+
hashes[:batch_size],
259+
kvcaches,
260+
num_tokens,
261+
block_len,
262+
block_layer,
263+
block_dim,
264+
)
209265

210266
w_bw_list.append(w_bw)
211267
r_bw_list.append(r_bw)
@@ -222,22 +278,34 @@ def main():
222278
avg_r_bw = sum(r_bw_list) / repeat
223279
avg_w_time = sum(w_time_list) / repeat
224280
avg_r_time = sum(r_time_list) / repeat
225-
avg_w_size = w_size_sum / (1024 ** 3) / repeat
226-
avg_r_size = r_size_sum / (1024 ** 3) / repeat
227-
228-
writer.writerow([
229-
model_name,
230-
num_tokens, batch_size, block_layer, block_elem_size,
231-
kv, num_head, block_len, io_num, io_size,
232-
f"{avg_w_size:.4f}", f"{avg_w_time:.4f}", f"{avg_w_bw:.4f}",
233-
f"{avg_r_time:.4f}", f"{avg_r_bw:.4f}"
234-
])
235-
281+
avg_w_size = w_size_sum / (1024**3) / repeat
282+
avg_r_size = r_size_sum / (1024**3) / repeat
283+
284+
writer.writerow(
285+
[
286+
model_name,
287+
num_tokens,
288+
batch_size,
289+
block_layer,
290+
block_elem_size,
291+
kv,
292+
num_head,
293+
block_len,
294+
io_num,
295+
io_size,
296+
f"{avg_w_size:.4f}",
297+
f"{avg_w_time:.4f}",
298+
f"{avg_w_bw:.4f}",
299+
f"{avg_r_time:.4f}",
300+
f"{avg_r_bw:.4f}",
301+
]
302+
)
303+
236304
csv_fp.flush()
237-
305+
238306
print("\n" + "=" * 60 + "\n= All combinations tested =\n" + "=" * 60 + "\n")
239307

240308

241309
if __name__ == "__main__":
242310
os.environ["UC_LOGGER_LEVEL"] = "debug"
243-
main()
311+
main()

0 commit comments

Comments
 (0)