Skip to content

Commit a4df015

Browse files
committed
Move marlin_qqq_tensor to prototype
1 parent d3db93e commit a4df015

File tree

9 files changed

+415
-354
lines changed

9 files changed

+415
-354
lines changed

benchmarks/microbenchmarks/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def string_to_config(
218218
)
219219
if "marlin" in quantization:
220220
if "qqq" in quantization:
221-
from torchao.dtypes import MarlinQQQLayout
221+
from torchao.prototype.dtypes import MarlinQQQLayout
222222

223223
return Int8DynamicActivationInt4WeightConfig(
224224
group_size=128,

test/quantization/test_marlin_qqq.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from torch import nn
1111
from torch.testing._internal.common_utils import TestCase, run_tests
1212

13-
from torchao.dtypes import MarlinQQQLayout
13+
from torchao.prototype.dtypes import MarlinQQQLayout
1414
from torchao.quantization.marlin_qqq import (
1515
pack_to_marlin_qqq,
1616
unpack_from_marlin_qqq,
@@ -132,5 +132,25 @@ def test_pack_unpack_equivalence(self):
132132
)
133133

134134

135+
def test_marlin_qqq_tensor_deprecation_warning():
136+
"""Test that importing from the old location raises a deprecation warning"""
137+
import warnings
138+
139+
with warnings.catch_warnings(record=True) as w:
140+
warnings.simplefilter("always")
141+
# Import from the old deprecated location
142+
from torchao.dtypes.uintx.marlin_qqq_tensor import ( # noqa: F401
143+
MarlinQQQLayout,
144+
)
145+
146+
# Verify the deprecation warning was raised
147+
assert len(w) == 1
148+
assert issubclass(w[-1].category, DeprecationWarning)
149+
assert "torchao.dtypes.uintx.marlin_qqq_tensor is deprecated" in str(
150+
w[-1].message
151+
)
152+
assert "torchao.prototype.dtypes import" in str(w[-1].message)
153+
154+
135155
if __name__ == "__main__":
136156
run_tests()

torchao/_models/llama/generate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ def ffn_or_attn_only(mod, fqn):
460460
)
461461
if "marlin" in quantization:
462462
if "qqq" in quantization:
463-
from torchao.dtypes import MarlinQQQLayout
463+
from torchao.prototype.dtypes import MarlinQQQLayout
464464

465465
quantize_(
466466
model,

torchao/dtypes/affine_quantized_tensor_ops.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,6 @@
3939
_linear_fp_act_uint4_weight_int8_zero_check,
4040
_linear_fp_act_uint4_weight_int8_zero_impl,
4141
)
42-
from torchao.dtypes.uintx.marlin_qqq_tensor import (
43-
_linear_int8_act_int4_weight_marlin_qqq_check,
44-
_linear_int8_act_int4_weight_marlin_qqq_impl,
45-
)
4642
from torchao.dtypes.uintx.marlin_sparse_layout import (
4743
_linear_fp_act_int4_weight_sparse_marlin_check,
4844
_linear_fp_act_int4_weight_sparse_marlin_impl,
@@ -94,6 +90,10 @@
9490
_linear_int8_act_int4_weight_cpu_check,
9591
_linear_int8_act_int4_weight_cpu_impl,
9692
)
93+
from torchao.prototype.dtypes.uintx.marlin_qqq_tensor import (
94+
_linear_int8_act_int4_weight_marlin_qqq_check,
95+
_linear_int8_act_int4_weight_marlin_qqq_impl,
96+
)
9797
from torchao.quantization.quant_primitives import (
9898
ZeroPointDomain,
9999
_dequantize_affine_no_zero_point,

torchao/dtypes/uintx/__init__.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
from torchao.prototype.dtypes.uintx.marlin_qqq_tensor import (
2+
MarlinQQQLayout,
3+
MarlinQQQTensor,
4+
to_marlinqqq_quantized_intx,
5+
)
6+
17
from .dyn_int8_act_int4_wei_cpu_layout import (
28
Int8DynamicActInt4WeightCPULayout,
39
)
@@ -7,11 +13,6 @@
713
from .int4_xpu_layout import (
814
Int4XPULayout,
915
)
10-
from .marlin_qqq_tensor import (
11-
MarlinQQQLayout,
12-
MarlinQQQTensor,
13-
to_marlinqqq_quantized_intx,
14-
)
1516
from .marlin_sparse_layout import (
1617
MarlinSparseLayout,
1718
)

0 commit comments

Comments
 (0)