Skip to content

Commit 50be103

Browse files
[mxfp] support EXPT_IS_INNER in non-persistent matmul
1 parent 0173f75 commit 50be103

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,6 @@ def _matmul_ogs(
247247

248248
# TODO: refactor if/else when triton front end improves
249249
if is_w_microscaled:
250-
tl.static_assert(not EXPT_IS_INNER, "Not supported yet")
251250
WMxScale += expt_id * stride_w_mx_e
252251

253252
if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
@@ -281,7 +280,8 @@ def _matmul_ogs(
281280
offs_n_scale = (pid_n * SCALE_BLOCK_N + tl.arange(0, SCALE_BLOCK_N)) % N
282281
offs_n_scale = tl.max_contiguous(tl.multiple_of(offs_n_scale, SCALE_BLOCK_N), SCALE_BLOCK_N)
283282
# K dimension must be the last dimension for the scales
284-
offs_k_scale = PACKED_MX_BLOCK * pid_k + tl.arange(0, PACKED_MX_BLOCK)
283+
tl.static_assert(not EXPT_IS_INNER or W_IS_PADDED)
284+
offs_k_scale = off_k_x // BLOCK_K * PACKED_MX_BLOCK + tl.arange(0, PACKED_MX_BLOCK)
285285
WMxScalePtrs = WMxScale + offs_k_scale.to(index_type)[None, :] * stride_scale_k + offs_n_scale.to(index_type)[:, None] * stride_w_mx_n
286286
else:
287287
WMxScalePtrs = None
@@ -295,7 +295,7 @@ def _matmul_ogs(
295295
XMxScale += start_z.to(index_type) * stride_x_mx_z
296296
if GatherIndx is None:
297297
XMxScale += start_m * stride_x_mx_m
298-
offs_x_k_scale = MX_SCALE_BLOCK_K * pid_k + tl.arange(0, MX_SCALE_BLOCK_K)
298+
offs_x_k_scale = off_k_x // MXFP_BLOCK_SIZE + tl.arange(0, MX_SCALE_BLOCK_K)
299299
XMxScalePtrs = XMxScale + offs_x_m.to(index_type)[:, None] * stride_x_mx_m + offs_x_k_scale.to(index_type)[None, :] * stride_x_mx_k
300300
else:
301301
XMxScalePtrs = None

python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def _p_matmul_ogs(
249249
XMxScalePtrs = XMxScale + start_z.to(index_type) * stride_x_mx_z
250250
if GatherIndx is None:
251251
XMxScalePtrs += start_m * stride_x_mx_m
252-
offs_k_scale = MX_SCALE_BLOCK_K * pid_k + tl.arange(0, MX_SCALE_BLOCK_K)
252+
offs_k_scale = off_k_x0 // MXFP_BLOCK_SIZE + tl.arange(0, MX_SCALE_BLOCK_K)
253253
XMxScalePtrs += (offs_x_m if USE_GATHER_TMA else offs_m).to(index_type)[:, None] * stride_x_mx_m
254254
XMxScalePtrs += offs_k_scale.to(index_type)[None, :] * stride_x_mx_k
255255
else:

0 commit comments

Comments
 (0)