@@ -49,13 +49,13 @@ def setup(


 def make_buffers(
-    block_number, device_id, batch_size, head_dim, block_len, block_layer, num_head
+    block_number, device_id, batch_size, head_dim, block_len, block_layer, num_head, kv
 ):
     hashes = [secrets.token_hex(16) for _ in range(block_number)]
     kv_caches = {}
     for i in range(block_layer):
         kv_caches[i] = torch.rand(
-            [1, block_number, block_len, num_head, head_dim],
+            [kv, block_number, block_len, num_head, head_dim],
             dtype=torch.bfloat16,
             device=f"cuda:{device_id}",
         )
@@ -73,10 +73,7 @@ def embed(
     store: UcmKVStoreBase,
     hashes: List[str],
     kvcaches: Dict[int, torch.Tensor],
-    num_tokens: int,
-    block_len: int,
-    block_layer: int,
-    block_dim: int,
+    mla: bool,
 ):
     start_time = time.perf_counter()

@@ -86,14 +83,23 @@ def embed(
     for i, hash_val in enumerate(hashes):
         offset = 0
         for layer_id, kv_layer in kvcaches.items():
-            tensor = kv_layer[0][i]  # kv=1
-            total_tensors.append(tensor)
+            k_tensor = kv_layer[0][i]  # kv=1
+            total_tensors.append(k_tensor)
             total_block_ids.append(hash_val)
             total_offsets.append(offset)
-            sz = tensor.numel() * tensor.element_size()
+            sz = k_tensor.numel() * k_tensor.element_size()
             offset += sz
             total_size += sz

+            if not mla:
+                v_tensor = kv_layer[1][i]
+                total_tensors.append(v_tensor)
+                total_block_ids.append(hash_val)
+                total_offsets.append(offset)
+                sz = v_tensor.numel() * v_tensor.element_size()
+                offset += sz
+                total_size += sz
+
     task = store.dump(total_block_ids, total_offsets, total_tensors)
     store.wait(task)

@@ -112,10 +118,7 @@ def fetch(
     store: UcmKVStoreBase,
     hashes: List[str],
     kvcaches: Dict[int, torch.Tensor],
-    num_tokens: int,
-    block_len: int,
-    block_layer: int,
-    block_dim: int,
+    mla: bool,
 ):
     start_time = time.perf_counter()

@@ -129,14 +132,23 @@ def fetch(
     for i, hash_val in enumerate(hashes):
         offset = 0
         for layer_id, kv_layer in kvcaches.items():
-            tensor = kv_layer[0][i]  # kv=1
+            k_tensor = kv_layer[0][i]  # kv=1
             block_ids.append(hash_val)
             offsets.append(offset)
-            tensors.append(tensor)
-            sz = tensor.numel() * tensor.element_size()
+            tensors.append(k_tensor)
+            sz = k_tensor.numel() * k_tensor.element_size()
             offset += sz
             total_size += sz

+            if not mla:
+                v_tensor = kv_layer[1][i]
+                block_ids.append(hash_val)
+                offsets.append(offset)
+                tensors.append(v_tensor)
+                sz = v_tensor.numel() * v_tensor.element_size()
+                offset += sz
+                total_size += sz
+
     task = store.load(block_ids, offsets, tensors)
     ret = store.wait(task)
     assert ret == 0, "Load operation failed"
@@ -163,6 +175,8 @@ def run(
     block_layer: int,
     head_size: int,
     block_elem_size: int,
+    kv: int,
+    mla: bool,
 ) -> Tuple[float, float, float, float, float, float]:
     """
     Run a single test with given parameters and return performance metrics.
@@ -195,6 +209,7 @@ def run(
         block_len,
         block_layer,
         num_head,
+        kv,
     )

     results = store.create(hashes[:batch_size])
@@ -204,10 +219,7 @@ def run(
         store,
         hashes[:batch_size],
         kvcaches,
-        num_tokens,
-        block_len,
-        block_layer,
-        block_dim,
+        mla,
     )
     store.commit(hashes[:batch_size], True)

@@ -217,10 +229,7 @@ def run(
         store,
         hashes[:batch_size],
         kvcaches,
-        num_tokens,
-        block_len,
-        block_layer,
-        block_dim,
+        mla,
     )

     w_bw_list.append(w_bw)
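For context (not part of the change above), here is a minimal CPU-only sketch of the layout implied by the new kv and mla parameters: with kv=2 and mla=False, each block stores a K slice and a V slice back-to-back for every layer, so the per-block byte offset advances twice per layer, while mla=True keeps only the K slice at index 0. The helper name block_layout and the CPU tensors are illustrative assumptions; the shape and bfloat16 dtype follow make_buffers in the diff.

import torch

# Sketch only (assumed helper, not in the benchmark): mirrors the offset
# bookkeeping of embed()/fetch() without UcmKVStoreBase, CUDA, or real hashes.
def block_layout(block_layer, block_len, num_head, head_dim, kv, mla):
    kv_caches = {
        i: torch.rand([kv, 1, block_len, num_head, head_dim], dtype=torch.bfloat16)
        for i in range(block_layer)
    }
    entries = []
    offset = 0
    for layer_id, kv_layer in kv_caches.items():
        k_tensor = kv_layer[0][0]  # K slice, always written
        entries.append(("K", layer_id, offset))
        offset += k_tensor.numel() * k_tensor.element_size()
        if not mla:  # V slice only when the model is not MLA
            v_tensor = kv_layer[1][0]
            entries.append(("V", layer_id, offset))
            offset += v_tensor.numel() * v_tensor.element_size()
    return entries, offset  # per-block (kind, layer, byte offset) entries and total bytes

# Example: block_layout(block_layer=2, block_len=128, num_head=8, head_dim=128, kv=2, mla=False)
# yields four entries (K and V for each of the two layers) and 4 * 128*8*128 * 2 = 1048576 bytes.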