diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py index 8382f6e58..79b4911f7 100644 --- a/gptqmodel/models/loader.py +++ b/gptqmodel/models/loader.py @@ -571,7 +571,7 @@ def skip(*args, **kwargs): # If we use marlin or bitblas to load the quantized model, the model is already a converted model, # and we no longer need to call load_checkpoint_in_model() - if load_checkpoint_in_model and backend not in [BACKEND.MARLIN, BACKEND.MARLIN_FP16, BACKEND.BITBLAS]: + if load_checkpoint_in_model and backend not in [BACKEND.MARLIN, BACKEND.MARLIN_FP16, BACKEND.BITBLAS, BACKEND.MACHETE]: load_checkpoint_in_model_then_tie_weights( model, dtype=torch_dtype, diff --git a/gptqmodel/nn_modules/qlinear/machete.py b/gptqmodel/nn_modules/qlinear/machete.py new file mode 100644 index 000000000..c2b295554 --- /dev/null +++ b/gptqmodel/nn_modules/qlinear/machete.py @@ -0,0 +1,227 @@ +# Copyright 2024-2025 ModelCloud.ai +# Copyright 2024-2025 qubitium@modelcloud.ai +# Contact: qubitium@modelcloud.ai, x.com/qubitium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from typing import Optional, Tuple +from functools import partial +from ...models._const import DEVICE, PLATFORM +from ...adapter.adapter import Adapter, Lora +from .marlin import MarlinQuantLinear, replace_tensor +from ...utils.backend import BACKEND +from ...utils.scalar_type import scalar_types, ScalarType + +machete_import_exception = None +try: + import gptqmodel_machete_kernels +except ImportError as e: + machete_import_exception = e + +TYPE_MAP = { + (4, True): scalar_types.uint4b8, + (8, True): scalar_types.uint8b128, + } + +MACHETE_PREPACKED_BLOCK_SHAPE = [64, 128] + +def check_machete_supports_shape(in_features: int, out_features: int) \ + -> Tuple[bool, Optional[str]]: + if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0: + return False, "Input features size must be divisible by "\ + f"{MACHETE_PREPACKED_BLOCK_SHAPE[0]}" + if out_features % MACHETE_PREPACKED_BLOCK_SHAPE[1] != 0: + return False, "Output features size must be divisible by "\ + f"{MACHETE_PREPACKED_BLOCK_SHAPE[1]}" + return True, None + +def pack_quantized_values_into_int32(w_q: torch.Tensor, + wtype: ScalarType, + packed_dim: int = 0): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + assert w_q_perm.shape[-1] % pack_factor == 0 + new_shape_perm[-1] //= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i + + return res.permute(inv_perm) + + +def unpack_quantized_values_into_int32(w_q: torch.Tensor, + wtype: ScalarType, + packed_dim: int = 0): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], 
packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + new_shape_perm[-1] *= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask + + return res.permute(inv_perm) + +def machete_mm( + a: torch.Tensor, + # b_q Should be the tensor returned by machete_prepack_B + b_q: torch.Tensor, + b_type: ScalarType, + out_type: Optional[torch.dtype] = None, + b_group_scales: Optional[torch.Tensor] = None, + b_group_zeros: Optional[torch.Tensor] = None, + b_group_size: Optional[int] = None, + b_channel_scales: Optional[torch.Tensor] = None, + a_token_scales: Optional[torch.Tensor] = None, + schedule: Optional[str] = None) -> torch.Tensor: + return gptqmodel_machete_kernels.machete_mm(a, b_q, b_type.id, out_type, b_group_scales, + b_group_zeros, b_group_size, + b_channel_scales, a_token_scales, schedule) + + +def machete_prepack_B( + b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType, + group_scales_type: Optional[torch.dtype]) -> torch.Tensor: + return gptqmodel_machete_kernels.machete_prepack_B(b_q_weight, a_type, b_type.id, + group_scales_type) + + +def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: + return gptqmodel_machete_kernels.permute_cols(a, perm) + +class MacheteQuantLinear(MarlinQuantLinear): + SUPPORTS_BITS = [4, 8] + SUPPORTS_GROUP_SIZE = [-1, 128] + SUPPORTS_DESC_ACT = [True, False] + SUPPORTS_SYM = [True] + SUPPORTS_SHARDS = True + SUPPORTS_TRAINING = False + SUPPORTS_AUTO_PADDING = False + SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [1] + SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [64] + + SUPPORTS_DEVICES = [DEVICE.CUDA] + SUPPORTS_PLATFORM = [PLATFORM.LINUX] + SUPPORTS_PACK_DTYPES = [torch.int32] + SUPPORTS_ADAPTERS = [Lora] + + SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] + + # for transformers/optimum tests compat + QUANT_TYPE = "machete" + + def __init__( + self, bits: int, + group_size: int, + desc_act: bool, + sym: bool, + in_features: int, + out_features: int, + bias: bool = False, + pack_dtype: torch.dtype = torch.int32, + adapter: Adapter = None, + **kwargs): + + if machete_import_exception is not None: + raise ValueError( + f"Trying to use the machete backend, but could not import the C++/CUDA dependencies with the following error: {machete_import_exception}" + ) + + super().__init__( + bits=bits, + group_size=group_size, + sym=sym, + desc_act=desc_act, + in_features=in_features, + out_features=out_features, + bias=bias, + pack_dtype=pack_dtype, + backend=kwargs.pop("backend", BACKEND.MACHETE), + adapter=adapter, + **kwargs) + + _, err = check_machete_supports_shape(self.in_features, self.out_features) + if err is not None: + raise ValueError( + f"check_machete_supports_shape failed, {err}" + ) + + self.quant_type = TYPE_MAP.get((bits, sym), None) + + def post_init(self): + perm = torch.argsort(self.g_idx) \ + .to(torch.int) + + self.act_perm = lambda x: x[:, perm] + if self.in_features % 8 == 0: + self.act_perm = partial(permute_cols, perm=perm) + + x_unpacked = unpack_quantized_values_into_int32(self.qweight.data, + self.quant_type, + packed_dim=0) + + x_perm = x_unpacked[perm, :] + self.qweight.data = pack_quantized_values_into_int32(x_perm, + self.quant_type, + packed_dim=0) + + machete_qweight = 
machete_prepack_B(self.qweight.data.t().contiguous().t(), + a_type=self.scales.dtype, + b_type=self.quant_type, + group_scales_type=self.scales.dtype) + + replace_tensor(self, "qweight", machete_qweight) + + marlin_scales = self.scales.data.contiguous() + + replace_tensor(self, "scales", marlin_scales) + + def forward(self, x: torch.Tensor): + if x.shape[0] == 0: + return torch.empty((0, self.out_features), dtype=x.dtype, device=x.device) + + # make sure scales is synced with x/input + if x.dtype != self.scales.dtype: + self.scales = self.scales.to(dtype=x.dtype) + + x_2d = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (self.out_features,) + + x_2d = self.act_perm(x_2d) + output = machete_mm(a=x_2d, + b_q=self.qweight, + b_type=self.quant_type, + b_group_zeros=None, + b_group_scales=self.scales, + b_group_size=self.group_size) + + if self.bias is not None: + output.add_(self.bias) # In-place add + + return output.reshape(out_shape) diff --git a/gptqmodel/nn_modules/qlinear/marlin.py b/gptqmodel/nn_modules/qlinear/marlin.py index 6f27853f6..d355d8d39 100644 --- a/gptqmodel/nn_modules/qlinear/marlin.py +++ b/gptqmodel/nn_modules/qlinear/marlin.py @@ -495,4 +495,5 @@ def dequantize_qzeros(layer): return unpacked_qzeros -__all__ = ["MarlinQuantLinear"] + +__all__ = ["MarlinQuantLinear", "replace_tensor"] diff --git a/gptqmodel/utils/backend.py b/gptqmodel/utils/backend.py index f814d1356..eeee857a9 100644 --- a/gptqmodel/utils/backend.py +++ b/gptqmodel/utils/backend.py @@ -29,6 +29,7 @@ class BACKEND(str, Enum): EXLLAMA_EORA = "exllama_eora" MARLIN = "marlin" # FASTEST: marlin reduce ops in fp32 (higher precision -> more accurate, slightly slower) MARLIN_FP16 = "marlin_fp16" # FASTEST and then some: marlin reduce ops in fp16 (lower precision -> less accurate, slightly faster) + MACHETE = "machete" BITBLAS = "bitblas" # EXTREMELY FAST: speed at the cost of 10+ minutes of AOT (ahead of time compilation with disk cache) IPEX = "ipex" # Best kernel for Intel XPU and Intel/AMD CPU with AVX512, AMX, # XMX diff --git a/gptqmodel/utils/importer.py b/gptqmodel/utils/importer.py index 4bee62999..b8abf6e1c 100644 --- a/gptqmodel/utils/importer.py +++ b/gptqmodel/utils/importer.py @@ -29,6 +29,7 @@ from ..nn_modules.qlinear.exllamav2 import ExllamaV2QuantLinear from ..nn_modules.qlinear.ipex import IPEXQuantLinear from ..nn_modules.qlinear.marlin import MarlinQuantLinear +from ..nn_modules.qlinear.machete import MacheteQuantLinear from ..nn_modules.qlinear.qqq import QQQQuantLinear from ..nn_modules.qlinear.torch import TorchQuantLinear from ..nn_modules.qlinear.tritonv2 import TRITON_AVAILABLE, TRITON_INSTALL_HINT, TritonV2QuantLinear @@ -51,7 +52,7 @@ BACKEND.IPEX: IPEXQuantLinear, # best kernel Intel XPU and CPU with amx/avx512/xmx BACKEND.BITBLAS: BitBLASQuantLinear, # super slow AOT pre-compiler but fastest for bs=1 BACKEND.TORCH: TorchQuantLinear, # slightly slower than Triton but getting close in Torch 2.6.0+ - + BACKEND.MACHETE: MacheteQuantLinear, # machete kernel BACKEND.QQQ: QQQQuantLinear, # qqq kernel based on marlin }) @@ -250,6 +251,8 @@ def select_quant_linear( qlinear = QQQQuantLinear elif backend == BACKEND.TORCH: qlinear = TorchQuantLinear + elif backend == BACKEND.MACHETE: + qlinear = MacheteQuantLinear else: qlinear = TorchQuantLinear diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index eb891ebc9..981872593 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -37,6 +37,7 @@ import transformers from 
gptqmodel.nn_modules.qlinear.exllama_eora import ExllamaEoraQuantLinear from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear +from gptqmodel.nn_modules.qlinear.machete import MacheteQuantLinear from gptqmodel.nn_modules.qlinear.qqq import QQQQuantLinear from huggingface_hub import HfApi, hf_hub_download from packaging import version @@ -371,7 +372,7 @@ def hf_convert_gptq_v1_to_v2_format( ) -> Tuple[nn.Module, bool]: if checkpoint_format == "gptq": # skip v1 to v2 conversion for kernels that can only operate on sym=True (gptq_v1) - if qlinear_kernel in [IPEXQuantLinear, MarlinQuantLinear, ExllamaEoraQuantLinear]: + if qlinear_kernel in [IPEXQuantLinear, MarlinQuantLinear, MacheteQuantLinear, ExllamaEoraQuantLinear]: return model, False cfg = QuantizeConfig(bits=bits) @@ -470,7 +471,7 @@ def convert_gptq_v1_to_v2_format( qlinear_kernel: Type[BaseQuantLinear], ): # skip v2 to v1 conversion for gptq_v1 kernels - if qlinear_kernel in [IPEXQuantLinear, MarlinQuantLinear, ExllamaEoraQuantLinear, QQQQuantLinear]: + if qlinear_kernel in [IPEXQuantLinear, MarlinQuantLinear, MacheteQuantLinear, ExllamaEoraQuantLinear, QQQQuantLinear]: log.info( f"Format: Skipped v1 to v2 conversion due to Kernel `{qlinear_kernel}`.") return model @@ -547,7 +548,7 @@ def convert_gptq_v2_to_v1_format( ): # skip v2 to v1 conversion for gptq_v1 kernels - if qlinear_kernel in [IPEXQuantLinear, MarlinQuantLinear, ExllamaEoraQuantLinear, QQQQuantLinear]: + if qlinear_kernel in [IPEXQuantLinear, MarlinQuantLinear, MacheteQuantLinear, ExllamaEoraQuantLinear, QQQQuantLinear]: return model # Limit thread usage to avoid auto-parallizataion regression diff --git a/gptqmodel/utils/scalar_type.py b/gptqmodel/utils/scalar_type.py new file mode 100644 index 000000000..5d893a3a5 --- /dev/null +++ b/gptqmodel/utils/scalar_type.py @@ -0,0 +1,346 @@ +# SPDX-License-Identifier: Apache-2.0 + +import functools +import struct +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Union + +_SCALAR_TYPES_ID_MAP = {} + + +# Mirrors enum in `core/scalar_type.hpp` +class NanRepr(Enum): + NONE = 0 # nans are not supported + IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s + EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s + + +# This ScalarType class is a parallel implementation of the C++ ScalarType +# class found in csrc/core/scalar_type.hpp. These two classes should be kept +# in sync until the inductor fully supports custom C++ classes. +@dataclass(frozen=True) +class ScalarType: + """ + ScalarType can represent a wide range of floating point and integer + types, in particular it can be used to represent sub-byte data types + (something that torch.dtype currently does not support). It is also + capable of representing types with a bias, i.e.: + `stored_value = value + bias`, + this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias + of 8). The implementation for this class can be found in + csrc/core/scalar_type.hpp, these type signatures should be kept in sync + with that file. + """ + + exponent: int + """ + Number of bits in the exponent if this is a floating point type + (zero if this is an integer type) + """ + + mantissa: int + """ + Number of bits in the mantissa if this is a floating point type, + or the number of bits representing an integer excluding the sign bit if + this is an integer type. + """ + + signed: bool + "If the type is signed (i.e. 
has a sign bit)" + + bias: int + """ + bias used to encode the values in this scalar type + (value = stored_value - bias, default 0) for example if we store the + type as an unsigned integer with a bias of 128 then the value 0 will be + stored as 128 and -1 will be stored as 127 and 1 will be stored as 129. + """ + + _finite_values_only: bool = False + """ + Private: if infs are supported, used `has_infs()` instead. + """ + + nan_repr: NanRepr = NanRepr.IEEE_754 + """ + How NaNs are represent in this scalar type, returns NanRepr value. + (not applicable for integer types) + """ + + def _floating_point_max_int(self) -> int: + assert ( + self.mantissa <= 52 and self.exponent <= 11 + ), f"Cannot represent max/min as a double for type {self.__str__()}" + + max_mantissa = (1 << self.mantissa) - 1 + if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN: + max_mantissa = max_mantissa - 1 + + max_exponent = (1 << self.exponent) - 2 + if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN + or self.nan_repr == NanRepr.NONE): + assert ( + self.exponent < 11 + ), f"Cannot represent max/min as a double for type {self.__str__()}" + max_exponent = max_exponent + 1 + + # adjust the exponent to match that of a double + # for now we assume the exponent bias is the standard 2^(e-1) -1, (where + # e is the exponent bits), there is some precedent for non-standard + # biases, example `float8_e4m3b11fnuz` here: + # https://github.com/jax-ml/ml_dtypes but to avoid premature over + # complication we are just assuming the standard exponent bias until + # there is a need to support non-standard biases + exponent_bias = (1 << (self.exponent - 1)) - 1 + exponent_bias_double = (1 << 10) - 1 # double e = 11 + + max_exponent_double = (max_exponent - exponent_bias + + exponent_bias_double) + + # shift the mantissa and exponent into the proper positions for an + # IEEE double and bitwise-or them together. + return (max_mantissa << + (52 - self.mantissa)) | (max_exponent_double << 52) + + def _floating_point_max(self) -> float: + double_raw = self._floating_point_max_int() + return struct.unpack('!d', struct.pack('!Q', double_raw))[0] + + def _raw_max(self) -> Union[int, float]: + if self.is_floating_point(): + return self._floating_point_max() + else: + assert (self.size_bits < 64 or self.size_bits == 64 + and self.is_signed()), "Cannot represent max as an int" + return (1 << self.mantissa) - 1 + + def _raw_min(self) -> Union[int, float]: + if self.is_floating_point(): + assert self.is_signed( + ), "We currently assume all floating point types are signed" + sign_bit_double = 1 << 63 + + max_raw = self._floating_point_max_int() + min_raw = max_raw | sign_bit_double + return struct.unpack('!d', struct.pack('!Q', min_raw))[0] + else: + assert (not self.is_signed() or self.size_bits + <= 64), "Cannot represent min as a int64_t" + + if self.is_signed(): + return -(1 << (self.size_bits - 1)) + else: + return 0 + + @functools.cached_property + def id(self) -> int: + """ + Convert the ScalarType to an int which can be passed to pytorch custom + ops. This layout of the int must be kept in sync with the C++ + ScalarType's from_id method. 
+ """ + val = 0 + offset = 0 + + def or_and_advance(member, bit_width): + nonlocal val + nonlocal offset + bit_mask = (1 << bit_width) - 1 + val = val | (int(member) & bit_mask) << offset + offset = offset + bit_width + + or_and_advance(self.exponent, 8) + or_and_advance(self.mantissa, 8) + or_and_advance(self.signed, 1) + or_and_advance(self.bias, 32) + or_and_advance(self._finite_values_only, 1) + or_and_advance(self.nan_repr.value, 8) + + assert offset <= 64, \ + f"ScalarType fields too big {offset} to fit into an int64" + + _SCALAR_TYPES_ID_MAP[val] = self + + return val + + @property + def size_bits(self) -> int: + return self.exponent + self.mantissa + int(self.signed) + + def min(self) -> Union[int, float]: + """ + Min representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_min() - self.bias + + def max(self) -> Union[int, float]: + """ + Max representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_max() - self.bias + + def is_signed(self) -> bool: + """ + If the type is signed (i.e. has a sign bit), same as `signed` + added for consistency with: + https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html + """ + return self.signed + + def is_floating_point(self) -> bool: + "If the type is a floating point type" + return self.exponent != 0 + + def is_integer(self) -> bool: + "If the type is an integer type" + return self.exponent == 0 + + def has_bias(self) -> bool: + "If the type has a non-zero bias" + return self.bias != 0 + + def has_infs(self) -> bool: + "If the type is floating point and supports infinity" + return not self._finite_values_only + + def has_nans(self) -> bool: + return self.nan_repr != NanRepr.NONE.value + + def is_ieee_754(self) -> bool: + """ + If the type is a floating point type that follows IEEE 754 + conventions + """ + return self.nan_repr == NanRepr.IEEE_754.value and \ + not self._finite_values_only + + def __str__(self) -> str: + """ + naming generally follows: https://github.com/jax-ml/ml_dtypes + for floating point types (leading f) the scheme is: + `float_em[flags]` + flags: + - no-flags: means it follows IEEE 754 conventions + - f: means finite values only (no infinities) + - n: means nans are supported (non-standard encoding) + for integer types the scheme is: + `[u]int[b]` + - if bias is not present it means its zero + """ + if self.is_floating_point(): + ret = "float" + str(self.size_bits) + "_e" + str( + self.exponent) + "m" + str(self.mantissa) + + if not self.is_ieee_754(): + if self._finite_values_only: + ret = ret + "f" + if self.nan_repr != NanRepr.NONE: + ret = ret + "n" + + return ret + else: + ret = ("int" if self.is_signed() else "uint") + str(self.size_bits) + if self.has_bias(): + ret = ret + "b" + str(self.bias) + return ret + + def __repr__(self) -> str: + return "ScalarType." + self.__str__() + + # __len__ needs to be defined (and has to throw TypeError) for pytorch's + # opcheck to work. + def __len__(self) -> int: + raise TypeError + + # + # Convenience Constructors + # + + @classmethod + def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + "Create a signed integer scalar type (size_bits includes sign-bit)." 
+ ret = cls(0, size_bits - 1, True, bias if bias else 0) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + """Create an unsigned integer scalar type.""" + ret = cls(0, size_bits, False, bias if bias else 0) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType': + """ + Create a standard floating point type + (i.e. follows IEEE 754 conventions). + """ + assert (mantissa > 0 and exponent > 0) + ret = cls(exponent, mantissa, True, 0) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, + nan_repr: NanRepr) -> 'ScalarType': + """ + Create a non-standard floating point type + (i.e. does not follow IEEE 754 conventions). + """ + assert (mantissa > 0 and exponent > 0) + assert (nan_repr != NanRepr.IEEE_754), ( + "use `float_IEEE754` constructor for floating point types that " + "follow IEEE 754 conventions") + ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def from_id(cls, scalar_type_id: int): + if scalar_type_id not in _SCALAR_TYPES_ID_MAP: + raise ValueError( + f"scalar_type_id {scalar_type_id} doesn't exist.") + return _SCALAR_TYPES_ID_MAP[scalar_type_id] + + +# naming generally follows: https://github.com/jax-ml/ml_dtypes +# for floating point types (leading f) the scheme is: +# `float_em[flags]` +# flags: +# - no-flags: means it follows IEEE 754 conventions +# - f: means finite values only (no infinities) +# - n: means nans are supported (non-standard encoding) +# for integer types the scheme is: +# `[u]int[b]` +# - if bias is not present it means it's zero + + +class scalar_types: + int4 = ScalarType.int_(4, None) + uint4 = ScalarType.uint(4, None) + int8 = ScalarType.int_(8, None) + uint8 = ScalarType.uint(8, None) + float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) + float8_e5m2 = ScalarType.float_IEEE754(5, 2) + float16_e8m7 = ScalarType.float_IEEE754(8, 7) + float16_e5m10 = ScalarType.float_IEEE754(5, 10) + + # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main + float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) + + # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + float4_e2m1fn = ScalarType.float_(2, 1, True, NanRepr.NONE) + + # "gptq" types + uint2b2 = ScalarType.uint(2, 2) + uint3b4 = ScalarType.uint(3, 4) + uint4b8 = ScalarType.uint(4, 8) + uint8b128 = ScalarType.uint(8, 128) + + # colloquial names + bfloat16 = float16_e8m7 + float16 = float16_e5m10 diff --git a/gptqmodel_ext/machete/machete.cpp b/gptqmodel_ext/machete/machete.cpp new file mode 100644 index 000000000..001898175 --- /dev/null +++ b/gptqmodel_ext/machete/machete.cpp @@ -0,0 +1,15 @@ +#include +#include +#include +#include +#include +#include + +#include "machete_pytorch.cuh" +#include "permute_cols.cuh" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("machete_prepack_B", &machete::prepack_B, "machete_prepack_B"); + m.def("machete_mm", &machete::mm, "machete_mm"); + m.def("permute_cols", &permute_cols, "permute_cols"); +} \ No newline at end of file diff --git a/gptqmodel_ext/machete/machete_collective_builder.cuh b/gptqmodel_ext/machete/machete_collective_builder.cuh new file mode 100644 index 000000000..ee825583d --- /dev/null +++ 
b/gptqmodel_ext/machete/machete_collective_builder.cuh @@ -0,0 +1,31 @@ +#pragma once + +#include "cutlass_extensions/vllm_collective_builder.cuh" +#include "machete_mainloop.cuh" + +namespace cutlass::gemm::collective { +using namespace cute; + +struct MacheteKernelTag {}; + +template +struct VLLMCollectiveBuilder< + MacheteKernelTag, arch::Sm90, arch::OpClassTensorOp, ElementPairA_, + GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_, AlignmentB, + ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType, + KernelScheduleType, + cute::enable_if_t<( + cute::is_same_v || + cute::is_same_v || + cute::is_same_v)>> { + using CollectiveOp = machete::MacheteCollectiveMma< + ElementPairA_, GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_, + AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, + StageCountType, KernelScheduleType>; +}; + +}; // namespace cutlass::gemm::collective diff --git a/gptqmodel_ext/machete/machete_interleaving_utils.cuh b/gptqmodel_ext/machete/machete_interleaving_utils.cuh new file mode 100644 index 000000000..d397f87f1 --- /dev/null +++ b/gptqmodel_ext/machete/machete_interleaving_utils.cuh @@ -0,0 +1,35 @@ +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" + +namespace machete { + +using namespace cute; + +// get an interleaved block layout where each element consecutive element has a +// stride of bit_stride and the block width is blk_bit_width, +// examples: +// size_bits = 8, bit_stride = 8, blk_bit_width = 32 -> 4:1 +// size_bits = 8, bit_stride = 16, blk_bit_width = 32 -> (2, 2):(2, 1) +// size_bits = 4, bit_stride = 8, blk_bit_width = 32 -> (4, 2):(2, 1) +// size_bits = 4, bit_stride = 16, blk_bit_width = 32 -> (2, 4):(4, 1) +template +CUTE_HOST_DEVICE static constexpr auto get_interleaved_blk_layout() { + static_assert(blk_bit_width % bit_stride == 0); + static_assert(bit_stride % cute::sizeof_bits_v == 0); + + constexpr auto elems_per_blk = blk_bit_width / cute::sizeof_bits_v; + + if constexpr (cute::sizeof_bits_v == bit_stride) { + // identity layout + return Layout>>{}; + } else { + constexpr auto elems_per_stride = bit_stride / cute::sizeof_bits_v; + constexpr auto num_strides = elems_per_blk / elems_per_stride; + return Layout, Int>, + Stride, Int<1>>>{}; + } +} + +}; // namespace machete diff --git a/gptqmodel_ext/machete/machete_mainloop.cuh b/gptqmodel_ext/machete/machete_mainloop.cuh new file mode 100644 index 000000000..572894064 --- /dev/null +++ b/gptqmodel_ext/machete/machete_mainloop.cuh @@ -0,0 +1,1474 @@ +// +// Based off of: +// cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp +// Specifically: +// https://github.com/NVIDIA/cutlass/tree/06b21349bcf6ddf6a1686a47a137ad1446579db9/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp +// Referred to as upstream from in the comments +// +// The main optimization machete implements compared to upstream is to prepack +// the weight matrix to more closely match the shape of the wgmma instructions +// allowing for wider (ideally 128bit) shared memory loads. For subbyte types +// this is done by packing values from multiple wgmma loads (for a single +// thread) into a single 128bit load. This is very similar to layout used in +// Marlin, although specific to the wgmma instructions. 
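+// (For example, with 4-bit weights a single 128-bit shared memory load carries 32 packed values, versus 8 values for a plain 32-bit load.)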
+// +// Since the wgmma instructions only support sourcing from registers for the A +// operand, and we want to upconvert/decompress the weight values/elements +// before feeding them into the tensor cores in registers, we need the weight +// matrix to be A. To achieve this we compute the transpose of Y = XW^t as +// Y^t = W^tX^t. This is mostly done outside of this file in +// csrc/quantization/machete/machete_mm_kernel.cuh, but this why A is the +// quantized/narrow type and has the prepacked layout despite the API being: +// B_prepacked = machete_prepack_B(B) +// Y = machete_mm(A, B_prepacked) +// +#pragma once + +// clang-format off +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/detail/layout.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_traits_sm90_tma.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/transform/collective/sm90_wgmma_transpose.hpp" +#include "cutlass/trace.h" + +#include "cutlass/detail/collective.hpp" +// clang-format on + +#include "cutlass_extensions/cute_utils.cuh" + +namespace machete { + +using namespace cute; +using namespace cutlass; +using namespace cutlass::gemm; +using namespace cutlass::gemm::collective; +using namespace cutlass::gemm::collective::detail; + +template +struct MacheteCollectiveMma { + using Schedule = KernelScheduleType; + static_assert( + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v, + "KernelSchedule must be one of the warp specialized policies"); + + public: + static constexpr bool ALayoutIsPrepacked = true; + + // Prepacked block shape (N is M in the transposed problem) + using PPBlockShape_MK = typename GmemLayoutA::PPBlockShape_NK; + // Prepacked blocks per dim for a single MMA tile + using PPBlocksPerTile_MK = decltype(make_shape( + size<0>(TileShape_MNK{}) / size<0>(PPBlockShape_MK{}), + size<2>(TileShape_MNK{}) / size<1>(PPBlockShape_MK{}))); + + using IlvdBlkLayout = typename GmemLayoutA::IlvdBlkLayout; + + static_assert(size<0>(TileShape_MNK{}) % size<0>(PPBlockShape_MK{}) == 0, + "M in PPBlockShape_MK must evenly divide M TileShape_MNK"); + static_assert(size<2>(TileShape_MNK{}) % size<1>(PPBlockShape_MK{}) == 0, + "K in PPBlockShape_MK must evenly divide K TileShape_MNK"); + + using ArchTag = arch::Sm90; + using TileShape = TileShape_MNK; + using ClusterShape = ClusterShape_MNK; + using ElementA = deduce_mixed_width_dtype_t<0, ElementATuple_>; + using StrideA = TagToStrideA_t; + using ElementB = ElementB_; + using StrideB = TagToStrideB_t; + using ElementAccumulator = ElementAccumulator_; + using ElementMma = ElementB; + using ElementATuple = + cute::conditional_t::value, + cute::tuple, ElementATuple_>; + + static constexpr cute::GMMA::Major GmmaMajorA = + gmma_rs_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = + gmma_rs_tag_to_major_B(); + + // For coop schedules we have two warp groups cooperatively issuing wgmma + // instructions so we use 2 atoms along the M dim (one for each warpgroup) + using AtomLayoutMNK = cute::conditional_t< + cute::is_same_v, + Layout>, Layout>>; + + using TiledMma = 
decltype(cute::make_tiled_mma( + cute::GMMA::rs_op_selector(), + AtomLayoutMNK{})); + + private: + // + // the setup section (until "section setup end") contains a combination of + // modified code from (used as a starting point): + // `cutlass/gemm/collective/builders/sm90_gmma_builder.inl` + // `cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp` + // (upstream) + // + // however in-order to simplify the code we combine a lot of the logic from + // `CollectiveMma` and `CollectiveBuilder` into this class, this also makes + // sense given that we have flexibility on layouts here. We also simplify the + // code by only supporting scales and zeros for A (in the transposed problem, + // B from an API perspective), also since we force A to be the narrow type + // (i.e. the type to be upconverted) we can remove all the `SwapAB` logic in + // the upstream also simplifying the code. This section includes new logic + // (compared ustream) for handling the prepacked-A layouts (in the transposed + // problem, B from an API perspective) + // + using ElementScale = deduce_mixed_width_dtype_t<1, ElementATuple_>; + using ElementZero = deduce_mixed_width_dtype_t<2, ElementATuple_>; + + static constexpr bool IsANarrow = cutlass::sizeof_bits::value < + cutlass::sizeof_bits::value; + static_assert(IsANarrow, + "A must be the narrow one since its the one that flows through " + "registers."); + + public: + static constexpr int PipelineStages = + compute_stage_count_or_override_single_affine_transformed_input< + sm90_smem_capacity_bytes, ElementA, ElementB, ElementScale, + ElementZero, TileShape_MNK>(StageCountType{}); + + struct DispatchPolicy { + constexpr static int Stages = PipelineStages; + using ClusterShape = ClusterShape_MNK; + using Schedule = KernelScheduleType; + }; + + using GmemTiledCopyA = + decltype(sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = + decltype(sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + // ((T, V), (BlocksM, BlocksK), pipe) -> offset + using SmemLayoutA = decltype(GmemLayoutA::TVbNbKL_to_offset( + make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}), + Int{}))); + + using SmemLayoutACopy = decltype(GmemLayoutA::TVbNbKL_to_offset_copy( + make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}), + Int{}))); + + using SmemLayoutAtomARowMajor = + decltype(rs_smem_selector(TileShape_MNK{})), + decltype(cute::get<2>(TileShape_MNK{}))>()); + + using SmemLayoutAtomScale = Layout< + Shape(SmemLayoutAtomARowMajor{})), cute::Int<1>>>; + + using SmemLayoutAtomB = + decltype(rs_smem_selector(TileShape_MNK{})), + decltype(cute::get<2>(TileShape_MNK{}))>()); + + using SmemCopyAtomA = Copy_Atom; + using SmemCopyAtomB = void; + + // + // Validity checks + // + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(is_aligned(), + "Should meet TMA alignment requirement\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, + "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + + private: + enum class ConversionMode { + DirectConvert, + ConvertAndScale, + ConvertAndScaleWithZero + }; + + public: + // + // Type Aliases + // + using KernelSchedule = KernelScheduleType; + + // For cases where we can't have a void type, we can use this to allow the + // code to compile when the scale / zero is void. 
+ using NonVoidElementScale = + cute::conditional_t, float, ElementScale>; + using NonVoidElementZero = + cute::conditional_t, float, ElementZero>; + + // These are always MN major + using StrideScale = cute::Stride, int64_t, int64_t>; + // For cases where we can't have a void scale, we can use this to allow the + // code to compile when the scale is void. + using NonVoidStrideScale = + cute::conditional_t, + cute::Stride<_1, int64_t, int64_t>, StrideScale>; + + static_assert((cutlass::gemm::detail::is_k_major()), + "The transformed matrix (A) must be K-major."); + + static_assert((sizeof(ElementB) == 2) || + (cutlass::gemm::detail::is_k_major() && + cutlass::gemm::detail::is_k_major()), + "The unscaled element (matrix B) must be 2 bytes OR both " + "inputs must be K-major"); + + static_assert(cutlass::gemm::detail::is_mn_major(), + "Scale must be MN major [Col Major if A is scaled, Row Major " + "if B is scaled]."); + + static_assert(std::is_same_v, + "TiledMma::ValTypeC must be the same as ElementAccumulator."); + + using GmemTiledCopyScale = cute::SM90_TMA_LOAD; + + using SmemCopyAtomScale = Copy_Atom; + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any + // rounding by TMA. + static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using InternalElementA = + cute::conditional_t>>; + using InternalElementB = + cute::conditional_t>>; + + using TransformA = cute::identity; + using TransformB = cute::identity; + + static constexpr int IsSubbyteA = cute::sizeof_bits_v < 8; + using TmaElementA = + cute::conditional_t; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + // One threads per CTA are producers (1 for operand tile) + static constexpr int NumProducerThreadEvents = 1; + + using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}), + shape<1>(SmemLayoutAtomScale{}))); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, + "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomScale{}) == 2, + "SmemLayoutAtomScale must be rank 2"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomScale{})) == 0, + "SmemLayoutAtomScale must equal the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0, + "SmemLayoutAtomScale must evenly divide tile k shape."); + + // Tile along modes in a way that maximizes the TMA box size + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), + Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), + Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + + // It is assumed that the scales and zero-points share the same smem layout + using SmemLayoutScale = decltype(tile_to_shape( + SmemLayoutAtomScale{}, + make_shape(shape<0>(ScaleTileShape{}), shape<1>(ScaleTileShape{}), + Int{}))); + + // If A mn-layout and B mn-layout, transposing B matrix since WGMMA is k-major + // only (e.g. tf32, fp32, fp8, int8). 
+ static constexpr bool IsLayoutAmnBmn = + cute::is_same_v, + layout::ColumnMajor> && + cute::is_same_v, + layout::RowMajor>; + + static_assert(DispatchPolicy::Stages >= 2, + "Specialization requires Stages set to value 2 or more."); + static_assert(not cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source A from rmem and B operand from smem_desc " + "for this mainloop."); + static_assert(cute::is_same_v || + cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || + cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + using GmmaSmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), + Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), + Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + + // These two restrictions are related, so we place the assertions together. + // To relax them, we need to handle loading more than 1 row of scales for + // every main loop iteration. We must also handle updating the pipeline + // transaction bytes on the fly. NOTE: Deleting this assertion without + // required changes will cause the code to hang. + static_assert(size<1>(SmemLayoutAtomScale{}) == 1, + "size<1>(SmemLayoutAtomScale) must be 1."); + + private: + static constexpr ConversionMode get_conversion_mode() { + if constexpr (cute::is_void_v) { + return ConversionMode::DirectConvert; + } else if constexpr (cute::is_void_v) { + return ConversionMode::ConvertAndScale; + } else { + return ConversionMode::ConvertAndScaleWithZero; + } + } + + static constexpr ConversionMode KernelConversionMode = get_conversion_mode(); + static constexpr bool ModeHasScales = + KernelConversionMode == ConversionMode::ConvertAndScale || + KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; + + // Same as upstream, should be kept the same when possible + static constexpr auto elements_per_smem_scale() { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return 0; + } else if constexpr (ModeHasScales) { + return cute::cosize_v; + } else { + static_assert(cutlass::detail::dependent_false, + "Type not handled in scale smem allocation."); + } + } + + // Same as upstream, should be kept the same when possible + static constexpr auto elements_per_smem_zero() { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert || + KernelConversionMode == ConversionMode::ConvertAndScale) { + return 0; + } else if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + return cute::cosize_v; + } else { + static_assert(cutlass::detail::dependent_false, + "Type not handled in scale smem allocation."); + } + } + + // Same as upstream, should be kept the same when possible, not formatte for + // easier comparison + // clang-format off + // These methods use some the public members of the class. For that reason, we define them after the public section. 
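+  // Note: the helpers below compute the expected TMA transaction bytes per pipeline stage: the A tile bytes plus, when scales/zeros are present, the per-stage scale (and zero) bytes; B is counted separately in compute_tma_transaction_bytes_nk().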
+ static constexpr uint32_t + compute_tma_transaction_bytes_mk() { + constexpr uint32_t baseline_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(cute::sizeof_bits_v)); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return baseline_bytes; + } + else if constexpr (ModeHasScales) { + constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v)); + static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned."); // required by TMA + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return baseline_bytes + scale_tx_bytes; + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + // Scale and zero share smem layout + constexpr uint32_t zero_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v)); + static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned."); // required by TMA + return baseline_bytes + scale_tx_bytes + zero_tx_bytes; + } + else { + static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); + } + } + + static constexpr uint32_t + compute_tma_transaction_bytes_nk() { + return cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(cute::sizeof_bits_v)); + } + // clang-format on + + // ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx) + using PrepackedStrideA = decltype(stride(GmemLayoutA::TVbNbKL_to_offset_copy( + make_shape(int32_t(0), int32_t(0), int32_t(0))))); + + using ATensor = decltype(make_tensor( + get_logical_ptr(static_cast(nullptr)), + shape(GmemLayoutA::TVbNbKL_to_offset_copy( + make_shape(int32_t(0), int32_t(0), int32_t(0)))), + PrepackedStrideA{})); + + using BTensor = decltype(make_tensor( + get_logical_ptr(static_cast(nullptr)), + repeat_like(StrideB{}, int32_t(0)), StrideB{})); + using ScaleTensor = decltype(make_tensor( + get_logical_ptr(static_cast(nullptr)), + repeat_like(NonVoidStrideScale{}, int32_t(0)), NonVoidStrideScale{})); + + using ZeroTensor = decltype(make_tensor( + get_logical_ptr(static_cast(nullptr)), + repeat_like(NonVoidStrideScale{}, int32_t(0)), NonVoidStrideScale{})); + + static constexpr auto make_tma_copy_A(ATensor tensor_a = ATensor{}) { + return make_tma_copy( + GmemTiledCopyA{}, tensor_a, SmemLayoutACopy{}(_, _, cute::Int<0>{}), + shape(SmemLayoutACopy{}(_, _, cute::Int<0>{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + } + + static constexpr auto make_tma_copy_scale( + ScaleTensor tensor_scale = ScaleTensor{}) { + return make_tma_copy(GmemTiledCopyScale{}, tensor_scale, + SmemLayoutScale{}(_, _, cute::Int<0>{}), + ScaleTileShape{}, + _1{}); // mcast along N mode for this M load, if any + } + + static constexpr auto make_tma_copy_zero( + ZeroTensor tensor_zero = ZeroTensor{}) { + return make_tma_copy(GmemTiledCopyScale{}, tensor_zero, + SmemLayoutScale{}(_, _, cute::Int<0>{}), + ScaleTileShape{}, + _1{}); // mcast along N mode for this M load, if any + } + + static constexpr auto make_tma_copy_B(BTensor tensor_b = BTensor{}) { + return make_tma_copy( + GmemTiledCopyB{}, tensor_b, SmemLayoutB{}(_, _, cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), 
shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + } + + public: + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // with `RealInternalElementA` -> `ElementA` since we support `SwapAB` logic + // clang-format off + static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); + + static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); + + // Just pick the max alignment of A and B since it is required to be at least 128B + static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB); + + static_assert(SmemAlignmentA >= 128 and SmemAlignmentB >= 128, "Require at least 128B alignment"); + + struct SharedStorage + { + static constexpr int scale_elements = elements_per_smem_scale(); + static constexpr int zero_elements = elements_per_smem_zero(); + struct TensorStorage : cute::aligned_struct { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine smem_scale; + cute::ArrayEngine smem_zero; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A = nullptr; + StrideA dA{}; + ElementB const* ptr_B = nullptr; + StrideB dB{}; + ElementScale const* ptr_S = nullptr; + NonVoidStrideScale dS{}; + int group_size = 0; + ElementZero const* ptr_Z = nullptr; + uint32_t mma_promotion_interval = 4; + }; + // clang-format on + + // + // section setup end + // + + // Similar (but not idendtical) to upstream, should be kept the same when + // possible + // compared to upstream we use `make_tma_copy_A`, `make_tma_copy_B` etc. to + // define the TMA types + // Device side kernel params + struct Params { + public: + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy_A()); + using TMA_Scale = decltype(make_tma_copy_scale()); + using TMA_Zero = decltype(make_tma_copy_zero()); + using TMA_B = decltype(make_tma_copy_B()); + + // required by outer loop: i.e. 
+ // cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_Scale tma_load_scale; + TMA_Zero tma_load_zero; + int64_t scale_k; + int group_size; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + }; + + // + // Methods + // + + // Similar (but not idendtical) to upstream, should be kept the same when + // possible + // compared to upstream we use `make_tma_copy_A` and `TVbNbKL_to_offset` here + // to handle the prepacked layout + template + static constexpr Params to_underlying_arguments( + ProblemShape const& problem_shape, Arguments const& args, + void* workspace) { + (void)workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is + // only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); + + auto make_logical_tensor = [&](auto ptr, auto shape, auto stride) { + return make_tensor(get_logical_ptr(ptr), make_layout(shape, stride)); + }; + + typename Params::TMA_A tma_load_a; + typename Params::TMA_B tma_load_b; + typename Params::TMA_Scale tma_load_scale; + typename Params::TMA_Zero tma_load_zero; + + auto layout = GmemLayoutA::TVbNbKL_to_offset_copy(make_shape(M, K, L)); + tma_load_a = make_tma_copy_A( + make_logical_tensor(ptr_A, shape(layout), stride(layout))); + + tma_load_b = make_tma_copy_B( + make_logical_tensor(ptr_B, make_shape(N, K, L), args.dB)); + + int32_t scale_k = + (ModeHasScales) ? (K + args.group_size - 1) / args.group_size : 0; + int32_t group_size = (ModeHasScales) ? args.group_size : 0; + + if constexpr (ModeHasScales) { + tma_load_scale = make_tma_copy_scale( + make_logical_tensor(args.ptr_S, make_shape(M, scale_k, L), args.dS)); + } + + if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + tma_load_zero = make_tma_copy_zero( + make_logical_tensor(args.ptr_Z, make_shape(M, scale_k, L), args.dS)); + } + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert || + KernelConversionMode == ConversionMode::ConvertAndScale || + KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + return {tma_load_a, tma_load_b, tma_load_scale, + tma_load_zero, scale_k, group_size}; + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in to_underlying_arguments."); + } + } + + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // with `SwapAB ? 
N : M -> M` since we dont support SwapAB + // clang-format off + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + implementable = implementable && (args.ptr_S == nullptr); + implementable = implementable && (args.ptr_Z == nullptr); + } + else if constexpr (ModeHasScales) { + const int scale_mn = M; + const int scale_k = (K + args.group_size - 1) / args.group_size; + constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), StrideScale{}); + implementable = implementable && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0)); + implementable = implementable && args.group_size != 0; + implementable = implementable && (args.ptr_S != nullptr); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + implementable = implementable && (args.ptr_Z == nullptr); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + constexpr int min_tma_aligned_elements_zero = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), StrideScale{}); + implementable = implementable && (args.ptr_Z != nullptr); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr uint32_t TmaTransactionBytesMK = compute_tma_transaction_bytes_mk(); + static constexpr uint32_t TmaTransactionBytesNK = compute_tma_transaction_bytes_nk(); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // Nothing extra to do + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor()); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + 
cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_zero.get_tma_descriptor()); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA prefetch."); + } + + } + // clang-format off + + // Modified from upstream, should be kept close to that when possible + // the main difference is special handling for the prepacked A layout + // + // Set up the data needed by this collective for load and mma. + // Returns a tuple of tensors. The collective and the kernel layer have the + // contract Returned tuple must contain at least two elements, with the first + // two elements being: gA_mkl - The tma tensor, A after a local tile so it + // has shape (TILE_V,TILE_B,m,k,l) gB_nkl - The tma tensor, B after a local + // tile so it has shape (TILE_N,TILE_K,n,k,l) The rest of the tensors can be + // specified as needed by this collective. + // NOTE: TILE_B is the prepacked block index within a tile. TILE_V is the + // values within a prepacked block. + template + CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, + Params const& mainloop_params) const { + using X = Underscore; + auto M = get<0>(problem_shape_MNKL), N = get<1>(problem_shape_MNKL), + K = get<2>(problem_shape_MNKL), L = get<3>(problem_shape_MNKL); + + // (TILE_V,TILE_B,m,k,l) + auto make_gA_mkl = [&]() { + // ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx) + auto layout = GmemLayoutA::TVbNbKL_to_offset_copy(make_shape(M, K, L)); + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(shape(layout)); + return local_tile(mA_mkl, + make_shape(size<0>(layout), PPBlocksPerTile_MK{}), + make_coord(0, make_coord(_, _))); + }; + + // (TILE_N,TILE_K,n,k,l) + auto make_gB_nkl = [&]() { + Tensor mB_nkl = + mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); + return local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), + Step{}); + }; + + // (TILE_M,TILE_Scale_K,m,scale_k,l) + auto make_gS_mkl = [&]() { + auto scale_k = mainloop_params.scale_k; + Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor( + make_shape(M, scale_k, L)); + return local_tile(mS_mkl, ScaleTileShape{}, make_coord(_, _)); + }; + + // (TILE_M,TILE_Scale_K,m,scale_k,l) + auto make_gZ_mkl = [&]() { + auto scale_k = mainloop_params.scale_k; + Tensor mZ_mkl = mainloop_params.tma_load_zero.get_tma_tensor( + make_shape(M, scale_k, L)); + return local_tile(mZ_mkl, ScaleTileShape{}, make_coord(_, _)); + }; + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple(make_gA_mkl(), make_gB_nkl()); + } else if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScale) { + return cute::make_tuple(make_gA_mkl(), make_gB_nkl(), make_gS_mkl()); + } else if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + return cute::make_tuple(make_gA_mkl(), make_gB_nkl(), make_gS_mkl(), + make_gZ_mkl()); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in load_init."); + } + } + + // Similar to upstream, should be kept close to that when possible + // the main difference is in the layout comments + // clang-format off + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + /// This overload gets triggered when we have scales. + template < + class... 
Ts, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + static_assert(sizeof... (Ts) == 2, "Direct convert needs two inputs"); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + static_assert(sizeof... (Ts) == 3, "Scaled convert needs three inputs"); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + static_assert(sizeof... (Ts) == 4, "Scaled and zero convert needs four inputs"); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA load."); + } + + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A, B and Scales + // + + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. 
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (TILE_V,TILE_B,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (TILE_N,TILE_K,k) + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + uint16_t mcast_mask_s = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + auto extra_input_partitions = partition_extra_tma_inputs(mainloop_params, load_inputs, shared_tensors, cluster_local_block_id, m_coord, l_coord); + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // Nothing extra to do. + } + else if constexpr (ModeHasScales) { + auto tSgS = get<0>(extra_input_partitions); + auto tSsS = get<1>(extra_input_partitions); + + // Temporary factor which will determine which k tile to reload from gmem. Needed so we don't modify tma transaction bytes + // on the fly. + // We must do a ceiling divide here to correctly handle with group_size == K. In that case, we don't require that K + // is a multiple of the threadblock tile K + const int ReloadFactor = (mainloop_params.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}); + const int scale_load_k = *k_tile_iter / ReloadFactor; // This will always be 0 when group_size == K. 
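The reload arithmetic in the comment above can be sanity-checked with plain integer math. A minimal host-side sketch, assuming an illustrative K-tile extent of 64 (the value TileShapeK works out to for 16-bit A); the helper name scale_load_k_for is hypothetical and not part of the kernel:

// Sketch of the ceiling-divide scale-reload logic described above.
#include <cassert>

int scale_load_k_for(int k_tile, int group_size, int tile_k) {
  const int reload_factor = (group_size + tile_k - 1) / tile_k;  // ceiling divide
  return k_tile / reload_factor;
}

int main() {
  // group_size == 128, TILE_K == 64: a fresh scale k-tile every 2 k-tiles.
  assert(scale_load_k_for(0, 128, 64) == 0);
  assert(scale_load_k_for(3, 128, 64) == 1);
  // group_size == K (e.g. K == 4096): scale_load_k stays 0 for every k-tile.
  assert(scale_load_k_for(63, 4096, 64) == 0);
  return 0;
}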
+ copy(mainloop_params.tma_load_scale.with(*tma_barrier, mcast_mask_s), tSgS(_,_,_,scale_load_k), tSsS(_,_,_,write_stage)); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tZgZ = get<2>(extra_input_partitions); + auto tZsZ = get<3>(extra_input_partitions); + copy(mainloop_params.tma_load_zero.with(*tma_barrier, mcast_mask_s), tZgZ(_,_,_,scale_load_k), tZsZ(_,_,_,write_stage)); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); + } + + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + // clang-format off + + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // clang-format off + // Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + // clang-format on + + // Modified from upstream, should be kept close to that when possible + // the main differences are handling the prepacked A layout, and separating + // the loading of A from upcoverting A + // + // Perform a collective-scoped matrix multiply-accumulate + // Consumer Perspective + template + CUTLASS_DEVICE void mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, FrgTensorC& accum, + int k_tile_count, int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + static_assert(is_rmem::value, + "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutB{}) == 3, + "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, + "SmemLayoutAtomB must be rank 2."); + static_assert(!cute::is_void_v, + "SM90 GMMA mainloops must specify a non-void copy atom for " + "RF sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for " + "smem sourced instructions."); + + // Obtain warp index + int warp_idx = canonical_warp_idx_sync(); + [[maybe_unused]] int warp_group_thread_idx = thread_idx % 128; + + // ((T, (FrgV,(RestM, RestK)), (BlocksM, BlocksK), pipe) -> offset + auto constexpr smem_A = SmemLayoutA{}; + + // convert: + // ((T, (MMA,(MMA_M, MMA_K)), (BlocksM, BlocksK), pipe) -> offset + // to: + // (T, MMA, ((MMA_M, BlocksM), (MMA_K, BlocksK)), pipe) -> offset + // which can be thought of as: + // (T, MMA, (MMA_M, MMA_K), pipe) -> offset + auto constexpr smem_A_mma_ = + make_layout(get<0, 0>(smem_A), get<0, 1, 0>(smem_A), + zip(get<0, 1, 1>(smem_A), get<1>(smem_A)), get<2>(smem_A)); + // flatten to: + // (T, MMA, MMA_M, MMA_K, pipe) -> offset + auto constexpr smem_A_mma = smem_A_mma_(_, _, make_coord(_, _), _); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), + smem_A_mma); // (T, MMA, MMA_M, MMA_K, pipe) + Tensor sB = 
make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), + SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + Tensor tCsA = sA(thread_idx, _, _, _, _); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate fragments and descriptors + Tensor tCrA_load = make_tensor( + tCsA(_, _, _, Int<0>{}).shape()); // (MMA,MMA_N,MMA_K) + Tensor tCrA_mma = make_fragment_like(tCrA_load); + + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + static constexpr int A_CPY_VEC = + decltype(max_common_vector(tCsA, tCrA_load)){}; + + static constexpr int COVERSION_WIDTH = + std::min(A_CPY_VEC, int(size<0>(tCrA_mma))); + + auto load_A_to_registers = [&](int read_stage) { + copy(create_auto_vectorizing_copy(), + tCsA(_, _, _, read_stage), tCrA_load(_, _, _)); + }; + + // Partition of thread -> shared and thread -> RF + auto partitioned_extra_info = + partition_extra_mma_info(thread_mma, shared_tensors); + auto copy_partitions_extra_info = retile_extra_mma_info( + tiled_mma, partitioned_extra_info, warp_group_thread_idx); + CUTE_STATIC_ASSERT_V(size<1>(tCrA_mma) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + + auto convert_A = [&, a_vec = Int{}](int k_block, + int read_stage) { + load_extra_info_to_registers(partitioned_extra_info, + copy_partitions_extra_info, k_block, + read_stage); + transform_A_kblock(tCrA_load, a_vec, tCrA_mma, partitioned_extra_info, + k_block); + }; + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum); + + constexpr int K_BLOCK_MAX = size<2>(tCrA_load); + + ConsumerToken barrier_token = {BarrierStatus::WaitAgain}; + // first k tile + { + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + ++smem_pipe_read; + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + + // copy smem->rmem for A operand + load_A_to_registers(read_stage); + convert_A(0, read_stage); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + if (k_block < K_BLOCK_MAX - 1) { + convert_A(k_block + 1, smem_pipe_read.index()); + } + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_, _, k_block), + tCrB(_, _, k_block, read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + } + + --k_tile_count; + if (k_tile_count > 0) { + // Wait for K_BLOCK_MAX - 1 to be in flight to ensure that it is safe to + // overwrite the A registers for the first mma. 
+ warpgroup_wait(); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + load_A_to_registers(smem_pipe_read.index()); + convert_A(0, smem_pipe_read.index()); + } + } + + if (k_tile_count == 0) { + return; + } + + warpgroup_fence_operand(accum); + // Mainloop GMMAs + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 1; --k_tile_count) { + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + ++smem_pipe_read; + + warpgroup_fence_operand(accum); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_, _, k_block), + tCrB(_, _, k_block, read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + + warpgroup_wait(); + if (k_block == K_BLOCK_MAX - 1) { + // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage, + // so we can release prior barrier + pipeline.consumer_release( + smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ + // on it + ++smem_pipe_release; + } + + if (k_block == 0) { + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + } + + if (k_block == K_BLOCK_MAX - 1) { + pipeline.consumer_wait(smem_pipe_read, barrier_token); + load_A_to_registers(smem_pipe_read.index()); + convert_A(0, smem_pipe_read.index()); + } else { + convert_A(k_block + 1, read_stage); + } + } + warpgroup_fence_operand(accum); + } + + warpgroup_fence_operand(accum); + + { + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + warpgroup_fence_operand(accum); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_, _, k_block), + tCrB(_, _, k_block, read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + warpgroup_wait(); + if (k_block == K_BLOCK_MAX - 1) { + // release prior barrier + pipeline.consumer_release( + smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ + // on it + ++smem_pipe_release; + } + + if (k_block < K_BLOCK_MAX - 1) { + convert_A(k_block + 1, read_stage); + } + } + } + + warpgroup_fence_operand(accum); + } + + // Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, + PipelineState smem_pipe_release, + int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = 1; + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release( + smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on + // it + ++smem_pipe_release; + } + } + + private: + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // clang-format off + /// Utilities for any additional inputs inside of the TMA load + template + CUTLASS_DEVICE + auto partition_extra_tma_inputs( + Params const& mainloop_params, + cute::tuple const& load_inputs, + TensorStorage& shared_tensors, + uint2 const& cluster_local_block_id, + int const m_coord, + int const l_coord) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple(); + } + else if constexpr (ModeHasScales) { + Tensor sS = 
make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) + Tensor gS_mkl = get<2>(load_inputs); + auto block_tma_s = mainloop_params.tma_load_scale.get_slice(cluster_local_block_id.y); + Tensor gS = gS_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + + Tensor tSgS = block_tma_s.partition_S(gS); // (TMA,TMA_M,TMA_K,k) + Tensor tSsS = block_tma_s.partition_D(sS); // (TMA,TMA_M,TMA_K,PIPE) + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tSgS, tSsS); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) + Tensor gZ_mkl = get<3>(load_inputs); + auto block_tma_z = mainloop_params.tma_load_zero.get_slice(cluster_local_block_id.y); + Tensor gZ = gZ_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + + Tensor tZgZ = block_tma_z.partition_S(gZ); // (TMA,TMA_M,TMA_K,k) + Tensor tZsZ = block_tma_z.partition_D(sZ); // (TMA,TMA_M,TMA_K,PIPE) + return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); + } + } + // clang-format off + + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // clang-format off + /// Utilities for partitioning extra inputs for loading from smem in the mainloop. + template + CUTLASS_DEVICE + auto partition_extra_mma_info( + ThreadMma const& mma_thread_slice, + TensorStorage& shared_tensors) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + return cute::make_tuple(); + } + else if constexpr (ModeHasScales) { + Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsS = mma_thread_slice.partition_A(sS); + Tensor tCrS = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).shape()); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tCsS, tCrS); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsZ = mma_thread_slice.partition_A(sZ); + Tensor tCrZ = make_tensor(mma_thread_slice.partition_fragment_A(sZ(_,_,Int<0>{})).shape()); + return cute::make_tuple(tCsS, tCrS, tCsZ, tCrZ); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + // clang-format on + + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // clang-format off + /// Returns the tiled copy and copy views for the extra inputs. 
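The partition_extra_* helpers above all follow one compile-time pattern: the conversion mode decides the arity of the returned tuple (empty for DirectConvert, scale views for ConvertAndScale, scale plus zero views for ConvertAndScaleWithZero). A standalone sketch of that pattern, with plain ints standing in for the CuTe tensor views (names here are illustrative only):

#include <tuple>

enum class ConversionMode { DirectConvert, ConvertAndScale, ConvertAndScaleWithZero };

template <ConversionMode Mode>
auto partition_extra(int scale_gmem, int scale_smem, int zero_gmem, int zero_smem) {
  if constexpr (Mode == ConversionMode::DirectConvert) {
    return std::make_tuple();                                              // nothing extra
  } else if constexpr (Mode == ConversionMode::ConvertAndScale) {
    return std::make_tuple(scale_gmem, scale_smem);                        // scales only
  } else {
    return std::make_tuple(scale_gmem, scale_smem, zero_gmem, zero_smem);  // scales + zeros
  }
}

static_assert(std::tuple_size_v<decltype(partition_extra<ConversionMode::DirectConvert>(0, 0, 0, 0))> == 0);
static_assert(std::tuple_size_v<decltype(partition_extra<ConversionMode::ConvertAndScale>(0, 0, 0, 0))> == 2);
static_assert(std::tuple_size_v<decltype(partition_extra<ConversionMode::ConvertAndScaleWithZero>(0, 0, 0, 0))> == 4);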
+ template + CUTLASS_DEVICE + auto retile_extra_mma_info( + TiledMma const& tiled_mma, + cute::tuple& partitioned_extra_info, + int const warp_group_thread_idx) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + return cute::make_tuple(); + } + else if constexpr (ModeHasScales) { + auto smem_tiled_copy_S = make_tiled_copy_A(SmemCopyAtomScale{}, tiled_mma); + auto smem_thr_copy_S = smem_tiled_copy_S.get_thread_slice(warp_group_thread_idx); + Tensor tCrS_copy_view = smem_thr_copy_S.retile_D(cute::get<1>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor tCrZ_copy_view = smem_thr_copy_S.retile_D(cute::get<3>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) + return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view, tCrZ_copy_view); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + // clang-format on + + // Similar to `copy_A_and_extra_info` upstream, should be kept the same when + // possible + // the main differences this only loads the extra info into registers and + // not A (since we now preload more of A in the main pipeline) + // Load scales and zeros into registers if required + template + CUTLASS_DEVICE void load_extra_info_to_registers( + cute::tuple const& partitioned_mma_extra_info, + cute::tuple const& tiled_copy_and_views, int k_block, + int read_stage) { + if (k_block == 0) { + // We are starting a new k-tile so copy the scale + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + } else if constexpr (ModeHasScales) { + auto smem_tiled_copy_S = cute::get<0>(tiled_copy_and_views); + auto tCrS_copy_view = cute::get<1>(tiled_copy_and_views); + auto tCsS = cute::get<0>(partitioned_mma_extra_info); + copy(smem_tiled_copy_S, tCsS(_, _, k_block, read_stage), + tCrS_copy_view(_, _, k_block)); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do + } else if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + auto tCsZ = cute::get<2>(partitioned_mma_extra_info); + auto tCrZ_copy_view = cute::get<2>(tiled_copy_and_views); + copy(smem_tiled_copy_S, tCsZ(_, _, k_block, read_stage), + tCrZ_copy_view(_, _, k_block)); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in A -> RF path."); + } + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in A -> RF path."); + } + } + } + + // Similar to upstream, should be kept the same when possible. + // the main differences are that `convert_tensor` supports interleaved + // layouts and bfloat16 has been optimized. `transform_internal_A` has also + // been inlined for code simplicity. + // Utilities to transform A. 
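A scalar reference for what transform_A_kblock below computes per element may help when reading the vectorized version: upconvert the quantized value, multiply by the group scale, and in the zero-point mode add the converted zero. A minimal sketch under that reading (dequant_ref is a hypothetical helper, not part of the kernel):

#include <cassert>

// Scalar model of the per-element transform; the real code operates on
// register fragments with vectorized converters and, for bf16, packed math.
float dequant_ref(int q_upconverted, float scale, float zero, bool with_zero) {
  float v = static_cast<float>(q_upconverted) * scale;
  return with_zero ? v + zero : v;
}

int main() {
  assert(dequant_ref(3, 0.5f, 0.0f, /*with_zero=*/false) == 1.5f);
  assert(dequant_ref(3, 0.5f, 1.0f, /*with_zero=*/true) == 2.5f);
  return 0;
}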
+ template + CUTLASS_DEVICE void transform_A_kblock( + TCrA_load const& tCrA_load, cute::Int vec_A, + TCrA_mma& tCrA_mma, cute::tuple const& partitioned_extra_info, + int const k_block) { + auto in = tCrA_load(_, _, k_block); + auto out = tCrA_mma(_, _, k_block); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + convert_tensor(in, out, vec_A); + } else if constexpr (ModeHasScales) { + auto tCrS = cute::get<1>(partitioned_extra_info); + auto converted_inputs = + make_fragment_like(tCrA_mma)(_, _, k_block); + auto scales = tCrS(_, _, 0); + + // First, we upcast the inputs to the scale type + convert_tensor(in, converted_inputs, vec_A); + // Apply scales and broadcast across inputs, store in converted_inputs + + // We need to cast to nv_bfloat16 for the multiply since + // `cutlass::bfloat16_t` has an overloaded operator* that upconverts to + // float, which nvcc will not optimize to using vectorized fma + // instructions (i.e. hfma.bf16_v2) + if constexpr (std::is_same_v) { + cute::transform( + recast(converted_inputs), recast(scales), + recast(converted_inputs), cute::multiplies{}); + } else { + cute::transform(converted_inputs, scales, converted_inputs, + cute::multiplies{}); + } + + // Apply zeros if required + if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + auto tCrZ = cute::get<3>(partitioned_extra_info); + auto converted_zeros = make_fragment_like(tCrZ)(_, _, 0); + + convert_tensor(tCrZ(_, _, 0), converted_zeros); + if constexpr (std::is_same_v) { + cute::transform(recast(converted_inputs), + recast(converted_zeros), + recast(converted_inputs), cute::plus{}); + } else { + cute::transform(converted_inputs, converted_zeros, converted_inputs, + cute::plus{}); + } + } + + // Finally, we convert the scaled inputs to the mma type. + convert_tensor(converted_inputs, out); + } else { + static_assert(cutlass::detail::dependent_false, + "No A data is loaded."); + } + } + + // Modified from upstream, should be kept the same when possible + // the main differences is that this version supports interleaved converts + // Utilities for transforming the A operand prior to issuing tensorcore math. + template > + CUTLASS_DEVICE void convert_tensor( + Tensor const& in, + Tensor& out, + cute::Int width = {}) { + // This is an element-wise conversion where we expect both tensors to have + // the same layout. As a result, we can cast as a cutlass array to use the + // fast numeric converters without worrying about indexing into the layout. + constexpr int N = cosize_v; + + // The inputs must be backed by registers & be statically sized. 
+ static_assert(is_rmem::value, + "Input tensor for A conversion must come from registers"); + static_assert(is_rmem::value, + "Output tensor for A conversion must come from registers"); + static_assert(is_static_v, + "Tensor layout for the conversion must be static"); + static_assert(cosize_v == size(TensorLayout{}), + "Cosize and size of the layout must be equal."); + static_assert( + N % ConversionVectorWidth == 0, + "Conversion vector width must divide cosize of the tensor layout."); + + using SrcType = typename EngineIn::value_type; + using DstType = typename EngineOut::value_type; + + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + + constexpr cutlass::FloatRoundStyle RoundStyle = + cutlass::FloatRoundStyle::round_to_nearest; + + using Converter = cutlass::InterleavedNumericArrayConverter< + IlvdBlkLayout, DstType, SrcType, ConversionVectorWidth, RoundStyle>; + + constexpr int NumIterations = N / ConversionVectorWidth; + + for (int ii = 0; ii < NumIterations; ++ii) { + SrcArray const* src_array_ptr = + reinterpret_cast(raw_pointer_cast(in.data())) + ii; + DstArray* dst_array_ptr = + reinterpret_cast(raw_pointer_cast(out.data())) + ii; + *dst_array_ptr = Converter::convert(*src_array_ptr); + } + } +}; + +} // namespace machete diff --git a/gptqmodel_ext/machete/machete_mm_dispatch.cu b/gptqmodel_ext/machete/machete_mm_dispatch.cu new file mode 100644 index 000000000..94d2bf76b --- /dev/null +++ b/gptqmodel_ext/machete/machete_mm_dispatch.cu @@ -0,0 +1,454 @@ + +#include "machete_mm_launcher.cuh" + +namespace machete { + + + +extern torch::Tensor impl_f16u4b8f16voidvoidvoidf16f32_sch_128x128_2x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_f16u4b8f16voidvoidvoidf16f32_sch_128x256_2x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_f16u4b8f16voidvoidvoidf16f32_sch_128x64_2x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_f16u4b8f16voidvoidvoidf16f32_sch_128x32_2x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_f16u4b8f16voidvoidvoidf16f32_sch_256x128_2x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_f16u4b8f16voidvoidvoidf16f32_sch_128x16_1x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_f16u4b8f16voidvoidvoidf16f32_sch_256x64_2x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_f16u4b8f16voidvoidvoidf16f32_sch_256x32_2x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_f16u4b8f16voidvoidvoidf16f32_sch_256x16_1x1x1_TmaMI__TmaCoop_streamK(MMArgs); + +torch::Tensor mm_dispatch_f16u4b8f16voidvoidvoidf16f32(MMArgs args) { + [[maybe_unused]] auto M = args.A.size(0); + [[maybe_unused]] auto N = args.B.size(1); + [[maybe_unused]] auto K = args.A.size(1); + + if (!args.maybe_schedule) { + if (M > 256 && K <= 16384 && N <= 4096) + return impl_f16u4b8f16voidvoidvoidf16f32_sch_128x128_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 256) + return impl_f16u4b8f16voidvoidvoidf16f32_sch_128x256_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 128 && K <= 4096 && N <= 4096) + return impl_f16u4b8f16voidvoidvoidf16f32_sch_128x64_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 128 && K <= 8192 && N <= 8192) + return impl_f16u4b8f16voidvoidvoidf16f32_sch_128x128_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 128) + return impl_f16u4b8f16voidvoidvoidf16f32_sch_128x256_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 64 && K <= 4069 && N <= 4069) + return impl_f16u4b8f16voidvoidvoidf16f32_sch_128x32_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 64 && K <= 4069 && N <= 8192) + return 
impl_f16u4b8f16voidvoidvoidf16f32_sch_128x64_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 64 && K >= 8192 && N >= 12288) + return impl_f16u4b8f16voidvoidvoidf16f32_sch_256x128_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 64) + return impl_f16u4b8f16voidvoidvoidf16f32_sch_128x128_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 32 && K <= 6144 && N <= 6144) + return impl_f16u4b8f16voidvoidvoidf16f32_sch_128x16_1x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 32 && K >= 16384 && N >= 12288) + return impl_f16u4b8f16voidvoidvoidf16f32_sch_256x64_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 32) + return impl_f16u4b8f16voidvoidvoidf16f32_sch_128x64_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 16 && K <= 12288 && N <= 8192) + return impl_f16u4b8f16voidvoidvoidf16f32_sch_128x32_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 16) + return impl_f16u4b8f16voidvoidvoidf16f32_sch_256x32_2x1x1_TmaMI__TmaCoop_streamK(args); + if (N >= 26624) + return impl_f16u4b8f16voidvoidvoidf16f32_sch_256x16_1x1x1_TmaMI__TmaCoop_streamK(args); + else + return impl_f16u4b8f16voidvoidvoidf16f32_sch_128x16_1x1x1_TmaMI__TmaCoop_streamK(args); + } + if (*args.maybe_schedule == "128x128_2x1x1_TmaMI__TmaCoop_streamK") + return impl_f16u4b8f16voidvoidvoidf16f32_sch_128x128_2x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "128x256_2x1x1_TmaMI__TmaCoop_streamK") + return impl_f16u4b8f16voidvoidvoidf16f32_sch_128x256_2x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "128x64_2x1x1_TmaMI__TmaCoop_streamK") + return impl_f16u4b8f16voidvoidvoidf16f32_sch_128x64_2x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "128x32_2x1x1_TmaMI__TmaCoop_streamK") + return impl_f16u4b8f16voidvoidvoidf16f32_sch_128x32_2x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "256x128_2x1x1_TmaMI__TmaCoop_streamK") + return impl_f16u4b8f16voidvoidvoidf16f32_sch_256x128_2x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "128x16_1x1x1_TmaMI__TmaCoop_streamK") + return impl_f16u4b8f16voidvoidvoidf16f32_sch_128x16_1x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "256x64_2x1x1_TmaMI__TmaCoop_streamK") + return impl_f16u4b8f16voidvoidvoidf16f32_sch_256x64_2x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "256x32_2x1x1_TmaMI__TmaCoop_streamK") + return impl_f16u4b8f16voidvoidvoidf16f32_sch_256x32_2x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "256x16_1x1x1_TmaMI__TmaCoop_streamK") + return impl_f16u4b8f16voidvoidvoidf16f32_sch_256x16_1x1x1_TmaMI__TmaCoop_streamK(args); + TORCH_CHECK_NOT_IMPLEMENTED(false, "machete_gemm(..) 
is not implemented for " + "schedule = ", *args.maybe_schedule); +} + +extern torch::Tensor impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_128x128_2x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_128x256_2x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_128x64_2x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_128x32_2x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_256x128_2x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_128x16_1x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_256x64_2x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_256x32_2x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_256x16_1x1x1_TmaMI__TmaCoop_streamK(MMArgs); + +torch::Tensor mm_dispatch_bf16u4b8bf16voidvoidvoidbf16f32(MMArgs args) { + [[maybe_unused]] auto M = args.A.size(0); + [[maybe_unused]] auto N = args.B.size(1); + [[maybe_unused]] auto K = args.A.size(1); + + if (!args.maybe_schedule) { + if (M > 256 && K <= 16384 && N <= 4096) + return impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_128x128_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 256) + return impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_128x256_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 128 && K <= 4096 && N <= 4096) + return impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_128x64_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 128 && K <= 8192 && N <= 8192) + return impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_128x128_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 128) + return impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_128x256_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 64 && K <= 4069 && N <= 4069) + return impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_128x32_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 64 && K <= 4069 && N <= 8192) + return impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_128x64_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 64 && K >= 8192 && N >= 12288) + return impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_256x128_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 64) + return impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_128x128_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 32 && K <= 6144 && N <= 6144) + return impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_128x16_1x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 32 && K >= 16384 && N >= 12288) + return impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_256x64_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 32) + return impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_128x64_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 16 && K <= 12288 && N <= 8192) + return impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_128x32_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 16) + return impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_256x32_2x1x1_TmaMI__TmaCoop_streamK(args); + if (N >= 26624) + return impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_256x16_1x1x1_TmaMI__TmaCoop_streamK(args); + else + return impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_128x16_1x1x1_TmaMI__TmaCoop_streamK(args); + } + if (*args.maybe_schedule == "128x128_2x1x1_TmaMI__TmaCoop_streamK") + return impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_128x128_2x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "128x256_2x1x1_TmaMI__TmaCoop_streamK") + return 
impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_128x256_2x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "128x64_2x1x1_TmaMI__TmaCoop_streamK") + return impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_128x64_2x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "128x32_2x1x1_TmaMI__TmaCoop_streamK") + return impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_128x32_2x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "256x128_2x1x1_TmaMI__TmaCoop_streamK") + return impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_256x128_2x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "128x16_1x1x1_TmaMI__TmaCoop_streamK") + return impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_128x16_1x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "256x64_2x1x1_TmaMI__TmaCoop_streamK") + return impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_256x64_2x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "256x32_2x1x1_TmaMI__TmaCoop_streamK") + return impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_256x32_2x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "256x16_1x1x1_TmaMI__TmaCoop_streamK") + return impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_256x16_1x1x1_TmaMI__TmaCoop_streamK(args); + TORCH_CHECK_NOT_IMPLEMENTED(false, "machete_gemm(..) is not implemented for " + "schedule = ", *args.maybe_schedule); +} + +extern torch::Tensor impl_f16u8b128f16voidvoidvoidf16f32_sch_128x128_2x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_f16u8b128f16voidvoidvoidf16f32_sch_128x256_2x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_f16u8b128f16voidvoidvoidf16f32_sch_128x64_2x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_f16u8b128f16voidvoidvoidf16f32_sch_128x32_2x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_f16u8b128f16voidvoidvoidf16f32_sch_256x128_2x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_f16u8b128f16voidvoidvoidf16f32_sch_128x16_1x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_f16u8b128f16voidvoidvoidf16f32_sch_256x64_2x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_f16u8b128f16voidvoidvoidf16f32_sch_256x32_2x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_f16u8b128f16voidvoidvoidf16f32_sch_256x16_1x1x1_TmaMI__TmaCoop_streamK(MMArgs); + +torch::Tensor mm_dispatch_f16u8b128f16voidvoidvoidf16f32(MMArgs args) { + [[maybe_unused]] auto M = args.A.size(0); + [[maybe_unused]] auto N = args.B.size(1); + [[maybe_unused]] auto K = args.A.size(1); + + if (!args.maybe_schedule) { + if (M > 256 && K <= 16384 && N <= 4096) + return impl_f16u8b128f16voidvoidvoidf16f32_sch_128x128_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 256) + return impl_f16u8b128f16voidvoidvoidf16f32_sch_128x256_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 128 && K <= 4096 && N <= 4096) + return impl_f16u8b128f16voidvoidvoidf16f32_sch_128x64_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 128 && K <= 8192 && N <= 8192) + return impl_f16u8b128f16voidvoidvoidf16f32_sch_128x128_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 128) + return impl_f16u8b128f16voidvoidvoidf16f32_sch_128x256_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 64 && K <= 4069 && N <= 4069) + return impl_f16u8b128f16voidvoidvoidf16f32_sch_128x32_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 64 && K <= 4069 && N <= 8192) + return impl_f16u8b128f16voidvoidvoidf16f32_sch_128x64_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 64 && K >= 8192 && N >= 12288) + return 
impl_f16u8b128f16voidvoidvoidf16f32_sch_256x128_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 64) + return impl_f16u8b128f16voidvoidvoidf16f32_sch_128x128_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 32 && K <= 6144 && N <= 6144) + return impl_f16u8b128f16voidvoidvoidf16f32_sch_128x16_1x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 32 && K >= 16384 && N >= 12288) + return impl_f16u8b128f16voidvoidvoidf16f32_sch_256x64_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 32) + return impl_f16u8b128f16voidvoidvoidf16f32_sch_128x64_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 16 && K <= 12288 && N <= 8192) + return impl_f16u8b128f16voidvoidvoidf16f32_sch_128x32_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 16) + return impl_f16u8b128f16voidvoidvoidf16f32_sch_256x32_2x1x1_TmaMI__TmaCoop_streamK(args); + if (N >= 26624) + return impl_f16u8b128f16voidvoidvoidf16f32_sch_256x16_1x1x1_TmaMI__TmaCoop_streamK(args); + else + return impl_f16u8b128f16voidvoidvoidf16f32_sch_128x16_1x1x1_TmaMI__TmaCoop_streamK(args); + } + if (*args.maybe_schedule == "128x128_2x1x1_TmaMI__TmaCoop_streamK") + return impl_f16u8b128f16voidvoidvoidf16f32_sch_128x128_2x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "128x256_2x1x1_TmaMI__TmaCoop_streamK") + return impl_f16u8b128f16voidvoidvoidf16f32_sch_128x256_2x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "128x64_2x1x1_TmaMI__TmaCoop_streamK") + return impl_f16u8b128f16voidvoidvoidf16f32_sch_128x64_2x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "128x32_2x1x1_TmaMI__TmaCoop_streamK") + return impl_f16u8b128f16voidvoidvoidf16f32_sch_128x32_2x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "256x128_2x1x1_TmaMI__TmaCoop_streamK") + return impl_f16u8b128f16voidvoidvoidf16f32_sch_256x128_2x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "128x16_1x1x1_TmaMI__TmaCoop_streamK") + return impl_f16u8b128f16voidvoidvoidf16f32_sch_128x16_1x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "256x64_2x1x1_TmaMI__TmaCoop_streamK") + return impl_f16u8b128f16voidvoidvoidf16f32_sch_256x64_2x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "256x32_2x1x1_TmaMI__TmaCoop_streamK") + return impl_f16u8b128f16voidvoidvoidf16f32_sch_256x32_2x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "256x16_1x1x1_TmaMI__TmaCoop_streamK") + return impl_f16u8b128f16voidvoidvoidf16f32_sch_256x16_1x1x1_TmaMI__TmaCoop_streamK(args); + TORCH_CHECK_NOT_IMPLEMENTED(false, "machete_gemm(..) 
is not implemented for " + "schedule = ", *args.maybe_schedule); +} + +extern torch::Tensor impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_128x128_2x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_128x256_2x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_128x64_2x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_128x32_2x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_256x128_2x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_128x16_1x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_256x64_2x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_256x32_2x1x1_TmaMI__TmaCoop_streamK(MMArgs); +extern torch::Tensor impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_256x16_1x1x1_TmaMI__TmaCoop_streamK(MMArgs); + +torch::Tensor mm_dispatch_bf16u8b128bf16voidvoidvoidbf16f32(MMArgs args) { + [[maybe_unused]] auto M = args.A.size(0); + [[maybe_unused]] auto N = args.B.size(1); + [[maybe_unused]] auto K = args.A.size(1); + + if (!args.maybe_schedule) { + if (M > 256 && K <= 16384 && N <= 4096) + return impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_128x128_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 256) + return impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_128x256_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 128 && K <= 4096 && N <= 4096) + return impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_128x64_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 128 && K <= 8192 && N <= 8192) + return impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_128x128_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 128) + return impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_128x256_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 64 && K <= 4069 && N <= 4069) + return impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_128x32_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 64 && K <= 4069 && N <= 8192) + return impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_128x64_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 64 && K >= 8192 && N >= 12288) + return impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_256x128_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 64) + return impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_128x128_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 32 && K <= 6144 && N <= 6144) + return impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_128x16_1x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 32 && K >= 16384 && N >= 12288) + return impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_256x64_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 32) + return impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_128x64_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 16 && K <= 12288 && N <= 8192) + return impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_128x32_2x1x1_TmaMI__TmaCoop_streamK(args); + if (M > 16) + return impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_256x32_2x1x1_TmaMI__TmaCoop_streamK(args); + if (N >= 26624) + return impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_256x16_1x1x1_TmaMI__TmaCoop_streamK(args); + else + return impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_128x16_1x1x1_TmaMI__TmaCoop_streamK(args); + } + if (*args.maybe_schedule == "128x128_2x1x1_TmaMI__TmaCoop_streamK") + return impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_128x128_2x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "128x256_2x1x1_TmaMI__TmaCoop_streamK") + return 
impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_128x256_2x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "128x64_2x1x1_TmaMI__TmaCoop_streamK") + return impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_128x64_2x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "128x32_2x1x1_TmaMI__TmaCoop_streamK") + return impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_128x32_2x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "256x128_2x1x1_TmaMI__TmaCoop_streamK") + return impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_256x128_2x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "128x16_1x1x1_TmaMI__TmaCoop_streamK") + return impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_128x16_1x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "256x64_2x1x1_TmaMI__TmaCoop_streamK") + return impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_256x64_2x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "256x32_2x1x1_TmaMI__TmaCoop_streamK") + return impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_256x32_2x1x1_TmaMI__TmaCoop_streamK(args); + if (*args.maybe_schedule == "256x16_1x1x1_TmaMI__TmaCoop_streamK") + return impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_256x16_1x1x1_TmaMI__TmaCoop_streamK(args); + TORCH_CHECK_NOT_IMPLEMENTED(false, "machete_gemm(..) is not implemented for " + "schedule = ", *args.maybe_schedule); +} + + +static inline std::optional maybe_scalartype( + std::optional const& t) { + if (!t) { + return std::nullopt; + } else { + return t->scalar_type(); + }; +} + +torch::Tensor mm_dispatch(MMArgs args) { + auto out_type = args.maybe_out_type.value_or(args.A.scalar_type()); + auto a_type = args.A.scalar_type(); + auto maybe_g_scales_type = maybe_scalartype(args.maybe_group_scales); + auto maybe_g_zeros_type = maybe_scalartype(args.maybe_group_zeros); + auto maybe_ch_scales_type = maybe_scalartype(args.maybe_channel_scales); + auto maybe_tok_scales_type = maybe_scalartype(args.maybe_token_scales); + + + if (args.b_type == vllm::kU4B8 + && a_type == at::ScalarType::Half + && out_type == at::ScalarType::Half + && maybe_g_scales_type == at::ScalarType::Half + && !maybe_g_zeros_type + && !maybe_ch_scales_type + && !maybe_tok_scales_type + ) { + return mm_dispatch_f16u4b8f16voidvoidvoidf16f32(args); + } + if (args.b_type == vllm::kU4B8 + && a_type == at::ScalarType::BFloat16 + && out_type == at::ScalarType::BFloat16 + && maybe_g_scales_type == at::ScalarType::BFloat16 + && !maybe_g_zeros_type + && !maybe_ch_scales_type + && !maybe_tok_scales_type + ) { + return mm_dispatch_bf16u4b8bf16voidvoidvoidbf16f32(args); + } + if (args.b_type == vllm::kU8B128 + && a_type == at::ScalarType::Half + && out_type == at::ScalarType::Half + && maybe_g_scales_type == at::ScalarType::Half + && !maybe_g_zeros_type + && !maybe_ch_scales_type + && !maybe_tok_scales_type + ) { + return mm_dispatch_f16u8b128f16voidvoidvoidf16f32(args); + } + if (args.b_type == vllm::kU8B128 + && a_type == at::ScalarType::BFloat16 + && out_type == at::ScalarType::BFloat16 + && maybe_g_scales_type == at::ScalarType::BFloat16 + && !maybe_g_zeros_type + && !maybe_ch_scales_type + && !maybe_tok_scales_type + ) { + return mm_dispatch_bf16u8b128bf16voidvoidvoidbf16f32(args); + } + + TORCH_CHECK_NOT_IMPLEMENTED( + false, "machete_mm(..) is not implemented for " + "a_type=", args.A.scalar_type(), + ", b_type=", args.b_type.str(), + ", out_type=", out_type, + ", with_group_scale_type=", maybe_g_scales_type + ? toString(*maybe_g_scales_type) : "None", + ", with_group_zeropoint_type=", maybe_g_zeros_type + ? 
toString(*maybe_g_zeros_type) : "None", + ", with_channel_scale_type=", maybe_ch_scales_type + ? toString(*maybe_ch_scales_type) : "None", + ", with_token_scale_type=", maybe_tok_scales_type + ? toString(*maybe_tok_scales_type) : "None", + "; implemented types are: \n", + "\ta_type=f16, b_type=u4b8, with_group_scale_type=f16, with_group_zeropoint_type=void, with_channel_scale_type=void, a_token_scale_type=void, out_type=f16, accumulator_type=f32\n", + "\ta_type=bf16, b_type=u4b8, with_group_scale_type=bf16, with_group_zeropoint_type=void, with_channel_scale_type=void, a_token_scale_type=void, out_type=bf16, accumulator_type=f32\n", + "\ta_type=f16, b_type=u8b128, with_group_scale_type=f16, with_group_zeropoint_type=void, with_channel_scale_type=void, a_token_scale_type=void, out_type=f16, accumulator_type=f32\n", + "\ta_type=bf16, b_type=u8b128, with_group_scale_type=bf16, with_group_zeropoint_type=void, with_channel_scale_type=void, a_token_scale_type=void, out_type=bf16, accumulator_type=f32\n", + ""); +} + +std::vector supported_schedules_dispatch( + SupportedSchedulesArgs args) { + auto out_type = args.maybe_out_type.value_or(args.a_type); + + + if (args.b_type == vllm::kU4B8 + && args.a_type == at::ScalarType::Half + && out_type == at::ScalarType::Half + && args.maybe_group_scales_type == at::ScalarType::Half + && !args.maybe_group_zeros_type + ) { + return { + "128x128_2x1x1_TmaMI__TmaCoop_streamK", + "128x256_2x1x1_TmaMI__TmaCoop_streamK", + "128x64_2x1x1_TmaMI__TmaCoop_streamK", + "128x32_2x1x1_TmaMI__TmaCoop_streamK", + "256x128_2x1x1_TmaMI__TmaCoop_streamK", + "128x16_1x1x1_TmaMI__TmaCoop_streamK", + "256x64_2x1x1_TmaMI__TmaCoop_streamK", + "256x32_2x1x1_TmaMI__TmaCoop_streamK", + "256x16_1x1x1_TmaMI__TmaCoop_streamK" + }; + } + if (args.b_type == vllm::kU4B8 + && args.a_type == at::ScalarType::BFloat16 + && out_type == at::ScalarType::BFloat16 + && args.maybe_group_scales_type == at::ScalarType::BFloat16 + && !args.maybe_group_zeros_type + ) { + return { + "128x128_2x1x1_TmaMI__TmaCoop_streamK", + "128x256_2x1x1_TmaMI__TmaCoop_streamK", + "128x64_2x1x1_TmaMI__TmaCoop_streamK", + "128x32_2x1x1_TmaMI__TmaCoop_streamK", + "256x128_2x1x1_TmaMI__TmaCoop_streamK", + "128x16_1x1x1_TmaMI__TmaCoop_streamK", + "256x64_2x1x1_TmaMI__TmaCoop_streamK", + "256x32_2x1x1_TmaMI__TmaCoop_streamK", + "256x16_1x1x1_TmaMI__TmaCoop_streamK" + }; + } + if (args.b_type == vllm::kU8B128 + && args.a_type == at::ScalarType::Half + && out_type == at::ScalarType::Half + && args.maybe_group_scales_type == at::ScalarType::Half + && !args.maybe_group_zeros_type + ) { + return { + "128x128_2x1x1_TmaMI__TmaCoop_streamK", + "128x256_2x1x1_TmaMI__TmaCoop_streamK", + "128x64_2x1x1_TmaMI__TmaCoop_streamK", + "128x32_2x1x1_TmaMI__TmaCoop_streamK", + "256x128_2x1x1_TmaMI__TmaCoop_streamK", + "128x16_1x1x1_TmaMI__TmaCoop_streamK", + "256x64_2x1x1_TmaMI__TmaCoop_streamK", + "256x32_2x1x1_TmaMI__TmaCoop_streamK", + "256x16_1x1x1_TmaMI__TmaCoop_streamK" + }; + } + if (args.b_type == vllm::kU8B128 + && args.a_type == at::ScalarType::BFloat16 + && out_type == at::ScalarType::BFloat16 + && args.maybe_group_scales_type == at::ScalarType::BFloat16 + && !args.maybe_group_zeros_type + ) { + return { + "128x128_2x1x1_TmaMI__TmaCoop_streamK", + "128x256_2x1x1_TmaMI__TmaCoop_streamK", + "128x64_2x1x1_TmaMI__TmaCoop_streamK", + "128x32_2x1x1_TmaMI__TmaCoop_streamK", + "256x128_2x1x1_TmaMI__TmaCoop_streamK", + "128x16_1x1x1_TmaMI__TmaCoop_streamK", + "256x64_2x1x1_TmaMI__TmaCoop_streamK", + "256x32_2x1x1_TmaMI__TmaCoop_streamK", + 
"256x16_1x1x1_TmaMI__TmaCoop_streamK" + }; + } + + return {}; +}; + +}; // namespace machete \ No newline at end of file diff --git a/gptqmodel_ext/machete/machete_mm_impl.cu b/gptqmodel_ext/machete/machete_mm_impl.cu new file mode 100644 index 000000000..dd48289ce --- /dev/null +++ b/gptqmodel_ext/machete/machete_mm_impl.cu @@ -0,0 +1,300 @@ + +#include "machete_mm_launcher.cuh" + +namespace machete { + + +struct sch_256x16_1x1x1_TmaMI__TmaCoop_streamK { + using TileShapeNM = Shape<_256, _16>; + using ClusterShape = Shape<_1, _1, _1>; + // TODO: Reimplement + // using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using TileScheduler = cutlass::gemm::StreamKScheduler; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; +}; + +struct sch_128x128_2x1x1_TmaMI__TmaCoop_streamK { + using TileShapeNM = Shape<_128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + // TODO: Reimplement + // using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using TileScheduler = cutlass::gemm::StreamKScheduler; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; +}; + +struct sch_256x128_2x1x1_TmaMI__TmaCoop_streamK { + using TileShapeNM = Shape<_256, _128>; + using ClusterShape = Shape<_2, _1, _1>; + // TODO: Reimplement + // using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using TileScheduler = cutlass::gemm::StreamKScheduler; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; +}; + +struct sch_128x32_2x1x1_TmaMI__TmaCoop_streamK { + using TileShapeNM = Shape<_128, _32>; + using ClusterShape = Shape<_2, _1, _1>; + // TODO: Reimplement + // using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using TileScheduler = cutlass::gemm::StreamKScheduler; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; +}; + +struct sch_256x32_2x1x1_TmaMI__TmaCoop_streamK { + using TileShapeNM = Shape<_256, _32>; + using ClusterShape = Shape<_2, _1, _1>; + // TODO: Reimplement + // using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using TileScheduler = cutlass::gemm::StreamKScheduler; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; +}; + +struct sch_128x64_2x1x1_TmaMI__TmaCoop_streamK { + using TileShapeNM = Shape<_128, _64>; + using ClusterShape = Shape<_2, _1, _1>; + // TODO: Reimplement + // using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using TileScheduler = cutlass::gemm::StreamKScheduler; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; +}; + +struct sch_128x256_2x1x1_TmaMI__TmaCoop_streamK { + using TileShapeNM = Shape<_128, _256>; + using ClusterShape = Shape<_2, _1, _1>; + // TODO: Reimplement + // using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using TileScheduler = cutlass::gemm::StreamKScheduler; + using EpilogueTileType = 
cutlass::epilogue::collective::EpilogueTileAuto; +}; + +struct sch_256x64_2x1x1_TmaMI__TmaCoop_streamK { + using TileShapeNM = Shape<_256, _64>; + using ClusterShape = Shape<_2, _1, _1>; + // TODO: Reimplement + // using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using TileScheduler = cutlass::gemm::StreamKScheduler; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; +}; + +struct sch_128x16_1x1x1_TmaMI__TmaCoop_streamK { + using TileShapeNM = Shape<_128, _16>; + using ClusterShape = Shape<_1, _1, _1>; + // TODO: Reimplement + // using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using TileScheduler = cutlass::gemm::StreamKScheduler; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; +}; + + + +template +using Kernel_f16u4b8f16voidvoidvoidf16f32 = MacheteKernelTemplate< + cutlass::half_t, // ElementA + cutlass::vllm_uint4b8_t, // ElementB + cutlass::half_t, // ElementD + float, // Accumulator + cutlass::half_t, // GroupScaleT + void, // GroupZeroT + void, // ChannelScaleT + void, // TokenScaleT + cutlass::gemm::KernelTmaWarpSpecializedCooperative, + Sch>; + + +torch::Tensor +impl_f16u4b8f16voidvoidvoidf16f32_sch_128x128_2x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor +impl_f16u4b8f16voidvoidvoidf16f32_sch_128x256_2x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor +impl_f16u4b8f16voidvoidvoidf16f32_sch_128x64_2x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor +impl_f16u4b8f16voidvoidvoidf16f32_sch_128x32_2x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor +impl_f16u4b8f16voidvoidvoidf16f32_sch_256x128_2x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor +impl_f16u4b8f16voidvoidvoidf16f32_sch_128x16_1x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor +impl_f16u4b8f16voidvoidvoidf16f32_sch_256x64_2x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor +impl_f16u4b8f16voidvoidvoidf16f32_sch_256x32_2x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor +impl_f16u4b8f16voidvoidvoidf16f32_sch_256x16_1x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +template +using Kernel_bf16u4b8bf16voidvoidvoidbf16f32 = MacheteKernelTemplate< + cutlass::bfloat16_t, // ElementA + cutlass::vllm_uint4b8_t, // ElementB + cutlass::bfloat16_t, // ElementD + float, // Accumulator + cutlass::bfloat16_t, // GroupScaleT + void, // GroupZeroT + void, // ChannelScaleT + void, // TokenScaleT + cutlass::gemm::KernelTmaWarpSpecializedCooperative, + Sch>; + + +torch::Tensor +impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_128x128_2x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor +impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_128x256_2x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor +impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_128x64_2x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor +impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_128x32_2x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor 
+impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_256x128_2x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor +impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_128x16_1x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor +impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_256x64_2x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor +impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_256x32_2x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor +impl_bf16u4b8bf16voidvoidvoidbf16f32_sch_256x16_1x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +template +using Kernel_f16u8b128f16voidvoidvoidf16f32 = MacheteKernelTemplate< + cutlass::half_t, // ElementA + cutlass::vllm_uint8b128_t, // ElementB + cutlass::half_t, // ElementD + float, // Accumulator + cutlass::half_t, // GroupScaleT + void, // GroupZeroT + void, // ChannelScaleT + void, // TokenScaleT + cutlass::gemm::KernelTmaWarpSpecializedCooperative, + Sch>; + + +torch::Tensor +impl_f16u8b128f16voidvoidvoidf16f32_sch_128x128_2x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor +impl_f16u8b128f16voidvoidvoidf16f32_sch_128x256_2x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor +impl_f16u8b128f16voidvoidvoidf16f32_sch_128x64_2x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor +impl_f16u8b128f16voidvoidvoidf16f32_sch_128x32_2x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor +impl_f16u8b128f16voidvoidvoidf16f32_sch_256x128_2x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor +impl_f16u8b128f16voidvoidvoidf16f32_sch_128x16_1x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor +impl_f16u8b128f16voidvoidvoidf16f32_sch_256x64_2x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor +impl_f16u8b128f16voidvoidvoidf16f32_sch_256x32_2x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor +impl_f16u8b128f16voidvoidvoidf16f32_sch_256x16_1x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +template +using Kernel_bf16u8b128bf16voidvoidvoidbf16f32 = MacheteKernelTemplate< + cutlass::bfloat16_t, // ElementA + cutlass::vllm_uint8b128_t, // ElementB + cutlass::bfloat16_t, // ElementD + float, // Accumulator + cutlass::bfloat16_t, // GroupScaleT + void, // GroupZeroT + void, // ChannelScaleT + void, // TokenScaleT + cutlass::gemm::KernelTmaWarpSpecializedCooperative, + Sch>; + + +torch::Tensor +impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_128x128_2x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor +impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_128x256_2x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor +impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_128x64_2x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor +impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_128x32_2x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor +impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_256x128_2x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor +impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_128x16_1x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor 
+impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_256x64_2x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor +impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_256x32_2x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} +torch::Tensor +impl_bf16u8b128bf16voidvoidvoidbf16f32_sch_256x16_1x1x1_TmaMI__TmaCoop_streamK(MMArgs args) { + return run_impl>(args); +} + +}; // namespace machete \ No newline at end of file diff --git a/gptqmodel_ext/machete/machete_mm_kernel.cuh b/gptqmodel_ext/machete/machete_mm_kernel.cuh new file mode 100644 index 000000000..69ec46e81 --- /dev/null +++ b/gptqmodel_ext/machete/machete_mm_kernel.cuh @@ -0,0 +1,325 @@ +#pragma once + +#include +#include +#include + +// clang-format off +// The cutlass include order matters (annoyingly) +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +// clang-format on + +#include "cutlass_extensions/cute_utils.cuh" +#include "cutlass_extensions/vllm_numeric_conversion.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" +#include "cutlass_extensions/torch_utils.hpp" +#include "machete_collective_builder.cuh" +#include "machete_prepacked_layout.cuh" +#include "machete_interleaving_utils.cuh" + +namespace machete { + +using namespace cute; + +// NOTE This kernel computes D = alpha * A * B + beta * C by computing +// D^t = alpha * B^t * A^t + beta * C^t, this is because the wgmma +// instructions only support sourcing from registers for the left-hand +// operand, we want to upconvert/decompress the quantized operand in +// register. Since the primary use case we want to support is Y = XW^t where +// W is quantized, in this situation or right-hand operand is quantized so +// we compute the transpose to move it to the left-hand side. 
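The NOTE above is the central design point of the kernel: wgmma can only source the register-resident (upconverted) operand on the left, so the quantized weight has to become the left-hand operand, which is done by computing the transposed product. The identity itself can be checked numerically with dense float stand-ins and plain libtorch, independent of Machete:

#include <torch/torch.h>

// Checks the identity the kernel relies on: Y = X * W^T  <=>  Y^T = W * X^T.
// W here is a dense stand-in for the quantized weight matrix.
int main() {
  auto opts = torch::TensorOptions().dtype(torch::kFloat32);
  auto X = torch::randn({8, 64}, opts);   // activations, M x K
  auto W = torch::randn({32, 64}, opts);  // weights, N x K (Y = X * W^T)

  auto Y_direct = torch::matmul(X, W.t());      // M x N
  auto Y_via_t  = torch::matmul(W, X.t()).t();  // (N x M)^T, i.e. M x N again

  TORCH_CHECK(torch::allclose(Y_direct, Y_via_t, /*rtol=*/1e-5, /*atol=*/1e-5),
              "transposed formulation must match the direct product");
  return 0;
}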
+template +struct MacheteKernelTemplate { + static constexpr bool with_C = false; // not ever used + static constexpr bool with_group_scales = !std::is_same_v; + static constexpr bool with_group_zeropoints = + !std::is_same_v; + static constexpr bool with_channel_scales = + !std::is_same_v; + static constexpr bool with_token_scales = !std::is_same_v; + + using MmaType = ElementA_; + using ElementA = ElementA_; + using ElementB = ElementB_; + using ElementD = ElementD_; + using ElementC = cute::conditional_t; + using ElementAccumulator = AccumulatorT; + using ElementCompute = AccumulatorT; // For Epilogue + // Use dummy values when we don't have scales or zeropoints + using ElementZGroup = + cute::conditional_t; + using ElementSGroup = + cute::conditional_t; + using ElementConvertGroup = + cute::conditional_t; + using ElementSChannel = + cute::conditional_t; + using ElementSToken = + cute::conditional_t; + + using BTypeTuple = cute::conditional_t< + with_group_scales, + cute::conditional_t, + cute::tuple>, + ElementB>; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = LayoutC; + using LayoutScale = cutlass::layout::RowMajor; + // not actually used since B has the prepacked layout, but required by cutlass + using _LayoutB = cutlass::layout::ColumnMajor; + + // Interface strides expected by create_arguments (will get transposed) + using StrideA = cutlass::detail::TagToStrideA_t; + using StrideC = cutlass::detail::TagToStrideA_t; + using StrideD = cutlass::detail::TagToStrideA_t; + using StrideSGroup = cutlass::detail::TagToStrideA_t; + using StrideZGroup = StrideSGroup; + + using LayoutA_Transpose = + typename cutlass::layout::LayoutTranspose::type; + using LayoutC_Transpose = + typename cutlass::layout::LayoutTranspose::type; + using LayoutD_Transpose = + typename cutlass::layout::LayoutTranspose::type; + + using ArchTag = cutlass::arch::Sm90; + using OperatorClass = cutlass::arch::OpClassTensorOp; + + using PrepackedLayoutB = + PrepackedLayoutBTemplate; + + static int constexpr TileShapeK = + 128 * 8 / cutlass::sizeof_bits::value; + static int constexpr AlignmentA = 128 / cutlass::sizeof_bits_v; + static int constexpr AlignmentB = 128 / cutlass::sizeof_bits_v; + static int constexpr AlignmentC = + (with_C) ? 
128 / cutlass::sizeof_bits_v : 0; + static int constexpr AlignmentD = 128 / cutlass::sizeof_bits_v; + + using TileShape = decltype(append(typename ScheduleConfig::TileShapeNM{}, + cute::Int{})); + using ClusterShape = typename ScheduleConfig::ClusterShape; + using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; + using EpilogueTileType = typename ScheduleConfig::EpilogueTileType; + using TileScheduler = typename ScheduleConfig::TileScheduler; + + static_assert( + (!with_channel_scales && !with_token_scales) || + ((with_channel_scales && with_token_scales) && + std::is_same_v), + "Currently token and channel scales (if present) must be the same type"); + + // Currently only supports float scales + using ChTokScalesEpilogue = + typename vllm::c3x::ScaledEpilogue; + static_assert((with_channel_scales || with_token_scales) || + (std::is_same_v && + std::is_same_v), + "Currently token and channel scales (if present) must be float " + "(and if one is present the other must be too)"); + + using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT< + cutlass::epilogue::fusion::Sm90AccFetch>; + + using EVTCompute = + std::conditional_t; + + // EVTCompute + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, + ElementAccumulator, ElementSChannel, ElementC, LayoutC_Transpose, + AlignmentC, ElementD, LayoutD_Transpose, AlignmentD, EpilogueSchedule, + EVTCompute>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::VLLMCollectiveBuilder< + cutlass::gemm::collective::MacheteKernelTag, ArchTag, OperatorClass, + BTypeTuple, PrepackedLayoutB, AlignmentB, ElementA, LayoutA_Transpose, + AlignmentA, ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, CollectiveEpilogue, TileScheduler>; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // stride_B is unused (since B is prepacked), but still required by cutlass + using _StrideB = cutlass::detail::TagToStrideB_t<_LayoutB>; + + using Arguments = typename Gemm::Arguments; + using MainloopArguments = typename GemmKernel::MainloopArguments; + using EpilogueArguments = typename GemmKernel::EpilogueArguments; + + static Arguments create_arguments( + cudaStream_t stream, + torch::Tensor const& A, // MxK matrix + torch::Tensor const& B, // KxN prepacked matrix + torch::Tensor& D, // MxN matrix + std::optional const& maybe_g_scales, // scale_KxN matrix + std::optional const& maybe_g_zeros, // scale_KxN matrix + std::optional maybe_group_size, + std::optional const& maybe_ch_scales, // len N vector + std::optional const& maybe_tok_scales) // len M vector + { + static_assert(!with_group_zeropoints || with_group_scales); + + int M = A.size(0), N = B.size(1), K = A.size(1); + TORCH_CHECK(D.size(0) == M && D.size(1) == N); + + auto layout_A = make_cute_layout(A, "A"); + auto layout_D = make_cute_layout(D, "D"); + auto layout_S_group = + maybe_make_cute_layout(maybe_g_scales, "group_scales"); + auto layout_Z_group = + maybe_make_cute_layout(maybe_g_zeros, "group_zeros"); + int64_t numel_S_channel = maybe_ch_scales ? maybe_ch_scales->numel() : 0; + int64_t numel_S_token = maybe_tok_scales ? 
maybe_tok_scales->numel() : 0; + + auto unwrap = [](auto const& t) { + return t ? t->const_data_ptr() : nullptr; + }; + auto A_ptr = static_cast(A.const_data_ptr()); + auto B_ptr = static_cast(B.const_data_ptr()); + auto D_ptr = static_cast(D.mutable_data_ptr()); + auto S_group_ptr = + static_cast(unwrap(maybe_g_scales)); + auto Z_group_ptr = static_cast(unwrap(maybe_g_zeros)); + auto S_channel_ptr = + static_cast(unwrap(maybe_ch_scales)); + auto S_token_ptr = + static_cast(unwrap(maybe_tok_scales)); + + int const group_size = + maybe_group_size == -1 ? K : maybe_group_size.value_or(K); + int const scale_k = (K + group_size - 1) / group_size; + + TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K); + TORCH_CHECK(size<0>(layout_D) == M && size<1>(layout_D) == N); + + if constexpr (with_group_scales) { + TORCH_CHECK(S_group_ptr && layout_S_group); + TORCH_CHECK((size<0>(*layout_S_group) == scale_k && + size<1>(*layout_S_group) == N)); + } else { + TORCH_CHECK(!S_group_ptr, "Scales not supported"); + } + + if constexpr (with_group_zeropoints) { + TORCH_CHECK(Z_group_ptr && layout_Z_group); + TORCH_CHECK((size<0>(*layout_Z_group) == scale_k && + size<1>(*layout_Z_group) == N)); + TORCH_CHECK(layout_S_group && *layout_Z_group == *layout_S_group, + "Scales and zeros must have the same layout"); + } else { + TORCH_CHECK(!Z_group_ptr, "Zeropoints not supported"); + } + + if constexpr (with_channel_scales || with_token_scales) { + TORCH_CHECK( + (maybe_ch_scales->numel() == N || maybe_ch_scales->numel() == 1) && + (maybe_tok_scales->numel() == M || maybe_tok_scales->numel() == 1)); + } + + // Transpose A and D + // A doesn't need to be transposed since cutlass expects a NxK matrix + // for B (which is At) + auto stride_At = layout_A.stride(); + auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride(); + + MainloopArguments mainloop_arguments{}; + // {Accum, C, C_layout, D, D} + EpilogueArguments epilogue_arguments{}; + + if constexpr (with_channel_scales || with_token_scales) { + epilogue_arguments = + EpilogueArguments{ChTokScalesEpilogue::prepare_args( + *maybe_ch_scales, *maybe_tok_scales), + nullptr, + {}, + D_ptr, + stride_Dt}; + } else { + epilogue_arguments = EpilogueArguments{{}, nullptr, {}, D_ptr, stride_Dt}; + } + + if constexpr (with_group_scales && with_group_zeropoints) { + auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride(); + mainloop_arguments = MainloopArguments{ + B_ptr, _StrideB{}, A_ptr, stride_At, + S_group_ptr, stride_S_group, group_size, Z_group_ptr}; + } else if constexpr (with_group_scales) { + auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride(); + mainloop_arguments = + MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At, + S_group_ptr, stride_S_group, group_size}; + } else { + mainloop_arguments = + MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At}; + } + + return Arguments{cutlass::gemm::GemmUniversalMode::kGemm, + {N, M, K, 1}, + mainloop_arguments, + epilogue_arguments}; + }; + + static size_t get_workspace_size(Arguments const& args) { + return Gemm::get_workspace_size(args); + } + + static bool can_implement(Arguments const& args) { + return Gemm::can_implement(args) == cutlass::Status::kSuccess; + } + + static const char* cutlassStatusToString(cutlass::Status status) { + switch (status) { + case cutlass::Status::kSuccess: return "Success"; + case cutlass::Status::kErrorMisalignedOperand: return "ErrorMisalignedOperand"; + case cutlass::Status::kErrorInvalidLayout: return "ErrorInvalidLayout"; + case 
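In create_arguments above, a group_size of -1 (or an absent optional) collapses to a single group spanning all of K, and the group-scale tensor is then expected to have shape (scale_k, N) with scale_k = ceil(K / group_size). A small host-side sketch of just that arithmetic:

#include <cassert>
#include <optional>

// Mirrors the group_size / scale_k computation in create_arguments():
// group_size == -1 or an empty optional both mean one group covering all of K.
inline int expected_scale_k(int K, std::optional<int> maybe_group_size) {
  int const group_size =
      maybe_group_size == -1 ? K : maybe_group_size.value_or(K);
  return (K + group_size - 1) / group_size;  // ceil-div
}

int main() {
  assert(expected_scale_k(4096, 128) == 32);          // 128-wide groups
  assert(expected_scale_k(4096, -1) == 1);            // one group over all of K
  assert(expected_scale_k(4096, std::nullopt) == 1);  // absent behaves the same
  // The group scales (and zeros, if present) must then be (scale_k, N).
  return 0;
}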
cutlass::Status::kErrorInvalidProblem: return "ErrorInvalidProblem"; + case cutlass::Status::kErrorNotSupported: return "ErrorNotSupported"; + case cutlass::Status::kErrorWorkspaceNull: return "ErrorWorkspaceNull"; + case cutlass::Status::kErrorInternal: return "ErrorInternal"; + case cutlass::Status::kInvalid: return "Invalid"; + default: return "Unknown CUTLASS Status"; + } + } + + static void run(Arguments const& args, void* workspace, cudaStream_t stream) { + Gemm gemm_op; + + cutlass::Status status = gemm_op.initialize(args, workspace, stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, + "Machete kernel failed to initialize workspace. Status: ", + cutlassStatusToString(status)); + + status = gemm_op.run(stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Machete kernel failed Status: ", + cutlassStatusToString(status)); + } +}; + +}; // namespace machete diff --git a/gptqmodel_ext/machete/machete_mm_launcher.cuh b/gptqmodel_ext/machete/machete_mm_launcher.cuh new file mode 100644 index 000000000..d920067a3 --- /dev/null +++ b/gptqmodel_ext/machete/machete_mm_launcher.cuh @@ -0,0 +1,78 @@ +#pragma once + +#include +#include + +#include "machete_mm_kernel.cuh" +#include "cutlass_extensions/torch_utils.hpp" +#include "core/scalar_type.hpp" + +namespace machete { + +struct MMArgs { + torch::Tensor const& A; + torch::Tensor const& B; + vllm::ScalarType const& b_type; + std::optional const& maybe_out_type; + std::optional const& maybe_group_scales; + std::optional const& maybe_group_zeros; + std::optional maybe_group_size; + std::optional const& maybe_channel_scales; + std::optional const& maybe_token_scales; + std::optional maybe_schedule; +}; + +struct SupportedSchedulesArgs { + at::ScalarType a_type; + vllm::ScalarType b_type; + std::optional maybe_group_scales_type; + std::optional maybe_group_zeros_type; + std::optional maybe_channel_scales_type; + std::optional maybe_token_scales_type; + std::optional maybe_out_type; +}; + +torch::Tensor mm_dispatch(MMArgs args); + +std::vector supported_schedules_dispatch( + SupportedSchedulesArgs args); + +template +torch::Tensor run_impl(MMArgs args) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(args.A)); + + auto device = args.A.device(); + auto stream = at::cuda::getCurrentCUDAStream(device.index()); + + int M = args.A.size(0); + int N = args.B.size(1); + int K = args.A.size(1); + + std::cout << "M = " << M << ", N = " << N << ", K = " << K << std::endl; + + // Allocate output + torch::Tensor D = torch::empty( + {M, N}, + torch::TensorOptions() + .dtype(equivalent_scalar_type_v) + .device(device)); + + auto arguments = MacheteKernel::create_arguments( + stream, // + args.A, args.B, D, args.maybe_group_scales, args.maybe_group_zeros, + args.maybe_group_size, args.maybe_channel_scales, + args.maybe_token_scales); + TORCH_CHECK(MacheteKernel::can_implement(arguments), + "Machete kernel cannot be run with these arguments"); + + size_t workspace_size = MacheteKernel::get_workspace_size(arguments); + std::cout << "Workspace size: " << workspace_size << " bytes" << std::endl; + torch::Tensor workspace = torch::empty( + workspace_size, torch::TensorOptions().dtype(torch::kU8).device(device)); + + MacheteKernel::run(arguments, workspace.mutable_data_ptr(), stream); + + return D; +}; + +}; // namespace machete \ No newline at end of file diff --git a/gptqmodel_ext/machete/machete_prepack.cu b/gptqmodel_ext/machete/machete_prepack.cu new file mode 100644 index 000000000..f1fa26b58 --- /dev/null +++ 
b/gptqmodel_ext/machete/machete_prepack.cu @@ -0,0 +1,73 @@ + +#include "machete_prepack_launcher.cuh" + +namespace machete { + +torch::Tensor prepack_B_dispatch(PrepackBArgs args) { + auto convert_type = args.maybe_group_scales_type.value_or(args.a_type); + + if (args.a_type == at::ScalarType::Half + && args.b_type.size_bits() == 4 + && convert_type == at::ScalarType::Half) { + return prepack_impl< + PrepackedLayoutBTemplate< + cutlass::half_t, // ElementA + cutlass::uint4b_t, // ElementB + cutlass::half_t, // ElementConvert + float, // Accumulator + cutlass::layout::ColumnMajor, + cutlass::gemm::KernelTmaWarpSpecializedCooperative> + >(args.B); + } + + if (args.a_type == at::ScalarType::BFloat16 + && args.b_type.size_bits() == 4 + && convert_type == at::ScalarType::BFloat16) { + return prepack_impl< + PrepackedLayoutBTemplate< + cutlass::bfloat16_t, // ElementA + cutlass::uint4b_t, // ElementB + cutlass::bfloat16_t, // ElementConvert + float, // Accumulator + cutlass::layout::ColumnMajor, + cutlass::gemm::KernelTmaWarpSpecializedCooperative> + >(args.B); + } + + if (args.a_type == at::ScalarType::Half + && args.b_type.size_bits() == 8 + && convert_type == at::ScalarType::Half) { + return prepack_impl< + PrepackedLayoutBTemplate< + cutlass::half_t, // ElementA + uint8_t, // ElementB + cutlass::half_t, // ElementConvert + float, // Accumulator + cutlass::layout::ColumnMajor, + cutlass::gemm::KernelTmaWarpSpecializedCooperative> + >(args.B); + } + + if (args.a_type == at::ScalarType::BFloat16 + && args.b_type.size_bits() == 8 + && convert_type == at::ScalarType::BFloat16) { + return prepack_impl< + PrepackedLayoutBTemplate< + cutlass::bfloat16_t, // ElementA + uint8_t, // ElementB + cutlass::bfloat16_t, // ElementConvert + float, // Accumulator + cutlass::layout::ColumnMajor, + cutlass::gemm::KernelTmaWarpSpecializedCooperative> + >(args.B); + } + + TORCH_CHECK_NOT_IMPLEMENTED(false, + "prepack_B_dispatch(..) is not implemented for " + "atype = ", args.a_type, + ", b_type = ", args.b_type.str(), + ", with_group_scales_type= ", args.maybe_group_scales_type ? 
+ toString(*args.maybe_group_scales_type) : "None"); +} + +}; // namespace machete \ No newline at end of file diff --git a/gptqmodel_ext/machete/machete_prepack_kernel.cuh b/gptqmodel_ext/machete/machete_prepack_kernel.cuh new file mode 100644 index 000000000..d002355ca --- /dev/null +++ b/gptqmodel_ext/machete/machete_prepack_kernel.cuh @@ -0,0 +1,76 @@ +#pragma once + +#include "machete_mm_kernel.cuh" +#include "cutlass_extensions/cute_utils.cuh" +#include "cutlass_extensions/torch_utils.hpp" + +namespace machete { + +template +static __global__ void prepack_B_kernel(BInTensor B_in, ElementB* B_out_ptr) { + auto constexpr block_size = + Int{}; + auto constexpr eles_per_thread = Int{}; + static_assert(block_size % threads == 0, + "block_size must be divisible by the number of threads"); + + // Which pre-packed are we responsible for + auto blk_coord = make_coord(blockIdx.x, blockIdx.y, blockIdx.z); + auto tB_in = local_tile( + B_in, append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}), + blk_coord); + + // Find the start offset in the output for this pre-packed block + auto bNbKL_to_offset = PrepackedLayoutB::bNbKL_to_offset(shape(B_in)); + + // Tensor representing a 1:1 mapping to the output space in 1D + auto tB_out_linear = + make_tensor(get_logical_ptr(B_out_ptr) + bNbKL_to_offset(blk_coord), + make_layout(make_shape(block_size))); + // Mapping from output space (1D) to input space + auto tB_in_linear = make_tensor( + tB_in.data(), + tB_in.layout() + .compose(right_inverse(PrepackedLayoutB::ppblock_ilvd_NK_to_offset())) + .with_shape(make_shape(block_size))); + + // Tile for this specific thread (could have used a TiledCopy but these work + // best with 2d layouts, this is a simple 1d layout so local_tile is enough, + // we are also not that concerned with performance for this kernel) + auto thr_tB_in_linear = + local_tile(tB_in_linear, make_shape(eles_per_thread), threadIdx.x); + auto thr_tB_out_linear = + local_tile(tB_out_linear, make_shape(eles_per_thread), threadIdx.x); + + // Construct a register-backed Tensor with the same shape as each thread's + // partition + auto fragment = make_tensor(shape(thr_tB_in_linear)); + + copy(thr_tB_in_linear, fragment); + copy(Copy_Atom{}, fragment, thr_tB_out_linear); +} + +template +static void prepack_B_template( + cudaStream_t stream, typename PrepackedLayoutB::ElementB const* B_in_ptr, + InLayout B_layout, typename PrepackedLayoutB::ElementB* B_out_ptr) { + using TileShapeNKL = + decltype(append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{})); + auto ilvd_NKbNbKL_to_offset = + PrepackedLayoutB::ilvd_NKbNbKL_to_offset(shape(B_layout)); + + TORCH_CHECK(size<0>(B_layout) % size<0>(TileShapeNKL{}) == 0); + TORCH_CHECK(size<1>(B_layout) % size<1>(TileShapeNKL{}) == 0); + + auto N_tiles = size<0>(B_layout) / size<0>(TileShapeNKL{}); + auto K_tiles = size<1>(B_layout) / size<1>(TileShapeNKL{}); + auto L_tiles = size<2>(B_layout); + + auto B_in = make_tensor(get_logical_ptr(B_in_ptr), B_layout); + + prepack_B_kernel<128, PrepackedLayoutB> + <<>>(B_in, B_out_ptr); +} + +}; // namespace machete \ No newline at end of file diff --git a/gptqmodel_ext/machete/machete_prepack_launcher.cuh b/gptqmodel_ext/machete/machete_prepack_launcher.cuh new file mode 100644 index 000000000..634b651a4 --- /dev/null +++ b/gptqmodel_ext/machete/machete_prepack_launcher.cuh @@ -0,0 +1,74 @@ +#pragma once + +#include "machete_prepack_kernel.cuh" +#include "cutlass_extensions/torch_utils.hpp" +#include "core/scalar_type.hpp" + +namespace machete { + +struct 
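prepack_B_kernel above copies one prepacked block per CTA: with the PPBlockShape_NK of (128, 64) defined in machete_prepacked_layout.cuh and the 128-thread launch in prepack_B_template, each block holds 128 * 64 = 8192 quantized elements and each thread moves 64 of them. The host-side bookkeeping for a concrete weight shape, as a rough sketch assuming those two constants:

#include <cassert>

// Launch-geometry arithmetic for prepack_B_template, assuming a
// PPBlockShape_NK of (128, 64) and 128 threads per CTA.
struct PrepackGeometry {
  int n_tiles, k_tiles, l_tiles, eles_per_thread;
};

inline PrepackGeometry prepack_geometry(int N, int K, int L = 1) {
  constexpr int ppblock_n = 128, ppblock_k = 64, threads = 128;
  // prepack_B_template TORCH_CHECKs this divisibility before launching.
  assert(N % ppblock_n == 0 && K % ppblock_k == 0);
  return {N / ppblock_n, K / ppblock_k, L, (ppblock_n * ppblock_k) / threads};
}

int main() {
  auto g = prepack_geometry(/*N=*/4096, /*K=*/4096);
  assert(g.n_tiles == 32 && g.k_tiles == 64 && g.l_tiles == 1);  // grid (32, 64, 1)
  assert(g.eles_per_thread == 64);  // sub-byte elements copied per thread
  return 0;
}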
PrepackBArgs { + torch::Tensor const& B; + at::ScalarType a_type; + vllm::ScalarType b_type; + std::optional maybe_group_scales_type; +}; + +template +torch::Tensor prepack_impl(torch::Tensor const B) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(B)); + using ElementB = typename PrepackedLayoutB::ElementB; + using PPBlockShape_NK = typename PrepackedLayoutB::PPBlockShape_NK; + + auto device = B.device(); + auto stream = at::cuda::getCurrentCUDAStream(device.index()); + auto B_ptr = static_cast(B.const_data_ptr()); + // elements per storage item for B + auto eles_per_storage = + (B.dtype().itemsize() * 8) / cute::sizeof_bits_v; + + // torch B passed in is/should be (packed_K,N), the kernel expects (N,K,L) (to + // match cutlass using (N,K,L) for B), so we transpose B to (N,packed_K,L) + auto Bt_packed = B.t(); + + TORCH_CHECK( + (B.size(0) * eles_per_storage) % size<1>(PPBlockShape_NK{}) == 0, + "B.shape[0] (in terms of unpacked elements) must be a multiple of ", + size<1>(PPBlockShape_NK{})); + TORCH_CHECK(B.size(1) % size<0>(PPBlockShape_NK{}) == 0, + "B.shape[1] must be a multiple of ", size<0>(PPBlockShape_NK{})); + + using StrideB = cutlass::detail::TagToStrideB_t; + auto const l_Bt_packed = make_cute_layout(Bt_packed, "B"); + + // convert (N,packed_K,L) layout to (N,K,L) layout + // in effect we want to do: blocked_product(layout_Bt_packed, + // make_ordered_layout(make_shape(_1{}, eles_per_storage, _1{}), + // Step<_1, _0, _2>{})); + // but blocked_product does not support dynamic strides so we implement the + // equivalent manually, + // new_shape = (N, packed_K, L) * (1, eles_per_storage, 1) -> (N, K, L) + // new_stride = (s0, s1, s2) * (eles_per_storage, 1, eles_per_storage) + // when s1 == 1 + TORCH_CHECK(stride<1>(l_Bt_packed) == 1); + // clang-format off + auto const layout_Bt = make_layout( + transform_with_idx(l_Bt_packed.shape(), [&](auto ele, auto idx) { + return idx == 1 ? ele * eles_per_storage : ele; + }), + transform_with_idx(l_Bt_packed.stride(), [&](auto ele, auto idx) { + return idx != 1 ? 
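prepack_impl above receives B in its packed (packed_K, N) storage form and widens the packed dimension by eles_per_storage to recover the logical (N, K, L) extent before building the cute layout. For 4-bit weights stored in 32-bit words that factor is 32 / 4 = 8. A minimal sketch of the same bookkeeping with plain integers:

#include <cassert>

// Quantized elements per storage word, matching the eles_per_storage
// expansion in prepack_impl.
inline int eles_per_storage(int storage_bits, int weight_bits) {
  assert(storage_bits % weight_bits == 0);
  return storage_bits / weight_bits;
}

int main() {
  // 4-bit weights packed into 32-bit words:
  int per_word = eles_per_storage(/*storage_bits=*/32, /*weight_bits=*/4);
  assert(per_word == 8);

  // A packed qweight of shape (packed_K, N) = (512, 4096) therefore describes
  // a logical B of shape (N, K, L) = (4096, 4096, 1) after widening.
  int packed_K = 512;
  int K = packed_K * per_word;
  assert(K == 4096);
  return 0;
}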
ele * eles_per_storage : ele; + })); + // clang-format on + + // Allocate output + torch::Tensor D = torch::empty_like(B, {}, at::MemoryFormat::Contiguous); + + prepack_B_template( + stream, B_ptr, layout_Bt, static_cast(D.mutable_data_ptr())); + + return D; +}; + +torch::Tensor prepack_B_dispatch(PrepackBArgs args); + +}; // namespace machete \ No newline at end of file diff --git a/gptqmodel_ext/machete/machete_prepacked_layout.cuh b/gptqmodel_ext/machete/machete_prepacked_layout.cuh new file mode 100644 index 000000000..81aaa6c4f --- /dev/null +++ b/gptqmodel_ext/machete/machete_prepacked_layout.cuh @@ -0,0 +1,249 @@ +#pragma once + +#include +#include +#include + +// clang-format off +// The cutlass include order matters (annoyingly) + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +// clang-format on + +#include "cutlass_extensions/cute_utils.cuh" +#include "machete_collective_builder.cuh" +#include "machete_interleaving_utils.cuh" + +namespace machete { + +using namespace cute; + +struct IlvBlkLayoutAuto {}; + +// This defines a prepacked layout for the B matrix, where the matrix is broken +// up into PPBlockShape_NK blocks. The data within each block is then compactly +// stored in memory such that when performing a TiledMMA operation with the same +// shape as prepacked block, all the data for a given thread is contiguous in +// memory. This allows us to use wider shared memory loads when loading B from +// shared memory. The values within a thread are also potentially interlaeved +// inorder to allow for more efficient upconverting. +// +// The contract here is that the `TiledMma` determined below matches the one +// ultimately used in the kernel. (this is also why the other element types are +// required along with the kernel schedule) +template +// clang-format on +struct PrepackedLayoutBTemplate { + using MmaType = ElementA_; + using ElementA = ElementA_; + using ElementB = ElementB_; + using ElementAccumulator = AccumulatorT; + using ElementMma = MmaType; + + // Interleave for 4bit bit types when we are not upconverting to fp8 or int8, + // in those cases case we use a LUT using prmt instructions to upconvert and + // is more efficient if the data is not interleaved For 8bit+ prmt + // instructions makes non-interleaved layouts efficient enough we don't need + // iterleaved layouts (and can reuse more of the existing cutlass converts) + static constexpr bool should_interleave = + sizeof_bits_v <= 4 && + !std::is_same_v && + !std::is_same_v; + + // Only use interleaved layouts for subbyte weights, + using IlvdBlkLayout = std::conditional_t< + std::is_same_v, + std::conditional_t< + should_interleave, + decltype(get_interleaved_blk_layout< + ElementB, sizeof_bits_v, 32>()), + void>, + IlvBlkLayout_>; + + // TODO (LucasWilkinson): compare the performance for other sizes + // Prepacked block shape, smallest layout atom for loading into registers + // (can contain multiple wgmma instructions worth of data in one block) + // We ideally want this to be configured such that a thread can perform 128bit + // loads, i.e. 
we amount of data associated with each thread within a + // prepacked block is a multiple of 128bits, when using a cooperative sechdule + // we have 256 threads working a single block at a time, this means each + // thread works on `sizeof_bits_v * (128*64) / 256` bits of data, + // for a 4bit type this would be 128bits + using PPBlockShape_NK = Shape<_128, _64>; + + // Create the shape of the tile anticipated to be used by the GEMM kernel, + // when the kernel executes we will compute `Ct = Bt * At` since the + // quantized weights (B), must be the lhs operand so the flow through + // registers. + // The _128 here doesn't actually impact the shape of the stored tile directly + // but may impact the op selected by rs_op_selector + using GemmTileShape = decltype(make_shape(size<0>(PPBlockShape_NK{}), _128{}, + size<1>(PPBlockShape_NK{}))); + + static constexpr cute::GMMA::Major GmmaMajorB = + gmma_rs_tag_to_major_B(); + + // For coop schedules we have two warp groups cooperatively issuing wgmma + // instructions so we use 2 atoms along the M dim (one for each warpgroup) + using AtomLayoutMNK = cute::conditional_t< + cute::is_same_v, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma( + cute::GMMA::rs_op_selector(), + AtomLayoutMNK{})); + + // Prepacked block, (athrid, val) -> (N,K) + // i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (N,K) + CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_NK() { + return TiledMma{}.thrfrg_A(make_layout(PPBlockShape_NK{})); + } + + // Prepacked block, (N,K) -> (athrid, val) + // i.e. (N,K) -> ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) + CUTE_HOST_DEVICE static constexpr auto ppblock_NK_to_TV() { + return right_inverse(ppblock_TV_to_NK()).with_shape(PPBlockShape_NK{}); + } + + // Prepacked block, (athrid, val) -> (storage_offset) + // i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (storage_idx) + CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_offset() { + // Return iterleaved layout + return make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{}); + } + + // Prepacked block, (athrid, val) -> (storage_offset) + // i.e. 
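The comment above picks PPBlockShape_NK = (128, 64) so that, with the 256 threads of a cooperative warp-group pair working one prepacked block, every thread owns a multiple of 128 bits of B data. That arithmetic can be pinned down at compile time, using only the numbers quoted in the comment:

// Bits of prepacked B data owned by each of the 256 cooperative threads
// for a PPBlockShape_NK = (128, 64) block.
constexpr int bits_per_thread(int weight_bits) {
  return weight_bits * (128 * 64) / 256;
}

static_assert(bits_per_thread(4) == 128, "4-bit weights: exactly one 128-bit load");
static_assert(bits_per_thread(8) == 256, "8-bit weights: two 128-bit loads");
static_assert(bits_per_thread(4) % 128 == 0 && bits_per_thread(8) % 128 == 0,
              "per-thread data stays a multiple of 128 bits");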
((ThrV,(ThrM,ThrK)),(IlvdFrgV,(RestM,RestK,...))) -> (storage_idx) + CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_TV_to_offset() { + auto layout_no_interleave = + make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{}); + + if constexpr (std::is_same_v) { + return layout_no_interleave; + } else { + // interleave by transforming FrgV into interleaved blocks where each + // block has the layout IlvdBlkLayout, for example if IlvdBlkLayout is + // (2, 2) : (2, 1) then we get: ((2, 2), size(FrgV) / 4) : ((2, 1), 4) + // if FrgV is {A, B, C, D, E, F, G, H} + // then ((IlvBlk), FrgB) is {A, C, B, D, C, G, D, H} + auto frgV = get<1, 0>(layout_no_interleave); + auto ilvdBlk = IlvdBlkLayout{}; + static_assert(size(frgV) % size(ilvdBlk) == 0, + "FrgV must be divisible by size(ilvdBlk)"); + auto ilvd_FrgV = make_layout( + make_shape(shape(ilvdBlk), Int{}), + make_stride(stride(ilvdBlk), size(ilvdBlk))); + + // Return iterleaved layout + return make_layout( + get<0>(layout_no_interleave), + make_layout(ilvd_FrgV, get<1, 1>(layout_no_interleave))); + } + } + + // Prepacked block, (M,K) -> (storage_offset) + CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_NK_to_offset() { + // do (M,K) -> (athrid, val) -> (storage_idx) + return ppblock_ilvd_TV_to_offset().compose(ppblock_NK_to_TV()); + } + + // ((athrid, val), (BlocksN, BlocksK), L) -> (storage_idx) + template + CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset( + Shape_NKL shape_mkl) { + constexpr auto block_layout = ppblock_TV_to_offset(); + + // (BlocksN, BlocksK, L) + auto blocks_shape = + cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}), + [](auto x, auto y) { return x / y; }); + + // ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx) + auto result = make_layout( + block_layout, + make_layout(blocks_shape, + compact_col_major(blocks_shape, size(block_layout)))); + + // ((athrid, val), (BlocksN, BlocksK, L)) + // => ((athrid, val), (BlocksN, BlocksK), L) + return group<1, 3>(result(_, repeat(result)>(_))); + } + + // ((athrid_val), (BlocksN, BlocksK, L)) -> (N, K, L) + template + CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset_copy( + Shape_NKL shape_mkl) { + auto layout = TVbNbKL_to_offset(shape_mkl); + return make_layout(coalesce(get<0>(layout)), get<1>(layout), + get<2>(layout)); + } + + // ((BlockN, BlockK), (BlocksN, BlocksK), L) -> (storage_idx) + template + CUTE_HOST_DEVICE static constexpr auto ilvd_NKbNbKL_to_offset( + Shape_NKL shape_mkl) { + constexpr auto block_layout = ppblock_ilvd_NK_to_offset(); + + // (BlocksN, BlocksK, L) + auto blocks_shape = + cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}), + [](auto x, auto y) { return x / y; }); + + // ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx) + auto result = make_layout( + block_layout, + make_layout(blocks_shape, + compact_col_major(blocks_shape, size(block_layout)))); + + // ((athrid, val), (BlocksN, BlocksK, L)) => ((athrid, val), (BlocksN, + // BlocksK), L) + return group<1, 3>(result(_, repeat(result)>(_))); + } + + // (BlocksN, BlocksK, L) -> (storage_idx) + template + CUTE_HOST_DEVICE static constexpr auto bNbKL_to_offset(Shape_NKL shape_mkl) { + // (BlocksN, BlocksK, L) + auto blocks_shape = + cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}), + [](auto x, auto y) { return x / y; }); + auto stride = size(PPBlockShape_NK{}); + + // (BlocksN, BlocksK, L) -> (storage_idx) + return make_layout(blocks_shape, compact_col_major(blocks_shape, stride)); + } + + // ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, 
L) + template + CUTE_HOST_DEVICE static auto TVbNbK_to_NKL(Shape_NKL shape_mkl) { + auto tile = make_tile(make_layout(size<0>(PPBlockShape_NK{})), + make_layout(size<1>(PPBlockShape_NK{}))); + + // ((BlockN, BlockK), (BlocksN, BlocksK, L)) -> (N, K, L) + auto tiled_A = zipped_divide(make_layout(shape_mkl), tile); + return tiled_A.compose(ppblock_TV_to_NK(), _); + } + + // (N, K, L) -> ((athrid, val), (BlocksN, BlocksK), L) + template + CUTE_HOST_DEVICE static auto NKL_to_TVbNbK(Shape_NKL shape_mkl) { + auto TVbNbK_to_NKL_layout = TVbNbK_to_NKL(shape_mkl); + return blocked_product(ppblock_NK_to_TV(), + make_layout(shape<1>(TVbNbK_to_NKL_layout))); + } +}; + +}; // namespace machete diff --git a/gptqmodel_ext/machete/machete_pytorch.cu b/gptqmodel_ext/machete/machete_pytorch.cu new file mode 100644 index 000000000..ead6dcc65 --- /dev/null +++ b/gptqmodel_ext/machete/machete_pytorch.cu @@ -0,0 +1,63 @@ +#include "machete_mm_launcher.cuh" +#include "machete_prepack_launcher.cuh" +#include "core/scalar_type.hpp" + +#include "core/registration.h" + +namespace machete { + +using namespace vllm; + +std::vector supported_schedules( + at::ScalarType a_type, int64_t b_type_id, + std::optional maybe_group_scales_type, + std::optional maybe_group_zeros_type, + std::optional maybe_channel_scales_type, + std::optional maybe_token_scales_type, + std::optional maybe_out_type) { + ScalarType const b_type = ScalarType::from_id(b_type_id); + return supported_schedules_dispatch({ + .a_type = a_type, + .b_type = b_type, + .maybe_group_scales_type = maybe_group_scales_type, + .maybe_group_zeros_type = maybe_group_zeros_type, + .maybe_channel_scales_type = maybe_channel_scales_type, + .maybe_token_scales_type = maybe_token_scales_type, + .maybe_out_type = maybe_out_type, + }); +} + +torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B, + int64_t b_type_id, + std::optional const& maybe_out_type, + std::optional const& maybe_group_scales, + std::optional const& maybe_group_zeros, + std::optional maybe_group_size, + std::optional const& maybe_channel_scales, + std::optional const& maybe_token_scales, + std::optional maybe_schedule) { + ScalarType const b_type = ScalarType::from_id(b_type_id); + return mm_dispatch({.A = A, + .B = B, + .b_type = b_type, + .maybe_out_type = maybe_out_type, + .maybe_group_scales = maybe_group_scales, + .maybe_group_zeros = maybe_group_zeros, + .maybe_group_size = maybe_group_size, + .maybe_channel_scales = maybe_channel_scales, + .maybe_token_scales = maybe_token_scales, + .maybe_schedule = maybe_schedule}); +} + +torch::Tensor prepack_B( + torch::Tensor const& B, at::ScalarType const& a_type, int64_t b_type_id, + std::optional const& maybe_group_scales_type) { + ScalarType const b_type = ScalarType::from_id(b_type_id); + return prepack_B_dispatch( + {.B = B, + .a_type = a_type, + .b_type = b_type, + .maybe_group_scales_type = maybe_group_scales_type}); +} + +}; // namespace machete diff --git a/gptqmodel_ext/machete/machete_pytorch.cuh b/gptqmodel_ext/machete/machete_pytorch.cuh new file mode 100644 index 000000000..921782260 --- /dev/null +++ b/gptqmodel_ext/machete/machete_pytorch.cuh @@ -0,0 +1,22 @@ +#include +#include "core/scalar_type.hpp" + +namespace machete { + +using namespace vllm; + +torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B, + int64_t b_type_id, + std::optional const& maybe_out_type, + std::optional const& maybe_group_scales, + std::optional const& maybe_group_zeros, + std::optional maybe_group_size, + std::optional const& 
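mm, prepack_B, and supported_schedules above take the weight type as an integer b_type_id and rebuild the full vllm::ScalarType with ScalarType::from_id, because a torch custom op cannot accept the class directly. A short sketch of that round trip for the symmetric 4-bit GPTQ type (uint4 with bias 8):

#include <cstdint>

#include "core/scalar_type.hpp"

int main() {
  constexpr auto b_type = vllm::kU4B8;  // uint4 with bias 8
  static_assert(b_type.size_bits() == 4);

  // What the caller passes across the op boundary...
  int64_t b_type_id = b_type.id();

  // ...and what machete::mm / machete::prepack_B reconstruct from it.
  auto roundtripped = vllm::ScalarType::from_id(b_type_id);
  TORCH_CHECK(roundtripped == b_type, "id() / from_id() must round-trip");
  return 0;
}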
maybe_channel_scales, + std::optional const& maybe_token_scales, + std::optional maybe_schedule); + + +torch::Tensor prepack_B( + torch::Tensor const& B, at::ScalarType const& a_type, int64_t b_type_id, + std::optional const& maybe_group_scales_type); +}; \ No newline at end of file diff --git a/gptqmodel_ext/machete/permute_cols.cu b/gptqmodel_ext/machete/permute_cols.cu new file mode 100644 index 000000000..f51fa7329 --- /dev/null +++ b/gptqmodel_ext/machete/permute_cols.cu @@ -0,0 +1,88 @@ +#include + +#include +#include + +#include + +static constexpr int default_threads = 256; +static constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } + +// For a given "a" of size [M,K] performs a permutation of the K columns based +// on the given "perm" indices. +// Currently only supports 16bit types (since we permute half types) +__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, int size_m, + int size_k, int block_rows) { + int start_row = block_rows * blockIdx.x; + int finish_row = start_row + block_rows; + if (finish_row > size_m) { + finish_row = size_m; + } + int cur_block_rows = std::max(finish_row - start_row, 0); + + int row_stride = size_k * sizeof(half) / 16; + + auto permute_row = [&](int row) { + int iters = size_k / default_threads; + int rest = size_k % default_threads; + + int offset = row * row_stride; + + half const* a_row_half = reinterpret_cast(a_int4_ptr + offset); + half* out_half = reinterpret_cast(out_int4_ptr + offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + + base_k += default_threads; + } + + if (rest) { + if (threadIdx.x < rest) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int i = 0; i < cur_block_rows; i++) { + int cur_row = start_row + i; + if (cur_row < size_m) { + permute_row(cur_row); + } + } +} + +// More efficient version of A[..., perm] +// taken from gptq_marlin.cu +torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); + auto dev = A.get_device(); + auto stream = at::cuda::getCurrentCUDAStream(dev); + + TORCH_CHECK(A.scalar_type() == at::kHalf || A.scalar_type() == at::kBFloat16, + "Currently only 16bit types are supported"); + TORCH_CHECK(A.is_contiguous(), "A must be contiguous"); + TORCH_CHECK(A.size(-1) % 8 == 0, + "A columns must be a multiple of 8 (128bits)"); + auto A_2d = A.view({-1, A.size(-1)}); + + torch::Tensor D = torch::empty_like(A); + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + int block_rows = div_ceil(A_2d.size(0), sms); + permute_cols_kernel<<>>( + reinterpret_cast(A_2d.const_data_ptr()), + perm.const_data_ptr(), reinterpret_cast(D.mutable_data_ptr()), + A_2d.size(0), A_2d.size(1), block_rows); + return D; +} \ No newline at end of file diff --git a/gptqmodel_ext/machete/permute_cols.cuh b/gptqmodel_ext/machete/permute_cols.cuh new file mode 100644 index 000000000..e1fec9af5 --- /dev/null +++ b/gptqmodel_ext/machete/permute_cols.cuh @@ -0,0 +1,3 @@ +#include + +torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm); \ No newline at end of file diff --git a/gptqmodel_ext/vllm_csrc/core/exception.hpp b/gptqmodel_ext/vllm_csrc/core/exception.hpp new file mode 
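permute_cols above is a fused replacement for indexing the activation columns with A[..., perm] (used when the K columns have been permuted for activation reordering). Its output can be cross-checked against the plain indexing it replaces; a minimal sketch, assuming the extension providing permute_cols has been built and linked and a CUDA device is present:

#include <torch/torch.h>

#include "permute_cols.cuh"  // torch::Tensor permute_cols(A, perm)

int main() {
  if (!torch::cuda::is_available()) return 0;

  auto A = torch::randn({16, 256},
                        torch::dtype(torch::kHalf).device(torch::kCUDA));
  auto perm = torch::randperm(
      256, torch::dtype(torch::kInt32).device(torch::kCUDA));

  // Kernel result vs. the equivalent (slower) indexing it replaces.
  auto fast = permute_cols(A, perm);
  auto ref = A.index_select(/*dim=*/1, perm.to(torch::kLong));
  TORCH_CHECK(torch::allclose(fast.to(torch::kFloat), ref.to(torch::kFloat)),
              "permute_cols must match A[:, perm]");
  return 0;
}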
100644 index 000000000..f3b2ffaef --- /dev/null +++ b/gptqmodel_ext/vllm_csrc/core/exception.hpp @@ -0,0 +1,3 @@ +#pragma once + +#define VLLM_IMPLIES(p, q) (!(p) || (q)) diff --git a/gptqmodel_ext/vllm_csrc/core/math.hpp b/gptqmodel_ext/vllm_csrc/core/math.hpp new file mode 100644 index 000000000..6764e1fd6 --- /dev/null +++ b/gptqmodel_ext/vllm_csrc/core/math.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include +#include + +inline constexpr uint32_t next_pow_2(uint32_t const num) { + if (num <= 1) return num; + return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); +} + +template +static inline constexpr auto div_ceil(A a, B b) { + return (a + b - 1) / b; +} + +// Round a down to the next multiple of b. The caller is responsible for making +// sure that b is non-zero +template +inline constexpr T round_to_previous_multiple_of(T a, T b) { + return a % b == 0 ? a : (a / b) * b; +} + +// Round a up to the next multiple of b. The caller is responsible for making +// sure that b is non-zero +template +inline constexpr T round_to_next_multiple_of(T a, T b) { + return a % b == 0 ? a : ((a / b) + 1) * b; +} diff --git a/gptqmodel_ext/vllm_csrc/core/registration.h b/gptqmodel_ext/vllm_csrc/core/registration.h new file mode 100644 index 000000000..4d0ce1c57 --- /dev/null +++ b/gptqmodel_ext/vllm_csrc/core/registration.h @@ -0,0 +1,27 @@ +#pragma once + +#include + +#define _CONCAT(A, B) A##B +#define CONCAT(A, B) _CONCAT(A, B) + +#define _STRINGIFY(A) #A +#define STRINGIFY(A) _STRINGIFY(A) + +// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME +// could be a macro instead of a literal token. +#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) + +// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME +// could be a macro instead of a literal token. +#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \ + TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE) + +// REGISTER_EXTENSION allows the shared library to be loaded and initialized +// via python's import statement. +#define REGISTER_EXTENSION(NAME) \ + PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \ + static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \ + STRINGIFY(NAME), nullptr, 0, nullptr}; \ + return PyModule_Create(&module); \ + } diff --git a/gptqmodel_ext/vllm_csrc/core/scalar_type.hpp b/gptqmodel_ext/vllm_csrc/core/scalar_type.hpp new file mode 100644 index 000000000..c2ae554c9 --- /dev/null +++ b/gptqmodel_ext/vllm_csrc/core/scalar_type.hpp @@ -0,0 +1,347 @@ +#pragma once + +// For TORCH_CHECK +#include + +namespace vllm { + +// +// ScalarType can represent a wide range of floating point and integer types, +// in particular it can be used to represent sub-byte data types (something +// that torch.dtype currently does not support). +// +// The type definitions on the Python side can be found in: vllm/scalar_type.py +// these type definitions should be kept up to date with any Python API changes +// here. 
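The helpers in core/math.hpp above are small constexpr one-liners, so their behaviour can be pinned down with compile-time checks, for example:

#include "core/math.hpp"

// Spot checks for the arithmetic helpers defined above.
static_assert(next_pow_2(1) == 1);
static_assert(next_pow_2(17) == 32);
static_assert(div_ceil(4096, 128) == 32);
static_assert(div_ceil(100, 64) == 2);
static_assert(round_to_previous_multiple_of(70, 64) == 64);
static_assert(round_to_next_multiple_of(70, 64) == 128);
static_assert(round_to_next_multiple_of(128, 64) == 128);  // already aligned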
+// +class ScalarType { + public: + enum NanRepr : uint8_t { + NAN_NONE = 0, // nans are not supported + NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s + NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s + + NAN_REPR_ID_MAX + }; + + constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_, + int32_t bias, bool finite_values_only = false, + NanRepr nan_repr = NAN_IEEE_754) + : exponent(exponent), + mantissa(mantissa), + signed_(signed_), + bias(bias), + finite_values_only(finite_values_only), + nan_repr(nan_repr) {}; + + static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits - 1, true, bias); + } + + static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits, false, bias); + } + + // IEEE 754 compliant floating point type + static constexpr ScalarType float_IEEE754(uint8_t exponent, + uint8_t mantissa) { + TORCH_CHECK(mantissa > 0 && exponent > 0); + return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754); + } + + // IEEE 754 non-compliant floating point type + static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa, + bool finite_values_only, + NanRepr nan_repr) { + TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr"); + TORCH_CHECK(mantissa > 0 && exponent > 0); + TORCH_CHECK(nan_repr != NAN_IEEE_754, + "use `float_IEEE754` constructor for floating point types that " + "follow IEEE 754 conventions"); + return ScalarType(exponent, mantissa, true, 0, finite_values_only, + nan_repr); + } + + uint8_t const exponent; // size of the exponent field (0 for integer types) + uint8_t const mantissa; // size of the mantissa field (size of the integer + // excluding the sign bit for integer types) + bool const signed_; // flag if the type supports negative numbers (i.e. has a + // sign bit) + int32_t const bias; // stored values equal value + bias, + // used for quantized type + + // Extra Floating point info + bool const finite_values_only; // i.e. no +/-inf if true + NanRepr const nan_repr; // how NaNs are represented + // (not applicable for integer types) + + using Id = int64_t; + + private: + // Field size in id + template + static constexpr size_t member_id_field_width() { + using T = std::decay_t; + return std::is_same_v ? 1 : sizeof(T) * 8; + } + + template + static constexpr auto reduce_members_helper(Fn f, Init val, Member member, + Rest... 
rest) { + auto new_val = f(val, member); + if constexpr (sizeof...(rest) > 0) { + return reduce_members_helper(f, new_val, rest...); + } else { + return new_val; + }; + } + + template + constexpr auto reduce_members(Fn f, Init init) const { + // Should be in constructor order for `from_id` + return reduce_members_helper(f, init, exponent, mantissa, signed_, bias, + finite_values_only, nan_repr); + }; + + template + static constexpr auto reduce_member_types(Fn f, Init init) { + constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE); + return dummy_type.reduce_members(f, init); + }; + + static constexpr auto id_size_bits() { + return reduce_member_types( + [](int acc, auto member) -> int { + return acc + member_id_field_width(); + }, + 0); + } + + public: + // unique id for this scalar type that can be computed at compile time for + // c++17 template specialization this is not needed once we migrate to + // c++20 and can pass literal classes as template parameters + constexpr Id id() const { + static_assert(id_size_bits() <= sizeof(Id) * 8, + "ScalarType id is too large to be stored"); + + auto or_and_advance = [](std::pair result, + auto member) -> std::pair { + auto [id, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1)) + << bit_offset, + bit_offset + bits}; + }; + return reduce_members(or_and_advance, std::pair{}).first; + } + + // create a ScalarType from an id, for c++17 template specialization, + // this is not needed once we migrate to c++20 and can pass literal + // classes as template parameters + static constexpr ScalarType from_id(Id id) { + auto extract_and_advance = [id](auto result, auto member) { + using T = decltype(member); + auto [tuple, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + auto extracted_val = static_cast((int64_t(id) >> bit_offset) & + ((uint64_t(1) << bits) - 1)); + auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val)); + return std::pair{new_tuple, bit_offset + bits}; + }; + + auto [tuple_args, _] = reduce_member_types(extract_and_advance, + std::pair, int>{}); + return std::apply([](auto... 
args) { return ScalarType(args...); }, + tuple_args); + } + + constexpr int64_t size_bits() const { + return mantissa + exponent + is_signed(); + } + constexpr bool is_signed() const { return signed_; } + constexpr bool is_integer() const { return exponent == 0; } + constexpr bool is_floating_point() const { return exponent > 0; } + constexpr bool is_ieee_754() const { + return is_floating_point() && finite_values_only == false && + nan_repr == NAN_IEEE_754; + } + constexpr bool has_nans() const { + return is_floating_point() && nan_repr != NAN_NONE; + } + constexpr bool has_infs() const { + return is_floating_point() && finite_values_only == false; + } + constexpr bool has_bias() const { return bias != 0; } + + private: + double _floating_point_max() const { + TORCH_CHECK(mantissa <= 52 && exponent <= 11, + "Cannot represent max/min as a double for type ", str()); + + uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) { + max_mantissa -= 1; + } + + uint64_t max_exponent = (uint64_t(1) << exponent) - 2; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) { + TORCH_CHECK(exponent < 11, + "Cannot represent max/min as a double for type ", str()); + max_exponent += 1; + } + + // adjust the exponent to match that of a double + // for now we assume the exponent bias is the standard 2^(e-1) -1, (where e + // is the exponent bits), there is some precedent for non-standard biases, + // example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes + // but to avoid premature over complication we are just assuming the + // standard exponent bias until there is a need to support non-standard + // biases + uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1; + uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11 + + uint64_t max_exponent_double = + max_exponent - exponent_bias + exponent_bias_double; + + // shift the mantissa into the position for a double and + // the exponent + uint64_t double_raw = + (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52); + + return *reinterpret_cast(&double_raw); + } + + constexpr std::variant _raw_max() const { + if (is_floating_point()) { + return {_floating_point_max()}; + } else { + TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(), + "Cannot represent max as a int64_t"); + return {(int64_t(1) << mantissa) - 1}; + } + } + + constexpr std::variant _raw_min() const { + if (is_floating_point()) { + TORCH_CHECK(is_signed(), + "We currently assume all floating point types are signed"); + constexpr uint64_t sign_bit_double = (uint64_t(1) << 63); + + double max = _floating_point_max(); + uint64_t max_raw = *reinterpret_cast(&max); + uint64_t min_raw = max_raw | sign_bit_double; + return {*reinterpret_cast(&min_raw)}; + } else { + TORCH_CHECK(!is_signed() || size_bits() <= 64, + "Cannot represent min as a int64_t"); + if (is_signed()) { + // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0 + // then perform an arithmetic shift right to set all the bits above + // (size_bits() - 1) to 1 + return {INT64_MIN >> (64 - size_bits())}; + } else { + return {int64_t(0)}; + } + } + } + + public: + // Max representable value for this scalar type. + // (accounting for bias if there is one) + constexpr std::variant max() const { + return std::visit( + [this](auto x) -> std::variant { return {x - bias}; }, + _raw_max()); + } + + // Min representable value for this scalar type. 
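A concrete instance makes the bias handling above easier to follow: the 4-bit GPTQ weight type is ScalarType::uint(4, 8) (exposed further down as kU4B8), i.e. four stored bits with a bias of 8, so raw codes 0..15 represent logical values -8..7. A small sketch:

#include <cstdint>
#include <variant>

#include "core/scalar_type.hpp"

int main() {
  constexpr auto u4b8 = vllm::ScalarType::uint(4, 8);  // same value as vllm::kU4B8

  static_assert(u4b8.size_bits() == 4);
  static_assert(!u4b8.is_signed() && u4b8.is_integer());
  static_assert(u4b8.has_bias());

  // Raw storage spans 0..15; subtracting the bias gives the logical range.
  TORCH_CHECK(std::get<int64_t>(u4b8.max()) == 7);   // 15 - 8
  TORCH_CHECK(std::get<int64_t>(u4b8.min()) == -8);  //  0 - 8
  TORCH_CHECK(u4b8.str() == "uint4b8");
  return 0;
}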
+ // (accounting for bias if there is one) + constexpr std::variant min() const { + return std::visit( + [this](auto x) -> std::variant { return {x - bias}; }, + _raw_min()); + } + + std::string str() const { + /* naming generally follows: https://github.com/jax-ml/ml_dtypes + * for floating point types (leading f) the scheme is: + * `float_em[flags]` + * flags: + * - no-flags: means it follows IEEE 754 conventions + * - f: means finite values only (no infinities) + * - n: means nans are supported (non-standard encoding) + * for integer types the scheme is: + * `[u]int[b]` + * - if bias is not present it means its zero + */ + if (is_floating_point()) { + auto ret = "float" + std::to_string(size_bits()) + "_e" + + std::to_string(exponent) + "m" + std::to_string(mantissa); + if (!is_ieee_754()) { + if (finite_values_only) { + ret += "f"; + } + if (nan_repr != NAN_NONE) { + ret += "n"; + } + } + return ret; + } else { + auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits()); + if (has_bias()) { + ret += "b" + std::to_string(bias); + } + return ret; + } + } + + constexpr bool operator==(ScalarType const& other) const { + return mantissa == other.mantissa && exponent == other.exponent && + bias == other.bias && signed_ == other.signed_ && + finite_values_only == other.finite_values_only && + nan_repr == other.nan_repr; + } +}; + +using ScalarTypeId = ScalarType::Id; + +// "rust style" names generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70 +static inline constexpr auto kS4 = ScalarType::int_(4); +static inline constexpr auto kU4 = ScalarType::uint(4); +static inline constexpr auto kU4B8 = ScalarType::uint(4, 8); +static inline constexpr auto kS8 = ScalarType::int_(8); +static inline constexpr auto kU8 = ScalarType::uint(8); +static inline constexpr auto kU8B128 = ScalarType::uint(8, 128); + +static inline constexpr auto kFE3M2f = + ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); +static inline constexpr auto kFE4M3fn = + ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2); +static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7); +static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10); + +// Fixed width style names, generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57 +static inline constexpr auto kInt4 = kS4; +static inline constexpr auto kUint4 = kU4; +static inline constexpr auto kUint4b8 = kU4B8; +static inline constexpr auto kInt8 = kS8; +static inline constexpr auto kUint8 = kU8; +static inline constexpr auto kUint8b128 = kU8B128; + +static inline constexpr auto kFloat6_e3m2f = kFE3M2f; +static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; +static inline constexpr auto kFloat8_e5m2 = kFE5M2; +static inline constexpr auto kFloat16_e8m7 = kFE8M7; +static inline constexpr auto kFloat16_e5m10 = kFE5M10; + +// colloquial names +static inline constexpr auto kHalf = kFE5M10; +static inline constexpr auto kFloat16 = kHalf; +static inline constexpr auto kBFloat16 = kFE8M7; + +static inline constexpr auto kFloat16Id = kFloat16.id(); +}; // namespace vllm diff --git a/gptqmodel_ext/vllm_csrc/cutlass_extensions/.DS_Store b/gptqmodel_ext/vllm_csrc/cutlass_extensions/.DS_Store new file mode 100644 index 000000000..51a8cedc6 Binary files /dev/null and 
b/gptqmodel_ext/vllm_csrc/cutlass_extensions/.DS_Store differ diff --git a/gptqmodel_ext/vllm_csrc/cutlass_extensions/common.cpp b/gptqmodel_ext/vllm_csrc/cutlass_extensions/common.cpp new file mode 100644 index 000000000..3d2093ab9 --- /dev/null +++ b/gptqmodel_ext/vllm_csrc/cutlass_extensions/common.cpp @@ -0,0 +1,11 @@ +#include "cutlass_extensions/common.hpp" + +int32_t get_sm_version_num() { + int32_t major_capability, minor_capability; + cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, + 0); + cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, + 0); + int32_t version_num = major_capability * 10 + minor_capability; + return version_num; +} \ No newline at end of file diff --git a/gptqmodel_ext/vllm_csrc/cutlass_extensions/common.hpp b/gptqmodel_ext/vllm_csrc/cutlass_extensions/common.hpp new file mode 100644 index 000000000..dbe0e30f5 --- /dev/null +++ b/gptqmodel_ext/vllm_csrc/cutlass_extensions/common.hpp @@ -0,0 +1,61 @@ +#pragma once + +#include "cutlass/cutlass.h" +#include +#include "cuda_runtime.h" +#include + +/** + * Helper function for checking CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + TORCH_CHECK(error == cutlass::Status::kSuccess, \ + cutlassGetStatusString(error)); \ + } + +/** + * Panic wrapper for unwinding CUDA runtime errors + */ +#define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \ + } + +inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { + int max_shared_mem_per_block_opt_in = 0; + cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, + cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + return max_shared_mem_per_block_opt_in; +} + +int32_t get_sm_version_num(); + +/** + * A wrapper for a kernel that is used to guard against compilation on + * architectures that will never use the kernel. The purpose of this is to + * reduce the size of the compiled binary. + * __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef + * into code that will be executed on the device where it is defined. + */ +template +struct enable_sm90_or_later : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 + Kernel::operator()(std::forward(args)...); +#endif + } +}; + +template +struct enable_sm90_only : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... 
args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 900 + Kernel::operator()(std::forward(args)...); +#endif + } +}; diff --git a/gptqmodel_ext/vllm_csrc/cutlass_extensions/cute_utils.cuh b/gptqmodel_ext/vllm_csrc/cutlass_extensions/cute_utils.cuh new file mode 100644 index 000000000..f61fe3ceb --- /dev/null +++ b/gptqmodel_ext/vllm_csrc/cutlass_extensions/cute_utils.cuh @@ -0,0 +1,68 @@ +#pragma once + +#include +#include +namespace cute { + +//////////////////////////////////////////////////////////////////// +// layout utils +//////////////////////////////////////////////////////////////////// + +// Permute layout based on indices, example: +// permute_layout<1, 0>(layout) will swap the two dimensions +// permute_layout<0, 2, 1>(layout) will swap the last two dimensions +template +CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) { + static_assert(rank(l) == sizeof...(I), "Invalid permutation, rank mismatch"); + return cute::make_layout(cute::get(l)...); +} + +// is the layout f(x) = x +template +CUTE_HOST_DEVICE static constexpr bool is_identity_layout() { + if constexpr (std::is_same_v) { + return true; + } else { + constexpr auto coalesced_layout = coalesce(Layout{}); + if constexpr (rank(coalesced_layout) == 1 && + stride<0>(coalesced_layout) == 1) { + return true; + } + return false; + } +} + +//////////////////////////////////////////////////////////////////// +// Pointer utils +//////////////////////////////////////////////////////////////////// + +template +static constexpr auto get_logical_ptr(PointerType* ptr) { + if constexpr (cute::sizeof_bits_v < 8) { + return cute::subbyte_iterator(ptr); + } else { + return ptr; + } +} + +//////////////////////////////////////////////////////////////////// +// Misc utils +//////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE static constexpr auto create_auto_vectorizing_copy() { + constexpr auto bits = sizeof_bits_v * Elements{}; + if constexpr (bits % 128 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<128>{}; + } else if constexpr (bits % 64 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<64>{}; + } else if constexpr (bits % 32 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<32>{}; + } else if constexpr (bits % 16 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<16>{}; + } else { + return AutoVectorizingCopyWithAssumedAlignment<8>{}; + } +} + +}; // namespace cute diff --git a/gptqmodel_ext/vllm_csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp b/gptqmodel_ext/vllm_csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp new file mode 100644 index 000000000..5c1d6e3f4 --- /dev/null +++ b/gptqmodel_ext/vllm_csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp @@ -0,0 +1,457 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// +// This file is a modified excerpt of +// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +// from https://github.com/NVIDIA/cutlass v3.5.0 +// It has been modified to support either row/column or scalar broadcasting +// where the tensor being loaded from is always passed in via a device pointer. +// This lets one compiled kernel handle all cases of per-tensor or +// per-channel/per-token quantization. +// +// This interface also allows the scales to be passed in as tensors that +// consistently reside on the device, which avoids an issue with a previous +// implementation where scalars needed to be on the CPU since they +// were passed in via float values. This created a potential performance hazard +// if scales were initially on the device, and caused torch.compile graphs +// breaks when moving scales to the CPU. +// +#pragma once + +// Turn off clang-format for the entire file to keep it close to upstream +// clang-format off + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" + +#include "cute/tensor.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +// Row vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_0,_1,_0>, + int Alignment = 128 / sizeof_bits_v +> +struct Sm90RowOrScalarBroadcastArray { + static_assert(Stages == 0, "Row broadcast doesn't support smem usage"); + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{}); + + struct SharedStorage { + array_aligned(CtaTileShapeMNK{})> smem; + }; + + // This struct has been modified to have a bool indicating that ptr_row is a + // scalar that must be broadcast, instead of containing a scalar that is + // valid if ptr_row is null. 
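  // Annotation (not part of the upstream file): the bool + device-pointer design
  // lets one compiled kernel serve both per-channel scales (a real row vector per
  // group) and per-tensor scales (a single value per group). A rough host-side
  // sketch of how these arguments get filled, using hypothetical names:
  //
  //   // group_scale_ptrs: device array of per-group base pointers
  //   // (const Element* const*); a group whose scale tensor has a single
  //   // element is broadcast as a scalar.
  //   Arguments args{/*ptr_row_array=*/group_scale_ptrs,
  //                  /*row_broadcast=*/scales_are_per_channel,
  //                  /*dRow=*/{}};
  //
  // This mirrors the args_from_tensor helpers later in this diff, which pass
  // {data_ptr, tensor.numel() != 1} for the plain variants and
  // {data_ptr, do_broadcast} for the array variants.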
+ struct Arguments { + const Element* const* ptr_row_array = nullptr; + bool row_broadcast = true; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90RowOrScalarBroadcastArray() { } + + CUTLASS_HOST_DEVICE + Sm90RowOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage) + : params(params) + , smem(const_cast(shared_storage.smem.data())) { } + + Params params; + Element *smem = nullptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (!params.row_broadcast && *(params.ptr_row_array[group]) == Element(0)); + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_, + GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_, + SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_, + CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, + int group, Params const& params_) + : tGS_gRow(tGS_gRow_) + , tGS_sRow(tGS_sRow_) + , tGS_cRow(tGS_cRow_) + , tiled_G2S(tiled_g2s_) + , tSR_sRow(tSR_sRow_) + , tSR_rRow(tSR_rRow_) + , tCcRow(tCcRow_) + , residue_tCcRow(residue_tCcRow_) + , group(group) + , params(params_) {} + + GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N) + GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N) + GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N) + Tiled_G2S tiled_G2S; + + SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ThrResidue residue_tCcRow; // (m, n) + ThrNum thr_num; + int group; + Params const& params; + + CUTLASS_DEVICE void + begin() { + if (!params.row_broadcast) { + fill(tSR_rRow, *(params.ptr_row_array[group])); + return; + } + + auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); + Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); + Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride())); + + for (int i = 0; i < size(tGS_gRow_flt); ++i) { + if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { + continue; // OOB of SMEM, + } + if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) { + tGS_sRow_flt(i) = tGS_gRow_flt(i); + } + else { + tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds. 
+ } + } + synchronize(); + } + + CUTLASS_DEVICE void + begin_loop(int epi_m, int epi_n) { + if (epi_m == 0) { // Assumes M-major subtile loop + if (!params.row_broadcast) return; // Do not issue LDS when row is scalar + Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); + Tensor tSR_rRow_flt = filter_zeros(tSR_rRow); + copy(tSR_sRow_flt, tSR_rRow_flt); + } + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_row; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_row[i] = tSR_rRow(epi_v * FragmentSize + i); + } + + return frg_row; + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + using ThreadCount = decltype(size(args.tiled_copy)); + + Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow); + Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) + Tensor sRow = make_tensor(make_smem_ptr(smem), + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) + //// G2S: Gmem to Smem + auto tiled_g2s = make_tiled_copy(Copy_Atom{}, + Layout< Shape<_1, ThreadCount>, + Stride<_0, _1>>{}, + Layout<_1>{}); + auto thr_g2s = tiled_g2s.get_slice(args.thread_idx); + Tensor tGS_gRow = thr_g2s.partition_S(gRow); + Tensor tGS_sRow = thr_g2s.partition_D(sRow); + + //// G2S: Coord + auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}))); + Tensor tGS_cRow = thr_g2s.partition_S(cRow); + + //// S2R: Smem to Reg + Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) + + return ConsumerStoreCallbacks( + tGS_gRow, + tGS_sRow, + tGS_cRow, tiled_g2s, + tSR_sRow, + tSR_rRow, + args.tCcD, + args.residue_cD, + ThreadCount{}, + l, + params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Column vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_1,_0,_0>, + int Alignment = 128 / sizeof_bits_v +> +struct Sm90ColOrScalarBroadcastArray { + static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet"); + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert( + (cute::is_same_v>) || // col vector broadcast, e.g. per-row alpha/bias + (cute::is_same_v>)); // batched col vector broadcast, e.g. batched per-row bias + + // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem + struct SharedStorage { }; + + // This struct has been modified to have a bool indicating that ptr_col is a + // scalar that must be broadcast, instead of containing a scalar that is + // valid if ptr_col is null. 
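  // Annotation (not part of the upstream file): the *Array variants differ from
  // the plain Sm90RowOrScalarBroadcast / Sm90ColOrScalarBroadcast added later in
  // this diff only in that they carry one device pointer per group and index it
  // with the batch/group coordinate l of the tile (see ptr_row_array[l] /
  // ptr_col_array[l] in get_consumer_store_callbacks), presumably so grouped or
  // batched problems can use a different scale tensor per group.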
+ struct Arguments { + const Element* const* ptr_col_array = nullptr; + bool col_broadcast = true; + StrideMNL dCol = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (!params.col_broadcast && *(params.ptr_col_array[group]) == Element(0)); + } + + CUTLASS_HOST_DEVICE + Sm90ColOrScalarBroadcastArray() { } + + CUTLASS_HOST_DEVICE + Sm90ColOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GTensor&& tCgCol, + RTensor&& tCrCol, + CTensor&& tCcCol, + ProblemShape problem_shape, + int group, + Params const& params + ): + tCgCol(cute::forward(tCgCol)), + tCrCol(cute::forward(tCrCol)), + tCcCol(cute::forward(tCcCol)), + m(get<0>(problem_shape)), + group(group), + params(params) {} + + GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + RTensor tCrCol; + CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + Params const& params; + int m; + int group; + + CUTLASS_DEVICE void + begin() { + Tensor pred = make_tensor(shape(tCgCol)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(pred); ++i) { + pred(i) = get<0>(tCcCol(i)) < m; + } + + if (!params.col_broadcast) { + fill(tCrCol, *(params.ptr_col_array[group])); + return; + } + + // Filter so we don't issue redundant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + copy_if(pred, filter(tCgCol), filter(tCrCol)); + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_col; + Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i); + } + + return frg_col; + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + + Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol); + Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + // Generate an identity tensor matching the shape of the global tensor and + // partition the same way, this will be used to generate the predicate + // tensor for loading + Tensor cCol = make_identity_tensor(mCol.shape()); + Tensor tCcCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + return ConsumerStoreCallbacks( + cute::move(tCgCol), + cute::move(tCrCol), + cute::move(tCcCol), + args.problem_shape_mnkl, + l, + params + ); + } +}; + +} diff --git a/gptqmodel_ext/vllm_csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp b/gptqmodel_ext/vllm_csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp new file mode 100644 index 000000000..7aa87feb4 --- /dev/null +++ b/gptqmodel_ext/vllm_csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp @@ -0,0 +1,497 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +// +// This file is a modified excerpt of +// include/cutlass/epilogue/fusion/visitor_load.hpp from +// https://github.com/NVIDIA/cutlass v3.5.0 +// It has been modified to support either +// row/column or scalar broadcasting where the tensor being loaded from is +// always passed in via a device pointer. This lets one compiled kernel handle +// all cases of per-tensor or per-channel/per-token quantization. +// +// This interface also allows the scales to be passed in as tensors that +// consistently reside on the device, which avoids an issue with a previous +// implementation where scalars needed to be on the CPU since they +// were passed in via float values. This created a potential performance hazard +// if scales were initially on the device, and caused torch.compile graph +// breaks when moving scales to the CPU. +// +#pragma once + +// Turn off clang-format for the entire file to keep it close to upstream +// clang-format off + +#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp" +#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" +#include "cute/tensor.hpp" + +namespace cutlass::epilogue::threadblock { + +using namespace cute; +using namespace detail; + +template< + class ThreadMap, + class Element, + class StrideMNL +> +struct VisitorRowOrScalarBroadcast { + + // This struct has been modified to have a bool indicating that ptr_row is a + // scalar that must be broadcast. + struct Arguments { + Element const* ptr_row = nullptr; + bool row_broadcast = true; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage {}; + + // Global load type + static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; + using VecType = uint_bit_t; + static int constexpr VecLength = sizeof(VecType) / sizeof(Element); + + CUTLASS_HOST_DEVICE + VisitorRowOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gRow, + RTensor&& tC_rRow, + CTensor&& tC_cRow, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gRow(cute::forward(tC_gRow)), + tC_rRow(cute::forward(tC_rRow)), + tC_cRow(cute::forward(tC_cRow)), + n(get<1>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gRow; + RTensor tC_rRow; + CTensor tC_cRow; + Params const* params_ptr; + int n; + + // This function is modified from VisitorRowBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rRow); + auto src_v = filter(tC_gRow); + auto coord_v = filter(tC_cRow); + auto dst_v = filter(tC_rRow); + + if (params_ptr->row_broadcast) { + // In this case we are loading from a row vector and broadcasting + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + bool guard = get<1>(coord_v(i)) < n; + cutlass::arch::global_load( + dst_v(i), (void const*)&src_v(i), guard); + } + } else { + // In this case we are loading from a scalar and broadcasting + VecType filled_vec; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < VecLength; 
i++) { + reinterpret_cast(&filled_vec)[i] = *(params_ptr->ptr_row); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + if (get<1>(coord_v(i)) < n) { + dst_v(i) = filled_vec; + } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Tensor rRow_frg = recast>(coalesce(tC_rRow)); + return rRow_frg(column_idx); + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mRow = make_tensor( + make_gmem_ptr(params_ptr->ptr_row), + problem_shape, + params_ptr->dRow); + + // VECTOR, FRAGMENT_COLUMN + Tensor tC_gRow = recast( + ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset) + )(_,_,_0{},_0{},_0{},_0{}); + Tensor tC_rRow = make_tensor_like(tC_gRow); + + // Generate the pred tensor + Tensor cRow = make_identity_tensor(mRow.shape()); + Tensor tC_cRow = outer_partition( + ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}), + Shape>{}, + (_0{}) + ); + + return Callbacks< + decltype(tC_gRow), decltype(tC_rRow), + decltype(tC_cRow), ProblemShape>( + cute::move(tC_gRow), + cute::move(tC_rRow), + cute::move(tC_cRow), + problem_shape, + params_ptr + ); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// This is a modified RowBroadcast that will broadcast 0 if ptr_row is null +template< + class ThreadMap, + class Element, + class StrideMNL +> +struct VisitorRowOrZeroBroadcast { + + // This struct has been modified to remove null_default (because it's always 0) + struct Arguments { + Element const* ptr_row = nullptr; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage {}; + + // Global load type + static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; + using VecType = uint_bit_t; + static int constexpr VecLength = sizeof(VecType) / sizeof(Element); + + CUTLASS_HOST_DEVICE + VisitorRowOrZeroBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorRowOrZeroBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gRow, + RTensor&& tC_rRow, + CTensor&& tC_cRow, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gRow(cute::forward(tC_gRow)), + tC_rRow(cute::forward(tC_rRow)), + tC_cRow(cute::forward(tC_cRow)), + n(get<1>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gRow; + RTensor tC_rRow; + CTensor tC_cRow; + Params const* params_ptr; + int n; + + // This function is modified from VisitorRowBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rRow); + auto src_v = filter(tC_gRow); + auto coord_v = filter(tC_cRow); + auto dst_v = filter(tC_rRow); + + if (params_ptr->ptr_row != nullptr) { + // In this case we are loading from a row vector and broadcasting + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + bool guard = get<1>(coord_v(i)) < n; + cutlass::arch::global_load( + dst_v(i), (void const*)&src_v(i), guard); + } 
+ } else { + // In this case we are broadcasting 0 + VecType filled_vec; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < VecLength; i++) { + reinterpret_cast(&filled_vec)[i] = Element{0}; + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + if (get<1>(coord_v(i)) < n) { + dst_v(i) = filled_vec; + } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Tensor rRow_frg = recast>(coalesce(tC_rRow)); + return rRow_frg(column_idx); + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mRow = make_tensor( + make_gmem_ptr(params_ptr->ptr_row), + problem_shape, + params_ptr->dRow); + + // VECTOR, FRAGMENT_COLUMN + Tensor tC_gRow = recast( + ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset) + )(_,_,_0{},_0{},_0{},_0{}); + Tensor tC_rRow = make_tensor_like(tC_gRow); + + // Generate the pred tensor + Tensor cRow = make_identity_tensor(mRow.shape()); + Tensor tC_cRow = outer_partition( + ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}), + Shape>{}, + (_0{}) + ); + + return Callbacks< + decltype(tC_gRow), decltype(tC_rRow), + decltype(tC_cRow), ProblemShape>( + cute::move(tC_gRow), + cute::move(tC_rRow), + cute::move(tC_cRow), + problem_shape, + params_ptr + ); + } + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Column vector broadcast +template< + class ThreadMap, + class Element, + class StrideMNL = Stride<_1,_0,_0> +> +struct VisitorColOrScalarBroadcast { + + // This struct has been modified to have a bool indicating that ptr_col is a + // scalar that must be broadcast. 
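  // Annotation (not part of the upstream file): unlike the sm90 fusion nodes
  // added later in this diff, these CUTLASS 2.x visitors never stage the vector
  // through shared memory. begin_epilogue() either reads the column directly
  // from global memory with a predicated copy_if (the row visitors above use
  // guarded cutlass::arch::global_load), or, when col_broadcast is false, writes
  // the single *ptr_col value into every predicated register slot.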
+ struct Arguments { + Element const* ptr_col = nullptr; + bool col_broadcast = true; + StrideMNL dCol = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage { }; + + CUTLASS_HOST_DEVICE + VisitorColOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gCol, + RTensor&& tC_rCol, + CTensor&& tC_cCol, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gCol(cute::forward(tC_gCol)), + tC_rCol(cute::forward(tC_rCol)), + tC_cCol(cute::forward(tC_cCol)), + m(get<0>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gCol; + RTensor tC_rCol; + CTensor tC_cCol; + Params const* params_ptr; + int m; + + // This function is modified from VisitorColBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rCol); + + Tensor pred = make_tensor(shape(tC_gCol)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(pred); ++i) { + pred(i) = get<0>(tC_cCol(i)) < m; + } + + if (params_ptr->col_broadcast) { + // In this case we are loading from a column vector and broadcasting + copy_if(pred, tC_gCol, tC_rCol); + } else { + // In this case we are loading from a scalar and broadcasting + auto dst_v = filter(tC_rCol); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(dst_v); ++i) { + if (pred(i)) { + dst_v(i) = *(params_ptr->ptr_col); + } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Array frg_col; + frg_col.fill(tC_rCol(row_idx,iter_idx)); + return frg_col; + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mCol = make_tensor( + make_gmem_ptr(params_ptr->ptr_col), + problem_shape, + params_ptr->dCol); + + // VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER + Tensor tC_gCol = group_modes<1,4>( + ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); + Tensor tC_rCol = make_tensor_like(tC_gCol); + + // Generate the pred tensor + Tensor cCol = make_identity_tensor(mCol.shape()); + Tensor tC_cCol = group_modes<1,4>( + ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); + + return Callbacks< + decltype(tC_gCol), decltype(tC_rCol), + decltype(tC_cCol), ProblemShape>( + cute::move(tC_gCol), + cute::move(tC_rCol), + cute::move(tC_cCol), + problem_shape, + params_ptr + ); + } +}; + +} diff --git a/gptqmodel_ext/vllm_csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp b/gptqmodel_ext/vllm_csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp new file mode 100644 index 000000000..58b1e8ff1 --- /dev/null +++ b/gptqmodel_ext/vllm_csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp @@ -0,0 +1,447 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. 
SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// +// This file is a modified excerpt of +// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +// from https://github.com/NVIDIA/cutlass v3.5.0 +// It has been modified to support either row/column or scalar broadcasting +// where the tensor being loaded from is always passed in via a device pointer. +// This lets one compiled kernel handle all cases of per-tensor or +// per-channel/per-token quantization. +// +// This interface also allows the scales to be passed in as tensors that +// consistently reside on the device, which avoids an issue with a previous +// implementation where scalars needed to be on the CPU since they +// were passed in via float values. This created a potential performance hazard +// if scales were initially on the device, and caused torch.compile graphs +// breaks when moving scales to the CPU. 
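//
// Annotation (not part of the upstream file): the broadcast protocol below has
// two paths. When row_broadcast is false, the consumer-store callbacks simply
// fill the register fragment with the single *ptr_row value. When it is true,
// begin() copies the row tile from global to shared memory (zero-filling
// out-of-bounds columns so the later shared-memory loads need no predicates)
// and syncs on a named barrier, and begin_loop() then copies the staged row
// from shared memory into registers once per epilogue subtile.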
+// +#pragma once + +// Turn off clang-format for the entire file to keep it close to upstream +// clang-format off + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" + +#include "cute/tensor.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +// Row vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_0,_1,_0>, + int Alignment = 128 / sizeof_bits_v +> +struct Sm90RowOrScalarBroadcast { + static_assert(Stages == 0, "Row broadcast doesn't support smem usage"); + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{}); + + struct SharedStorage { + array_aligned(CtaTileShapeMNK{})> smem; + }; + + // This struct has been modified to have a bool indicating that ptr_row is a + // scalar that must be broadcast, instead of containing a scalar that is + // valid if ptr_row is null. + struct Arguments { + Element const* ptr_row = nullptr; + bool row_broadcast = true; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90RowOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params(params) + , smem(const_cast(shared_storage.smem.data())) { } + + Params params; + Element *smem = nullptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (!params.row_broadcast && *(params.ptr_row) == Element(0)); + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_, + GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_, + SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_, + CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_) + : tGS_gRow(tGS_gRow_) + , tGS_sRow(tGS_sRow_) + , tGS_cRow(tGS_cRow_) + , tiled_G2S(tiled_g2s_) + , tSR_sRow(tSR_sRow_) + , tSR_rRow(tSR_rRow_) + , tCcRow(tCcRow_) + , residue_tCcRow(residue_tCcRow_) + , params(params_) {} + + GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N) + GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N) + GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N) + Tiled_G2S tiled_G2S; + + SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ThrResidue residue_tCcRow; // (m, n) + ThrNum thr_num; + Params const& params; + + CUTLASS_DEVICE void 
+ begin() { + if (!params.row_broadcast) { + fill(tSR_rRow, *(params.ptr_row)); + return; + } + + auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); + Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); + Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride())); + + for (int i = 0; i < size(tGS_gRow_flt); ++i) { + if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { + continue; // OOB of SMEM, + } + if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) { + tGS_sRow_flt(i) = tGS_gRow_flt(i); + } + else { + tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds. + } + } + synchronize(); + } + + CUTLASS_DEVICE void + begin_loop(int epi_m, int epi_n) { + if (epi_m == 0) { // Assumes M-major subtile loop + if (!params.row_broadcast) return; // Do not issue LDS when row is scalar + Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); + Tensor tSR_rRow_flt = filter_zeros(tSR_rRow); + copy(tSR_sRow_flt, tSR_rRow_flt); + } + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_row; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_row[i] = tSR_rRow(epi_v * FragmentSize + i); + } + + return frg_row; + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + using ThreadCount = decltype(size(args.tiled_copy)); + + Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); + Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) + Tensor sRow = make_tensor(make_smem_ptr(smem), + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) + //// G2S: Gmem to Smem + auto tiled_g2s = make_tiled_copy(Copy_Atom{}, + Layout< Shape<_1, ThreadCount>, + Stride<_0, _1>>{}, + Layout<_1>{}); + auto thr_g2s = tiled_g2s.get_slice(args.thread_idx); + Tensor tGS_gRow = thr_g2s.partition_S(gRow); + Tensor tGS_sRow = thr_g2s.partition_D(sRow); + + //// G2S: Coord + auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}))); + Tensor tGS_cRow = thr_g2s.partition_S(cRow); + + //// S2R: Smem to Reg + Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) + + return ConsumerStoreCallbacks( + tGS_gRow, + tGS_sRow, + tGS_cRow, tiled_g2s, + tSR_sRow, + tSR_rRow, + args.tCcD, + args.residue_cD, + ThreadCount{}, + params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Column vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_1,_0,_0>, + int Alignment = 128 / sizeof_bits_v +> +struct Sm90ColOrScalarBroadcast { + static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet"); + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert( + 
(cute::is_same_v>) || // col vector broadcast, e.g. per-row alpha/bias + (cute::is_same_v>)); // batched col vector broadcast, e.g. batched per-row bias + + // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem + struct SharedStorage { }; + + // This struct has been modified to have a bool indicating that ptr_col is a + // scalar that must be broadcast, instead of containing a scalar that is + // valid if ptr_col is null. + struct Arguments { + Element const* ptr_col = nullptr; + bool col_broadcast = true; + StrideMNL dCol = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (!params.col_broadcast && *(params.ptr_col) == Element(0)); + } + + CUTLASS_HOST_DEVICE + Sm90ColOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GTensor&& tCgCol, + RTensor&& tCrCol, + CTensor&& tCcCol, + ProblemShape problem_shape, + Params const& params + ): + tCgCol(cute::forward(tCgCol)), + tCrCol(cute::forward(tCrCol)), + tCcCol(cute::forward(tCcCol)), + m(get<0>(problem_shape)), + params(params) {} + + GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + RTensor tCrCol; + CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + Params const& params; + int m; + + CUTLASS_DEVICE void + begin() { + Tensor pred = make_tensor(shape(tCgCol)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(pred); ++i) { + pred(i) = get<0>(tCcCol(i)) < m; + } + + if (!params.col_broadcast) { + fill(tCrCol, *(params.ptr_col)); + return; + } + + // Filter so we don't issue redundant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + copy_if(pred, filter(tCgCol), filter(tCrCol)); + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_col; + Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i); + } + + return frg_col; + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); + Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + // Generate an identity tensor matching the shape of the global tensor and + // partition the same way, this will be used to generate the predicate + // tensor for loading + Tensor cCol = make_identity_tensor(mCol.shape()); + Tensor tCcCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + return ConsumerStoreCallbacks( + cute::move(tCgCol), + cute::move(tCrCol), + cute::move(tCcCol), + args.problem_shape_mnkl, + params + ); + } +}; + +} diff --git a/gptqmodel_ext/vllm_csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp b/gptqmodel_ext/vllm_csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp new file mode 100644 index 000000000..64b7ddae3 --- /dev/null +++ b/gptqmodel_ext/vllm_csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp @@ -0,0 +1,321 @@ +#pragma once + +#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp" + +/* + This file defines custom epilogues for fusing channel scales, token scales, + bias, and activation zero-points onto a GEMM operation using the + CUTLASS 2.x API, for sm80 (Ampere) NVIDIA GPUs. + + Epilogues must contain a public type named EVTCompute of type Sm80EVT, + as well as a static prepare_args function that constructs an + EVTCompute::Arguments struct. +*/ + +namespace vllm::c2x { + +using namespace cute; + +/* + * This class provides the common load descriptors for the + * ScaledEpilogue[...] classes + */ +template +struct ScaledEpilogueBase { + protected: + using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; + + template + using ColOrScalarLoad = + cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< + OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; + + template + using RowOrScalarLoad = + cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< + OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + + template + using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast< + OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; + + template + using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast< + OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + + template + using RowOrZeroLoad = + cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast< + OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + + // This utility function constructs the arguments for the load descriptors + // from a tensor. It can handle both row and column, as well as row/column or + // scalar cases. 
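  // Annotation (not part of the upstream file): the load descriptors above
  // follow the same stride convention as the broadcast visitors they wrap:
  //
  //   ColOrScalarLoad / ColLoad -> values vary along M only (per-row / per-token,
  //                                e.g. a_scales of shape (m, 1) or a scalar)
  //   RowOrScalarLoad / RowLoad -> values vary along N only (per-column /
  //                                per-channel, e.g. b_scales of shape (1, n)
  //                                or a scalar, and the per-channel bias)
  //
  // consistent with the StrideMNL defaults (Stride<_1,_0,_0> for columns,
  // Stride<_0,_1,_0> for rows) of the visitors earlier in this diff.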
+ template + static auto args_from_tensor(torch::Tensor const& tensor) { + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = static_cast(tensor.data_ptr()); + if constexpr (std::is_same_v> || + std::is_same_v>) { + return Arguments{data_ptr, tensor.numel() != 1}; + } else { + // it would technically work but no use case as data_ptr is never nullptr + static_assert(!std::is_same_v>); + return Arguments{data_ptr}; + } + } + + // This overload handles the case where there might not be a tensor, in which + // case a nullptr is passed and a constant (0) is used. + template + static auto args_from_tensor(std::optional const& tensor) { + static_assert(std::is_same_v>); + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; + return Arguments{data_ptr}; + } +}; + +/* + This epilogue function defines a quantized GEMM operation similar to + torch._scaled_mm. + + A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or + per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. +*/ +template +struct ScaledEpilogue + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + + using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::threadblock::Sm80EVT; + + using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::threadblock::Sm80EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, {}}; + } +}; + +/* + * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. + * This bias can also be used in the per-tensor azp case, where the activation + * zero point (azp) is used to compute an azp correction term, + * which is folded into the bias. + * + * The bias tensor must be per-output channel. + * ScaleA and ScaleB can be per-tensor or per-token/per-channel. 
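 *
 * As a concrete sketch (an annotation, not from the upstream comment): with Acc
 * the raw GEMM accumulator (int32 for int8 inputs), per-row or scalar a_scales,
 * per-column or scalar b_scales, and a per-output-channel bias, the fused
 * visitor tree evaluates
 *
 *   D[i, j] = a_scales[i] * (b_scales[j] * Acc[i, j]) + bias[j]
 *
 * in one pass: the inner multiply is Compute0 (multiplies on float) and the
 * outer step is Compute1 (multiply_add producing ElementD), mirroring the
 * nesting of the argument structs built in prepare_args below.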
+ */ +template +struct ScaledEpilogueBias + : protected ScaledEpilogueBase { + protected: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::threadblock::Sm80EVT; + + using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT; + using ArgumentType = typename EVTCompute::Arguments; + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, bias_args, {}}; + } +}; + +/* + * This epilogue directly supports per-tensor azp in int32 form. + * As opposed to the per-token epilogue below, this epilogue only has an azp_adj + * term, which should already be multiplied with the scalar azp. + * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzp + : protected ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowOrZeroLoad; + + // This is the full AZP term, azp * J @ B, shape (1,n) + using AzpWithAdj = typename SUPER::template RowLoad; + + // Compute float(accum - azp_adj), both operands are int32_t + using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::threadblock::Sm80EVT; + + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + std::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{ + b_args, evt_azp_args, {}}; + return ArgumentType{a_args, evt_scale_b_args, bias_args, {}}; + } +}; + +/* 
+ * This epilogue supports per-token azp by computing and applying + * the correction term using a rank-1 update. If the term were materialized, + * it would require O(m*n) space, and this way it only requires O(m+n) space. + * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero + * point for each row of A. + * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzpToken + : protected ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowOrZeroLoad; + + // Per-token azp term, shape (m,1) + using Azp = typename SUPER::template ColLoad; + + // This is the AZP adjustment term, J @ B, shape (1,n) + using AzpAdj = typename SUPER::template RowLoad; + + // Compute azp * azp_adj + using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, int32_t, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::threadblock::Sm80EVT; + + // Compute float(accum - azp*azp_adj), all operands are int32_t + using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAcc = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::threadblock::Sm80EVT; + + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + torch::Tensor const& azp, + std::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_args = SUPER::template args_from_tensor(azp); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}}; + typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{ + b_args, evt_acc_args, {}}; + return ArgumentType{a_args, evt_scale_b_args, bias_args, {}}; + } +}; + +}; // namespace vllm::c2x diff --git a/gptqmodel_ext/vllm_csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/gptqmodel_ext/vllm_csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp new file mode 100644 index 000000000..62b848a0a --- /dev/null +++ b/gptqmodel_ext/vllm_csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -0,0 +1,450 @@ +#pragma once + +#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp" +#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp" + +/* + This file defines custom epilogues for fusing channel scales, token scales, + bias, and activation zero-points onto a 
GEMM operation using the + CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper) or later. + + Epilogues must contain a public type named EVTCompute of type Sm90EVT, + as well as a static prepare_args function that constructs an + EVTCompute::Arguments struct. +*/ + +namespace vllm::c3x { + +using namespace cute; + +template +struct identity { + CUTLASS_HOST_DEVICE + T operator()(T lhs) const { return lhs; } +}; + +template +struct TrivialEpilogue { + private: + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + using Compute = cutlass::epilogue::fusion::Sm90Compute< + cutlass::epilogue::thread::Identity, ElementD, ElementAcc, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + template + static ArgumentType prepare_args(Args... args) { + return {}; + } +}; + +/* + * This class provides the common load descriptors for the + * ScaledEpilogue[...] classes + */ +template +struct ScaledEpilogueBase { + protected: + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + template + using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast< + 0 /*Stages*/, TileShape, T, Stride, Int<0>, Int<0>>>; + + template + using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< + 0 /*Stages*/, TileShape, T, Stride, Int<1>, Int<0>>>; + + // Don't want to support nullptr by default + template + using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0 /*Stages*/, TileShape, T, T, Stride, Int<0>, Int<0>>, + 128 / sizeof_bits_v, EnableNullPtr>; + + // Don't want to support nullptr by default + template + using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< + 0 /*Stages*/, TileShape, T, T, Stride, Int<1>, Int<0>>, + 128 / sizeof_bits_v, EnableNullPtr>; + + template + using ColOrScalarLoadArray = + cutlass::epilogue::fusion::Sm90ColOrScalarBroadcastArray< + 0 /*Stages*/, TileShape, T, Stride, Int<0>, Int<0>>>; + + template + using RowOrScalarLoadArray = + cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray< + 0 /*Stages*/, TileShape, T, Stride, Int<1>, Int<0>>>; + + // This utility function constructs the arguments for the load descriptors + // from a tensor. It can handle both row and column, as well as row/column or + // scalar cases. + template + static auto args_from_tensor(torch::Tensor const& tensor) { + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = static_cast(tensor.data_ptr()); + if constexpr (std::is_same_v> || + std::is_same_v>) { + return Arguments{data_ptr, tensor.numel() != 1}; + } else { + static_assert(!std::is_same_v> && + !std::is_same_v>); + return Arguments{data_ptr}; + } + } + + // This overload handles the case where there might not be a tensor, in which + // case a nullptr is passed and a constant (0) is used. + template + static auto args_from_tensor(std::optional const& tensor) { + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; + static_assert(std::is_same_v> || + std::is_same_v>); + return Arguments{data_ptr}; + } + + template + static auto args_from_tensor(const T* const* data_ptr, bool do_broadcast) { + using Arguments = typename Descriptor::Arguments; + static_assert(std::is_same_v> || + std::is_same_v>); + return Arguments{data_ptr, do_broadcast}; + } +}; + +/* + This epilogue function defines a quantized GEMM operation similar to + torch.scaled_mm_. + + A and B may be both either int8 or fp8_e4m3. 
A can be + quantized per-tensor or per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. +*/ +template +struct ScaledEpilogue + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, {}}; + } +}; + +/* + * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. + * This bias can also be used in the per-tensor azp case, where the activation + * zero point (azp) is used to compute an azp correction term, + * which is folded into the bias. + * + * The bias tensor must be per-output channel. + * ScaleA and ScaleB can be per-tensor or per-token/per-channel. + */ +template +struct ScaledEpilogueBias + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + + using ArgumentType = typename EVTCompute::Arguments; + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, bias_args, {}}; + } +}; + +/* + * This epilogue performs the same operation as ScaledEpilogueBias, but the + * bias is a column vector instead of a row vector. Useful e.g. if we are + * computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels. 
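+ * (Clarifying note, not in the upstream vLLM comment: in that transposed
+ * formulation the per-output-channel bias indexes the rows of the computed
+ * output C^T rather than its columns, which is why it is broadcast as an
+ * (m,1) column below instead of the usual (1,n) row.)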
+ */ +template +struct ScaledEpilogueColumnBias + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template ColLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + + using ArgumentType = typename EVTCompute::Arguments; + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, bias_args, {}}; + } +}; + +/* + * This epilogue directly supports per-tensor azp in int32 form. + * As opposed to the per-token epilogue below, this epilogue only has an azp_adj + * term, which should already be multiplied with the scalar azp. + * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzp + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + // This is the full AZP term, azp * J @ B, shape (1,n) + using AzpWithAdj = typename SUPER::template RowLoad; + + // Compute float(accum - azp_adj), both operands are int32_t + using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + std::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{ + b_args, evt_azp_args, {}}; + return ArgumentType{a_args, evt_scale_b_args, bias_args, {}}; + } +}; + +/* + * This epilogue supports per-token azp by computing and applying 
+ * the correction term using a rank-1 update. If the term were materialized, + * it would require O(m*n) space, and this way it only requires O(m+n) space. + * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero + * point for each row of A. + * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzpToken + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + // Per-token azp term, shape (m,1) + using Azp = typename SUPER::template ColLoad; + + // This is the AZP adjustment term, J @ B, shape (1,n) + using AzpAdj = typename SUPER::template RowLoad; + + // Compute azp * azp_adj + using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, int32_t, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::fusion::Sm90EVT; + + // Compute float(accum - azp*azp_adj), all operands are int32_t + using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAcc = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + torch::Tensor const& azp, + std::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_args = SUPER::template args_from_tensor(azp); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}}; + typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{ + b_args, evt_acc_args, {}}; + return ArgumentType{a_args, evt_scale_b_args, bias_args, {}}; + } +}; + +/* + This epilogue works like ScaledEpilogue, but ScaleA and ScaleB are pointers + to arrays containing different scales used in group gemm. The number of + pointers in ScaleA and the number of pointers in ScaleB are equal to the + group size. 
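+ * Each pointer in the arrays addresses the scales of one GEMM in the group; the
+ * a_col_broadcast / b_row_broadcast flags of prepare_args play the role of the
+ * tensor.numel() != 1 check used above, i.e. true selects per-token / per-channel
+ * scales and false a single scalar per group member (editor's reading of the
+ * args_from_tensor overload, not an upstream guarantee).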
+*/ +template +struct ScaledEpilogueArray + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoadArray; + using ScaleB = typename SUPER::template RowOrScalarLoadArray; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + using ScaleAArray = typename SUPER::template ColOrScalarLoadArray; + using ScaleBArray = typename SUPER::template RowOrScalarLoadArray; + + static ArgumentType prepare_args(float const* const* a_scales_ptr, + float const* const* b_scales_ptr, + bool a_col_broadcast, bool b_row_broadcast) { + auto a_args = SUPER::template args_from_tensor( + a_scales_ptr, a_col_broadcast); + auto b_args = SUPER::template args_from_tensor( + b_scales_ptr, b_row_broadcast); + + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, {}}; + } +}; + +}; // namespace vllm::c3x diff --git a/gptqmodel_ext/vllm_csrc/cutlass_extensions/gemm/.DS_Store b/gptqmodel_ext/vllm_csrc/cutlass_extensions/gemm/.DS_Store new file mode 100644 index 000000000..0e47daf33 Binary files /dev/null and b/gptqmodel_ext/vllm_csrc/cutlass_extensions/gemm/.DS_Store differ diff --git a/gptqmodel_ext/vllm_csrc/cutlass_extensions/gemm/collective/collective_builder.hpp b/gptqmodel_ext/vllm_csrc/cutlass_extensions/gemm/collective/collective_builder.hpp new file mode 100644 index 000000000..ec75c29e5 --- /dev/null +++ b/gptqmodel_ext/vllm_csrc/cutlass_extensions/gemm/collective/collective_builder.hpp @@ -0,0 +1,123 @@ +// Modified from: cutlass/gemm/collective/builders/sm90_gmma_builder.inl +// clang-format off +#pragma once + +#include "cutlass/gemm/collective/builders/sm90_gmma_builder.inl" + +#include "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp" + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_WS_SS (BlockScaled Builders) +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + int ScaleGranularityM +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum, + cute::enable_if_t< + not detail::is_use_rmem_A()> +> { + using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum; + + static_assert(is_static::value); + static_assert(is_static::value); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + 
static_assert(detail::is_aligned(), + "Should meet TMA alignment requirement\n"); + + static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v); + static constexpr bool IsFP8Input = detail::is_input_fp8(); + static_assert((!IsFP8Input || !IsArrayOfPointersGemm), + "KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with FP8 Blocked Scaled version right now."); + + // For fp32 types, map to tf32 MMA value type + using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; + using ElementBMma = cute::conditional_t, tfloat32_t, ElementB>; + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + static constexpr bool IsCooperative = cute::is_any_of_v>; + using AtomLayoutMNK = cute::conditional_t>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< + ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector< + GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector< + GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0; + static constexpr int KernelSmemCarveout = static_cast(TensorMapStorage); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + TagToStrideA_t, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/gptqmodel_ext/vllm_csrc/cutlass_extensions/gemm/collective/fp8_accumulation.hpp b/gptqmodel_ext/vllm_csrc/cutlass_extensions/gemm/collective/fp8_accumulation.hpp new file mode 100644 index 000000000..13b90e998 --- /dev/null +++ b/gptqmodel_ext/vllm_csrc/cutlass_extensions/gemm/collective/fp8_accumulation.hpp @@ -0,0 +1,183 @@ +// clang-format off +// adapted from: https://github.com/soundOfDestiny/cutlass/blob/a4208aa6958864923505cade9c63eb2a6daf16e5/include/cutlass/gemm/collective/fp8_accumulation.hpp + +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cute/algorithm/clear.hpp" +#include "cute/tensor.hpp" + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////FP8 Accumulation/////////////////////////// +////////////////////////////////////////////////////////////////////////////// +/// This class provides API to promote (add) or scale (multiply_add) the results +/// from the tensor core accumulators to the main accumulators when the number +/// of MMAs reaches the max number of MMA interval specified by user, after that +/// the tensor core accumulators are zeroed. +////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +template < + class EngineAccum, + class LayoutAccum> +struct GmmaFP8AccumulationWithScale { + using TensorAccum = cute::Tensor; + using ElementAccumulator = typename EngineAccum::value_type; + + static_assert(is_static::value, "Accumulator Layout should be static"); + static_assert(is_rmem::value , "Accumulator tensor must be rmem resident."); + +private: + TensorAccum& accum_; + TensorAccum accum_temp_; + + uint32_t accum_promotion_interval_; // defines the max num of executed MMAs after which accum should be promoted. + uint32_t mma_count_per_mainloop_iteration_; // num of MMAs per k_tile of mainloop + uint32_t mma_count_; // current executed MMAs + uint32_t reset_accum_flag_; // accum needs to be zeroed or not. + + // promote or `add` the partial accumulators to main accumulator (FADD). + CUTLASS_DEVICE + void promote_core() { + warpgroup_wait<0>(); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum_); ++i) { + accum_(i) += accum_temp_(i); + } + } + + // `multiply` scale the partial accumulators and `add` to main accumulator (FFMA). 
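+ // Concretely: accum_(i) += accum_temp_(i) * scale(i) for every accumulator
+ // element, i.e. one fused multiply-add per element instead of a plain add.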
+ template < + class EngineScale, + class LayoutScale> + CUTLASS_DEVICE + void scale_core(const cute::Tensor &scale) { + using TensorScale = cute::Tensor; + + static_assert(is_static::value, "Scale Layout should be static"); + static_assert(is_rmem::value , "Scale tensor must be rmem resident."); + + static_assert(LayoutAccum{}.shape() == LayoutScale{}.shape(), "Accumulator and scale must have same shape."); + + warpgroup_wait<0>(); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum_); ++i) { + accum_(i) += accum_temp_(i) * scale(i); + } + } + +public: + CUTLASS_DEVICE + GmmaFP8AccumulationWithScale( + TensorAccum &accum, + uint32_t accum_promotion_interval, + uint32_t mma_count_per_mainloop_iteration) + : accum_(accum), + accum_promotion_interval_(accum_promotion_interval), + mma_count_per_mainloop_iteration_(mma_count_per_mainloop_iteration), + mma_count_(0), + reset_accum_flag_(0) + { + accum_temp_ = cute::make_fragment_like(accum); + } + + // + // Methods (Common) + // + + CUTLASS_DEVICE + TensorAccum& operator()() { + return accum_temp_; + } + + /// prepare the MMA accumulators when initialization or zeroing is required. + CUTLASS_DEVICE + bool prepare_if_needed() { + return reset_accum_flag_; + } + + // + // Methods (for FADD version) + // + + /// promote (add) the results from the MMA accumulators to main accumulator if needed. + CUTLASS_DEVICE + void promote_if_needed() { + mma_count_ += mma_count_per_mainloop_iteration_; + reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); + if (reset_accum_flag_) { + promote_core(); + mma_count_ = 0; + } + } + + /// promote (add) the residue results from the MMA accumulators to main accumulator if needed. + CUTLASS_DEVICE + void promote_residue_if_needed() { + if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { + promote_core(); + } + } + + // + // Methods (for FFMA version) + // + + /// scale (multiply_add) the results from the MMA accumulators to main accumulator if needed. + template < + class EngineScale, + class LayoutScale> + CUTLASS_DEVICE + void scale_if_needed(const cute::Tensor &scale) { + mma_count_ += mma_count_per_mainloop_iteration_; + reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); + if (reset_accum_flag_) { + scale_core(scale); + mma_count_ = 0; + } + } + + /// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed. 
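+ // Residue handling: if the number of k tiles is not a multiple of the promotion
+ // interval, partial sums are still pending in the temporary accumulators
+ // (mma_count_ > 0) and must be folded into the main accumulators one last time.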
+ template < + class EngineScale, + class LayoutScale> + CUTLASS_DEVICE + void scale_residue_if_needed(const cute::Tensor &scale) { + if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { + scale_core(scale); + } + } +}; + +} // namespace cutlass::gemm::collective diff --git a/gptqmodel_ext/vllm_csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/gptqmodel_ext/vllm_csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp new file mode 100644 index 000000000..d922a3349 --- /dev/null +++ b/gptqmodel_ext/vllm_csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp @@ -0,0 +1,730 @@ +// clang-format off +// Adapted (Heavily) from: https://github.com/soundOfDestiny/cutlass/blob/9d997ce0dea4c5fa1a617db6b7ff29aa9235822c/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp + +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/trace.h" +#include "cutlass/numeric_types.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm80.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +#include "cutlass_extensions/gemm/dispatch_policy.hpp" +#include "cutlass_extensions/gemm/collective/fp8_accumulation.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape, + class KernelSchedule, + int ScaleGranularityM_, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using ElementBlockScale = ElementAccumulator; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + using PipelineParams = typename MainloopPipeline::Params; + + // Two threads per CTA are producers (1 for operand tile and 32 for scales) + static constexpr int NumProducerThreadEvents = 33; + + static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? 
size<0>(TileShape{}) : ScaleGranularityM_; + static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert((size<0>(TileShape{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M."); + + // Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + // Block scaling gmem-to-smem copy atom + using SmemBlockScalingCopyAtomA = Copy_Atom, ElementBlockScale>; + using SmemBlockScalingCopyAtomB = Copy_Atom, ElementBlockScale>; + + // Block scaling smem layout + using SmemLayoutScaleA = Layout, Int>>; + using SmemLayoutScaleB = Layout>, Stride<_1>>; // `ScaleNsPerTile` is always 1. 
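+ // Illustration (hypothetical numbers): with size<0>(TileShape{}) == 128 and
+ // ScaleGranularityM == 128, ScaleMsPerTile == 1 and a single A scale is staged per
+ // k-tile; with ScaleGranularityM == 1, 128 A scales are staged per k-tile, one per
+ // output row. B always contributes exactly one scale per k-tile stage.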
+ + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v, + "ElementAccumulator and ElementBlockScale should be same datatype"); + + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128> { + cute::array_aligned> smem_A; // mxk + cute::array_aligned> smem_B; // nxk + cute::array_aligned> smem_scale_A; // ScaleMsPerTile x k + cute::array_aligned> smem_scale_B; // 1xk + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + ElementBlockScale const* ptr_scale_A; + ElementBlockScale const* ptr_scale_B; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy_A_sm90( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,0), + TileShape{}, + ClusterShape{})); + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy_B_sm90( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,0), + TileShape{}, + ClusterShape{})); + TMA_A tma_load_a; + TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + // Block scaling factors for A and B + ElementBlockScale const* ptr_scale_A; + ElementBlockScale const* ptr_scale_B; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); + typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); + uint32_t transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t transaction_bytes_nk = TmaTransactionBytesNK; + uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk; + + return { + tma_load_a, + tma_load_b, + transaction_bytes, + transaction_bytes_mk, + transaction_bytes_nk, + args.ptr_scale_A, + args.ptr_scale_B + }; + } + + 
template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytesMK = + cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytesNK = + cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) + { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. 
The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + constexpr auto scales_m = Int{}; + auto tM = get<2>(gA_mkl.shape()); + auto tN = get<2>(gB_nkl.shape()); + auto tK = get<3>(gA_mkl.shape()); + + // Make the tiled views of scale tensors + auto scaleA_shape = make_shape(M / ScaleGranularityM, tK, L); // (scale_m,k,l) + auto scaleA_layout = make_ordered_layout(scaleA_shape, Step<_0, _1, _2>{}); + auto scaleB_shape = make_shape(tN, tK, L); // (n,k,l) + auto scaleB_layout = make_ordered_layout(scaleB_shape, Step<_1, _0, _2>{}); + + // Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and + // gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl. 
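+ // Concretely: scaleA_layout has shape (M / ScaleGranularityM, k_tiles, L), one
+ // scale per ScaleGranularityM rows of A and per k-tile, while scaleB_layout has
+ // shape (n_tiles, k_tiles, L), one scale per (N-tile, K-tile) block of B.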
+ Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (scale_m,k,l) + Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,k,l) + + return cute::make_tuple(gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, class TensorB, + class TensorScaleA, class TensorScaleB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + // Blockscaling: Tma loads for load_input and CpAsync for load_scale + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k) + Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k) + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. 
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + + // Block scaling: load_scale has scaling tensors in global memory which are not tiled + Tensor mScaleA_mkl = get<2>(load_inputs); + Tensor mScaleB_nkl = get<3>(load_inputs); + auto scales_m = get<0>(mScaleA_mkl.shape()); + + Tensor cScaleA_mkl = make_identity_tensor(mScaleA_mkl.shape()); + + Tensor gScaleA = local_tile( + mScaleA_mkl, make_tile(Int{}), + make_coord(m_coord,_,l_coord)); // (ScaleMsPerTile,k,1) + Tensor cScaleA = local_tile( + cScaleA_mkl, make_tile(Int{}), + make_coord(m_coord,_,l_coord)); + Tensor gScaleB = mScaleB_nkl(n_coord,_,l_coord); // (1,k,1) + + // TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128 + TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{}, + Layout>{}, Layout>{}); // (1,1,1) + TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{}, + Layout>{}, Layout>{}); // (1,1,1) + ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x); + ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x); + + Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA); + Tensor tAcA_ScaleA = thr_scale_copy_a.partition_S(cScaleA); + Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA); + + Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB); + Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB); + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads for GEMM operands A/B and CpAsync for scale tensors + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // Allocate predicate tensors for a_scales (since we can't guarantee that + // all scales are valid, since we could have a partial tiles along M) + Tensor tApA_ScaleA = make_tensor(shape(tAsA_ScaleA(_,_,0))); + #pragma unroll + for (int i = 0; i < size(tApA_ScaleA); ++i) { + tApA_ScaleA(i) = get<0>(tAcA_ScaleA(i)) < scales_m; + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + int write_stage = smem_pipe_write.index(); + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + // Copy operands A and B from global memory to shared memory + if (lane_predicate) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + if (lane_predicate) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + + 
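+ // Note the split: the operand tiles above are issued by the single elected lane
+ // via TMA, while the scale tensors below are copied cooperatively with cp.async
+ // by the scale-producer threads (cf. NumProducerThreadEvents above).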
// Copy scale tensors from global memory to shared memory + copy_if(scale_copy_a, tApA_ScaleA, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage)); + copy(scale_copy_b, tBgB_ScaleB(_,*k_tile_iter), tBsB_ScaleB(_,write_stage)); + pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc); + + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail( + MainloopPipeline pipeline, + PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + + + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Block scaling + Tensor sScaleAViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), + Layout< + Shape, Int>, cute::tuple_element_t<1, TileShape>, Int>, + Stride, _0, Int> + >{}); // ((ScaleGranularityM,ScaleMsPerTile),n,k) + Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k) + + // + // Define C accumulators and A/B partitioning + // + + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); + + Tensor tCsScaleAViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C. 
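+ // (Editor's reading: thread_mma was sliced with the warp-group-level thread
+ // layout, which matches how the GMMA smem descriptors for A and B are consumed;
+ // the scale view is consumed as a per-thread register fragment of C, so it is
+ // partitioned with get_slice(thread_idx) instead.)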
+ + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Per block scale values for operand A and B + + using RegLayoutScaleAViewAsC = decltype(make_layout_like(tCsScaleAViewAsC(_, _, _, 0).layout())); // `make_layout_like` makes a compact layout. + using RegLayoutScaleAEssential = decltype(filter_zeros(RegLayoutScaleAViewAsC{}.stride(), RegLayoutScaleAViewAsC{}.shape())); // an interface to traverse the underlying storage for the compact layout mentioned above + + Tensor tCrScaleAViewAsC = make_tensor(RegLayoutScaleAViewAsC{}); // (MMA,MMA_M,MMA_N) + ElementBlockScale scale_b; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + GmmaFP8AccumulationWithScale accumulation(accum, size<2>(TileShape{}) / size<2>(typename TiledMma::AtomShape_MNK{}), size<2>(tCrA)); + warpgroup_fence_operand(accumulation()); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + if (accumulation.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + int read_stage = smem_pipe_read.index(); + + // Load per block scale values from shared memory to registers. + scale_b = sScaleB[read_stage]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { + tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{})); + } + if constexpr (ScaleMsPerTile == 1) { + static_assert(size(RegLayoutScaleAEssential{}) == 1); + tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`. 
+ } else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { + tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b; + } + } + + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + // Block scale the accumulators with reg tensor `tCrScaleAViewAsC` + accumulation.scale_if_needed(tCrScaleAViewAsC); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accumulation()); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + // Load per block scale values from shared memory to registers (at most twice per block along M and exactly once per block along N) + scale_b = sScaleB[read_stage]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { + tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{})); + } + if constexpr (ScaleMsPerTile == 1) { + static_assert(size(RegLayoutScaleAEssential{}) == 1); + tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`. 
+ } else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { + tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b; + } + } + + if (accumulation.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + warpgroup_fence_operand(accumulation()); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accumulation()); + + // Block scale the accumulators with reg tensor `tCrScaleAViewAsC` + accumulation.scale_if_needed(tCrScaleAViewAsC); + + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + accumulation.scale_residue_if_needed(tCrScaleAViewAsC); + + warpgroup_fence_operand(accumulation()); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/gptqmodel_ext/vllm_csrc/cutlass_extensions/gemm/dispatch_policy.hpp b/gptqmodel_ext/vllm_csrc/cutlass_extensions/gemm/dispatch_policy.hpp new file mode 100644 index 000000000..df809e27a --- /dev/null +++ b/gptqmodel_ext/vllm_csrc/cutlass_extensions/gemm/dispatch_policy.hpp @@ -0,0 +1,39 @@ +#pragma once + +#include "cutlass/gemm/dispatch_policy.hpp" + +namespace cutlass::gemm { + +////////////////////////////////////////////////////////////////////////////// + +// FP8 related policies (including Blocked Scaled Accumulation) +// `ScaleGranularityM` specifies scaling granularity along M, while zero-value +// `ScaleGranularityM` indicates that scaling granularity is +// `size<0>(TileShape_MNK{})` along M. +template +struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum + : KernelTmaWarpSpecializedCooperative {}; + +// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp +// specialized dynamic schedule For FP8 kernels with Block Scaling +template , + class KernelSchedule = KernelTmaWarpSpecialized, + int ScaleGranularityM = + 0 // `ScaleGranularityM` specifies scaling granularity along M, + // while zero-value `ScaleGranularityM` indicates that scaling + // granularity is `size<0>(TileShape_MNK{})` along M. 
+ > +struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8 + : MainloopSm90TmaGmmaWarpSpecialized { + static_assert( + cute::is_same_v< + KernelSchedule, + KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum< + ScaleGranularityM>>, + "KernelSchedule must be one of the warp specialized policies"); +}; + +////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm \ No newline at end of file diff --git a/gptqmodel_ext/vllm_csrc/cutlass_extensions/torch_utils.hpp b/gptqmodel_ext/vllm_csrc/cutlass_extensions/torch_utils.hpp new file mode 100644 index 000000000..a1ff933cc --- /dev/null +++ b/gptqmodel_ext/vllm_csrc/cutlass_extensions/torch_utils.hpp @@ -0,0 +1,160 @@ +#pragma once + +#include + +#include "cute/layout.hpp" +#include "cutlass/layout/matrix.h" +#include "cutlass/bfloat16.h" +#include "cutlass/half.h" + +using ColumnMajor = typename cutlass::layout::ColumnMajor; +using RowMajor = typename cutlass::layout::RowMajor; + +namespace cute { + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t, F&& f, G&& g, + seq) { + return g(f(cute::get(static_cast(t)), I)...); +} + +template +CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f, seq) { + return make_shape(f(I)...); +} + +}; // namespace detail + +template +CUTE_HOST_DEVICE constexpr auto transform_with_idx(T const& t, F&& f) { + if constexpr (cute::is_tuple::value) { + return detail::tapply_with_idx( + t, f, [](auto const&... a) { return cute::make_tuple(a...); }, + tuple_seq{}); + } else { + return f(t); + } + + CUTE_GCC_UNREACHABLE; +} + +// calls: make_shape(f(0), f(1), ..., f(N-1)) +template +CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) { + return detail::make_shape_from_idx(f, make_seq{}); +} + +}; // namespace cute + +// Make a layout from a tensor with `rank(Stride{})`, where the shape is the +// shape of the passed in tensor and the strides are of type `Stride` and +// contain the strides of the passed in tensor, checking that any static strides +// in `Stride{}` match the strides of the passed in tensor. +// If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the extra +// strides are set to be 0 or 1. 
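+// Example (illustrative only): for a contiguous row-major [M, N] torch::Tensor t,
+// make_cute_layout<Stride<int64_t, Int<1>>>(t, "t") yields shape (M, N) and stride
+// (t.stride(0), _1); the static Int<1> is checked against t.stride(1), and a
+// dynamic stride is replaced by 0 whenever the corresponding dim has size 1.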
+template +static inline auto make_cute_layout(torch::Tensor const& tensor, + std::string_view name = "tensor") { + TORCH_CHECK(tensor.dim() <= rank(Stride{})); + auto stride = cute::transform_with_idx( + Stride{}, [&](auto const& stride_ele, auto const& idx) { + using StrideEle = std::decay_t; + + if (idx < tensor.dim()) { + if constexpr (cute::is_static_v) { + TORCH_CHECK(StrideEle::value == tensor.stride(idx), "Expected ", + name, ".stride(", idx, ") to be ", StrideEle::value); + return StrideEle{}; + } else { + if (tensor.size(idx) == 1) { + // use 0 stride for dim with size 1, this is easier for + // cute/cutlass to optimize (helps the TMA code flatten dims) + return StrideEle{0}; + } else { + return tensor.stride(idx); + } + } + } else { + // Extra strides are assumed to be 0 or 1 + if constexpr (cute::is_static_v) { + static_assert(StrideEle::value == 0 || StrideEle::value == 1); + } + return StrideEle{}; + } + }); + + auto shape = cute::make_shape_from_idx([&](auto const& idx) { + if (idx < tensor.dim()) + return tensor.size(idx); + else + return int64_t(1); + }); + + return make_layout(shape, stride); +} + +template +static inline auto maybe_make_cute_layout( + std::optional const& tensor, + std::string_view name = "tensor") { + using Layout = decltype(make_cute_layout(*tensor)); + + if (tensor) { + return std::optional{make_cute_layout(*tensor, name)}; + } else { + return std::optional{}; + } +} + +// +// Torch Type to Cutlass Type (equivalent_cutlass_type) +// + +template +struct equivalent_cutlass_type { + using type = T; +}; + +template +using equivalent_cutlass_type_t = typename equivalent_cutlass_type::type; + +template <> +struct equivalent_cutlass_type { + using type = cutlass::half_t; +}; + +template <> +struct equivalent_cutlass_type { + using type = cutlass::bfloat16_t; +}; + +// +// equivalent_scalar_t (basically inverse of equivalent_cutlass_type) +// + +// Return a `c10::CppTypeToScalarType` compatible type, i.e. get the C++ from +// c10 that is equivalent to T, e.g.: `cutlass::half_t -> c10::Half` +template +struct equivalent_scalar_type { + using type = T; +}; + +template +using equivalent_scalar_type_t = typename equivalent_scalar_type::type; + +template <> +struct equivalent_scalar_type { + using type = c10::Half; +}; + +template <> +struct equivalent_scalar_type { + using type = c10::BFloat16; +}; + +// get equivalent c10::ScalarType tag from compile time type +template +static inline constexpr c10::ScalarType equivalent_scalar_type_v = + c10::CppTypeToScalarType>::value; \ No newline at end of file diff --git a/gptqmodel_ext/vllm_csrc/cutlass_extensions/vllm_collective_builder.cuh b/gptqmodel_ext/vllm_csrc/cutlass_extensions/vllm_collective_builder.cuh new file mode 100644 index 000000000..e7fbba4cd --- /dev/null +++ b/gptqmodel_ext/vllm_csrc/cutlass_extensions/vllm_collective_builder.cuh @@ -0,0 +1,43 @@ +#pragma once + +#include "cutlass_extensions/gemm/collective/collective_builder.hpp" + +namespace cutlass::gemm::collective { +using namespace cute; + +// +// VLLMCollectiveBuilder is a wrapper around CollectiveBuilder that allows for +// for custom kernel tags, allowing you to build custom collectives. Without +// touching the cutlass library headers, using `CutlassKernelTag` will mean it +// will resort to using the standard cutlass collective builder. +// + +// Use the default Cutlass collective builder, i.e. 
use an unmodified cutless +// collective +struct CutlassKernelTag {}; + +template +struct VLLMCollectiveBuilder { + static_assert(sizeof(ElementA) == 0, + "Could not build a collective for given parameters."); +}; + +template +struct VLLMCollectiveBuilder< + CutlassKernelTag, ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA, + ElementB, GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, + ClusterShape_MNK, StageCountType, KernelScheduleType> { + using CollectiveOp = typename CollectiveBuilder< + ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA, ElementB, + GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, + ClusterShape_MNK, StageCountType, KernelScheduleType>::CollectiveOp; +}; + +}; // namespace cutlass::gemm::collective \ No newline at end of file diff --git a/gptqmodel_ext/vllm_csrc/cutlass_extensions/vllm_custom_types.cuh b/gptqmodel_ext/vllm_csrc/cutlass_extensions/vllm_custom_types.cuh new file mode 100644 index 000000000..6146bdc1f --- /dev/null +++ b/gptqmodel_ext/vllm_csrc/cutlass_extensions/vllm_custom_types.cuh @@ -0,0 +1,50 @@ +#pragma once + +#include "cutlass/integer_subbyte.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct vllm_biased_integer_subbyte : public integer_subbyte { + using Base = integer_subbyte; + + using Storage = typename Base::Storage; + using xint_t = typename Base::xint_t; + + using Base::bits_mask_; + using Base::sign_mask_; + using Base::storage; + + // + // Methods + // + + /// No operation + vllm_biased_integer_subbyte() = default; + + /// Conversion from integer type + CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(int value) + : Base(value) {} + CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(unsigned value) + : Base(value) {} + CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(double value) + : Base(value) {} +}; +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// "GPTQ" types, i.e. 
symmetric quantization +using vllm_uint4b8_t = vllm_biased_integer_subbyte<4, 8>; // u4b8 +using vllm_uint8b128_t = vllm_biased_integer_subbyte<8, 128>; // u8b128 + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct sizeof_bits> { + static constexpr int value = Bits; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/gptqmodel_ext/vllm_csrc/cutlass_extensions/vllm_cutlass_library_extension.py b/gptqmodel_ext/vllm_csrc/cutlass_extensions/vllm_cutlass_library_extension.py new file mode 100644 index 000000000..d64f0d0a5 --- /dev/null +++ b/gptqmodel_ext/vllm_csrc/cutlass_extensions/vllm_cutlass_library_extension.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: Apache-2.0 + +import enum +from typing import Union + +from cutlass_library import * + +# +# Extend cutlass library with custom types, and missing values +# + + +class VLLMDataType(enum.Enum): + u4b8 = enum_auto() + u8b128 = enum_auto() + + +class MixedInputKernelScheduleType(enum.Enum): + TmaWarpSpecialized = enum_auto() + TmaWarpSpecializedPingpong = enum_auto() + TmaWarpSpecializedCooperative = enum_auto() + + +VLLMDataTypeNames: dict[Union[VLLMDataType, DataType], str] = { + **DataTypeNames, # type: ignore + **{ + VLLMDataType.u4b8: "u4b8", + VLLMDataType.u8b128: "u8b128", + } +} + +VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = { + **DataTypeTag, # type: ignore + **{ + VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t", + VLLMDataType.u8b128: "cutlass::vllm_uint8b128_t", + } +} + +VLLMDataTypeSize: dict[Union[VLLMDataType, DataType], int] = { + **DataTypeSize, # type: ignore + **{ + VLLMDataType.u4b8: 4, + VLLMDataType.u8b128: 8, + } +} + +VLLMDataTypeVLLMScalarTypeTag: dict[Union[VLLMDataType, DataType], str] = { + VLLMDataType.u4b8: "vllm::kU4B8", + VLLMDataType.u8b128: "vllm::kU8B128", + DataType.u4: "vllm::kU4", + DataType.u8: "vllm::kU8", + DataType.s4: "vllm::kS4", + DataType.s8: "vllm::kS8", + DataType.f16: "vllm::kFloat16", + DataType.bf16: "vllm::kBfloat16", +} + +VLLMDataTypeTorchDataTypeTag: dict[Union[VLLMDataType, DataType], str] = { + DataType.u8: "at::ScalarType::Byte", + DataType.s8: "at::ScalarType::Char", + DataType.e4m3: "at::ScalarType::Float8_e4m3fn", + DataType.s32: "at::ScalarType::Int", + DataType.f16: "at::ScalarType::Half", + DataType.bf16: "at::ScalarType::BFloat16", + DataType.f32: "at::ScalarType::Float", +} + +VLLMKernelScheduleTag: dict[Union[ + MixedInputKernelScheduleType, KernelScheduleType], str] = { + **KernelScheduleTag, # type: ignore + **{ + MixedInputKernelScheduleType.TmaWarpSpecialized: + "cutlass::gemm::KernelTmaWarpSpecialized", + MixedInputKernelScheduleType.TmaWarpSpecializedPingpong: + "cutlass::gemm::KernelTmaWarpSpecializedPingpong", + MixedInputKernelScheduleType.TmaWarpSpecializedCooperative: + "cutlass::gemm::KernelTmaWarpSpecializedCooperative", + } + } diff --git a/gptqmodel_ext/vllm_csrc/cutlass_extensions/vllm_numeric_conversion.cuh b/gptqmodel_ext/vllm_csrc/cutlass_extensions/vllm_numeric_conversion.cuh new file mode 100644 index 000000000..90f226cf6 --- /dev/null +++ b/gptqmodel_ext/vllm_csrc/cutlass_extensions/vllm_numeric_conversion.cuh @@ -0,0 +1,992 @@ +#pragma once + +#include "cutlass/numeric_conversion.h" +#include "cutlass_extensions/vllm_custom_types.cuh" +#include "cutlass_extensions/cute_utils.cuh" +#include "cutlass_extensions/vllm_type_utils.cuh" + +// this file extends: +// 
https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h +// with vllm specific type conversions, namely: vllm_uint4b8_t, vllm_uint8b128_t +// as well as adds interleaved numeric array converters for specific types. +// (interleaved numeric array converters can be more efficient for subbyte +// types) + +namespace cutlass { + +// InterleavedNumericArrayConverter is like NumericArrayConverter but also +// deinterleaves converted elements based on IlvBlkLayout, interleaving can +// make subbyte converts more efficient by allowing for efficient extraction +// of subbyte elements from a 32bit register. +template +struct InterleavedNumericArrayConverter { + using Converter = NumericArrayConverter; + + using result_type = typename Converter::result_type; + using source_type = typename Converter::source_type; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + if (cute::elect_one_sync()) { + if constexpr (std::is_same_v) { + printf( + "Convert %s <= %s (N = %d, IlvBlkLayout = void), not implemented\n", + nameof_v, nameof_v, N); + } else { + printf( + "Convert %s <= %s (N = %d, size(IlvBlkLayout{}) = %d), not " + "implemented\n", + nameof_v, nameof_v, N, size(IlvBlkLayout{})); + } + __brkpt(); + } + return {}; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +template +struct InterleavedNumericArrayConverter< + IlvBlkLayout, T, S, N, Round, + std::enable_if_t()>> { + using Converter = NumericArrayConverter; + + using result_type = typename Converter::result_type; + using source_type = typename Converter::source_type; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return Converter::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +template +struct ArrayConverterPacked32Bit { + using result_type = Array; + using source_type = Array; + + using result_packed_8_t = Array; + using result_packed_4_t = Array; + using result_packed_2_t = Array; + using src_packed_8_t = Array; + using src_packed_4_t = Array; + using src_packed_2_t = Array; + + static_assert(N % 2 == 0, "N must be a multiple of 2"); + static_assert(cutlass::sizeof_bits_v >= 4); // TODO: add 16 packed sources + static_assert(32 % cutlass::sizeof_bits_v == 0); + static constexpr auto src_elems_per_32bit_reg = + 32 / cutlass::sizeof_bits_v; + + // Maybe not Valid. ScalarConverter will not actually work unless + // NumericConverter is implemented. However it won't be used + // anyways since we assert N % 2 == 0, just here for compliance with + // VectorizedConverter. + using ScalarConverter = NumericConverter; + + template + CUTLASS_DEVICE static auto to_regs(PackedSrc const& src) { + if constexpr (sizeof(PackedSrc) == 1) { + return Array{reinterpret_cast(src)}; + } else if constexpr (sizeof(PackedSrc) == 2) { + return Array{reinterpret_cast(src)}; + } else if constexpr (sizeof(PackedSrc) == 4) { + return Array{reinterpret_cast(src)}; + } else { + static_assert(sizeof(PackedSrc) == 8); + return reinterpret_cast const&>(src); + } + } + + // The core converter uses bit tricks to construct a known FP16 number, then + // does a subtraction in FP16 for the final result. 
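The "construct a known FP16 number, then subtract" trick mentioned above can be sanity-checked on the CPU. A hedged numpy sketch follows (plain arithmetic, not the CUDA path; the helper name is illustrative). The constants 1032, 72, and 1/16 match the hfma constants used by the half_t <= vllm_uint4b8_t converter further down in this file.

```python
# Hedged numpy check (CPU only) of the fp16 magic-number trick: OR a biased
# nibble into the mantissa of 0x6400, then subtract/fma the bias back out.
import numpy as np

def fp16_from_bits(bits):
    return float(np.array([bits], dtype=np.uint16).view(np.float16)[0])

for x in range(-8, 8):                  # logical int4 value
    stored = x + 8                      # u4b8 storage is biased by 8

    # low nibble: bits 0x6400 | stored encode the fp16 value 1024 + stored,
    # so subtracting 1032 recovers x
    assert fp16_from_bits(0x6400 | stored) - 1032.0 == x

    # high nibble: bits 0x6400 | (stored << 4) encode 1024 + 16 * stored,
    # so scaling by 1/16 and subtracting 72 recovers x
    assert fp16_from_bits(0x6400 | (stored << 4)) / 16.0 - 72.0 == x

print("fp16 magic-number trick verified for all int4 values")
```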
+ template + CUTLASS_DEVICE static PackedResultType packed_convert( + PackedSrcType const& source) { + static_assert(PackedSrcType::kElements == PackedResultType::kElements); + static_assert(PackedResultType::kElements == 2 || + PackedResultType::kElements == 4 || + PackedResultType::kElements == 8, + "Invalid PackedResultType must be 2, 4 or 8."); + static_assert(std::is_same_v); + static_assert(std::is_same_v); + + return RegConvert32bit::template convert(to_regs(source)); + } + + friend class detail::VectorizedConverter; + + public: + CUTLASS_DEVICE static result_type convert(source_type const& source) { + result_type result; + using ConverterType = + ArrayConverterPacked32Bit; + + if constexpr (src_elems_per_32bit_reg >= 8) { + detail::VectorizedConverter::convert< + ConverterType, result_packed_8_t, src_packed_8_t, result_packed_4_t, + src_packed_4_t, result_packed_2_t, src_packed_2_t>(result, source); + } else if constexpr (src_elems_per_32bit_reg >= 4) { + detail::VectorizedConverter::convert(result, source); + } else { + detail::VectorizedConverter::convert(result, source); + } + + return result; + } +}; + +// Convert 8 4bit values packed into a 32bit register to 8 8bit values packed +// into 2 32bit register. +template +CUTLASS_DEVICE cutlass::AlignedArray lut_4bit_to_8bit_convert( + uint32_t src) { + cutlass::AlignedArray r; + // Determines if the value is in the top half of the LUT if set or + // (i.e. LUT[8:15]) in the bottom half (i.e. LUT[0:7]) if not set. Then move + // into bit position 0x4 of each nibble so when or'd with final_prmt_base it + // selects the correct candidate. When elements in final_prmt_base + // are >= 0x4, the high candidate is selected (i.e. LUT[8:15]), when elements + // are < 0x4, the low candidate is selected (i.e. LUT[0:7]) + uint32_t high_bit = (src & 0x88888888) >> 1; + + // `high_bit` is OR'd with 0x31203120 to find the correct value in the LUT + // (selects correct high or low candidate) + const uint32_t final_prmt_base = 0x32103210; + + // Ignore the high bit when indexing into LUT, for each 4bit value + // we index into both the high and low candidates then use + // high_bit | final_prmt_base to select the correct candidate + uint32_t lut_idx = (src & 0x77777777); + + auto pack = [](uint8_t a, uint8_t b, uint8_t c, uint8_t d) { + return uint32_t(a) | (uint32_t(b) << 8) | (uint32_t(c) << 16) | + (uint32_t(d) << 24); + }; + + static constexpr uint32_t LOW_0 = pack(LUT0, LUT1, LUT2, LUT3); + static constexpr uint32_t LOW_1 = pack(LUT4, LUT5, LUT6, LUT7); + static constexpr uint32_t HIGH_0 = pack(LUT8, LUT9, LUT10, LUT11); + static constexpr uint32_t HIGH_1 = pack(LUT12, LUT13, LUT14, LUT15); + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 2; ++ii, lut_idx >>= 16, high_bit >>= 16) { + uint32_t final_prmt_idx = final_prmt_base | high_bit; + + // This uses a look up table to convert packed int4s to packed int8s, + // using the int4 value as the index to prmt. It first select both the + // high and low candidates, then uses the high bit (i.e. `high_bit`) to + // select the correct candidate. 
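For reference while reading the inline PTX that follows, here is a hedged host-side emulation of that lookup in pure Python (illustrative names only). The in-kernel version reaches the same result but splits the table into low/high halves and selects between them with prmt and the extracted high bit.

```python
# Hedged emulation of the 4-bit -> 8-bit LUT conversion described above.
# The LUT below is the one used by the int8 instantiation further down:
# entry (x + 8) holds the two's-complement byte for x in [-8, 7].
LUT_INT8 = [0xF8, 0xF9, 0xFA, 0xFB, 0xFC, 0xFD, 0xFE, 0xFF,   # x = -8..-1
            0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]   # x =  0..7

def lut_4bit_to_8bit(src: int) -> list:
    """Map each u4b8 nibble of a packed uint32 to its signed int8 value."""
    out = []
    for i in range(8):
        nibble = (src >> (4 * i)) & 0xF              # biased storage, x + 8
        byte = LUT_INT8[nibble]
        out.append(byte - 256 if byte >= 128 else byte)
    return out

print(lut_4bit_to_8bit(0x01234567))   # [-1, -2, -3, -4, -5, -6, -7, -8]
```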
+ asm volatile( + "{\n" + " .reg .b32 low, high;\n" + " prmt.b32 low, %1, %2, %5;\n" + " prmt.b32 high, %3, %4, %5;\n" + " prmt.b32 %0, low, high, %6;\n" + "}\n" + : "=r"(r[ii]) + : "n"(LOW_0), "n"(LOW_1), "n"(HIGH_0), "n"(HIGH_1), "r"(lut_idx), + "r"(final_prmt_idx)); + } + + return r; +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + // [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as int8s + auto r = lut_4bit_to_8bit_convert<0xF8, 0xF9, 0xFA, 0xFB, // + 0xFC, 0xFD, 0xFE, 0xFF, // + 0x00, 0x01, 0x02, 0x03, // + 0x04, 0x05, 0x06, 0x07>(src_[0]); + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + // [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as fp8s + auto r = lut_4bit_to_8bit_convert<0xD0, 0xCE, 0xCC, 0xCA, // + 0xC8, 0xC4, 0xC0, 0xB8, // + 0x00, 0x38, 0x40, 0x44, // + 0x48, 0x4A, 0x4C, 0x4E>(src_[0]); + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; + using RegArray = + cutlass::AlignedArray; + RegArray r; + + // Below constructs the following temporary: + // fp16s_01 = {0x00, i4_01, 0x00, i4_01} + // fp16s_23 = {0x00, i4_23, 0x00, i4_23} + // fp16s_45 = {0x00, i4_45, 0x00, i4_45} + // fp16s_67 = {0x00, i4_67, 0x00, i4_67} + // We use inline asm instead of __byte_perm intrinsic since we don't want + // the documented (& 0x7) on the index. NVCC might be able to optimize it + // out since the index is a constexpr, but we choose to be safe about it + // here. + uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343}; + static_assert(RegArray::kElements <= 4, + "Too many inputs for F16 -> I4 vector converter"); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " prmt.b32 %0, %1, %2, %3;\n" + "}\n" + : "=r"(r[ii]) + : "r"(src), "n"(0), "r"(prmt_indices[ii])); + } + + // Since the stored 4bit values are biased by 8 we get stored_val = (x+8) + // we are trying to construct x and a fp16 value + // The below XOR does the following: + // 1) Sets the exponent bits of the FP16 to the correct value for the + // FP16 magic_num. We will be constructing {1024+16*(x1+8), 1024+(x0+8)}, + // where x1 in the high nibble and x0 is the low nibble then using hfma + // to subtract 1032 from that + // The AND does the following: + // 1) Clear the set bits for the int4 we will ignore. 
+ // We use lop3 so that we can use 1 instruction for AND and XOR. + static constexpr uint32_t xor_mask = 0x64006400; + static constexpr uint32_t and_mask = 0xFFF0FF0F; + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + + // For each operand, computes: + // r[i] = (r[i] & and_mask) ^ xor_mask + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(and_mask), "n"(xor_mask), "n"(immLut)); + } + + // We will issue 2 hfmas that do the following: + // {x1, x0} = {1024+16*(x1+8), 1024+(x0+8)} * {1/16, 1} - {72, 1032} + // = {x1 + 1152, x0 + 1032} * {1/16, 1} - {72, 1032} + static constexpr uint32_t hfma_bias_rep = 0xD480E408; // {72, 1032} + static constexpr uint32_t hfma_scale_rep = 0x2C003C00; // {1 / 16, 1} + + const half2& hfma_bias = reinterpret_cast(hfma_bias_rep); + const half2& hfma_scale = reinterpret_cast(hfma_scale_rep); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); + fp16x2_val = __hfma2(hfma_scale, fp16x2_val, hfma_bias); + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::half_t, vllm_uint4b8_t, N, + Round, void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; + using RegArray = + cutlass::AlignedArray; + RegArray r; + + static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t xor_mask = 0x64006400; + + for (int ii = 0; ii < RegArray::kElements; ii += 2) { + auto src_ = src >> (4 * (ii)); + r[ii + 0] = src_; + r[ii + 1] = src_; + + static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa; + + static constexpr uint32_t low_nib_mask = 0x000F000F; + static constexpr uint32_t high_nib_mask = 0x00F000F0; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 0]) + : "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 1]) + : "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + // For low nibble: + // {x1, x0} = {1024+(x1+8), 1024+(x0+8)} * {1, 1} - {1032, 1032} + // For high nibble: + // {x1, x0} = {1024+16*(x1+8), 1024+16*(x0+8)} * {1/16, 1/16} + // - {72, 72} + static constexpr uint32_t low_nib_bias = 0x64086408; // {1032, 1032} + static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16} + static constexpr uint32_t high_nib_bias = 0xD480D480; // {-72, -72} + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]); + fp16x2_val = + __hsub2(fp16x2_val, reinterpret_cast(low_nib_bias)); + } + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]); + fp16x2_val = __hfma2(fp16x2_val, + reinterpret_cast(high_nib_scale), + reinterpret_cast(high_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + 
CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::half_t, uint4_t, N, Round, + void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; + using RegArray = + cutlass::AlignedArray; + RegArray r; + + static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t xor_mask = 0x64006400; + + for (int ii = 0; ii < RegArray::kElements; ii += 2) { + auto src_ = src >> (4 * (ii)); + r[ii + 0] = src_; + r[ii + 1] = src_; + + static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa; + + static constexpr uint32_t low_nib_mask = 0x000F000F; + static constexpr uint32_t high_nib_mask = 0x00F000F0; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 0]) + : "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 1]) + : "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + // For low nibble: + // {x1, x0} = {1024+x1, 1024+x0} - {1024, 1024} + // For high nibble: + // {x1, x0} = {1024+16*x1, 1024+16*x0} * {1/16, 1/16} - {64, 64} + static constexpr uint32_t low_nib_bias = 0x64006400; // {1024, 1024} + static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16} + static constexpr uint32_t high_nib_bias = 0xD400D400; // {-64, -64} + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]); + fp16x2_val = + __hsub2(fp16x2_val, reinterpret_cast(low_nib_bias)); + } + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]); + fp16x2_val = __hfma2(fp16x2_val, + reinterpret_cast(high_nib_scale), + reinterpret_cast(high_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; + // Hold output FP16s in reg. We need 1 reg for every 2 elements + using RegArray = + cutlass::AlignedArray; + RegArray r; + + uint32_t const prmt_indices[2] = {0x5150, 0x5352}; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(r[ii]) + : "r"(src), "n"(start_byte_for_fp16), + "r"(prmt_indices[ii])); + } + + // -128 is folded into bias subtraction, i.e. 
the 0x80 in the low bytes + static constexpr uint32_t bias_rep = 0x64806480; + const half2& bias = reinterpret_cast(bias_rep); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); + fp16x2_val = __hsub2(fp16x2_val, bias); + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; + PackedResultType r; + + // __byte_perm simulates the add.u32 0x4B000000 to every u8 element of + // u8x4 source and stores the result in r (without introducing extra + // cvt.u32.u8 instruction) + uint32_t const prmt_indices[4] = {0x7650, 0x7651, 0x7652, 0x7653}; + uint32_t* result_as_int = reinterpret_cast(&r); + for (int ii = 0; ii < PackedResultType::kElements; ++ii) { + result_as_int[ii] = __byte_perm(src, 0x4B000000, prmt_indices[ii]); + // Subtract the magic number 0x4B000000 from tmp in floating-point + // arithmetic to obtain final result + r[ii] -= (8388608.f + 128.f); // fold in -128 bias + } + + return r; + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src_reg = src_[0]; + // Hold output BF16s in reg. We need 1 reg for every 2 elements + using RegArray = + cutlass::AlignedArray; + RegArray r; + uint32_t src_reg_shifted = src_reg >> 4; + + // Below constructs the following temporary: + uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3}; + static_assert(RegArray::kElements <= 4, + "Too many inputs for uint4b8_t -> BF16 vector converter"); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " prmt.b32 %0, %1, %2, %3;\n" + "}\n" + : "=r"(r[ii]) + : "r"(src_reg), "r"(src_reg_shifted), "r"(prmt_indices[ii])); + } + + // Since the stored 4bit values are biased by 8 we get stored_val = (x+8) + // we are trying to construct x and a BF16 value + // The below XOR does the following: + // 1) Sets the exponent bits of the BF16 to the correct value for the + // BF16 magic_num. 
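A hedged numpy check of the bf16 variant of the same construction (bfloat16 emulated through the upper 16 bits of float32; not the CUDA path). The comment resumes below with the exact {128 + (x+8)} construction and the 136 bias this verifies.

```python
# Hedged check: OR a biased nibble into 0x4300, read it back as bfloat16,
# and confirm that subtracting 136 recovers the logical int4 value.
import numpy as np

def bf16_from_bits(bits):
    return float(np.array([bits << 16], dtype=np.uint32).view(np.float32)[0])

assert bf16_from_bits(0x4300) == 128.0        # the bf16 magic exponent
for x in range(-8, 8):
    stored = x + 8                            # u4b8 biased nibble
    assert bf16_from_bits(0x4300 | stored) - 136.0 == x   # 128 + (x+8) - 136
print("bf16 magic-number trick verified for all int4 values")
```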
We will be constructing {128 + (x1+8), 128 + (x0+8)} + // and subtracting 136 to get {x1, x0} + static constexpr uint32_t xor_mask = 0x43004300; + static constexpr uint32_t and_mask = 0x000F000F; + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + + // For each operand, computes: + // r[i] = (r[i] & and_mask) ^ xor_mask + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(and_mask), "n"(xor_mask), "n"(immLut)); + } + + // We will issue 2 bfmas that do the following: + // high BF16: + // hi_bf16 - 136, lo_bf16 - 136 + + // This is the BF16 {136, 136} represented as an integer. + static constexpr uint32_t bias_rep = 0x43084308; + const __nv_bfloat162& bias = + reinterpret_cast(bias_rep); + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + bf16x2_val = __hsub2(bf16x2_val, bias); + } + + return reinterpret_cast(r); + } + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::bfloat16_t, vllm_uint4b8_t, N, + Round, void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; + using RegArray = + cutlass::AlignedArray; + RegArray r; + + static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t or_mask = 0x43004300; + + // Unlike float16 where the mantissa is large enough to contain 2 + // nibbles, bfloat16 can only fit one, so we can only convert one + // nibble at a time + for (int ii = 0; ii < RegArray::kElements; ++ii) { + r[ii] = src >> (4 * ii); + + static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t low_nib_mask = 0x000F000F; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 0]) + : "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut)); + + // For low nibble: + // {x1, x0} = {128+(x1+8), 128+(x0+8)} * {1, 1} - {136, 136} + static constexpr uint32_t low_nib_bias = 0x43084308; // {136, 136} + + { + __nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + fp16x2_val = + __hsub2(fp16x2_val, + reinterpret_cast(low_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::bfloat16_t, uint4_t, N, Round, + void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; + using RegArray = + 
cutlass::AlignedArray; + RegArray r; + + static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t or_mask = 0x43004300; + + // Unlike float16 where the mantissa is large enough to contain 2 + // nibbles, bfloat16 can only fit one, so we can only convert one + // nibble at a time + for (int ii = 0; ii < RegArray::kElements; ++ii) { + r[ii] = src >> (4 * ii); + + static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t low_nib_mask = 0x000F000F; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut)); + + // For low nibble: + // {x1, x0} = {128 + x1, 128 + x0} * {1, 1} - {128, 128} + static constexpr uint32_t low_nib_bias = 0x43004300; // {128, 128} + + { + __nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + fp16x2_val = + __hsub2(fp16x2_val, + reinterpret_cast(low_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + private: + using result_packed_4_t = Array; + using result_packed_2_t = Array; + using src_packed_4_t = Array; + using src_packed_2_t = Array; + + // Not Valid, not supported, only here to satisfy the interface and to avoid + // a compile error. ScalarConverter will not actually work until + // NumericConverter is + // implemented + using ScalarConverter = + NumericConverter; + + template + CUTLASS_DEVICE static PackedResultType packed_convert( + PackedSrcType const& source) { + static_assert( + (platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private " + "convert dispatch."); + + NumericArrayConverter + convert_uint8_to_f32; + Array tmp = + convert_uint8_to_f32(source); + NumericArrayConverter + convert_f32_to_bf16_; + return convert_f32_to_bf16_(tmp); + } + + friend class detail::VectorizedConverter; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; + using ConverterType = + NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +#endif + +// for Array <= Array +// FastFP16toINT8 from https://arxiv.org/pdf/2406.09904 +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + struct RegConvert { + // FastFP16toINT8 from https://arxiv.org/pdf/2406.09904 + template + CUTLASS_DEVICE static PackedResultType convert( + Array src) { + // Hold output int8s in reg. 
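Before the register bookkeeping continues below, a hedged numpy check of the FastFP16toINT8 magic-bias trick this converter implements (integer-valued fp16 inputs only; not the CUDA path): after adding fp16 1152 (bits 0x6480), the low byte of each fp16 bit pattern is the biased uint8 value x + 128, and XOR with 0x80 turns that into the two's-complement int8, which is what the 0x80808080 mask below does.

```python
# Hedged numpy check of the FastFP16toINT8 magic-bias trick for x in [-128, 127].
import numpy as np

magic = np.array([0x6480], dtype=np.uint16).view(np.float16)      # fp16 1152.0
xs = np.arange(-128, 128, dtype=np.float16)

shifted = xs + magic                         # fp16 add, as __hadd2 does
low_bytes = shifted.view(np.uint16) & 0xFF   # the bytes prmt extracts
int8s = (low_bytes ^ 0x80).astype(np.uint8).view(np.int8)
assert (int8s.astype(np.int32) == np.arange(-128, 128)).all()
print("FastFP16toINT8 magic-bias trick verified for [-128, 127]")
```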
We need 1 reg for every 4 elements + using RegArray = cutlass::AlignedArray< + uint32_t, std::max(PackedResultType::kElements / 4, size_t(1))>; + RegArray r; + + static constexpr uint32_t MAGIC_BIAS_ = 0x64806480; + auto MAGIC_BIAS = *reinterpret_cast(&MAGIC_BIAS_); + + *reinterpret_cast(&src[0]) = + __hadd2(*reinterpret_cast(&src[0]), MAGIC_BIAS); + + if constexpr (src_regs > 1) { + *reinterpret_cast(&src[1]) = + __hadd2(*reinterpret_cast(&src[1]), MAGIC_BIAS); + } + + static_assert(PackedResultType::kElements <= 4); + uint32_t uint8s; + static constexpr uint32_t MASK_0246 = 0x6420; + static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(uint8s) + : "r"(src[0]), "r"((src_regs > 1) ? src[1] : src[0]), + "n"(MASK_0246)); + + uint32_t int8s = (uint8s ^ UINT8s_TO_INT8s_MASK); + + return reinterpret_cast(int8s); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/gptqmodel_ext/vllm_csrc/cutlass_extensions/vllm_type_utils.cuh b/gptqmodel_ext/vllm_csrc/cutlass_extensions/vllm_type_utils.cuh new file mode 100644 index 000000000..500ed508c --- /dev/null +++ b/gptqmodel_ext/vllm_csrc/cutlass_extensions/vllm_type_utils.cuh @@ -0,0 +1,42 @@ +#include "cutlass/bfloat16.h" +#include "cutlass/half.h" +#include "cuda_bf16.h" + +#include "cutlass_extensions/vllm_custom_types.cuh" + +namespace cutlass { + +template +struct nameof { + static constexpr char const* value = "unknown"; +}; + +template +inline constexpr auto nameof_v = nameof::value; + +#define NAMEOF_TYPE(T) \ + template <> \ + struct nameof { \ + static constexpr char const* value = #T; \ + }; + +NAMEOF_TYPE(float_e4m3_t) +NAMEOF_TYPE(float_e5m2_t) +NAMEOF_TYPE(half_t) +NAMEOF_TYPE(nv_bfloat16) +NAMEOF_TYPE(bfloat16_t) +NAMEOF_TYPE(float) + +NAMEOF_TYPE(int4b_t) +NAMEOF_TYPE(int8_t) +NAMEOF_TYPE(int32_t) +NAMEOF_TYPE(int64_t) + +NAMEOF_TYPE(vllm_uint4b8_t) +NAMEOF_TYPE(uint4b_t) +NAMEOF_TYPE(uint8_t) +NAMEOF_TYPE(vllm_uint8b128_t) +NAMEOF_TYPE(uint32_t) +NAMEOF_TYPE(uint64_t) + +}; // namespace cutlass \ No newline at end of file diff --git a/setup.py b/setup.py index 1f91f90a9..d8e01f3ff 100644 --- a/setup.py +++ b/setup.py @@ -164,14 +164,25 @@ def get_version_tag() -> str: from distutils.sysconfig import get_python_lib from torch.utils import cpp_extension as cpp_ext + project_root = os.path.dirname(os.path.abspath(__file__)) conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include") + cutlass_include_dir = os.path.join(get_python_lib(), "cutlass_library/source/include") + vllm_csrc_dir = os.path.join(project_root, "gptqmodel_ext/vllm_csrc") print("conda_cuda_include_dir", conda_cuda_include_dir) if os.path.isdir(conda_cuda_include_dir): include_dirs.append(conda_cuda_include_dir) print(f"appending conda cuda include dir {conda_cuda_include_dir}") + if os.path.isdir(cutlass_include_dir): + include_dirs.append(cutlass_include_dir) + print(f"appending conda cuda include dir {cutlass_include_dir}") + + if os.path.isdir(vllm_csrc_dir): + include_dirs.append(vllm_csrc_dir) + print(f"appending conda cuda include dir 
{vllm_csrc_dir}") + extra_link_args = [] extra_compile_args = { "cxx": [ @@ -191,6 +202,7 @@ def get_version_tag() -> str: "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", "-U__CUDA_NO_BFLOAT162_OPERATORS__", "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", + "-D CUTLASS_DEBUG_TRACE_LEVEL=2", ], } @@ -262,6 +274,19 @@ def _hipify_compile_flags(flags): extra_link_args=extra_link_args, extra_compile_args=extra_compile_args, ), + cpp_ext.CUDAExtension( + "gptqmodel_machete_kernels", + [ + "gptqmodel_ext/machete/machete.cpp", + "gptqmodel_ext/machete/machete_pytorch.cu", + "gptqmodel_ext/machete/machete_mm_dispatch.cu", + "gptqmodel_ext/machete/machete_mm_impl.cu", + "gptqmodel_ext/machete/machete_prepack.cu", + "gptqmodel_ext/machete/permute_cols.cu", + ], + extra_link_args=extra_link_args, + extra_compile_args=extra_compile_args, + ), cpp_ext.CUDAExtension( "gptqmodel_qqq_kernels", [