Commit 64f4151

Support mx_tensor and enable its test on Intel GPU

1 parent: 9231d4f
1 file changed (+5, -2)

torchao/prototype/mx_formats/inference_workflow.py

@@ -102,6 +102,9 @@ def _mx_inference_linear_transform(
     module: torch.nn.Module, config: MXFPInferenceConfig
 ):
     weight = module.weight
+    is_swizzled_scales = True
+    if "xpu" in weight.device.type:
+        is_swizzled_scales = False

     assert weight.dtype == torch.bfloat16, (
         f"Only supporting bf16 out dtype for now, got {weight.dtype}"
@@ -111,7 +114,7 @@ def _mx_inference_linear_transform(
         block_size=config.block_size,
         gemm_kernel_choice=config.gemm_kernel_choice,
         pack_fp6=False,
-        is_swizzled_scales=True,
+        is_swizzled_scales=is_swizzled_scales,
     )

     # Convert weight to MX Tensor
@@ -122,7 +125,7 @@ def _mx_inference_linear_transform(
         gemm_kernel_choice=config.gemm_kernel_choice,
         pack_fp6=False,  # TODO
         act_quant_kwargs=act_quant_kwargs,
-        is_swizzled_scales=True,
+        is_swizzled_scales=is_swizzled_scales,
     )

     module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
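
For context, the device check added by this commit boils down to the standalone sketch below. choose_swizzled_scales is a hypothetical helper name used only for illustration, not part of the torchao API; the rationale in the comment (XPU kernels not consuming the swizzled scale layout) is inferred from the change, not stated in the commit.

import torch

def choose_swizzled_scales(weight: torch.Tensor) -> bool:
    # Mirrors the logic added in this commit: the swizzled
    # (block-rearranged) scale layout stays enabled on CUDA and other
    # backends, but is disabled when the weight lives on an Intel GPU
    # ("xpu"), presumably because the XPU path expects a linear layout.
    return "xpu" not in weight.device.type

# Example: a bf16 weight on CPU keeps the default swizzled layout.
w = torch.nn.Linear(32, 32, dtype=torch.bfloat16).weight
assert choose_swizzled_scales(w) is True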
