Skip to content

Commit d3db93e

Browse files
committed
Move dyn_int8_act_int4_wei_cpu_layout to prototype
1 parent 40de7e0 commit d3db93e

File tree

6 files changed

+58
-319
lines changed

6 files changed

+58
-319
lines changed

test/quantization/test_da8w4_cpu.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,5 +176,32 @@ def forward(self, x):
176176
common_utils.instantiate_parametrized_tests(TestDa8w4Cpu)
177177

178178

179+
# TODO: Remove this test once the deprecated API has been removed
180+
def test_int8_dynamic_act_int4_weight_cpu_layout_deprecated():
    """Check that importing the layout from ``torchao.dtypes`` warns.

    Imports ``Int8DynamicActInt4WeightCPULayout`` from ``torchao.dtypes``
    while recording warnings, and asserts that at least one recorded warning
    is a ``DeprecationWarning`` whose message names the class.

    Raises:
        AssertionError: if no matching deprecation warning was recorded.
    """
    import sys
    import warnings

    # We need to clear the cache to force re-importing and trigger the warning
    # again: a warning raised from a module body is only emitted when that
    # body actually executes, which does not happen for cached modules.
    modules_to_clear = [
        "torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout",
        "torchao.dtypes",
    ]
    for mod in modules_to_clear:
        sys.modules.pop(mod, None)

    with warnings.catch_warnings(record=True) as w:
        # The filter must be installed BEFORE the import. Filters are
        # consulted at the moment a warning is issued, so setting "always"
        # afterwards cannot recover a DeprecationWarning that the ambient
        # filters already discarded.
        warnings.simplefilter("always")

        from torchao.dtypes import Int8DynamicActInt4WeightCPULayout  # noqa: F401

    assert any(
        issubclass(warning.category, DeprecationWarning)
        and "Int8DynamicActInt4WeightCPULayout" in str(warning.message)
        for warning in w
    ), (
        f"Expected deprecation warning for Int8DynamicActInt4WeightCPULayout, got: {[str(warning.message) for warning in w]}"
    )
204+
205+
179206
if __name__ == "__main__":
180207
run_tests()

torchao/dtypes/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from .uintx import (
1717
Int4CPULayout,
1818
Int4XPULayout,
19-
Int8DynamicActInt4WeightCPULayout,
2019
MarlinQQQLayout,
2120
MarlinQQQTensor,
2221
MarlinSparseLayout,
@@ -29,6 +28,7 @@
2928
)
3029
from .uintx.block_sparse_layout import BlockSparseLayout
3130
from .uintx.cutlass_int4_packed_layout import CutlassInt4PackedLayout
31+
from .uintx.dyn_int8_act_int4_wei_cpu_layout import Int8DynamicActInt4WeightCPULayout
3232
from .utils import (
3333
Layout,
3434
PlainLayout,

torchao/dtypes/affine_quantized_tensor_ops.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,6 @@
2525
_linear_f16_bf16_act_floatx_weight_check,
2626
_linear_f16_bf16_act_floatx_weight_impl,
2727
)
28-
from torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout import (
29-
_linear_int8_act_int4_weight_cpu_check,
30-
_linear_int8_act_int4_weight_cpu_impl,
31-
)
3228
from torchao.dtypes.uintx.gemlite_layout import (
3329
_linear_fp_act_int4_weight_gemlite_check,
3430
_linear_fp_act_int4_weight_gemlite_impl,
@@ -94,6 +90,10 @@
9490
_linear_int8_act_int4_weight_cutlass_check,
9591
_linear_int8_act_int4_weight_cutlass_impl,
9692
)
93+
from torchao.prototype.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout import (
94+
_linear_int8_act_int4_weight_cpu_check,
95+
_linear_int8_act_int4_weight_cpu_impl,
96+
)
9797
from torchao.quantization.quant_primitives import (
9898
ZeroPointDomain,
9999
_dequantize_affine_no_zero_point,

0 commit comments

Comments
 (0)