@@ -127,6 +127,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
127127 # dump tasks record request -> block -> list[task]
128128 self .dump_tasks : dict [str , dict [str , List [Task ]]] = {}
129129 self .load_tasks : dict [str , tuple [Task , Task ]] = {}
130+ self .layerwise_load_tasks : dict [str , dict [str , tuple [Task , Task ]]] = {}
130131 self .is_mla = self ._vllm_config .model_config .is_deepseek_mla
131132 self .num_layers = vllm_config .model_config .get_num_layers (
132133 vllm_config .parallel_config
@@ -294,9 +295,7 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
294295 if len (self .kv_caches ) == 0 :
295296 self ._init_kv_caches_from_forward_context (forward_context )
296297
297- self .layerwise_load_tasks : dict [
298- str , Generator [tuple [Task , Task ], None , None ]
299- ] = {}
298+ self .layerwise_load_tasks .clear ()
300299 self .current_layer = 0
301300 for request in metadata .requests :
302301 if request .load_paras is None or not request .load_paras .can_load :
@@ -339,18 +338,22 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
339338 )
340339 assert self .connector .wait (task ) == 0
341340 else :
342- layer_to_tensor [layer_name ] = (tensors , offsets )
343-
344- if layer_to_tensor :
345- layerwise_load_task = self .generate_layerwise_load_tasks (
346- fetch_block_hashes , layer_to_tensor
347- )
348- load_task = next (layerwise_load_task )
349- assert (
350- load_task is not None
351- ), "The first layerwise task should not be None!"
352- self .load_tasks [request .request_id ] = load_task
353- self .layerwise_load_tasks [request .request_id ] = layerwise_load_task
341+ k_task_id = self .connector .load (
342+ fetch_block_hashes , offsets [:blocks_len ], tensors [:blocks_len ]
343+ )
344+ v_task_id = None
345+ if not self .is_mla :
346+ v_task_id = self .connector .load (
347+ fetch_block_hashes ,
348+ offsets [blocks_len :],
349+ tensors [blocks_len :],
350+ )
351+ if request .request_id not in self .layerwise_load_tasks :
352+ self .layerwise_load_tasks [request .request_id ] = {}
353+ self .layerwise_load_tasks [request .request_id ][layer_name ] = (
354+ k_task_id ,
355+ v_task_id ,
356+ )
354357
355358 def wait_for_layer_load (self , layer_name : str ) -> None :
356359 """
@@ -371,18 +374,12 @@ def wait_for_layer_load(self, layer_name: str) -> None:
371374 assert (
372375 self .current_layer < self .num_layers
373376 ), "The current layer should be less than total layers!"
374- for request_id , gene_load_task in self .layerwise_load_tasks .items ():
375- k_task , v_task = self . load_tasks [ request_id ]
377+ for request_id , layer_to_task in self .layerwise_load_tasks .items ():
378+ k_task , v_task = layer_to_task [ layer_name ]
376379 assert self .connector .wait (k_task ) == 0
377- if v_task :
380+ if not self . is_mla :
378381 assert self .connector .wait (v_task ) == 0
379- if self .current_layer < self .num_layers - 1 :
380- self .load_tasks [request_id ] = next (gene_load_task )
381- assert (
382- self .load_tasks [request_id ] is not None
383- ), "The task for next layer should not be None!"
384- else :
385- logger .debug (f"Load tasks for { request_id } finished." )
382+ logger .debug (f"Load tasks for { request_id } on layer { layer_name } finished." )
386383
387384 def save_kv_layer (
388385 self ,
0 commit comments