@@ -247,7 +247,6 @@ def _matmul_ogs(
247
247
248
248
# TODO: refactor if/else when triton front end improves
249
249
if is_w_microscaled :
250
- tl .static_assert (not EXPT_IS_INNER , "Not supported yet" )
251
250
WMxScale += expt_id * stride_w_mx_e
252
251
253
252
if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE" :
@@ -281,7 +280,8 @@ def _matmul_ogs(
281
280
offs_n_scale = (pid_n * SCALE_BLOCK_N + tl .arange (0 , SCALE_BLOCK_N )) % N
282
281
offs_n_scale = tl .max_contiguous (tl .multiple_of (offs_n_scale , SCALE_BLOCK_N ), SCALE_BLOCK_N )
283
282
# K dimension must be the last dimension for the scales
284
- offs_k_scale = PACKED_MX_BLOCK * pid_k + tl .arange (0 , PACKED_MX_BLOCK )
283
+ tl .static_assert (not EXPT_IS_INNER or W_IS_PADDED )
284
+ offs_k_scale = off_k_x // BLOCK_K * PACKED_MX_BLOCK + tl .arange (0 , PACKED_MX_BLOCK )
285
285
WMxScalePtrs = WMxScale + offs_k_scale .to (index_type )[None , :] * stride_scale_k + offs_n_scale .to (index_type )[:, None ] * stride_w_mx_n
286
286
else :
287
287
WMxScalePtrs = None
@@ -295,7 +295,7 @@ def _matmul_ogs(
295
295
XMxScale += start_z .to (index_type ) * stride_x_mx_z
296
296
if GatherIndx is None :
297
297
XMxScale += start_m * stride_x_mx_m
298
- offs_x_k_scale = MX_SCALE_BLOCK_K * pid_k + tl .arange (0 , MX_SCALE_BLOCK_K )
298
+ offs_x_k_scale = off_k_x // MXFP_BLOCK_SIZE + tl .arange (0 , MX_SCALE_BLOCK_K )
299
299
XMxScalePtrs = XMxScale + offs_x_m .to (index_type )[:, None ] * stride_x_mx_m + offs_x_k_scale .to (index_type )[None , :] * stride_x_mx_k
300
300
else :
301
301
XMxScalePtrs = None
0 commit comments