Skip to content

Commit 265136e

Browse files
authored
Modify start_load_kv (#103)
1 parent 89d9c26 commit 265136e

File tree

2 files changed: 26 additions and 26 deletions.

test/test_uc_connector.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@ def init_uc(
106106
ucconnector.load_tasks: dict[str, tuple[Task, Task]] = {}
107107
ucconnector.total_tp_size = 2
108108
ucconnector._connector_metadata = metadata
109+
ucconnector.layerwise_load_tasks: dict[
110+
str, dict[str, tuple[Task, Task]]
111+
] = {}
109112
return ucconnector
110113

111114
def test_get_num_new_matched_tokens_hit(self):
@@ -293,7 +296,7 @@ def mock_load(
293296
ucconnector = self.init_uc(mock_connector, metadata=metadata)
294297
forward_context = Mock()
295298
ucconnector.start_load_kv(forward_context)
296-
assert mock_connector.load.call_count == 2
299+
assert mock_connector.load.call_count == 2 * self.num_layers
297300

298301
def test_generate_layerwise_load_tasks_success(self):
299302
# init implement

unifiedcache/integration/vllm/uc_connector.py

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
127127
# dump tasks record request -> block -> list[task]
128128
self.dump_tasks: dict[str, dict[str, List[Task]]] = {}
129129
self.load_tasks: dict[str, tuple[Task, Task]] = {}
130+
self.layerwise_load_tasks: dict[str, dict[str, tuple[Task, Task]]] = {}
130131
self.is_mla = self._vllm_config.model_config.is_deepseek_mla
131132
self.num_layers = vllm_config.model_config.get_num_layers(
132133
vllm_config.parallel_config
@@ -294,9 +295,7 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
294295
if len(self.kv_caches) == 0:
295296
self._init_kv_caches_from_forward_context(forward_context)
296297

297-
self.layerwise_load_tasks: dict[
298-
str, Generator[tuple[Task, Task], None, None]
299-
] = {}
298+
self.layerwise_load_tasks.clear()
300299
self.current_layer = 0
301300
for request in metadata.requests:
302301
if request.load_paras is None or not request.load_paras.can_load:
@@ -339,18 +338,22 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
339338
)
340339
assert self.connector.wait(task) == 0
341340
else:
342-
layer_to_tensor[layer_name] = (tensors, offsets)
343-
344-
if layer_to_tensor:
345-
layerwise_load_task = self.generate_layerwise_load_tasks(
346-
fetch_block_hashes, layer_to_tensor
347-
)
348-
load_task = next(layerwise_load_task)
349-
assert (
350-
load_task is not None
351-
), "The first layerwise task should not be None!"
352-
self.load_tasks[request.request_id] = load_task
353-
self.layerwise_load_tasks[request.request_id] = layerwise_load_task
341+
k_task_id = self.connector.load(
342+
fetch_block_hashes, offsets[:blocks_len], tensors[:blocks_len]
343+
)
344+
v_task_id = None
345+
if not self.is_mla:
346+
v_task_id = self.connector.load(
347+
fetch_block_hashes,
348+
offsets[blocks_len:],
349+
tensors[blocks_len:],
350+
)
351+
if request.request_id not in self.layerwise_load_tasks:
352+
self.layerwise_load_tasks[request.request_id] = {}
353+
self.layerwise_load_tasks[request.request_id][layer_name] = (
354+
k_task_id,
355+
v_task_id,
356+
)
354357

355358
def wait_for_layer_load(self, layer_name: str) -> None:
356359
"""
@@ -371,18 +374,12 @@ def wait_for_layer_load(self, layer_name: str) -> None:
371374
assert (
372375
self.current_layer < self.num_layers
373376
), "The current layer should be less than total layers!"
374-
for request_id, gene_load_task in self.layerwise_load_tasks.items():
375-
k_task, v_task = self.load_tasks[request_id]
377+
for request_id, layer_to_task in self.layerwise_load_tasks.items():
378+
k_task, v_task = layer_to_task[layer_name]
376379
assert self.connector.wait(k_task) == 0
377-
if v_task:
380+
if not self.is_mla:
378381
assert self.connector.wait(v_task) == 0
379-
if self.current_layer < self.num_layers - 1:
380-
self.load_tasks[request_id] = next(gene_load_task)
381-
assert (
382-
self.load_tasks[request_id] is not None
383-
), "The task for next layer should not be None!"
384-
else:
385-
logger.debug(f"Load tasks for {request_id} finished.")
382+
logger.debug(f"Load tasks for {request_id} on layer {layer_name} finished.")
386383

387384
def save_kv_layer(
388385
self,

0 commit comments

Comments
 (0)