|
| 1 | +# Copyright 2024-2025 ModelCloud.ai |
| 2 | +# Copyright 2024-2025 qubitium@modelcloud.ai |
| 3 | +# Contact: qubitium@modelcloud.ai, x.com/qubitium |
| 4 | +# |
| 5 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | +# you may not use this file except in compliance with the License. |
| 7 | +# You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, software |
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +# See the License for the specific language governing permissions and |
| 15 | +# limitations under the License. |
| 16 | + |
| 17 | + |
| 18 | +import torch |
| 19 | +import torch.nn as nn |
| 20 | +from packaging import version |
| 21 | +from torch import __version__ as torch_version |
| 22 | +from transformers import PreTrainedModel |
| 23 | + |
| 24 | +from ...adapter.adapter import Adapter, Lora |
| 25 | +from ...models._const import DEVICE, PLATFORM |
| 26 | +from ...nn_modules.qlinear import BaseQuantLinear, PackableQuantLinear |
| 27 | +from ...utils.backend import BACKEND |
| 28 | +from ...utils.logger import setup_logger |
| 29 | + |
| 30 | +log = setup_logger() |
| 31 | + |
| 32 | +class TorchFusedQuantLinear(PackableQuantLinear): |
| 33 | + SUPPORTS_BITS = [4] |
| 34 | + SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128] |
| 35 | + SUPPORTS_DESC_ACT = [True, False] |
| 36 | + SUPPORTS_SYM = [True, False] |
| 37 | + SUPPORTS_SHARDS = True |
| 38 | + SUPPORTS_TRAINING = True |
| 39 | + SUPPORTS_AUTO_PADDING = True |
| 40 | + SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [1] |
| 41 | + SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [1] |
| 42 | + |
| 43 | + SUPPORTS_DEVICES = [DEVICE.XPU] |
| 44 | + SUPPORTS_PLATFORM = [PLATFORM.ALL] |
| 45 | + SUPPORTS_PACK_DTYPES = [torch.int32] |
| 46 | + SUPPORTS_ADAPTERS = [Lora] |
| 47 | + |
| 48 | + SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] |
| 49 | + |
| 50 | + # for transformers/optimum tests compat |
| 51 | + QUANT_TYPE = "torch" |
| 52 | + |
| 53 | + def __init__( |
| 54 | + self, |
| 55 | + bits: int, |
| 56 | + group_size: int, |
| 57 | + sym: bool, |
| 58 | + desc_act: bool, |
| 59 | + in_features: int, |
| 60 | + out_features: int, |
| 61 | + bias: bool = False, |
| 62 | + pack_dtype: torch.dtype = torch.int32, |
| 63 | + adapter: Adapter = None, |
| 64 | + register_buffers: bool = True, |
| 65 | + **kwargs, |
| 66 | + ): |
| 67 | + super().__init__( |
| 68 | + bits=bits, |
| 69 | + group_size=group_size, |
| 70 | + sym=sym, |
| 71 | + desc_act=desc_act, |
| 72 | + in_features=in_features, |
| 73 | + out_features=out_features, |
| 74 | + bias=bias, |
| 75 | + pack_dtype=pack_dtype, |
| 76 | + backend=kwargs.pop("backend", BACKEND.TORCH), |
| 77 | + adapter=adapter, |
| 78 | + register_buffers=register_buffers, |
| 79 | + **kwargs) |
| 80 | + |
| 81 | + self.transformed = False |
| 82 | + self.dequant_dtype = torch.int16 if self.bits == 8 else torch.int8 |
| 83 | + |
| 84 | + def post_init(self): |
| 85 | + super().post_init() |
| 86 | + self.optimize() |
| 87 | + |
| 88 | + def optimize(self): |
| 89 | + if self.optimized: |
| 90 | + return |
| 91 | + |
| 92 | + super().optimize() |
| 93 | + |
| 94 | + def train(self, mode: bool = True): |
| 95 | + old_train = self.training |
| 96 | + if mode == old_train: |
| 97 | + return self |
| 98 | + |
| 99 | + from ...utils.model import convert_gptq_v1_to_v2_format_module |
| 100 | + |
| 101 | + if self.SUPPORTS_TRAINING_USE_TORCH_KERNEL: |
| 102 | + # training starts |
| 103 | + if mode: |
| 104 | + # one time clone v1 qzeros and save both v1 and v2 qzeros in memory |
| 105 | + if self.qzero_format() == 1: |
| 106 | + if not hasattr(self, "qzeros_data_v1"): |
| 107 | + self.qzeros_data_v1 = self.qzeros.data.clone() |
| 108 | + convert_gptq_v1_to_v2_format_module(self, bits=self.bits, pack_dtype=self.pack_dtype) |
| 109 | + self.qzeros_data_v2 = self.qzeros.data |
| 110 | + else: |
| 111 | + self.qzeros.data = self.qzeros_data_v2 |
| 112 | + self.qzero_format(format=2) |
| 113 | + |
| 114 | + # training switching to inference/eval |
| 115 | + else: |
| 116 | + if hasattr(self, "qzeros_data_v1"): |
| 117 | + # switch qzero back to v1 for inference/eval |
| 118 | + self.qzeros.data = self.qzeros_data_v1 |
| 119 | + self.qzero_format(format=1) |
| 120 | + |
| 121 | + return super().train(mode=mode) |
| 122 | + |
| 123 | + def transform(self, dtype): |
| 124 | + self.scales = self.scales.clone().to(dtype).contiguous() |
| 125 | + # Unpack qzeros |
| 126 | + zeros = torch.bitwise_right_shift( |
| 127 | + torch.unsqueeze(self.qzeros, 2).expand(-1, -1, self.pack_factor), |
| 128 | + self.wf_unsqueeze_zero # self.wf.unsqueeze(0), |
| 129 | + ).to(self.dequant_dtype) |
| 130 | + zeros = torch.bitwise_and(zeros, self.maxq).reshape(zeros.shape[0], -1) |
| 131 | + # Unpack and reorder qweight |
| 132 | + weight = torch.bitwise_and( |
| 133 | + torch.bitwise_right_shift( |
| 134 | + torch.unsqueeze(self.qweight, 1).expand(-1, self.pack_factor, -1), |
| 135 | + self.wf_unsqueeze_neg_one # self.wf.unsqueeze(-1) |
| 136 | + ).to(self.dequant_dtype), |
| 137 | + self.maxq |
| 138 | + ) |
| 139 | + self.ret_idx = torch.zeros(self.g_idx.shape[0], dtype=torch.int32).to(self.g_idx.device) |
| 140 | + groups = self.g_idx.shape[0] // self.group_size |
| 141 | + remainder = self.g_idx.shape[0] % self.group_size |
| 142 | + g_idx_2 = self.g_idx * self.group_size |
| 143 | + if remainder > 0: |
| 144 | + g_idx_2[self.g_idx == groups] += torch.arange(remainder).to(self.g_idx_2.device).to(self.g_idx_2.dtype) |
| 145 | + arange_tensor = torch.arange(self.group_size).to(self.g_idx.device).to(self.g_idx.dtype) |
| 146 | + for i in range(groups): |
| 147 | + g_idx_2[self.g_idx == i] += arange_tensor |
| 148 | + self.ret_idx[g_idx_2] = torch.arange(self.g_idx.shape[0]).to(self.ret_idx.device).to(self.ret_idx.dtype) |
| 149 | + weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]).index_select(0, self.ret_idx).t() |
| 150 | + # Pack qweight |
| 151 | + packed = torch.zeros(weight.shape[0], weight.shape[1] // self.pack_factor, dtype=torch.int32, device=weight.device) |
| 152 | + for col in range(weight.shape[1] // self.pack_factor): |
| 153 | + for i in range(self.pack_factor): |
| 154 | + packed_col = weight[:, col * self.pack_factor + i].to(torch.int32) |
| 155 | + packed[:, col] |= packed_col << (i * self.bits) |
| 156 | + |
| 157 | + self.qweight = packed.contiguous() |
| 158 | + self.qzeros = zeros.contiguous() |
| 159 | + |
| 160 | + def forward(self, x: torch.Tensor): |
| 161 | + out_shape = x.shape[:-1] + (self.out_features,) |
| 162 | + x = x.reshape(-1, x.shape[-1]) |
| 163 | + out = self._forward(x, out_shape) |
| 164 | + return out |
| 165 | + |
| 166 | + def _forward(self, x, out_shape): |
| 167 | + num_itr = self.g_idx.shape[0] // x.shape[-1] |
| 168 | + |
| 169 | + if not self.training and not self.transformed and version.parse(torch_version).release >= version.parse("2.8").release: |
| 170 | + self.transform(x.dtype) |
| 171 | + self.transformed = True |
| 172 | + |
| 173 | + if not self.transformed: |
| 174 | + # make sure dequant dtype matches input x |
| 175 | + weights = self.dequantize_weight(num_itr=num_itr).to(x.dtype) |
| 176 | + out = torch.matmul(x, weights).reshape(out_shape) |
| 177 | + else: |
| 178 | + x = x[:, self.ret_idx].contiguous() |
| 179 | + out = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros( |
| 180 | + x, self.qweight, self.group_size, self.scales, self.qzeros |
| 181 | + ).reshape(out_shape) |
| 182 | + |
| 183 | + if self.bias is not None: |
| 184 | + out.add_(self.bias) |
| 185 | + |
| 186 | + if self.adapter: |
| 187 | + out = self.adapter.apply(x=x, out=out) |
| 188 | + |
| 189 | + return out |
| 190 | + |
| 191 | + # clear gptq only weights: useful in de-quantization |
| 192 | + def _empty_gptq_only_weights(self): |
| 193 | + self.qzeros = None |
| 194 | + self.qweight = None |
| 195 | + self.g_idx = None |
| 196 | + self.scales = None |
| 197 | + |
| 198 | + |
| 199 | +def dequantize_model(model: PreTrainedModel): |
| 200 | + for name, module in model.named_modules(): |
| 201 | + if isinstance(module, BaseQuantLinear) and not isinstance(module, TorchFusedQuantLinear): |
| 202 | + raise ValueError( |
| 203 | + "Only models loaded using TorchFusedQuantLinear are supported for dequantization. " |
| 204 | + "Please load model using backend=BACKEND.TORCH." |
| 205 | + ) |
| 206 | + |
| 207 | + if isinstance(module, TorchFusedQuantLinear): |
| 208 | + # Create a new Linear layer with dequantized weights |
| 209 | + new_module = nn.Linear(module.in_features, module.out_features) |
| 210 | + new_module.weight = nn.Parameter(module.dequantize_weight().T.detach().to("cpu", torch.float16)) |
| 211 | + new_module.bias = torch.nn.Parameter(module.bias) |
| 212 | + |
| 213 | + # Replace the module in the model |
| 214 | + parent = model |
| 215 | + if '.' in name: |
| 216 | + parent_name, module_name = name.rsplit('.', 1) |
| 217 | + parent = dict(model.named_modules())[parent_name] |
| 218 | + else: |
| 219 | + module_name = name |
| 220 | + |
| 221 | + setattr(parent, module_name, new_module) |
| 222 | + |
| 223 | + del model.config.quantization_config |
| 224 | + return model |
| 225 | + |
| 226 | + |
| 227 | +__all__ = ["TorchFusedQuantLinear", "dequantize_model"] |
0 commit comments