@@ -102,6 +102,9 @@ def _mx_inference_linear_transform(
     module: torch.nn.Module, config: MXFPInferenceConfig
 ):
     weight = module.weight
+    is_swizzled_scales = True
+    if "xpu" in weight.device.type:
+        is_swizzled_scales = False
 
     assert weight.dtype == torch.bfloat16, (
         f"Only supporting bf16 out dtype for now, got {weight.dtype}"
@@ -111,7 +114,7 @@ def _mx_inference_linear_transform(
         block_size=config.block_size,
         gemm_kernel_choice=config.gemm_kernel_choice,
         pack_fp6=False,
-        is_swizzled_scales=True,
+        is_swizzled_scales=is_swizzled_scales,
     )
 
     # Convert weight to MX Tensor
@@ -122,7 +125,7 @@ def _mx_inference_linear_transform(
         gemm_kernel_choice=config.gemm_kernel_choice,
         pack_fp6=False,  # TODO
         act_quant_kwargs=act_quant_kwargs,
-        is_swizzled_scales=True,
+        is_swizzled_scales=is_swizzled_scales,
     )
 
     module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)