Skip to content

Commit 235cad9

Browse files
committed
Implement aten.add for IntxUnpackedToInt8Tensor
1 parent a257166 commit 235cad9

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,25 @@ def test_embedding(self):
5050
error = compute_error(original, quantized)
5151
self.assertTrue(error > 20)
5252

53+
def test_add(self):
54+
dtype = torch.bfloat16
55+
device = "cpu"
56+
a = torch.randint(low=0, high=128, size=(10,), device=device)
57+
a_orig = a.clone()
58+
b = torch.randint(low=0, high=128, size=(10,), device=device)
59+
sum = a + b
60+
61+
quantize_(a, self.config)
62+
a_quant_sum = a + b
63+
64+
quantize(b, self.config)
65+
b_quant_sum = a_orig + b
66+
a_b_quant_sum = a + b
67+
68+
for quantized_sum in [a_quant_sum, b_quant_sum, a_b_quant_sum]:
69+
error = compute_error(original, quantized_sum)
70+
self.assertTrue(error > 20)
71+
5372
def test_linear(self):
5473
dtype = torch.bfloat16
5574
device = "cpu"

torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,19 @@ def _(func, types, args, kwargs):
355355
return torch.nn.functional.embedding(indices, weight_tensor, **kwargs)
356356

357357

358+
@implements(aten.add.Tensor)
359+
def _(func, types, args, kwargs):
360+
assert len(args) == 2
361+
t1, t2 = args[0], args[1]
362+
if isinstance(t1, IntxUnpackedToInt8Tensor):
363+
assert t1.activation_quantization is None
364+
t1 = t1.dequantize()
365+
if isinstance(t2, IntxUnpackedToInt8Tensor):
366+
assert t2.activation_quantization is None
367+
t2 = t2.dequantize()
368+
return t1 + t2
369+
370+
358371
@implements(aten.slice.Tensor)
359372
def _(func, types, args, kwargs):
360373
self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])

0 commit comments

Comments
 (0)