From 40de7e0d1c1387d1fc39f17ea4716eeb603bf087 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Thu, 6 Nov 2025 01:11:31 +0000 Subject: [PATCH 1/6] Move dyn_int8_act_int4_wei_cpu_layout to prototype --- torchao/prototype/dtypes/uintx/__init__.py | 2 + .../uintx/dyn_int8_act_int4_wei_cpu_layout.py | 319 ++++++++++++++++++ 2 files changed, 321 insertions(+) create mode 100644 torchao/prototype/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py diff --git a/torchao/prototype/dtypes/uintx/__init__.py b/torchao/prototype/dtypes/uintx/__init__.py index 53edddb8ac..89c1f3f810 100644 --- a/torchao/prototype/dtypes/uintx/__init__.py +++ b/torchao/prototype/dtypes/uintx/__init__.py @@ -6,8 +6,10 @@ from .block_sparse_layout import BlockSparseLayout from .cutlass_int4_packed_layout import CutlassInt4PackedLayout +from .dyn_int8_act_int4_wei_cpu_layout import Int8DynamicActInt4WeightCPULayout __all__ = [ "BlockSparseLayout", "CutlassInt4PackedLayout", + "Int8DynamicActInt4WeightCPULayout", ] diff --git a/torchao/prototype/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py b/torchao/prototype/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py new file mode 100644 index 0000000000..8d0cfaddeb --- /dev/null +++ b/torchao/prototype/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py @@ -0,0 +1,319 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +from dataclasses import dataclass +from typing import Tuple + +import torch +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.dtypes.utils import Layout, PlainLayout, is_device +from torchao.utils import torch_version_at_least + +from .int4_cpu_layout import ( + Int4CPUAQTTensorImpl, + _is_float, +) + +aten = torch.ops.aten + + +@dataclass(frozen=True) +class Int8DynamicActInt4WeightCPULayout(Layout): + """Layout class for da8w4 CPU layout for affine quantized tensor""" + + pass + + +@register_layout(Int8DynamicActInt4WeightCPULayout) +class DA8W4CPUAQTTensorImpl(Int4CPUAQTTensorImpl): + """TensorImpl for da8w4 CPU layout for affine quantized tensor + It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 2-d tensor of + dimension: [n][k / 2] (uint8 dtype) + It is similar to Int4CPUAQTTensorImpl but with a different memory layout of weight data + fields: + packed_weight (torch.Tensor): the 2-d packed tensor in a Int4 CPU layout + scales (torch.Tensor): the scales Tensor used to map between floating point tensor to quantized tensor + qzeros (torch.Tensor): the zero_point Tensor used to map between floating point tensor to quantized tensor + """ + + def __new__( + cls, + packed_weight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + compensation: torch.Tensor, + transposed: bool, + _layout: Layout, + ): + kwargs = {} + kwargs["device"] = packed_weight.device + kwargs["layout"] = ( + kwargs.get("layout") + if kwargs.get("layout", False) + else packed_weight.layout + ) + kwargs["dtype"] = packed_weight.dtype + kwargs["requires_grad"] = False + shape = packed_weight.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + packed_weight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + compensation: torch.Tensor, + transposed: bool, + 
_layout: Layout, + ): + self.packed_weight = packed_weight + self.scales = scales + self.qzeros = qzeros + self.compensation = compensation + self.transposed = transposed + self._layout = _layout + + def __tensor_flatten__(self): + return ["packed_weight", "scales", "qzeros", "compensation"], [ + self.transposed, + self._layout, + ] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + packed_weight, scales, qzeros, compensation = ( + tensor_data_dict["packed_weight"], + tensor_data_dict["scales"], + tensor_data_dict["qzeros"], + tensor_data_dict["compensation"], + ) + ( + transposed, + _layout, + ) = tensor_attributes + return cls(packed_weight, scales, qzeros, compensation, transposed, _layout) + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + _layout: Layout, + ): + assert isinstance(_layout, Int8DynamicActInt4WeightCPULayout) + assert int_data.dtype == torch.uint8, "DA8W4 CPU: expects uint8 weight" + assert int_data.shape[1] % 2 == 0, "DA8W4 CPU: expects even number of columns" + if scale.dim() == 1: + scale.unsqueeze_(-1) + scale = scale.to(torch.float) + if zero_point.dim() == 1: + zero_point.unsqueeze_(-1) + + # Pack weight from [N, K] to [N / block_n, K / block_k, block_k, block_n]. + # Pack the inner blocks [block_k, block_n] to VNNI layout if AMX is available. + # Pack scales/qzeros from [N, num_groups] to [N / block_n, num_groups, block_n]. + # Compensation shape = [N / block_n, K / block_k, block_n]. + weight_int4, scales, qzeros, compensation = ( + torch.ops.torchao.da8w4_linear_prepack_cpu(int_data, scale, zero_point) + ) + return cls(weight_int4, scales, qzeros, compensation, False, _layout) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.packed_weight), + fn(self.scales), + fn(self.qzeros), + fn(self.compensation), + self.transposed, + self._layout, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + if func is aten.t.default: + """we don't need to repack the weight and just rely on external + shape being changed and record the status of transpose/no-transpose + """ + transposed = DA8W4CPUAQTTensorImpl( + args[0].packed_weight, + args[0].scales, + args[0].qzeros, + args[0].compensation, + not args[0].transposed, + args[0]._layout, + ) + return return_and_correct_aliasing(func, args, kwargs, transposed) + else: + return super().__torch_dispatch__(func, types, args, kwargs) + + __torch_function__ = torch._C._disabled_torch_function_impl + + @property + def block_size(self): + assert len(self.packed_weight.shape) == 2 + weight_shape = self.packed_weight.shape + N = weight_shape[0] + K = weight_shape[1] * 2 + groups = self.scales.numel() // N + group_size = K // groups + return (1, group_size) + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # Unpack weight by linear(eye(K), packed_weight).t() + packed_w_shape = self.packed_weight.shape + if len(packed_w_shape) == 4: + K = packed_w_shape[1] * packed_w_shape[2] + else: + K = packed_w_shape[1] + x = torch.eye(K).to(torch.uint8) + x_scale = torch.ones(K).float() + x_qzero = torch.zeros(K).to(torch.int32) + w_scale = torch.ones_like(self.scales).float() + w_qzero = torch.zeros_like(self.qzeros).to(torch.int8) + plain_weight = torch.ops.torchao.da8w4_linear_cpu.default( + x, + x_scale, + x_qzero, + self.packed_weight, + w_scale, + w_qzero, + self.compensation, + None, # 
bias
+            torch.float,  # out_dtype
+        )
+        plain_weight = plain_weight.t().contiguous()
+        plain_weight = plain_weight.to(torch.int8)
+
+        if self.scales.dim() == 2:
+            assert self.qzeros.dim() == 2
+            plain_scales = self.scales
+            plain_qzeros = self.qzeros
+        else:
+            assert self.scales.dim() == 3 and self.qzeros.dim() == 3
+            packed_shape = self.scales.shape  # [Nc, G, block_n]
+            plain_scales = (
+                self.scales.permute([0, 2, 1]).contiguous().view([-1, packed_shape[1]])
+            )
+            plain_qzeros = (
+                self.qzeros.permute([0, 2, 1]).contiguous().view([-1, packed_shape[1]])
+            )
+
+        return plain_weight, plain_scales, plain_qzeros
+
+
+def _aqt_is_uint8(aqt):
+    """Check if an AffineQuantizedTensor is uint8 quantized Tensor"""
+    return (
+        aqt.tensor_impl.dtype == torch.uint8
+        and aqt.quant_min == 0
+        and aqt.quant_max == 255
+    )
+
+
+def _aqt_is_int8(aqt):
+    """Check if an AffineQuantizedTensor is int8 quantized Tensor"""
+    return (
+        aqt.tensor_impl.dtype == torch.int8
+        and aqt.quant_min == -127
+        and aqt.quant_max == 127
+    )
+
+
+def _aqt_is_uint4(aqt):
+    """Check if an AffineQuantizedTensor is uint4 quantized Tensor"""
+    return (
+        aqt.tensor_impl.dtype == torch.uint8
+        and aqt.quant_min == 0
+        and aqt.quant_max == 15
+    )
+
+
+def _linear_int8_act_int4_weight_cpu_check(input_tensor, weight_tensor, bias):
+    return (
+        torch_version_at_least("2.7.0")
+        and is_device(input_tensor.device.type, "cpu")
+        and is_device(weight_tensor.device.type, "cpu")
+        and (bias is None or is_device(bias.device.type, "cpu"))
+        and isinstance(input_tensor, AffineQuantizedTensor)
+        and (_aqt_is_uint8(input_tensor) or _aqt_is_int8(input_tensor))
+        and _is_float(input_tensor.dtype)
+        and isinstance(input_tensor._layout, PlainLayout)
+        and isinstance(weight_tensor, AffineQuantizedTensor)
+        and _aqt_is_uint4(weight_tensor)
+        and _is_float(weight_tensor.dtype)
+        and isinstance(weight_tensor._layout, Int8DynamicActInt4WeightCPULayout)
+    )
+
+
+def _linear_int8_act_int4_weight_cpu_impl(input_tensor, weight_tensor, bias):
+    assert torch_version_at_least("2.7.0"), (
+        f"Requires PyTorch version at least 2.7, but got: {torch.__version__}"
+    )
+    if _aqt_is_int8(input_tensor):
+        assert torch_version_at_least("2.8.0"), (
+            f"Requires PyTorch version at least 2.8, but got: {torch.__version__}"
+        )
+    assert is_device(input_tensor.device.type, "cpu"), (
+        f"For CPU device only but got: {input_tensor.device}"
+    )
+    assert weight_tensor.block_size[0] == 1, (
+        f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}"
+    )
+    assert input_tensor.shape[-1] == weight_tensor.shape[1], (
+        f"need input_tensor shape: {input_tensor.shape} final "
+        f"dim to match weight_tensor shape: {weight_tensor.shape} second dim"
+    )
+
+    act_mat = input_tensor
+    act = act_mat.tensor_impl.int_data
+    act_scales = act_mat.tensor_impl.scale
+    act_qzeros = act_mat.tensor_impl.zero_point
+
+    packed_weight = weight_tensor.tensor_impl.packed_weight
+    wei_scales = weight_tensor.tensor_impl.scales
+    wei_qzeros = weight_tensor.tensor_impl.qzeros
+    compensation = weight_tensor.tensor_impl.compensation
+
+    orig_act_size = act_mat.size()
+    orig_dtype = act_mat.dtype
+
+    # reshape to 2D
+    act = act.reshape(-1, act.shape[-1])
+
+    y = torch.ops.torchao.da8w4_linear_cpu.default(
+        act.contiguous(),
+        act_scales,
+        act_qzeros,
+        packed_weight,
+        wei_scales,
+        wei_qzeros,
+        compensation,
+        bias.float() if bias is not None else bias,  # requires bias to be float
+        orig_dtype,  # out_dtype
+    )
+
+    # remove out_feature padding
+    orig_out_features = weight_tensor.shape[-2]
+ y = y[:, :orig_out_features] + y = y.reshape(*orig_act_size[:-1], orig_out_features) + + return y.to(orig_dtype) + + +# Register the concat linear fusion pass +# from ...prototype.inductor.fx_passes import register_da8w4_concat_linear_cpu_pass + +# register_da8w4_concat_linear_cpu_pass() From d3db93eca4d98f527a51153dde709c6805ae5f17 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Thu, 6 Nov 2025 03:48:30 +0000 Subject: [PATCH 2/6] Move dyn_int8_act_int4_wei_cpu_layout to prototype --- test/quantization/test_da8w4_cpu.py | 27 ++ torchao/dtypes/__init__.py | 2 +- torchao/dtypes/affine_quantized_tensor_ops.py | 8 +- .../uintx/dyn_int8_act_int4_wei_cpu_layout.py | 326 +----------------- torchao/prototype/dtypes/__init__.py | 7 +- .../uintx/dyn_int8_act_int4_wei_cpu_layout.py | 7 +- 6 files changed, 58 insertions(+), 319 deletions(-) diff --git a/test/quantization/test_da8w4_cpu.py b/test/quantization/test_da8w4_cpu.py index d4f68c4333..c4b0eac39f 100644 --- a/test/quantization/test_da8w4_cpu.py +++ b/test/quantization/test_da8w4_cpu.py @@ -176,5 +176,32 @@ def forward(self, x): common_utils.instantiate_parametrized_tests(TestDa8w4Cpu) +# TODO: Remove this test once the deprecated API has been removed +def test_int8_dynamic_act_int4_weight_cpu_layout_deprecated(): + import sys + import warnings + + # We need to clear the cache to force re-importing and trigger the warning again. + modules_to_clear = [ + "torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout", + "torchao.dtypes", + ] + for mod in modules_to_clear: + if mod in sys.modules: + del sys.modules[mod] + + with warnings.catch_warnings(record=True) as w: + from torchao.dtypes import Int8DynamicActInt4WeightCPULayout # noqa: F401 + + warnings.simplefilter("always") # Ensure all warnings are captured + assert any( + issubclass(warning.category, DeprecationWarning) + and "Int8DynamicActInt4WeightCPULayout" in str(warning.message) + for warning in w + ), ( + f"Expected deprecation warning for Int8DynamicActInt4WeightCPULayout, got: {[str(warning.message) for warning in w]}" + ) + + if __name__ == "__main__": run_tests() diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 252498bc97..354692e794 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -16,7 +16,6 @@ from .uintx import ( Int4CPULayout, Int4XPULayout, - Int8DynamicActInt4WeightCPULayout, MarlinQQQLayout, MarlinQQQTensor, MarlinSparseLayout, @@ -29,6 +28,7 @@ ) from .uintx.block_sparse_layout import BlockSparseLayout from .uintx.cutlass_int4_packed_layout import CutlassInt4PackedLayout +from .uintx.dyn_int8_act_int4_wei_cpu_layout import Int8DynamicActInt4WeightCPULayout from .utils import ( Layout, PlainLayout, diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index e46809059e..3816f9bf1f 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -25,10 +25,6 @@ _linear_f16_bf16_act_floatx_weight_check, _linear_f16_bf16_act_floatx_weight_impl, ) -from torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout import ( - _linear_int8_act_int4_weight_cpu_check, - _linear_int8_act_int4_weight_cpu_impl, -) from torchao.dtypes.uintx.gemlite_layout import ( _linear_fp_act_int4_weight_gemlite_check, _linear_fp_act_int4_weight_gemlite_impl, @@ -94,6 +90,10 @@ _linear_int8_act_int4_weight_cutlass_check, _linear_int8_act_int4_weight_cutlass_impl, ) +from torchao.prototype.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout import ( + 
_linear_int8_act_int4_weight_cpu_check, + _linear_int8_act_int4_weight_cpu_impl, +) from torchao.quantization.quant_primitives import ( ZeroPointDomain, _dequantize_affine_no_zero_point, diff --git a/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py b/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py index 8d0cfaddeb..d66f70e2ee 100644 --- a/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py +++ b/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py @@ -3,317 +3,25 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from dataclasses import dataclass -from typing import Tuple -import torch -from torch.utils._python_dispatch import ( - return_and_correct_aliasing, -) +# Backward compatibility stub - imports from the new location +import warnings -from torchao.dtypes.affine_quantized_tensor import ( - AffineQuantizedTensor, - register_layout, +warnings.warn( + "Importing from torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout is deprecated. " + "Please use 'from torchao.prototype.dtypes import Int8DynamicActInt4WeightCPULayout' instead. " + "This import path will be removed in a future release of torchao. " + "See https://github.com/pytorch/ao/issues/2752 for more details.", + DeprecationWarning, + stacklevel=2, ) -from torchao.dtypes.utils import Layout, PlainLayout, is_device -from torchao.utils import torch_version_at_least -from .int4_cpu_layout import ( - Int4CPUAQTTensorImpl, - _is_float, +from torchao.prototype.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout import ( # noqa: F401 + DA8W4CPUAQTTensorImpl, # noqa: F401 + Int8DynamicActInt4WeightCPULayout, # noqa: F401 + _aqt_is_int8, # noqa: F401 + _aqt_is_uint4, # noqa: F401 + _aqt_is_uint8, # noqa: F401 + _linear_int8_act_int4_weight_cpu_check, # noqa: F401 + _linear_int8_act_int4_weight_cpu_impl, # noqa: F401 ) - -aten = torch.ops.aten - - -@dataclass(frozen=True) -class Int8DynamicActInt4WeightCPULayout(Layout): - """Layout class for da8w4 CPU layout for affine quantized tensor""" - - pass - - -@register_layout(Int8DynamicActInt4WeightCPULayout) -class DA8W4CPUAQTTensorImpl(Int4CPUAQTTensorImpl): - """TensorImpl for da8w4 CPU layout for affine quantized tensor - It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 2-d tensor of - dimension: [n][k / 2] (uint8 dtype) - It is similar to Int4CPUAQTTensorImpl but with a different memory layout of weight data - fields: - packed_weight (torch.Tensor): the 2-d packed tensor in a Int4 CPU layout - scales (torch.Tensor): the scales Tensor used to map between floating point tensor to quantized tensor - qzeros (torch.Tensor): the zero_point Tensor used to map between floating point tensor to quantized tensor - """ - - def __new__( - cls, - packed_weight: torch.Tensor, - scales: torch.Tensor, - qzeros: torch.Tensor, - compensation: torch.Tensor, - transposed: bool, - _layout: Layout, - ): - kwargs = {} - kwargs["device"] = packed_weight.device - kwargs["layout"] = ( - kwargs.get("layout") - if kwargs.get("layout", False) - else packed_weight.layout - ) - kwargs["dtype"] = packed_weight.dtype - kwargs["requires_grad"] = False - shape = packed_weight.shape - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - packed_weight: torch.Tensor, - scales: torch.Tensor, - qzeros: torch.Tensor, - compensation: torch.Tensor, - transposed: bool, - _layout: Layout, - ): - self.packed_weight = 
packed_weight - self.scales = scales - self.qzeros = qzeros - self.compensation = compensation - self.transposed = transposed - self._layout = _layout - - def __tensor_flatten__(self): - return ["packed_weight", "scales", "qzeros", "compensation"], [ - self.transposed, - self._layout, - ] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - packed_weight, scales, qzeros, compensation = ( - tensor_data_dict["packed_weight"], - tensor_data_dict["scales"], - tensor_data_dict["qzeros"], - tensor_data_dict["compensation"], - ) - ( - transposed, - _layout, - ) = tensor_attributes - return cls(packed_weight, scales, qzeros, compensation, transposed, _layout) - - @classmethod - def from_plain( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - _layout: Layout, - ): - assert isinstance(_layout, Int8DynamicActInt4WeightCPULayout) - assert int_data.dtype == torch.uint8, "DA8W4 CPU: expects uint8 weight" - assert int_data.shape[1] % 2 == 0, "DA8W4 CPU: expects even number of columns" - if scale.dim() == 1: - scale.unsqueeze_(-1) - scale = scale.to(torch.float) - if zero_point.dim() == 1: - zero_point.unsqueeze_(-1) - - # Pack weight from [N, K] to [N / block_n, K / block_k, block_k, block_n]. - # Pack the inner blocks [block_k, block_n] to VNNI layout if AMX is available. - # Pack scales/qzeros from [N, num_groups] to [N / block_n, num_groups, block_n]. - # Compensation shape = [N / block_n, K / block_k, block_n]. - weight_int4, scales, qzeros, compensation = ( - torch.ops.torchao.da8w4_linear_prepack_cpu(int_data, scale, zero_point) - ) - return cls(weight_int4, scales, qzeros, compensation, False, _layout) - - def _apply_fn_to_data(self, fn): - return self.__class__( - fn(self.packed_weight), - fn(self.scales), - fn(self.qzeros), - fn(self.compensation), - self.transposed, - self._layout, - ) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - if func is aten.t.default: - """we don't need to repack the weight and just rely on external - shape being changed and record the status of transpose/no-transpose - """ - transposed = DA8W4CPUAQTTensorImpl( - args[0].packed_weight, - args[0].scales, - args[0].qzeros, - args[0].compensation, - not args[0].transposed, - args[0]._layout, - ) - return return_and_correct_aliasing(func, args, kwargs, transposed) - else: - return super().__torch_dispatch__(func, types, args, kwargs) - - __torch_function__ = torch._C._disabled_torch_function_impl - - @property - def block_size(self): - assert len(self.packed_weight.shape) == 2 - weight_shape = self.packed_weight.shape - N = weight_shape[0] - K = weight_shape[1] * 2 - groups = self.scales.numel() // N - group_size = K // groups - return (1, group_size) - - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # Unpack weight by linear(eye(K), packed_weight).t() - packed_w_shape = self.packed_weight.shape - if len(packed_w_shape) == 4: - K = packed_w_shape[1] * packed_w_shape[2] - else: - K = packed_w_shape[1] - x = torch.eye(K).to(torch.uint8) - x_scale = torch.ones(K).float() - x_qzero = torch.zeros(K).to(torch.int32) - w_scale = torch.ones_like(self.scales).float() - w_qzero = torch.zeros_like(self.qzeros).to(torch.int8) - plain_weight = torch.ops.torchao.da8w4_linear_cpu.default( - x, - x_scale, - x_qzero, - self.packed_weight, - w_scale, - w_qzero, - self.compensation, - None, # bias - torch.float, # out_dtype - ) - 
plain_weight = plain_weight.t().contiguous() - plain_weight = plain_weight.to(torch.int8) - - if self.scales.dim() == 2: - assert self.qzeros.dim() == 2 - plain_scales = self.scales - plain_qzeros = self.qzeros - else: - assert self.scales.dim() == 3 and self.qzeros.dim() == 3 - packed_shape = self.scales.shape # [Nc, G, block_n] - plain_scales = ( - self.scales.permute([0, 2, 1]).contiguous().view([-1, packed_shape[1]]) - ) - plain_qzeros = ( - self.qzeros.permute([0, 2, 1]).contiguous().view([-1, packed_shape[1]]) - ) - - return plain_weight, plain_scales, plain_qzeros - - -def _aqt_is_uint8(aqt): - """Check if an AffineQuantizedTensor is uint8 quantized Tensor""" - return ( - aqt.tensor_impl.dtype == torch.uint8 - and aqt.quant_min == 0 - and aqt.quant_max == 255 - ) - - -def _aqt_is_int8(aqt): - """Check if an AffineQuantizedTensor is uint8 quantized Tensor""" - return ( - aqt.tensor_impl.dtype == torch.int8 - and aqt.quant_min == -127 - and aqt.quant_max == 127 - ) - - -def _aqt_is_uint4(aqt): - """Check if an AffineQuantizedTensor is uint4 quantized Tensor""" - return ( - aqt.tensor_impl.dtype == torch.uint8 - and aqt.quant_min == 0 - and aqt.quant_max == 15 - ) - - -def _linear_int8_act_int4_weight_cpu_check(input_tensor, weight_tensor, bias): - return ( - torch_version_at_least("2.7.0") - and is_device(input_tensor.device.type, "cpu") - and is_device(weight_tensor.device.type, "cpu") - and (bias is None or is_device(bias.device.type, "cpu")) - and isinstance(input_tensor, AffineQuantizedTensor) - and (_aqt_is_uint8(input_tensor) or _aqt_is_int8(input_tensor)) - and _is_float(input_tensor.dtype) - and isinstance(input_tensor._layout, PlainLayout) - and isinstance(weight_tensor, AffineQuantizedTensor) - and _aqt_is_uint4(weight_tensor) - and _is_float(weight_tensor.dtype) - and isinstance(weight_tensor._layout, Int8DynamicActInt4WeightCPULayout) - ) - - -def _linear_int8_act_int4_weight_cpu_impl(input_tensor, weight_tensor, bias): - assert torch_version_at_least("2.7.0"), ( - f"Requires PyTorch version at least 2.7, but got: {torch.__version__}" - ) - if _aqt_is_int8(input_tensor): - assert torch_version_at_least("2.8.0"), ( - f"Requires PyTorch version at least 2.8, but got: {torch.__version__}" - ) - assert is_device(input_tensor.device.type, "cpu"), ( - f"For CPU device only but got: {input_tensor.device}" - ) - assert weight_tensor.block_size[0] == 1, ( - f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" - ) - assert input_tensor.shape[-1] == weight_tensor.shape[1], ( - f"need input_tensor shape: {input_tensor.shape} final" - f"dim to match weight_tensor shape: {weight_tensor.shape} second dim " - ) - - act_mat = input_tensor - act = act_mat.tensor_impl.int_data - act_scales = act_mat.tensor_impl.scale - act_qzeros = act_mat.tensor_impl.zero_point - - packed_weight = weight_tensor.tensor_impl.packed_weight - wei_scales = weight_tensor.tensor_impl.scales - wei_qzeros = weight_tensor.tensor_impl.qzeros - compensation = weight_tensor.tensor_impl.compensation - - orig_act_size = act_mat.size() - orig_dtype = act_mat.dtype - - # reshape to 2D - act = act.reshape(-1, act.shape[-1]) - - y = torch.ops.torchao.da8w4_linear_cpu.default( - act.contiguous(), - act_scales, - act_qzeros, - packed_weight, - wei_scales, - wei_qzeros, - compensation, - bias.float() if bias is not None else bias, # requires bias to be float - orig_dtype, # out_dtype - ) - - # remove out_feature padding - orig_out_features = weight_tensor.shape[-2] - y = y[:, :orig_out_features] - y = 
y.reshape(*orig_act_size[:-1], orig_out_features) - - return y.to(orig_dtype) - - -# Register the concat linear fusion pass -# from ...prototype.inductor.fx_passes import register_da8w4_concat_linear_cpu_pass - -# register_da8w4_concat_linear_cpu_pass() diff --git a/torchao/prototype/dtypes/__init__.py b/torchao/prototype/dtypes/__init__.py index 25f139d583..52a5aec425 100644 --- a/torchao/prototype/dtypes/__init__.py +++ b/torchao/prototype/dtypes/__init__.py @@ -4,9 +4,14 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from .uintx import BlockSparseLayout, CutlassInt4PackedLayout +from .uintx import ( + BlockSparseLayout, + CutlassInt4PackedLayout, + Int8DynamicActInt4WeightCPULayout, +) __all__ = [ "BlockSparseLayout", "CutlassInt4PackedLayout", + "Int8DynamicActInt4WeightCPULayout", ] diff --git a/torchao/prototype/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py b/torchao/prototype/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py index 8d0cfaddeb..24cc02e358 100644 --- a/torchao/prototype/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py +++ b/torchao/prototype/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py @@ -15,13 +15,12 @@ AffineQuantizedTensor, register_layout, ) -from torchao.dtypes.utils import Layout, PlainLayout, is_device -from torchao.utils import torch_version_at_least - -from .int4_cpu_layout import ( +from torchao.dtypes.uintx.int4_cpu_layout import ( Int4CPUAQTTensorImpl, _is_float, ) +from torchao.dtypes.utils import Layout, PlainLayout, is_device +from torchao.utils import torch_version_at_least aten = torch.ops.aten From 3fb6a2c9b96053b47ade4fd19225437791d23380 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Fri, 7 Nov 2025 05:54:45 +0000 Subject: [PATCH 3/6] Move all deprecated api tests into a single file --- test/dtypes/test_api_deprecation_warning.py | 85 +++++++++++++++++++++ test/integration/test_integration.py | 27 ------- test/quantization/test_da8w4_cpu.py | 27 ------- test/sparsity/test_sparse_api.py | 27 ------- 4 files changed, 85 insertions(+), 81 deletions(-) create mode 100644 test/dtypes/test_api_deprecation_warning.py diff --git a/test/dtypes/test_api_deprecation_warning.py b/test/dtypes/test_api_deprecation_warning.py new file mode 100644 index 0000000000..3c50e8aae3 --- /dev/null +++ b/test/dtypes/test_api_deprecation_warning.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for deprecated API imports that have been moved to prototype. +TODO: Remove these tests once the deprecated APIs have been removed. +""" + +import sys +import warnings + + +def test_int8_dynamic_act_int4_weight_cpu_layout_deprecated(): + """Test deprecation warning for Int8DynamicActInt4WeightCPULayout.""" + # We need to clear the cache to force re-importing and trigger the warning again. 
+ modules_to_clear = [ + "torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout", + "torchao.dtypes", + ] + for mod in modules_to_clear: + if mod in sys.modules: + del sys.modules[mod] + + with warnings.catch_warnings(record=True) as w: + from torchao.dtypes import Int8DynamicActInt4WeightCPULayout # noqa: F401 + + warnings.simplefilter("always") # Ensure all warnings are captured + assert any( + issubclass(warning.category, DeprecationWarning) + and "Int8DynamicActInt4WeightCPULayout" in str(warning.message) + for warning in w + ), ( + f"Expected deprecation warning for Int8DynamicActInt4WeightCPULayout, got: {[str(warning.message) for warning in w]}" + ) + + +def test_cutlass_int4_packed_layout_deprecated(): + """Test deprecation warning for CutlassInt4PackedLayout.""" + # We need to clear the cache to force re-importing and trigger the warning again. + modules_to_clear = [ + "torchao.dtypes.uintx.cutlass_int4_packed_layout", + "torchao.dtypes", + ] + for mod in modules_to_clear: + if mod in sys.modules: + del sys.modules[mod] + + with warnings.catch_warnings(record=True) as w: + from torchao.dtypes import CutlassInt4PackedLayout # noqa: F401 + + warnings.simplefilter("always") # Ensure all warnings are captured + assert any( + issubclass(warning.category, DeprecationWarning) + and "CutlassInt4PackedLayout" in str(warning.message) + for warning in w + ), ( + f"Expected deprecation warning for CutlassInt4PackedLayout, got: {[str(warning.message) for warning in w]}" + ) + + +def test_block_sparse_layout_deprecated(): + """Test deprecation warning for BlockSparseLayout.""" + # We need to clear the cache to force re-importing and trigger the warning again. + modules_to_clear = [ + "torchao.dtypes.uintx.block_sparse_layout", + "torchao.dtypes", + ] + for mod in modules_to_clear: + if mod in sys.modules: + del sys.modules[mod] + + with warnings.catch_warnings(record=True) as w: + from torchao.dtypes import BlockSparseLayout # noqa: F401 + + warnings.simplefilter("always") # Ensure all warnings are captured + assert any( + issubclass(warning.category, DeprecationWarning) + and "BlockSparseLayout" in str(warning.message) + for warning in w + ), ( + f"Expected deprecation warning for BlockSparseLayout, got: {[str(warning.message) for warning in w]}" + ) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 2d05426d73..dc58470526 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -1946,32 +1946,5 @@ def test_benchmark_model_cpu(self): assert self.run_benchmark_model("cpu") is not None -# TODO: Remove this test once the deprecated API has been removed -def test_cutlass_int4_packed_layout_deprecated(): - import sys - import warnings - - # We need to clear the cache to force re-importing and trigger the warning again. 
- modules_to_clear = [ - "torchao.dtypes.uintx.cutlass_int4_packed_layout", - "torchao.dtypes", - ] - for mod in modules_to_clear: - if mod in sys.modules: - del sys.modules[mod] - - with warnings.catch_warnings(record=True) as w: - from torchao.dtypes import CutlassInt4PackedLayout # noqa: F401 - - warnings.simplefilter("always") # Ensure all warnings are captured - assert any( - issubclass(warning.category, DeprecationWarning) - and "CutlassInt4PackedLayout" in str(warning.message) - for warning in w - ), ( - f"Expected deprecation warning for CutlassInt4PackedLayout, got: {[str(warning.message) for warning in w]}" - ) - - if __name__ == "__main__": unittest.main() diff --git a/test/quantization/test_da8w4_cpu.py b/test/quantization/test_da8w4_cpu.py index c4b0eac39f..d4f68c4333 100644 --- a/test/quantization/test_da8w4_cpu.py +++ b/test/quantization/test_da8w4_cpu.py @@ -176,32 +176,5 @@ def forward(self, x): common_utils.instantiate_parametrized_tests(TestDa8w4Cpu) -# TODO: Remove this test once the deprecated API has been removed -def test_int8_dynamic_act_int4_weight_cpu_layout_deprecated(): - import sys - import warnings - - # We need to clear the cache to force re-importing and trigger the warning again. - modules_to_clear = [ - "torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout", - "torchao.dtypes", - ] - for mod in modules_to_clear: - if mod in sys.modules: - del sys.modules[mod] - - with warnings.catch_warnings(record=True) as w: - from torchao.dtypes import Int8DynamicActInt4WeightCPULayout # noqa: F401 - - warnings.simplefilter("always") # Ensure all warnings are captured - assert any( - issubclass(warning.category, DeprecationWarning) - and "Int8DynamicActInt4WeightCPULayout" in str(warning.message) - for warning in w - ), ( - f"Expected deprecation warning for Int8DynamicActInt4WeightCPULayout, got: {[str(warning.message) for warning in w]}" - ) - - if __name__ == "__main__": run_tests() diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py index c9d41a98a9..66cd032a9a 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -267,33 +267,6 @@ def test_sparse(self, compile): torch.testing.assert_close(reference, sparse_result, rtol=1e-1, atol=1e-1) - # TODO: Remove this test once the deprecated API has been removed - def test_sparse_deprecated(self): - import sys - import warnings - - # We need to clear the cache to force re-importing and trigger the warning again. 
- modules_to_clear = [ - "torchao.dtypes.uintx.block_sparse_layout", - "torchao.dtypes", - ] - for mod in modules_to_clear: - if mod in sys.modules: - del sys.modules[mod] - - with warnings.catch_warnings(record=True) as w: - from torchao.dtypes import BlockSparseLayout # noqa: F401 - - warnings.simplefilter("always") # Ensure all warnings are captured - self.assertTrue( - any( - issubclass(warning.category, DeprecationWarning) - and "BlockSparseLayout" in str(warning.message) - for warning in w - ), - f"Expected deprecation warning for BlockSparseLayout, got: {[str(w.message) for w in w]}", - ) - common_utils.instantiate_parametrized_tests(TestSemiStructuredSparse) common_utils.instantiate_parametrized_tests(TestQuantSemiSparse) From 67bc40a18167619d2441c6e94198653610e7b7f1 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Thu, 6 Nov 2025 22:45:14 -0800 Subject: [PATCH 4/6] Add to docs --- docs/source/api_ref_dtypes.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/api_ref_dtypes.rst b/docs/source/api_ref_dtypes.rst index e347dfd2e3..5c73d275eb 100644 --- a/docs/source/api_ref_dtypes.rst +++ b/docs/source/api_ref_dtypes.rst @@ -52,6 +52,7 @@ Prototype BlockSparseLayout CutlassInt4PackedLayout + Int8DynamicActInt4WeightCPULayout .. _NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring From ed58e1e2fa02e5abf63bcfbe92b8506a125d0652 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Fri, 7 Nov 2025 18:30:31 +0000 Subject: [PATCH 5/6] Update tests --- test/dtypes/test_api_deprecation_warning.py | 85 --------------------- test/dtypes/test_uintx.py | 37 +++++++++ 2 files changed, 37 insertions(+), 85 deletions(-) delete mode 100644 test/dtypes/test_api_deprecation_warning.py diff --git a/test/dtypes/test_api_deprecation_warning.py b/test/dtypes/test_api_deprecation_warning.py deleted file mode 100644 index 3c50e8aae3..0000000000 --- a/test/dtypes/test_api_deprecation_warning.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. - -""" -Tests for deprecated API imports that have been moved to prototype. -TODO: Remove these tests once the deprecated APIs have been removed. -""" - -import sys -import warnings - - -def test_int8_dynamic_act_int4_weight_cpu_layout_deprecated(): - """Test deprecation warning for Int8DynamicActInt4WeightCPULayout.""" - # We need to clear the cache to force re-importing and trigger the warning again. - modules_to_clear = [ - "torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout", - "torchao.dtypes", - ] - for mod in modules_to_clear: - if mod in sys.modules: - del sys.modules[mod] - - with warnings.catch_warnings(record=True) as w: - from torchao.dtypes import Int8DynamicActInt4WeightCPULayout # noqa: F401 - - warnings.simplefilter("always") # Ensure all warnings are captured - assert any( - issubclass(warning.category, DeprecationWarning) - and "Int8DynamicActInt4WeightCPULayout" in str(warning.message) - for warning in w - ), ( - f"Expected deprecation warning for Int8DynamicActInt4WeightCPULayout, got: {[str(warning.message) for warning in w]}" - ) - - -def test_cutlass_int4_packed_layout_deprecated(): - """Test deprecation warning for CutlassInt4PackedLayout.""" - # We need to clear the cache to force re-importing and trigger the warning again. 
- modules_to_clear = [ - "torchao.dtypes.uintx.cutlass_int4_packed_layout", - "torchao.dtypes", - ] - for mod in modules_to_clear: - if mod in sys.modules: - del sys.modules[mod] - - with warnings.catch_warnings(record=True) as w: - from torchao.dtypes import CutlassInt4PackedLayout # noqa: F401 - - warnings.simplefilter("always") # Ensure all warnings are captured - assert any( - issubclass(warning.category, DeprecationWarning) - and "CutlassInt4PackedLayout" in str(warning.message) - for warning in w - ), ( - f"Expected deprecation warning for CutlassInt4PackedLayout, got: {[str(warning.message) for warning in w]}" - ) - - -def test_block_sparse_layout_deprecated(): - """Test deprecation warning for BlockSparseLayout.""" - # We need to clear the cache to force re-importing and trigger the warning again. - modules_to_clear = [ - "torchao.dtypes.uintx.block_sparse_layout", - "torchao.dtypes", - ] - for mod in modules_to_clear: - if mod in sys.modules: - del sys.modules[mod] - - with warnings.catch_warnings(record=True) as w: - from torchao.dtypes import BlockSparseLayout # noqa: F401 - - warnings.simplefilter("always") # Ensure all warnings are captured - assert any( - issubclass(warning.category, DeprecationWarning) - and "BlockSparseLayout" in str(warning.message) - for warning in w - ), ( - f"Expected deprecation warning for BlockSparseLayout, got: {[str(warning.message) for warning in w]}" - ) diff --git a/test/dtypes/test_uintx.py b/test/dtypes/test_uintx.py index cb0c88b21c..6be8b29400 100644 --- a/test/dtypes/test_uintx.py +++ b/test/dtypes/test_uintx.py @@ -3,6 +3,9 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +import sys +import warnings + import pytest import torch @@ -165,3 +168,37 @@ def test_uintx_model_size(dtype): quantize_(linear[0], UIntXWeightOnlyConfig(dtype)) quantized_size = get_model_size_in_bytes(linear) assert bf16_size * _dtype_to_ratio[dtype] == quantized_size + + +def test_uintx_api_deprecation(): + """ + Test that deprecated uintx APIs trigger deprecation warnings on import. + TODO: Remove this test once the deprecated APIs have been removed. 
+ """ + deprecated_apis = [ + ( + "Int8DynamicActInt4WeightCPULayout", + "torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout", + ), + ("CutlassInt4PackedLayout", "torchao.dtypes.uintx.cutlass_int4_packed_layout"), + ("BlockSparseLayout", "torchao.dtypes.uintx.block_sparse_layout"), + ] + + for api_name, module_path in deprecated_apis: + # Clear the cache to force re-importing and trigger the warning again + modules_to_clear = [module_path, "torchao.dtypes"] + for mod in modules_to_clear: + if mod in sys.modules: + del sys.modules[mod] + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") # Ensure all warnings are captured + + # Dynamically import the deprecated API + exec(f"from torchao.dtypes import {api_name}") + + assert any( + issubclass(warning.category, DeprecationWarning) + and api_name in str(warning.message) + for warning in w + ), f"Expected deprecation warning for {api_name}, got: {[str(warning.message) for warning in w]}" From 4f5ddd82d621ee9c600ef9bf0c235269b8762d1c Mon Sep 17 00:00:00 2001 From: jainapurva Date: Fri, 7 Nov 2025 18:33:26 +0000 Subject: [PATCH 6/6] lint fixes --- test/dtypes/test_uintx.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/dtypes/test_uintx.py b/test/dtypes/test_uintx.py index 6be8b29400..5d54a80753 100644 --- a/test/dtypes/test_uintx.py +++ b/test/dtypes/test_uintx.py @@ -201,4 +201,6 @@ def test_uintx_api_deprecation(): issubclass(warning.category, DeprecationWarning) and api_name in str(warning.message) for warning in w - ), f"Expected deprecation warning for {api_name}, got: {[str(warning.message) for warning in w]}" + ), ( + f"Expected deprecation warning for {api_name}, got: {[str(warning.message) for warning in w]}" + )
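
The `del sys.modules[...]` dance in the tests exists because a module body, and therefore a module-level `warnings.warn(...)`, runs only on the first import in a process; later imports are served from the `sys.modules` cache. A condensed sketch of the same eviction pattern (module names taken from the tests above; `sys.modules.pop` is just a shorter form of the guarded `del`):

    import importlib
    import sys
    import warnings

    # Evict the cached modules so the stub's body (and its warning) runs again.
    for mod in (
        "torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout",
        "torchao.dtypes",
    ):
        sys.modules.pop(mod, None)

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        importlib.import_module("torchao.dtypes")  # re-executes the stub body
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)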
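
The `get_plain` implementation above leans on a small algebraic identity worth spelling out: a linear op computes `y = x @ W.t()`, so feeding it `x = eye(K)` returns `W.t()` exactly, which lets the packed kernel unpack its own weight. A minimal demonstration with a plain float linear (no torchao packing involved; `get_plain` applies the same trick to the opaque da8w4 kernel, passing unit scales and zero zero-points so the integer weight comes through unscaled):

    import torch
    import torch.nn.functional as F

    N, K = 8, 16
    W = torch.randn(N, K)

    # F.linear(x, W) computes x @ W.t(); with x = eye(K) this is just W.t(),
    # so one transpose recovers the original [N, K] weight.
    recovered = F.linear(torch.eye(K), W).t()
    assert torch.equal(recovered, W)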
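
Similarly, the `block_size` property encodes two facts about the packed format: each uint8 byte stores two int4 values, so K is twice the packed column count, and scales are stored one per group per output row, so the group count along K is `scales.numel() // N`. A worked example with illustrative sizes (the numbers below are not taken from the PR):

    N, packed_cols = 64, 128      # packed weight shape [N, K // 2], uint8
    K = packed_cols * 2           # 256 int4 values per output row
    scales_numel = N * 4          # four quantization groups per row
    groups = scales_numel // N    # 4
    group_size = K // groups      # 64  -> block_size == (1, 64)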
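
For downstream code, the migration this series asks for is a one-line import change. A minimal sketch of before and after, assuming only what the patches above establish (the layout is a frozen dataclass with no fields, so it takes no constructor arguments):

    # Deprecated path: still works for now, but emits a DeprecationWarning
    # at import time via the backward-compatibility stub added in this series.
    #   from torchao.dtypes import Int8DynamicActInt4WeightCPULayout

    # New path:
    from torchao.prototype.dtypes import Int8DynamicActInt4WeightCPULayout

    layout = Int8DynamicActInt4WeightCPULayout()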