diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py index ba634857ebe6..ff51b844f40a 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py @@ -247,7 +247,6 @@ def _matmul_ogs( # TODO: refactor if/else when triton front end improves if is_w_microscaled: - tl.static_assert(not EXPT_IS_INNER, "Not supported yet") WMxScale += expt_id * stride_w_mx_e if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE": @@ -281,7 +280,8 @@ def _matmul_ogs( offs_n_scale = (pid_n * SCALE_BLOCK_N + tl.arange(0, SCALE_BLOCK_N)) % N offs_n_scale = tl.max_contiguous(tl.multiple_of(offs_n_scale, SCALE_BLOCK_N), SCALE_BLOCK_N) # K dimension must be the last dimension for the scales - offs_k_scale = PACKED_MX_BLOCK * pid_k + tl.arange(0, PACKED_MX_BLOCK) + tl.static_assert(not EXPT_IS_INNER or W_IS_PADDED) + offs_k_scale = off_k_x // BLOCK_K * PACKED_MX_BLOCK + tl.arange(0, PACKED_MX_BLOCK) WMxScalePtrs = WMxScale + offs_k_scale.to(index_type)[None, :] * stride_scale_k + offs_n_scale.to(index_type)[:, None] * stride_w_mx_n else: WMxScalePtrs = None @@ -295,7 +295,7 @@ def _matmul_ogs( XMxScale += start_z.to(index_type) * stride_x_mx_z if GatherIndx is None: XMxScale += start_m * stride_x_mx_m - offs_x_k_scale = MX_SCALE_BLOCK_K * pid_k + tl.arange(0, MX_SCALE_BLOCK_K) + offs_x_k_scale = off_k_x // MXFP_BLOCK_SIZE + tl.arange(0, MX_SCALE_BLOCK_K) XMxScalePtrs = XMxScale + offs_x_m.to(index_type)[:, None] * stride_x_mx_m + offs_x_k_scale.to(index_type)[None, :] * stride_x_mx_k else: XMxScalePtrs = None diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py index 0dbbb60af6eb..6a83f0af9b10 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py @@ -249,7 +249,7 @@ def _p_matmul_ogs( XMxScalePtrs = XMxScale + start_z.to(index_type) * stride_x_mx_z if GatherIndx is None: XMxScalePtrs += start_m * stride_x_mx_m - offs_k_scale = MX_SCALE_BLOCK_K * pid_k + tl.arange(0, MX_SCALE_BLOCK_K) + offs_k_scale = off_k_x0 // MXFP_BLOCK_SIZE + tl.arange(0, MX_SCALE_BLOCK_K) XMxScalePtrs += (offs_x_m if USE_GATHER_TMA else offs_m).to(index_type)[:, None] * stride_x_mx_m XMxScalePtrs += offs_k_scale.to(index_type)[None, :] * stride_x_mx_k else: