Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/api_ref_dtypes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Prototype

BlockSparseLayout
CutlassInt4PackedLayout
Int8DynamicActInt4WeightCPULayout

..
_NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring
Expand Down
39 changes: 39 additions & 0 deletions test/dtypes/test_uintx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import sys
import warnings

import pytest
import torch

Expand Down Expand Up @@ -165,3 +168,39 @@ def test_uintx_model_size(dtype):
quantize_(linear[0], UIntXWeightOnlyConfig(dtype))
quantized_size = get_model_size_in_bytes(linear)
assert bf16_size * _dtype_to_ratio[dtype] == quantized_size


def test_uintx_api_deprecation():
"""
Test that deprecated uintx APIs trigger deprecation warnings on import.
TODO: Remove this test once the deprecated APIs have been removed.
"""
deprecated_apis = [
(
"Int8DynamicActInt4WeightCPULayout",
"torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout",
),
("CutlassInt4PackedLayout", "torchao.dtypes.uintx.cutlass_int4_packed_layout"),
("BlockSparseLayout", "torchao.dtypes.uintx.block_sparse_layout"),
]

for api_name, module_path in deprecated_apis:
# Clear the cache to force re-importing and trigger the warning again
modules_to_clear = [module_path, "torchao.dtypes"]
for mod in modules_to_clear:
if mod in sys.modules:
del sys.modules[mod]

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always") # Ensure all warnings are captured

# Dynamically import the deprecated API
exec(f"from torchao.dtypes import {api_name}")

assert any(
issubclass(warning.category, DeprecationWarning)
and api_name in str(warning.message)
for warning in w
), (
f"Expected deprecation warning for {api_name}, got: {[str(warning.message) for warning in w]}"
)
27 changes: 0 additions & 27 deletions test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1948,32 +1948,5 @@ def test_benchmark_model_cpu(self):
assert self.run_benchmark_model("cpu") is not None


# TODO: Remove this test once the deprecated API has been removed
def test_cutlass_int4_packed_layout_deprecated():
import sys
import warnings

# We need to clear the cache to force re-importing and trigger the warning again.
modules_to_clear = [
"torchao.dtypes.uintx.cutlass_int4_packed_layout",
"torchao.dtypes",
]
for mod in modules_to_clear:
if mod in sys.modules:
del sys.modules[mod]

with warnings.catch_warnings(record=True) as w:
from torchao.dtypes import CutlassInt4PackedLayout # noqa: F401

warnings.simplefilter("always") # Ensure all warnings are captured
assert any(
issubclass(warning.category, DeprecationWarning)
and "CutlassInt4PackedLayout" in str(warning.message)
for warning in w
), (
f"Expected deprecation warning for CutlassInt4PackedLayout, got: {[str(warning.message) for warning in w]}"
)


if __name__ == "__main__":
unittest.main()
27 changes: 0 additions & 27 deletions test/sparsity/test_sparse_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,33 +267,6 @@ def test_sparse(self, compile):

torch.testing.assert_close(reference, sparse_result, rtol=1e-1, atol=1e-1)

# TODO: Remove this test once the deprecated API has been removed
def test_sparse_deprecated(self):
import sys
import warnings

# We need to clear the cache to force re-importing and trigger the warning again.
modules_to_clear = [
"torchao.dtypes.uintx.block_sparse_layout",
"torchao.dtypes",
]
for mod in modules_to_clear:
if mod in sys.modules:
del sys.modules[mod]

with warnings.catch_warnings(record=True) as w:
from torchao.dtypes import BlockSparseLayout # noqa: F401

warnings.simplefilter("always") # Ensure all warnings are captured
self.assertTrue(
any(
issubclass(warning.category, DeprecationWarning)
and "BlockSparseLayout" in str(warning.message)
for warning in w
),
f"Expected deprecation warning for BlockSparseLayout, got: {[str(w.message) for w in w]}",
)


common_utils.instantiate_parametrized_tests(TestSemiStructuredSparse)
common_utils.instantiate_parametrized_tests(TestQuantSemiSparse)
Expand Down
2 changes: 1 addition & 1 deletion torchao/dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from .uintx import (
Int4CPULayout,
Int4XPULayout,
Int8DynamicActInt4WeightCPULayout,
MarlinQQQLayout,
MarlinQQQTensor,
MarlinSparseLayout,
Expand All @@ -29,6 +28,7 @@
)
from .uintx.block_sparse_layout import BlockSparseLayout
from .uintx.cutlass_int4_packed_layout import CutlassInt4PackedLayout
from .uintx.dyn_int8_act_int4_wei_cpu_layout import Int8DynamicActInt4WeightCPULayout
from .utils import (
Layout,
PlainLayout,
Expand Down
8 changes: 4 additions & 4 deletions torchao/dtypes/affine_quantized_tensor_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@
_linear_f16_bf16_act_floatx_weight_check,
_linear_f16_bf16_act_floatx_weight_impl,
)
from torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout import (
_linear_int8_act_int4_weight_cpu_check,
_linear_int8_act_int4_weight_cpu_impl,
)
from torchao.dtypes.uintx.gemlite_layout import (
_linear_fp_act_int4_weight_gemlite_check,
_linear_fp_act_int4_weight_gemlite_impl,
Expand Down Expand Up @@ -94,6 +90,10 @@
_linear_int8_act_int4_weight_cutlass_check,
_linear_int8_act_int4_weight_cutlass_impl,
)
from torchao.prototype.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout import (
_linear_int8_act_int4_weight_cpu_check,
_linear_int8_act_int4_weight_cpu_impl,
)
from torchao.quantization.quant_primitives import (
ZeroPointDomain,
_dequantize_affine_no_zero_point,
Expand Down
Loading
Loading