Skip to content

Commit ec7eeff

Browse files
lk-chenpaulpak58
authored andcommitted
[V0 deprecation][P/D] Deprecate v0 KVConnectorBase code (1/2) (vllm-project#21785)
Signed-off-by: Linkun Chen <github@lkchen.net> Signed-off-by: Paul Pak <paulpak58@gmail.com>
1 parent 2603e16 commit ec7eeff

File tree

13 files changed

+31
-1040
lines changed

13 files changed

+31
-1040
lines changed

.buildkite/test-pipeline.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -749,7 +749,6 @@ steps:
749749
# this test fails consistently.
750750
# TODO: investigate and fix
751751
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
752-
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py
753752
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
754753
- pytest -v -s models/multimodal/generation/test_maverick.py
755754

tests/kv_transfer/test_disagg.py

Lines changed: 0 additions & 120 deletions
This file was deleted.
Lines changed: 4 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -1,142 +1,10 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3-
"""
4-
KVConnectorBase Class for Distributed KV Cache & Hidden State communication
5-
6-
The class provides two primary abstract methods:
7-
1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
8-
2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
9-
"""
10-
11-
from abc import ABC, abstractmethod
12-
from typing import TYPE_CHECKING, Optional, Union
13-
14-
import torch
3+
"""Defines the base type for KV cache connectors."""
154

165
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
17-
from vllm.sequence import IntermediateTensors
18-
19-
if TYPE_CHECKING:
20-
from vllm.config import VllmConfig
21-
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
22-
23-
24-
class KVConnectorBase(ABC):
25-
"""
26-
Abstract base class for a KV connector.
27-
28-
The class provides two primary abstract methods:
29-
1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
30-
2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
31-
"""
32-
33-
@abstractmethod
34-
def __init__(
35-
self,
36-
rank: int,
37-
local_rank: int,
38-
config: "VllmConfig",
39-
):
40-
raise NotImplementedError
41-
42-
@abstractmethod
43-
def close(self) -> None:
44-
"""Close the buffer and release resources.
45-
46-
This method is responsible for cleaning up resources related to the
47-
connector when it is no longer needed.
48-
49-
Raises:
50-
NotImplementedError: This method must be implemented in subclasses.
51-
"""
52-
raise NotImplementedError
53-
54-
@abstractmethod
55-
def send_kv_caches_and_hidden_states(
56-
self,
57-
model_executable: torch.nn.Module,
58-
model_input: "ModelInputForGPUWithSamplingMetadata",
59-
kv_caches: list[torch.Tensor],
60-
hidden_or_intermediate_states: Union[torch.Tensor,
61-
IntermediateTensors],
62-
) -> None:
63-
"""
64-
Send KV caches and hidden states to the connector.
65-
66-
This method processes the input tokens, KV caches, and
67-
hidden/intermediate states for a given model and sends the data to the
68-
decode instance.
69-
70-
Args:
71-
model_executable (torch.nn.Module): The model executable containing
72-
start and end layer information.
73-
model_input (ModelInputForGPUWithSamplingMetadata): The input
74-
metadata from vLLM.
75-
kv_caches (list[torch.Tensor]): List of KV caches (keys and values)
76-
for each layer.
77-
hidden_or_intermediate_states (Union[torch.Tensor,
78-
IntermediateTensors]):
79-
The hidden or intermediate states associated with the tokens.
80-
81-
Returns:
82-
None
83-
84-
"""
85-
86-
raise NotImplementedError
87-
88-
@abstractmethod
89-
def recv_kv_caches_and_hidden_states(
90-
self, model_executable: torch.nn.Module,
91-
model_input: "ModelInputForGPUWithSamplingMetadata",
92-
kv_caches: list[torch.Tensor]
93-
) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
94-
"ModelInputForGPUWithSamplingMetadata"]:
95-
"""
96-
Receive KV caches and hidden states from the connector.
97-
98-
This method attempts to retrieve KV caches and hidden states for input
99-
tokens. If all required KV caches and hidden states are received, it
100-
will bypass model input, else it will fall back to normal vLLM model
101-
forwarding.
102-
103-
Args:
104-
model_executable (torch.nn.Module):
105-
The model executable from vLLM modelrunner.
106-
model_input (ModelInputForGPUWithSamplingMetadata):
107-
The model input from vLLM modelrunner.
108-
kv_caches (list[torch.Tensor]):
109-
List of KV caches for each layer.
110-
111-
Returns:
112-
- hidden_or_intermediate_states (torch.Tensor or
113-
IntermediateTensors):
114-
Concatenated hidden states if all required data is retrieved,
115-
otherwise `None`.
116-
- bypass_model_exec (bool):
117-
Indicates whether the model execution can be skipped (True) or
118-
needs to be redone (False).
119-
- model_input (ModelInputForGPUWithSamplingMetadata):
120-
Optionally adjusted input metadata for re-execution when
121-
`bypass_model_exec=False`.
122-
123-
"""
124-
125-
raise NotImplementedError
126-
127-
@classmethod
128-
def get_required_kvcache_layout(
129-
cls, vllm_config: "VllmConfig") -> Optional[str]:
130-
"""
131-
Get the required KV cache layout for this connector.
132-
Args:
133-
vllm_config (VllmConfig): the vllm config.
134-
135-
Returns:
136-
str: the required KV cache layout. e.g. HND, or NHD.
137-
None if the connector does not require a specific layout.
138-
"""
139-
return None
1406

7+
KVConnectorBase = KVConnectorBase_V1
8+
KVConnectorBaseType = KVConnectorBase_V1
1419

142-
KVConnectorBaseType = Union[KVConnectorBase, KVConnectorBase_V1]
10+
__all__ = ["KVConnectorBase", "KVConnectorBaseType"]

vllm/distributed/kv_transfer/kv_connector/factory.py

Lines changed: 13 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,18 @@
55
from typing import TYPE_CHECKING, Callable
66

77
import vllm.envs as envs
8-
from vllm.config import KVTransferConfig
9-
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
10-
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
11-
KVConnectorRole)
8+
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
9+
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
1210
from vllm.logger import init_logger
1311

14-
from .base import KVConnectorBase
15-
1612
if TYPE_CHECKING:
1713
from vllm.config import VllmConfig
1814

1915
logger = init_logger(__name__)
2016

2117

2218
class KVConnectorFactory:
23-
_registry: dict[str, Callable[[], type[KVConnectorBaseType]]] = {}
19+
_registry: dict[str, Callable[[], type[KVConnectorBase]]] = {}
2420

2521
@classmethod
2622
def register_connector(cls, name: str, module_path: str,
@@ -29,28 +25,23 @@ def register_connector(cls, name: str, module_path: str,
2925
if name in cls._registry:
3026
raise ValueError(f"Connector '{name}' is already registered.")
3127

32-
def loader() -> type[KVConnectorBaseType]:
28+
def loader() -> type[KVConnectorBase]:
3329
module = importlib.import_module(module_path)
3430
return getattr(module, class_name)
3531

3632
cls._registry[name] = loader
3733

3834
@classmethod
39-
def create_connector_v0(cls, rank: int, local_rank: int,
40-
config: "VllmConfig") -> KVConnectorBase:
41-
if envs.VLLM_USE_V1:
42-
raise ValueError("Attempting to initialize a V0 Connector, "
35+
def create_connector(
36+
cls,
37+
config: "VllmConfig",
38+
role: KVConnectorRole,
39+
) -> KVConnectorBase:
40+
if not envs.VLLM_USE_V1:
41+
raise ValueError("Attempting to initialize a V1 Connector, "
4342
f"but found {envs.VLLM_USE_V1=}")
4443

45-
connector_cls = cls.get_connector_class(config.kv_transfer_config)
46-
assert issubclass(connector_cls, KVConnectorBase)
47-
return connector_cls(rank, local_rank, config)
48-
49-
@classmethod
50-
def get_connector_class(
51-
cls, kv_transfer_config: "KVTransferConfig"
52-
) -> type[KVConnectorBaseType]:
53-
"""Get the connector class by name."""
44+
kv_transfer_config = config.kv_transfer_config
5445
connector_name = kv_transfer_config.kv_connector
5546
if connector_name in cls._registry:
5647
connector_cls = cls._registry[connector_name]()
@@ -61,21 +52,7 @@ def get_connector_class(
6152
f"Unsupported connector type: {connector_name}")
6253
connector_module = importlib.import_module(connector_module_path)
6354
connector_cls = getattr(connector_module, connector_name)
64-
return connector_cls
65-
66-
@classmethod
67-
def create_connector_v1(
68-
cls,
69-
config: "VllmConfig",
70-
role: KVConnectorRole,
71-
) -> KVConnectorBase_V1:
72-
if not envs.VLLM_USE_V1:
73-
raise ValueError("Attempting to initialize a V1 Connector, "
74-
f"but found {envs.VLLM_USE_V1=}")
75-
76-
kv_transfer_config = config.kv_transfer_config
77-
connector_cls = cls.get_connector_class(kv_transfer_config)
78-
assert issubclass(connector_cls, KVConnectorBase_V1)
55+
assert issubclass(connector_cls, KVConnectorBase)
7956
logger.info("Creating v1 connector with name: %s and engine_id: %s",
8057
connector_cls.__name__, kv_transfer_config.engine_id)
8158
# NOTE(Kuntai): v1 connector is explicitly separated into two roles.
@@ -92,25 +69,6 @@ def create_connector_v1(
9269
# Register various connectors here.
9370
# The registration should not be done in each individual file, as we want to
9471
# only load the files corresponding to the current connector.
95-
KVConnectorFactory.register_connector(
96-
"PyNcclConnector",
97-
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
98-
"SimpleConnector")
99-
100-
KVConnectorFactory.register_connector(
101-
"MooncakeConnector",
102-
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
103-
"SimpleConnector")
104-
105-
KVConnectorFactory.register_connector(
106-
"LMCacheConnector",
107-
"vllm.distributed.kv_transfer.kv_connector.lmcache_connector",
108-
"LMCacheConnector")
109-
110-
KVConnectorFactory.register_connector(
111-
"MooncakeStoreConnector",
112-
"vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector",
113-
"MooncakeStoreConnector")
11472

11573
KVConnectorFactory.register_connector(
11674
"SharedStorageConnector",

0 commit comments

Comments
 (0)