
Commit dc92bb1

[wip] float8 rowwise quant along row 1 of tensor rank 2

Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: d3d5c06
ghstack-comment-id: 3497584430
Pull-Request: #3303

1 parent 6815e57 commit dc92bb1

File tree: 6 files changed, +91 −11 lines

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 38 additions & 0 deletions
@@ -466,6 +466,44 @@ def forward(self, x):
         sqnr = compute_error(original, quantized)
         self.assertTrue(sqnr > 20)

+    @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
+    @unittest.skipIf(not _is_fbgemm_gpu_genai_available(), "Need fbgemm_gpu_genai")
+    def test_bmm_weight_in_bkn_layout(self):
+        # Tests rowwise quantization of a 3d weight stored with shape (B, K, N)
+        # and contiguous with that shape. Since the `K` dimension is not last, we
+        # need to specify granularity with `PerRow(1)`.
+
+        # only per row quantization is supported
+        granularity = [PerRow(), PerRow(1)]
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
+
+        class Model(torch.nn.Module):
+            def __init__(self, weight):
+                super().__init__()
+                self.weight = weight
+
+            def forward(self, x):
+                return torch.bmm(x, self.weight)
+
+        dtype = torch.bfloat16
+        device = "cuda"
+
+        B, M, K, N = 10, 32, 128, 256
+
+        input = torch.randn(B, M, K, dtype=dtype, device=device)
+        weight = torch.randn(B, K, N, dtype=dtype, device=device)
+        m = Model(weight).eval()
+        original = m(input)
+        quantize_(m, config, filter_fn=lambda x, fqn: True)
+
+        assert m.weight.scale.shape == (B, 1, N), (
+            f"unexpected scale shape {m.weight.scale.shape}"
+        )
+
+        quantized = m(input)
+        sqnr = compute_error(original, quantized)
+        self.assertTrue(sqnr > 20)
+
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @common_utils.parametrize(
         "sizes",

test/quantization/test_quant_primitives.py

Lines changed: 25 additions & 0 deletions
@@ -10,6 +10,7 @@

 import torch

+from torchao.quantization.granularity import PerRow
 from torchao.quantization.quant_primitives import (
     MappingType,
     ZeroPointDomain,
@@ -27,6 +28,7 @@
 # TODO: remove test for utils?
 from torchao.quantization.utils import (
     _quantize_activation_per_token_absmax,
+    get_block_size,
     get_group_qparams_symmetric,
     groupwise_affine_dequantize_tensor_from_qparams,
     groupwise_affine_quantize_tensor_from_qparams,
@@ -844,6 +846,29 @@ def test_float8_blockwise_scaling(self):
         torch.testing.assert_close(scale, ref_scale, atol=0, rtol=0)
         torch.testing.assert_close(data.float(), ref_data.float(), atol=0, rtol=0)

+    def test_float8_rowwise_scaling_3d_weight_axis_1(self):
+        """
+        Test scaling a weight with shape (B, K, N) and row-major memory layout
+        across the K dimension.
+        """
+
+        B, K, N = 8, 16, 32
+        hp_tensor = torch.randn(B, K, N, dtype=torch.float)
+
+        granularity = PerRow(1)
+        block_size = get_block_size(hp_tensor.shape, granularity)
+        scale = _choose_scale_float8(
+            hp_tensor,
+            float8_dtype=torch.float8_e4m3fn,
+            block_size=block_size,
+            hp_value_lb=None,
+            hp_value_ub=None,
+        )
+        data = _quantize_affine_float8(hp_tensor, scale, torch.float8_e4m3fn)
+
+        assert scale.shape == (B, 1, N)
+        assert data.shape == (B, K, N)
+

 if __name__ == "__main__":
     unittest.main()
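A rough sketch of the quantize/dequantize round trip this test exercises, with the scale chosen as amax over the reduced dim divided by the float8_e4m3fn max (an assumption about _choose_scale_float8's behavior, shown only to make the shapes and broadcasting concrete):

    import torch

    B, K, N = 8, 16, 32
    hp = torch.randn(B, K, N)

    # scale reduces across dim 1 (K): one scale per (b, n) coordinate
    scale = hp.abs().amax(dim=1, keepdim=True) / torch.finfo(torch.float8_e4m3fn).max
    data = (hp / scale).to(torch.float8_e4m3fn)  # quantize, shape (B, K, N)
    hp_approx = data.float() * scale             # dequantize via broadcasting

    assert scale.shape == (B, 1, N)
    assert data.shape == (B, K, N)
    print((hp - hp_approx).abs().max())  # small rowwise quantization error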

torchao/quantization/granularity.py

Lines changed: 17 additions & 8 deletions
@@ -39,12 +39,14 @@ class PerAxis(Granularity):
     This granularity type calculates different quantization parameters
     along a specified axis of the tensor.

-    For example if the input tensor is shape [8, 16] and axis=0, then
-    the quantization parameters are calculated for each row of the tensor.
-    Giving a total of 8 quantization parameters.
+    Examples:
+      * input_tensor shape [A, B], axis 0 -> scale_shape [A, 1]
+      * input_tensor shape [A, B], axis 1 -> scale_shape [1, B]
+      * input_tensor shape [A, B, C], axis 1 -> scale_shape [1, B, 1]

     Attributes:
-        axis (int): The axis along which reduction is performed.
+        axis (int): The axis which is kept; reduction is performed across all
+            the other axes.
     """

     axis: int
@@ -76,12 +78,19 @@ class PerRow(Granularity):
     """
     Represents row-wise granularity in quantization.

-    This is a special case of per-axis quantization and is unique to Float8 matmuls
-    where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight
-    is quantized with a block_size of (1, weight.shape[1]).
+    Examples:
+      * input_tensor shape [A, B], dim 0 -> scale_shape [1, B]
+      * input_tensor shape [A, B], dim 1 -> scale_shape [A, 1]
+      * input_tensor shape [A, B], dim -1 -> scale_shape [A, 1]
+      * input_tensor shape [A, B, C], dim 1 -> scale_shape [A, 1, C]
+
+    Attributes:
+        dim (int): The dim which is reduced across; all other dims are kept.
     """

-    pass
+    # TODO(before land): any BC concerns with loading old checkpoints
+    # serialized without this arg? investigate this
+    dim: int = -1


 @dataclass(frozen=True)
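The docstring examples above can be spot-checked against get_block_size (also updated in this commit): each scale dim is input_shape[i] // block_size[i]. A small sketch using only names from these diffs:

    from torchao.quantization.granularity import PerRow
    from torchao.quantization.utils import get_block_size

    # each case mirrors one bullet in the PerRow docstring
    for shape, dim, expected in [
        ((4, 6), 0, (1, 6)),
        ((4, 6), 1, (4, 1)),
        ((4, 6), -1, (4, 1)),
        ((4, 6, 8), 1, (4, 1, 8)),
    ]:
        block_size = get_block_size(shape, PerRow(dim))
        scale_shape = tuple(s // b for s, b in zip(shape, block_size))
        assert scale_shape == expected, (shape, dim, scale_shape)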

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 3 additions & 1 deletion
@@ -180,6 +180,8 @@ def from_hp(
             and _is_fbgemm_gpu_genai_available()
             and is_sm_at_least_90()
             and isinstance(granularity, PerRow)
+            # fbgemm path only supports quantizing along the last dim
+            and granularity.dim in (-1, len(hp_tensor.shape) - 1)
             and float8_dtype == torch.float8_e4m3fn
             and hp_value_lb is None
         ):
@@ -438,7 +440,7 @@ def _(func, types, args, kwargs):

         res = torch.ops.fbgemm.f8f8bf16_rowwise_batched(
             a_data,
-            b_data.transpose(-2, -1),
+            b_data.transpose(-2, -1).contiguous(),
             a_scale,
             b_scale.transpose(-2, -1),
             b_scale,
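On the second hunk: transpose returns a strided view, so b_data.transpose(-2, -1) is not contiguous; the added .contiguous() materializes the transposed layout, presumably because the fbgemm rowwise-batched kernel expects contiguous inputs. A quick illustration:

    import torch

    b_data = torch.randn(10, 128, 256)
    bt = b_data.transpose(-2, -1)
    assert not bt.is_contiguous()           # view with swapped strides, no copy
    assert bt.contiguous().is_contiguous()  # materialized copy in (B, N, K) layout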

torchao/quantization/utils.py

Lines changed: 5 additions & 1 deletion
@@ -723,8 +723,12 @@ def get_block_size(
             f"Not all shapes in input shape {input_shape} are divisible by block size {block_size}"
         )
         return block_size
-    elif isinstance(granularity, (PerRow, PerToken)):
+    elif isinstance(granularity, PerToken):
         return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
+    elif isinstance(granularity, PerRow):
+        block_size = [1] * len(input_shape)
+        block_size[granularity.dim] = input_shape[granularity.dim]
+        return tuple(block_size)
     elif isinstance(granularity, PerGroup):
         assert input_shape[-1] % granularity.group_size == 0, (
             f"Last dimension of input {input_shape[-1]} is not divisible by group size {granularity.group_size}"

torchao/testing/utils.py

Lines changed: 3 additions & 1 deletion
@@ -444,7 +444,9 @@ def _test_slice_and_copy_similar_to_vllm(self, config: AOBaseConfig):
         dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
         # making the weight different
         dummy_l.weight = torch.nn.Parameter(
-            dummy_l.weight + 2 * torch.randn(1024, 1024, device=device, dtype=dtype),
+            dummy_l.weight
+            + 1.0
+            + 2 * torch.randn(1024, 1024, device=device, dtype=dtype),
             requires_grad=False,
         )
         quantize_(dummy_l, config)
