2222# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2323# SOFTWARE.
2424#
25+ import csv
2526import os
2627import secrets
2728import time
28- import csv
29- from typing import List , Dict
29+ from typing import Dict , List
3030
3131import torch
3232
@@ -46,7 +46,9 @@ def setup_store(storage_backends, block_size, device_id, io_size) -> UcmKVStoreB
4646 return UcmNfsStore (config )
4747
4848
49- def make_buffers (block_number , device_id , batch_size , head_dim , block_len , block_layer , num_head ):
49+ def make_buffers (
50+ block_number , device_id , batch_size , head_dim , block_len , block_layer , num_head
51+ ):
5052 hashes = [secrets .token_hex (16 ) for _ in range (block_number )]
5153 kv_caches = {}
5254 for i in range (block_layer ):
@@ -65,13 +67,20 @@ def store_all_hashes(hashes: List[str]):
6567 f .write (h + "\n " )
6668
6769
68- def embed (store : UcmKVStoreBase , hashes : List [str ], kvcaches : Dict [int , torch .Tensor ],
69- num_tokens : int , block_len : int , block_layer : int , block_dim : int ):
70+ def embed (
71+ store : UcmKVStoreBase ,
72+ hashes : List [str ],
73+ kvcaches : Dict [int , torch .Tensor ],
74+ num_tokens : int ,
75+ block_len : int ,
76+ block_layer : int ,
77+ block_dim : int ,
78+ ):
7079 start_time = time .perf_counter ()
71-
80+
7281 total_block_ids , total_offsets , total_tensors = [], [], []
7382 total_size = 0
74-
83+
7584 for i , hash_val in enumerate (hashes ):
7685 offset = 0
7786 for layer_id , kv_layer in kvcaches .items ():
@@ -85,27 +94,36 @@ def embed(store: UcmKVStoreBase, hashes: List[str], kvcaches: Dict[int, torch.Te
8594
8695 task = store .dump (total_block_ids , total_offsets , total_tensors )
8796 store .wait (task )
88-
97+
8998 elapsed_time = time .perf_counter () - start_time
90- throughput_gbps = (total_size / (1024 ** 3 )) / elapsed_time if elapsed_time > 0 else 0
91-
92- print (f"WRITE: Data Size={ (total_size / (1024 ** 3 )):.4f} GB, Time={ elapsed_time :.4f} s, "
93- f"Speed={ throughput_gbps :.4f} GB/s" )
94-
99+ throughput_gbps = (total_size / (1024 ** 3 )) / elapsed_time if elapsed_time > 0 else 0
100+
101+ print (
102+ f"WRITE: Data Size={ (total_size / (1024 ** 3 )):.4f} GB, Time={ elapsed_time :.4f} s, "
103+ f"Speed={ throughput_gbps :.4f} GB/s"
104+ )
105+
95106 return total_size , elapsed_time , throughput_gbps
96107
97108
98- def fetch (store : UcmKVStoreBase , hashes : List [str ], kvcaches : Dict [int , torch .Tensor ],
99- num_tokens : int , block_len : int , block_layer : int , block_dim : int ):
109+ def fetch (
110+ store : UcmKVStoreBase ,
111+ hashes : List [str ],
112+ kvcaches : Dict [int , torch .Tensor ],
113+ num_tokens : int ,
114+ block_len : int ,
115+ block_layer : int ,
116+ block_dim : int ,
117+ ):
100118 start_time = time .perf_counter ()
101-
119+
102120 founds = store .lookup (hashes )
103121 for f in founds :
104122 assert f , "Cache block miss detected"
105123
106124 block_ids , offsets , tensors = [], [], []
107125 total_size = 0
108-
126+
109127 for i , hash_val in enumerate (hashes ):
110128 offset = 0
111129 for layer_id , kv_layer in kvcaches .items ():
@@ -120,33 +138,35 @@ def fetch(store: UcmKVStoreBase, hashes: List[str], kvcaches: Dict[int, torch.Te
120138 task = store .load (block_ids , offsets , tensors )
121139 ret = store .wait (task )
122140 assert ret == 0 , "Load operation failed"
123-
141+
124142 elapsed_time = time .perf_counter () - start_time
125- throughput_gbps = (total_size / (1024 ** 3 )) / elapsed_time if elapsed_time > 0 else 0
126-
127- print (f"READ: Data Size={ (total_size / (1024 ** 3 )):.4f} GB, Time={ elapsed_time :.4f} s, "
128- f"Speed={ throughput_gbps :.4f} GB/s" )
129-
143+ throughput_gbps = (total_size / (1024 ** 3 )) / elapsed_time if elapsed_time > 0 else 0
144+
145+ print (
146+ f"READ: Data Size={ (total_size / (1024 ** 3 )):.4f} GB, Time={ elapsed_time :.4f} s, "
147+ f"Speed={ throughput_gbps :.4f} GB/s"
148+ )
149+
130150 return total_size , elapsed_time , throughput_gbps
131151
132152
133153def main ():
134154 storage_backends = "."
135155 device_id = 1
136156 mla = False
137- repeat = 3
157+ repeat = 3
138158 block_elem_size = 2
139159 num_tokens_list = [2048 , 4096 , 8192 , 16384 , 32768 ]
140-
160+
141161 if mla :
142- block_lens = [ 64 , 128 ]
162+ block_lens = [64 , 128 ]
143163 block_layer = 61
144164 head_size = 576
145165 kv = 1
146166 model_name = "deepseek-v3"
147167 num_head_list = [1 ]
148168 else :
149- block_lens = [ 128 , 256 ]
169+ block_lens = [128 , 256 ]
150170 block_layer = 64
151171 head_size = 128
152172 kv = 2
@@ -156,18 +176,31 @@ def main():
156176 SCRIPT_DIR = os .path .dirname (os .path .abspath (__file__ ))
157177 csv_file = os .path .join (SCRIPT_DIR , "embed_fetch_result1.csv" )
158178 need_header = not os .path .exists (csv_file )
159-
179+
160180 with open (csv_file , "a" , newline = "" , encoding = "utf-8" ) as csv_fp :
161181 writer = csv .writer (csv_fp )
162-
182+
163183 if need_header :
164- writer .writerow ([
165- "Model" , "Sequence Length" , "Batch Size" , "Layers" , "Element Size" ,
166- "KV" , "Num Head" , "Block Size" , "IO Count" , "IO Size(B)" ,
167- "Total Size(GB)" , "Write Avg Time(s)" , "Write Avg Bandwidth(GB/s)" ,
168- "Read Avg Time(s)" , "Read Avg Bandwidth(GB/s)"
169- ])
170-
184+ writer .writerow (
185+ [
186+ "Model" ,
187+ "Sequence Length" ,
188+ "Batch Size" ,
189+ "Layers" ,
190+ "Element Size" ,
191+ "KV" ,
192+ "Num Head" ,
193+ "Block Size" ,
194+ "IO Count" ,
195+ "IO Size(B)" ,
196+ "Total Size(GB)" ,
197+ "Write Avg Time(s)" ,
198+ "Write Avg Bandwidth(GB/s)" ,
199+ "Read Avg Time(s)" ,
200+ "Read Avg Bandwidth(GB/s)" ,
201+ ]
202+ )
203+
171204 for num_head in num_head_list :
172205 for block_len in block_lens :
173206 block_dim = head_size * num_head
@@ -177,35 +210,58 @@ def main():
177210
178211 for num_tokens in num_tokens_list :
179212 sep = "=" * 60
180- print (f"\n { sep } \n = num_head={ num_head } | num_tokens={ num_tokens :>6} | Repeat { repeat } times =\n { sep } \n " )
213+ print (
214+ f"\n { sep } \n = num_head={ num_head } | num_tokens={ num_tokens :>6} | Repeat { repeat } times =\n { sep } \n "
215+ )
181216
182217 batch_size = int (num_tokens / block_len )
183218 io_num = int (num_tokens / block_len * block_layer )
184-
219+
185220 w_bw_list , r_bw_list = [], []
186221 w_time_list , r_time_list = [], []
187222 w_size_sum , r_size_sum = 0.0 , 0.0
188223
189224 for r in range (repeat ):
190225 print (f"\n --- Round { r + 1 } ---" )
191- store = setup_store (storage_backends , block_size , device_id , io_size )
192-
226+ store = setup_store (
227+ storage_backends , block_size , device_id , io_size
228+ )
229+
193230 hashes , kvcaches = make_buffers (
194- real_blocks , device_id , batch_size , head_size ,
195- block_len , block_layer , num_head
231+ real_blocks ,
232+ device_id ,
233+ batch_size ,
234+ head_size ,
235+ block_len ,
236+ block_layer ,
237+ num_head ,
196238 )
197239
198240 results = store .create (hashes [:batch_size ])
199241 assert sum (results ) == 0 , "Create operation failed"
200242
201- w_size , w_time , w_bw = embed (store , hashes [:batch_size ], kvcaches ,
202- num_tokens , block_len , block_layer , block_dim )
243+ w_size , w_time , w_bw = embed (
244+ store ,
245+ hashes [:batch_size ],
246+ kvcaches ,
247+ num_tokens ,
248+ block_len ,
249+ block_layer ,
250+ block_dim ,
251+ )
203252 store .commit (hashes [:batch_size ], True )
204-
253+
205254 store_all_hashes (hashes [:batch_size ])
206255
207- r_size , r_time , r_bw = fetch (store , hashes [:batch_size ], kvcaches ,
208- num_tokens , block_len , block_layer , block_dim )
256+ r_size , r_time , r_bw = fetch (
257+ store ,
258+ hashes [:batch_size ],
259+ kvcaches ,
260+ num_tokens ,
261+ block_len ,
262+ block_layer ,
263+ block_dim ,
264+ )
209265
210266 w_bw_list .append (w_bw )
211267 r_bw_list .append (r_bw )
@@ -222,22 +278,34 @@ def main():
222278 avg_r_bw = sum (r_bw_list ) / repeat
223279 avg_w_time = sum (w_time_list ) / repeat
224280 avg_r_time = sum (r_time_list ) / repeat
225- avg_w_size = w_size_sum / (1024 ** 3 ) / repeat
226- avg_r_size = r_size_sum / (1024 ** 3 ) / repeat
227-
228- writer .writerow ([
229- model_name ,
230- num_tokens , batch_size , block_layer , block_elem_size ,
231- kv , num_head , block_len , io_num , io_size ,
232- f"{ avg_w_size :.4f} " , f"{ avg_w_time :.4f} " , f"{ avg_w_bw :.4f} " ,
233- f"{ avg_r_time :.4f} " , f"{ avg_r_bw :.4f} "
234- ])
235-
281+ avg_w_size = w_size_sum / (1024 ** 3 ) / repeat
282+ avg_r_size = r_size_sum / (1024 ** 3 ) / repeat
283+
284+ writer .writerow (
285+ [
286+ model_name ,
287+ num_tokens ,
288+ batch_size ,
289+ block_layer ,
290+ block_elem_size ,
291+ kv ,
292+ num_head ,
293+ block_len ,
294+ io_num ,
295+ io_size ,
296+ f"{ avg_w_size :.4f} " ,
297+ f"{ avg_w_time :.4f} " ,
298+ f"{ avg_w_bw :.4f} " ,
299+ f"{ avg_r_time :.4f} " ,
300+ f"{ avg_r_bw :.4f} " ,
301+ ]
302+ )
303+
236304 csv_fp .flush ()
237-
305+
238306 print ("\n " + "=" * 60 + "\n = All combinations tested =\n " + "=" * 60 + "\n " )
239307
240308
241309if __name__ == "__main__" :
242310 os .environ ["UC_LOGGER_LEVEL" ] = "debug"
243- main ()
311+ main ()
0 commit comments