|
From e45ed500c23f3b8905c68ada894657fd0794906b Mon Sep 17 00:00:00 2001
From: y00945504 <yuhui87@huawei.com>
Date: Fri, 22 Aug 2025 11:46:48 +0800
Subject: [PATCH] wire v1 KV-transfer connector hooks into Ascend attention path

---
 vllm_ascend/attention/attention_v1.py | 33 +++++++++++++++++++++++++++++++++
 vllm_ascend/worker/model_runner_v1.py | 16 +++++++++-------
 2 files changed, 42 insertions(+), 7 deletions(-)
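
This patch wires the v1 KV-transfer connector into the Ascend v1 attention
path: every attention layer first waits for its remotely prefilled KV
(wait_for_kv_layer_from_connector) and saves its own KV afterwards
(maybe_save_kv_layer_to_connector), and NPUModelRunner threads the
connector's wait_for_save() result back to the scheduler as finished_dumping.
A hedged connector sketch follows the patch trailer below.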

diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 694adab..487b12b 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -24,6 +24,9 @@ import torch_npu
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
+from vllm.distributed.kv_transfer import (get_kv_transfer_group,
+                                          has_kv_transfer_group,
+                                          is_v1_kv_transfer_group)
 from vllm.config import get_current_vllm_config
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import direct_register_custom_op
@@ -458,6 +461,8 @@ def unified_ascend_attention_with_output(
     output: torch.Tensor,
     layer_name: str,
 ) -> None:
+    wait_for_kv_layer_from_connector(layer_name)
+
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
     self = forward_context.no_compile_layers[layer_name]
@@ -470,8 +475,36 @@ def unified_ascend_attention_with_output(
                       attn_metadata,
                       output,
                       trace_flag=False)
+    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
     return
 
+def wait_for_kv_layer_from_connector(layer_name: str):
+    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+        return
+
+    connector = get_kv_transfer_group()
+
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    if attn_metadata is None:
+        return
+    connector.wait_for_layer_load(layer_name)
+
+def maybe_save_kv_layer_to_connector(
+    layer_name: str,
+    kv_cache_layer: List[torch.Tensor],
+):
+    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+        return
+
+    connector = get_kv_transfer_group()
+
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    if attn_metadata is None:
+        return
+    connector.save_kv_layer(layer_name, kv_cache_layer,
+                            attn_metadata)
 
 def unified_attention_with_output_fake(
     query: torch.Tensor,
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index dc28bfa..ddc996b 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -889,7 +889,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> tuple[SpecDecodeMetadata, torch.Tensor, SpecDecodeMetadata,
                torch.Tensor, int, torch.Tensor, Optional[set[str]],
-               Optional[set[str]]]:
+               Optional[set[str]], Optional[dict[str, list[str]]]]:
         # Check input valid
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
@@ -1140,6 +1140,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             positions = self.positions[:padded_num_tokens_across_dp]
 
         # Run forward pass
+        finished_dumping = None
         # TODO(zzzzwwjj): check param `num_tokens_across_dp` later.
         with set_ascend_forward_context(
                 attn_metadata,
@@ -1174,7 +1175,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                     inputs_embeds=inputs_embeds,
                     **model_kwargs)
 
-                self.maybe_wait_for_kv_save()
+                finished_dumping = self.maybe_wait_for_kv_save()
                 finished_sending, finished_recving = self.get_finished_kv_transfer(
                     scheduler_output)
                 use_spec_decode = len(
@@ -1202,7 +1203,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
 
         return (attn_metadata, hidden_states, spec_decode_metadata, positions,
                 total_num_scheduled_tokens, sample_indices, finished_sending,
-                finished_recving)
+                finished_recving, finished_dumping)
 
     def _calc_spec_decode_metadata(
         self,
@@ -1386,7 +1387,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
 
         (attn_metadata, hidden_states, spec_decode_metadata, positions,
          num_scheduled_tokens, sample_indices, finished_sending,
-         finished_recving) = (self._process_reqs(scheduler_output,
+         finished_recving, finished_dumping) = (self._process_reqs(scheduler_output,
                                                  intermediate_tensors))
 
         if self.dynamic_eplb:
@@ -1493,6 +1494,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             prompt_logprobs_dict={},
             finished_sending=finished_sending,
             finished_recving=finished_recving,
+            finished_dumping=finished_dumping
         )
 
         durations = ProfileExecuteDuration().pop_captured_sync()
@@ -1543,8 +1545,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
     @staticmethod
-    def maybe_wait_for_kv_save() -> None:
+    def maybe_wait_for_kv_save() -> Optional[dict[str, list[str]]]:
         if has_kv_transfer_group():
-            get_kv_transfer_group().wait_for_save()
-
+            return get_kv_transfer_group().wait_for_save()
+
     @staticmethod
     def get_finished_kv_transfer(
         scheduler_output: "SchedulerOutput",
-- 
2.50.1.windows.1

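Note: the hooks added above assume vLLM's v1 KV-transfer connector
interface, i.e. the object returned by get_kv_transfer_group() must expose
wait_for_layer_load(), save_kv_layer(), and wait_for_save(). The sketch
below is a minimal, hypothetical connector written only to illustrate the
call order this patch drives; the class name and its bookkeeping are
invented, and the assumption that wait_for_save() returns the finished
request ids (upstream vLLM annotates it as returning None) is exactly the
fork-specific behavior the model-runner change relies on.

# Hypothetical connector sketch -- illustrates the call sequence assumed by
# this patch, not vLLM's real KVConnectorBase_V1. All names are invented.
from typing import Any, Optional

import torch


class SketchKVConnector:

    def __init__(self) -> None:
        # request_id -> names of the layers whose KV blocks finished dumping
        self._finished: dict[str, list[str]] = {}

    def wait_for_layer_load(self, layer_name: str) -> None:
        # Called by wait_for_kv_layer_from_connector() right before a
        # layer's attention kernel runs: block until remotely prefilled KV
        # for this layer has landed in the local cache. A real connector
        # would wait on its transfer engine here; the sketch is a no-op.
        pass

    def save_kv_layer(self, layer_name: str,
                      kv_cache_layer: list[torch.Tensor],
                      attn_metadata: Any) -> None:
        # Called by maybe_save_kv_layer_to_connector() right after the
        # layer's attention kernel runs: kick off an asynchronous dump of
        # this layer's KV blocks for the requests in attn_metadata.
        pass

    def wait_for_save(self) -> Optional[dict[str, list[str]]]:
        # Called once per forward pass via maybe_wait_for_kv_save(). The
        # patched runner forwards this return value as finished_dumping,
        # letting the connector report which requests completed their dumps.
        done, self._finished = self._finished, {}
        return done or None

Per forward pass the patched code therefore drives, for each attention
layer L: wait_for_layer_load(L) -> attention kernel -> save_kv_layer(L, ...),
and after the model forward a single wait_for_save(), whose result is
threaded through _process_reqs() into the model-runner output as
finished_dumping.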