Commit 33b6fd9

Merge pull request #112 from hek14/dev_ucm_sparse
[Feature] [Doc] UCMSparse framework
2 parents 8ea70c9 + e400ebc commit 33b6fd9

File tree

12 files changed: +1453 −1 lines changed

docs/source/feature/sparse_attn.md

Lines changed: 40 additions & 0 deletions
@@ -1 +1,41 @@
# Sparse Attention

## Motivations

Attention mechanisms, especially in LLMs, are often the latency bottleneck during inference because of their computational complexity. Although attention is essential for capturing contextual relationships, traditional attention must process all token interactions, which leads to significant delays.

<p align="center">
<img alt="UCM" src="../images/attention_overhead.png" width="80%">
</p>

Researchers have found that attention in LLMs is highly dispersed:

<p align="center">
<img alt="UCM" src="../images/attention_sparsity.png" width="80%">
</p>

This motivates researchers to actively develop sparse attention algorithms that address the latency issue. These algorithms reduce the number of token interactions by focusing only on the most relevant parts of the input, thereby lowering computation and memory requirements.
While promising, the gap between theoretical prototypes and practical implementations in inference frameworks remains a significant challenge.

Many existing frameworks, like vLLM, are optimized for traditional attention mechanisms. Adapting them for sparse attention can be complex and may require substantial modifications to the underlying architecture.
Issues such as maintaining compatibility with existing model architectures, ensuring efficient memory usage, and leveraging hardware acceleration must be addressed to facilitate the adoption of sparse attention in real-world applications.

We present a **unified sparse attention framework** under UCM. A unified framework streamlines the integration of various sparse attention algorithms into inference engines such as vLLM by providing **standardized interfaces and utilities** that simplify the implementation process.
With UCM, researchers can rapidly prototype and test different sparse attention algorithms without extensive re-engineering of the inference engine, and shared optimizations within the framework help ensure that the performance gains of sparse attention are realized in real-world scenarios.

## Architecture

### Overview

The core concept of our UCMSparse attention framework is to offload the complete Key-Value (KV) cache to a dedicated KV cache store. The sparse attention algorithm then identifies the KV pairs most relevant to the current context and selectively loads only those portions of the KV cache from storage into High Bandwidth Memory (HBM). This design significantly reduces the HBM footprint while accelerating generation.

<p align="center">
<img alt="UCM" src="../images/sparse_attn_arch.png" width="80%">
</p>

### Key Concepts

- UCMSparse in scheduler: this instance lives in the same process as the `EngineCore` and acts as a sparse attention budget controller. It estimates the number of slots required by a specific sparse attention algorithm, and `KVCacheManager` then allocates the necessary blocks based on `num_slots_sparse`. For example, `ESA` needs only 20%~30% of the blocks used by normal attention.
- UCMSparse in model_runner: this instance lives in the same process as the `Worker`.
  A typical sparse attention algorithm works like this (see the sketch after this list):
  1. In prefill, it dumps the full KV cache from HBM to storage.
  2. In decode, it retrieves the most relevant blocks based on the context and loads them from storage to HBM.
  3. In decode, it also dumps newly generated blocks to keep the latest context accessible.
- With fine-grained task scheduling, retrieval and loading can be executed asynchronously and overlap with model execution, so UCMSparse introduces no extra overhead, and generation speed is boosted by the reduced computational load and fewer memory accesses.
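The worker-side flow above can be condensed into a minimal, self-contained sketch. Everything in it (the dict-based store, the dot-product scoring, the block size) is a hypothetical illustration of the dump/retrieve/load pattern, not the actual UCMSparse API:

```python
import numpy as np

BLOCK = 16                                  # tokens per KV block (illustrative)
store: dict[str, list[np.ndarray]] = {}     # stands in for the external KV cache store

def prefill_dump(req: str, kv_blocks: list[np.ndarray]) -> None:
    # 1. Prefill: dump the full KV cache from HBM to storage.
    store[req] = list(kv_blocks)

def decode_step(req: str, query: np.ndarray, budget: int) -> list[np.ndarray]:
    # 2. Decode: score each stored block against the current query and keep
    #    only the `budget` most relevant ones (the sparse selection).
    scores = [float(query @ blk.mean(axis=0)) for blk in store[req]]
    top = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)[:budget]
    hbm_blocks = [store[req][i] for i in top]   # load only these blocks into HBM
    # 3. Decode also dumps newly generated KV so the latest context stays accessible.
    store[req].append(np.random.randn(BLOCK, query.shape[0]))
    return hbm_blocks

# A request with 8 KV blocks, of which only 2 are loaded per decode step.
prefill_dump("req-0", [np.random.randn(BLOCK, 64) for _ in range(8)])
print(len(decode_step("req-0", np.random.randn(64), budget=2)), "blocks loaded to HBM")
```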
See `ESA` for more details.
Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
# Sparse Attention

This document provides a usage example and configuration guide for **sparse attention**, which is increasingly recognized for its ability to mitigate the challenges associated with High Bandwidth Memory (HBM) usage and to enhance the efficiency of large language models (LLMs).

## Configuration

To use sparse attention, you need to configure the `ucm_sparse_method` field in your model's launch configuration.

### Example:

```python
kv_connector_extra_config={
    "ucm_connector_name": "UcmDram",
    "ucm_connector_config": {
        "max_cache_size": 5368709120,
        "kv_block_size": 262144,
    },
    "ucm_sparse_method": "ESA"  # specify the sparse attention algorithm here
}
```

## Launching Inference

### Offline Inference

To start **offline inference** with sparse attention, modify the script `examples/offline_inference.py` to include the `ucm_sparse_method` field, and use a long prompt to see the acceleration effects:

```python
# In examples/offline_inference.py
ktc = KVTransferConfig(
    ...
    kv_connector_extra_config={
        "ucm_connector_name": "UcmDram",
        "ucm_connector_config": {
            "max_cache_size": 5368709120,
            "kv_block_size": 262144,
        },
        "ucm_sparse_method": "ESA"  # specify the sparse attention algorithm here
    }
)

prompts = [
    "PUT A LONG PROMPT HERE TO SEE ACCELERATION EFFECTS."
]
```

Then run the script as follows:

```bash
cd examples/
export PYTHONHASHSEED=123456
python offline_inference.py
```

### Online Inference

For **online inference**, vLLM with our connector can also be deployed as a server that implements the OpenAI API protocol. Run the following command to start the vLLM server with the Qwen/Qwen2.5-14B-Instruct model:

```bash
export PYTHONHASHSEED=123456
vllm serve /home/models/Qwen2.5-14B-Instruct \
  --max-model-len 20000 \
  --tensor-parallel-size 2 \
  --gpu_memory_utilization 0.87 \
  --trust-remote-code \
  --port 7800 \
  --kv-transfer-config \
  '{
    "kv_connector": "UnifiedCacheConnectorV1",
    "kv_connector_module_path": "unifiedcache.integration.vllm.uc_connector",
    "kv_role": "kv_both",
    "kv_connector_extra_config": {
      "ucm_connector_name": "UcmNfsStore",
      "ucm_connector_config": {
        "storage_backends": "/mnt/test",
        "kv_block_size": 33554432
      },
      "ucm_sparse_method": "ESA"
    }
  }'
```

If you see logs like the following:

```bash
INFO: Started server process [1049932]
INFO: Waiting for application startup.
INFO: Application startup complete.
```

Congratulations, you have successfully started the vLLM server with the NFS connector!
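Optionally, you can confirm the server is reachable before sending completions. A minimal check against the standard OpenAI-compatible `/v1/models` endpoint, assuming the `requests` package is installed and the port matches the command above:

```python
# Optional sanity check: list the models served by the OpenAI-compatible server.
import requests

resp = requests.get("http://localhost:7800/v1/models")
print(resp.status_code)   # expect 200
print(resp.json())        # served model(s), e.g. /home/models/Qwen2.5-14B-Instruct
```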

After the vLLM server has started successfully, you can interact with the API as follows:

```bash
curl http://localhost:7800/v1/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "/home/models/Qwen2.5-14B-Instruct",
    "prompt": "PUT A LONG PROMPT HERE TO SEE ACCELERATION EFFECTS.",
    "max_tokens": 100,
    "temperature": 0
  }'
```
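The same request can also be issued from Python. A minimal sketch using the official `openai` client (assumed to be installed; the placeholder API key is arbitrary since the server above was started without an API key):

```python
# Equivalent completion request via the OpenAI Python client (illustrative).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:7800/v1", api_key="EMPTY")
completion = client.completions.create(
    model="/home/models/Qwen2.5-14B-Instruct",
    prompt="PUT A LONG PROMPT HERE TO SEE ACCELERATION EFFECTS.",
    max_tokens=100,
    temperature=0,
)
print(completion.choices[0].text)
```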
3 binary image files added (87.1 KB, 177 KB, 103 KB)

examples/offline_inference.py

Lines changed: 2 additions & 1 deletion
@@ -30,6 +30,7 @@ def build_llm_with_uc(module_path: str, name: str, model: str):
                 "max_cache_size": 5368709120,
                 "kv_block_size": 262144,
             },
+            "ucm_sparse_method": "ESA",
         },
     )

@@ -66,7 +67,7 @@ def print_output(
 def main():
     module_path = "unifiedcache.integration.vllm.uc_connector"
     name = "UnifiedCacheConnectorV1"
-    model = "/home/models/Qwen2.5-14B-Instruct"
+    model = os.getenv("MODEL_PATH", "/home/models/Qwen2.5-14B-Instruct")
 
     setup_environment_variables()
 
Lines changed: 204 additions & 0 deletions
@@ -0,0 +1,204 @@
"""
UcmSparseBase provides interfaces for implementing general sparse attention algorithms in vLLM.

The class provides the following primitives:
    Scheduler-side: runs in the scheduler and binds metadata, which
        is used by the worker side to retrieve/load KV cache.
        estimate_num_slots_sparsed() - get the number of required slots.
        update_state_after_alloc() - update UcmSparse state after
            temporary buffer allocation by the CacheManager.
        request_finished_in_scheduler() - called when a request is finished, with
            the computed KV cache blocks for the request.
            Returns metadata for the next step.

    Worker-side: runs in each worker and retrieves/loads KV cache.
        execute_begin() - hook at the beginning of "ModelRunner->execute_model".
        execute_finished() - hook at the end of "ModelRunner->execute_model".
        attention_begin() - hook at the beginning of "unified_attention".
        attention_finished() - hook at the end of "unified_attention".
        request_finished_in_worker() - release resources, such as block features.
"""

from __future__ import annotations

import enum
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Dict, List, Optional, Union

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.request import Request
    from vllm.attention.backends.abstract import AttentionMetadata
    from unifiedcache.ucm_connector.base import UcmKVStoreBase
    from vllm.config import VllmConfig

import torch
from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
from vllm.forward_context import ForwardContext
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch

INVALID_SLOT = -1


class UcmSparseRole(enum.Enum):
    # sparser running in the scheduler process
    SCHEDULER = 0

    # sparser running in the worker process
    WORKER = 1


class UcmSparseMetadata(ABC):  # noqa: B024
    """
    Abstract metadata used to communicate between the
    scheduler UcmSparse instance and the worker UcmSparse instance.
    """

    pass


class UcmSparseBase(ABC):
    """
    A general interface for implementing sparse attention algorithms in vLLM.
    """

    def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole):
        self._sparse_metadata: Optional[UcmSparseMetadata] = None
        self._vllm_config = vllm_config
        self._role = role

    @property
    def role(self) -> UcmSparseRole:
        return self._role

    # ==============================
    # Worker-side methods
    # ==============================

    def bind_sparse_metadata(self, sparse_metadata: UcmSparseMetadata) -> None:
        """Set the sparse metadata from the scheduler.

        This function should be called by the model runner every time
        before the model execution. The metadata will be used for runtime
        KV cache loading and saving.

        Args:
            sparse_metadata (UcmSparseMetadata): the sparse metadata.
        """
        self._sparse_metadata = sparse_metadata

    def clear_sparse_metadata(self) -> None:
        """Clear the sparse metadata.

        This function should be called by the model runner every time
        after the model execution.
        """
        self._sparse_metadata = None

    def _get_sparse_metadata(self) -> UcmSparseMetadata:
        """Get the sparse metadata.

        This function should only be called inside UCMSparse.

        Returns:
            UcmSparseMetadata: the UCM sparse metadata.
        """
        # Should only be called while set to valid metadata.
        assert self._sparse_metadata is not None
        return self._sparse_metadata

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """
        Args:
            kv_caches: dictionary mapping layer names to KV cache tensors.
        """
        pass

    def execute_begin(self, scheduler_output: SchedulerOutput):
        """
        This is called at the beginning of the "ModelRunner->execute_model" function.
        """
        pass

    def execute_finished(self):
        """
        This is called at the end of the "ModelRunner->execute_model" function.
        """
        pass

    def attention_begin(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        layer_name: str,
        forward_context: ForwardContext,
    ) -> None:
        """
        This is called at the beginning of "unified_attention".
        A sparse attention algorithm can modify forward_context.attn_metadata if necessary.
        (UC_TODO: is modifying the dataclass not allowed in Python?)
        """
        pass

    def attention_finished(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_output: torch.Tensor,
        layer_name: str,
        forward_context: ForwardContext,
    ) -> None:
        """
        This is called at the end of "unified_attention".
        """
        pass

    def request_finished_in_worker(self, request_id: Union[int, str]):
        """
        This function releases the resources of finished requests on the worker side.
        """
        pass

    # ==============================
    # Scheduler-side methods
    # ==============================

    @abstractmethod
    def request_begin(self, request_id: Union[int, str], prompt_token_ids: List[int]):
        """
        This is called at the beginning of the "Scheduler->add_request" function.
        """
        pass

    def request_finished_in_scheduler(self, request_id: Union[int, str]):
        """
        This is called inside the "Scheduler->finish_requests" function.
        Generates the metadata required by the worker-side UcmSparse instance.
        """
        pass

    def estimate_num_slots_sparsed(self, request: Request) -> int:
        """
        This is called by the "Scheduler->schedule" function to estimate the number of required blocks.
        """
        pass

    def update_state_after_alloc(self, request: Request, num_blocks: int):
        """
        Update UcmSparse state after block allocation.
        """
        pass

    def build_sparse_meta(
        self,
        scheduler_output: SchedulerOutput,
        requests: dict[str, CachedRequestState],
        input_batch: InputBatch,
    ) -> UcmSparseMetadata:
        """
        Build the sparse metadata for this step.
        """
        pass
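As an illustration of how this interface is meant to be subclassed, here is a minimal, hypothetical sketch that is not part of this commit. It assumes the definitions above are importable and that vLLM's `Request` exposes `num_tokens`; the class name and trivial bodies are placeholders:

```python
# Hypothetical, minimal UcmSparseBase subclass: it applies no sparsity and only
# implements the abstract scheduler-side hook. Illustrative only.
class NoOpSparse(UcmSparseBase):
    def request_begin(self, request_id, prompt_token_ids):
        # A real algorithm would record per-request state here
        # (e.g. block features used later for retrieval).
        pass

    def estimate_num_slots_sparsed(self, request):
        # No sparsification: keep every token's KV, so the "budget" equals the
        # full request length. An ESA-style algorithm would return ~20-30% of it.
        return request.num_tokens
```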
