Skip to content

Commit 0b25498

Browse files
[Misc] add ignore mapper for quark quantization (#28275)
Signed-off-by: Haoyang Li <lihaoyang0109@gmail.com>
1 parent 0aecd91 commit 0b25498

File tree

1 file changed

+9
-3
lines changed
  • vllm/model_executor/layers/quantization/quark

1 file changed

+9
-3
lines changed

vllm/model_executor/layers/quantization/quark/quark.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

44
import fnmatch
5-
from typing import Any, Optional, cast
5+
from typing import TYPE_CHECKING, Any, Optional, cast
66

77
import torch
88

@@ -34,6 +34,9 @@
3434
)
3535
from vllm.platforms import current_platform
3636

37+
if TYPE_CHECKING:
38+
from vllm.model_executor.models.utils import WeightsMapper
39+
3740
__all__ = ["QuarkLinearMethod"]
3841

3942
logger = init_logger(__name__)
@@ -54,6 +57,7 @@ def __init__(
5457
self.kv_cache_group = kv_cache_group
5558
self.kv_cache_config = kv_cache_config
5659
self.pack_method = pack_method
60+
self.ignore: list[str] = cast(list[str], self.quant_config.get("exclude", []))
5761

5862
def get_linear_method(self) -> "QuarkLinearMethod":
5963
return QuarkLinearMethod(self)
@@ -74,9 +78,8 @@ def get_quant_method(
7478
from vllm.attention.layer import Attention # Avoid circular import
7579

7680
# Check if the layer is skipped for quantization.
77-
exclude_layers = cast(list[str], self.quant_config.get("exclude"))
7881
if should_ignore_layer(
79-
prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping
82+
prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
8083
):
8184
return UnquantizedLinearMethod()
8285
if isinstance(layer, LinearBase):
@@ -90,6 +93,9 @@ def get_quant_method(
9093
return QuarkMoEMethod.get_moe_method(self, module=layer, layer_name=prefix)
9194
return None
9295

96+
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
97+
self.ignore = hf_to_vllm_mapper.apply_list(self.ignore)
98+
9399
@classmethod
94100
def from_config(cls, config: dict[str, Any]) -> "QuarkConfig":
95101
export_config = config.get("export")

0 commit comments

Comments
 (0)