Skip to content

Commit 4e81da0

Browse files
committed
Add nfsstore bandwidth testing script
1 parent 39c09d3 commit 4e81da0

File tree

2 files changed

+36
-25
lines changed

2 files changed

+36
-25
lines changed

ucm/store/test/e2e/nfsstore_embed_fetch.py

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,13 @@ def setup(
4949

5050

5151
def make_buffers(
52-
block_number, device_id, batch_size, head_dim, block_len, block_layer, num_head
52+
block_number, device_id, batch_size, head_dim, block_len, block_layer, num_head, kv
5353
):
5454
hashes = [secrets.token_hex(16) for _ in range(block_number)]
5555
kv_caches = {}
5656
for i in range(block_layer):
5757
kv_caches[i] = torch.rand(
58-
[1, block_number, block_len, num_head, head_dim],
58+
[kv, block_number, block_len, num_head, head_dim],
5959
dtype=torch.bfloat16,
6060
device=f"cuda:{device_id}",
6161
)
@@ -73,10 +73,7 @@ def embed(
7373
store: UcmKVStoreBase,
7474
hashes: List[str],
7575
kvcaches: Dict[int, torch.Tensor],
76-
num_tokens: int,
77-
block_len: int,
78-
block_layer: int,
79-
block_dim: int,
76+
mla: bool,
8077
):
8178
start_time = time.perf_counter()
8279

@@ -86,14 +83,23 @@ def embed(
8683
for i, hash_val in enumerate(hashes):
8784
offset = 0
8885
for layer_id, kv_layer in kvcaches.items():
89-
tensor = kv_layer[0][i] # kv=1
90-
total_tensors.append(tensor)
86+
k_tensor = kv_layer[0][i] # kv=1
87+
total_tensors.append(k_tensor)
9188
total_block_ids.append(hash_val)
9289
total_offsets.append(offset)
93-
sz = tensor.numel() * tensor.element_size()
90+
sz = k_tensor.numel() * k_tensor.element_size()
9491
offset += sz
9592
total_size += sz
9693

94+
if not mla:
95+
v_tensor = kv_layer[1][i]
96+
total_tensors.append(v_tensor)
97+
total_block_ids.append(hash_val)
98+
total_offsets.append(offset)
99+
sz = v_tensor.numel() * v_tensor.element_size()
100+
offset += sz
101+
total_size += sz
102+
97103
task = store.dump(total_block_ids, total_offsets, total_tensors)
98104
store.wait(task)
99105

@@ -112,10 +118,7 @@ def fetch(
112118
store: UcmKVStoreBase,
113119
hashes: List[str],
114120
kvcaches: Dict[int, torch.Tensor],
115-
num_tokens: int,
116-
block_len: int,
117-
block_layer: int,
118-
block_dim: int,
121+
mla: bool,
119122
):
120123
start_time = time.perf_counter()
121124

@@ -129,14 +132,23 @@ def fetch(
129132
for i, hash_val in enumerate(hashes):
130133
offset = 0
131134
for layer_id, kv_layer in kvcaches.items():
132-
tensor = kv_layer[0][i] # kv=1
135+
k_tensor = kv_layer[0][i] # kv=1
133136
block_ids.append(hash_val)
134137
offsets.append(offset)
135-
tensors.append(tensor)
136-
sz = tensor.numel() * tensor.element_size()
138+
tensors.append(k_tensor)
139+
sz = k_tensor.numel() * k_tensor.element_size()
137140
offset += sz
138141
total_size += sz
139142

143+
if not mla:
144+
v_tensor = kv_layer[1][i]
145+
block_ids.append(hash_val)
146+
offsets.append(offset)
147+
tensors.append(v_tensor)
148+
sz = v_tensor.numel() * v_tensor.element_size()
149+
offset += sz
150+
total_size += sz
151+
140152
task = store.load(block_ids, offsets, tensors)
141153
ret = store.wait(task)
142154
assert ret == 0, "Load operation failed"
@@ -163,6 +175,8 @@ def run(
163175
block_layer: int,
164176
head_size: int,
165177
block_elem_size: int,
178+
kv: int,
179+
mla: bool,
166180
) -> Tuple[float, float, float, float, float, float]:
167181
"""
168182
Run a single test with given parameters and return performance metrics.
@@ -195,6 +209,7 @@ def run(
195209
block_len,
196210
block_layer,
197211
num_head,
212+
kv,
198213
)
199214

200215
results = store.create(hashes[:batch_size])
@@ -204,10 +219,7 @@ def run(
204219
store,
205220
hashes[:batch_size],
206221
kvcaches,
207-
num_tokens,
208-
block_len,
209-
block_layer,
210-
block_dim,
222+
mla,
211223
)
212224
store.commit(hashes[:batch_size], True)
213225

@@ -217,10 +229,7 @@ def run(
217229
store,
218230
hashes[:batch_size],
219231
kvcaches,
220-
num_tokens,
221-
block_len,
222-
block_layer,
223-
block_dim,
232+
mla,
224233
)
225234

226235
w_bw_list.append(w_bw)

ucm/store/test/e2e/nfsstore_embed_fetch_run.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def main():
6464
num_head_list = [1, 2, 4, 8]
6565

6666
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
67-
csv_file = os.path.join(SCRIPT_DIR, "embed_fetch_result.csv")
67+
csv_file = os.path.join(SCRIPT_DIR, "embed_fetch_result_all_6.csv")
6868
need_header = not os.path.exists(csv_file)
6969

7070
with open(csv_file, "a", newline="", encoding="utf-8") as csv_fp:
@@ -124,6 +124,8 @@ def main():
124124
block_layer,
125125
head_size,
126126
block_elem_size,
127+
kv,
128+
mla,
127129
),
128130
)
129131

0 commit comments

Comments
 (0)