From 8bf731e60187cbbce8ba9ca006b204a118d5092f Mon Sep 17 00:00:00 2001 From: jongsoo-openai Date: Mon, 6 Oct 2025 18:51:19 -0700 Subject: [PATCH 1/4] [mxfp] support EXPT_IS_INNER in non-persistent matmul --- .../triton_kernels/matmul_ogs_details/_matmul_ogs.py | 6 +++--- .../triton_kernels/matmul_ogs_details/_p_matmul_ogs.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py index 21884a1dfb21..f670c5da462e 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py @@ -247,7 +247,6 @@ def _matmul_ogs( # TODO: refactor if/else when triton front end improves if is_w_microscaled: - tl.static_assert(not EXPT_IS_INNER, "Not supported yet") WMxScale += expt_id * stride_w_mx_e if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE": @@ -281,7 +280,8 @@ def _matmul_ogs( offs_n_scale = (pid_n * SCALE_BLOCK_N + tl.arange(0, SCALE_BLOCK_N)) % N offs_n_scale = tl.max_contiguous(tl.multiple_of(offs_n_scale, SCALE_BLOCK_N), SCALE_BLOCK_N) # K dimension must be the last dimension for the scales - offs_k_scale = PACKED_MX_BLOCK * pid_k + tl.arange(0, PACKED_MX_BLOCK) + tl.static_assert(not EXPT_IS_INNER or W_IS_PADDED) + offs_k_scale = off_k_x // BLOCK_K * PACKED_MX_BLOCK + tl.arange(0, PACKED_MX_BLOCK) WMxScalePtrs = WMxScale + offs_k_scale.to(index_type)[None, :] * stride_scale_k + offs_n_scale.to(index_type)[:, None] * stride_w_mx_n else: WMxScalePtrs = None @@ -295,7 +295,7 @@ def _matmul_ogs( XMxScale += start_z.to(index_type) * stride_x_mx_z if GatherIndx is None: XMxScale += start_m * stride_x_mx_m - offs_x_k_scale = MX_SCALE_BLOCK_K * pid_k + tl.arange(0, MX_SCALE_BLOCK_K) + offs_x_k_scale = off_k_x // MXFP_BLOCK_SIZE + tl.arange(0, MX_SCALE_BLOCK_K) XMxScalePtrs = XMxScale + offs_x_m.to(index_type)[:, None] * stride_x_mx_m + offs_x_k_scale.to(index_type)[None, :] * stride_x_mx_k else: XMxScalePtrs = None diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py index 6503b7c8a5d1..865de1477469 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py @@ -249,7 +249,7 @@ def _p_matmul_ogs( XMxScalePtrs = XMxScale + start_z.to(index_type) * stride_x_mx_z if GatherIndx is None: XMxScalePtrs += start_m * stride_x_mx_m - offs_k_scale = MX_SCALE_BLOCK_K * pid_k + tl.arange(0, MX_SCALE_BLOCK_K) + offs_k_scale = off_k_x0 // MXFP_BLOCK_SIZE + tl.arange(0, MX_SCALE_BLOCK_K) XMxScalePtrs += (offs_x_m if USE_GATHER_TMA else offs_m).to(index_type)[:, None] * stride_x_mx_m XMxScalePtrs += offs_k_scale.to(index_type)[None, :] * stride_x_mx_k else: From 052a2c971cad0df10eb04c32726d67caddfd8265 Mon Sep 17 00:00:00 2001 From: jongsoo-openai Date: Sun, 19 Oct 2025 22:18:21 -0700 Subject: [PATCH 2/4] fix bugs; add tests --- python/triton_kernels/tests/test_matmul.py | 26 +++++++++++---- .../triton_kernels/matmul_ogs.py | 16 +++++++--- .../matmul_ogs_details/_common.py | 11 ++++--- .../matmul_ogs_details/_matmul_ogs.py | 4 +-- .../matmul_ogs_details/_p_matmul_ogs.py | 1 - .../matmul_ogs_details/opt_flags.py | 7 ++-- .../opt_flags_details/opt_flags_nvidia.py | 4 +-- .../triton_kernels/triton_kernels/tensor.py | 27 +++++++++------- .../layout_details/blackwell_scale.py | 32 +++++++++++++------ 9 files changed, 81 insertions(+), 47 deletions(-) diff --git a/python/triton_kernels/tests/test_matmul.py b/python/triton_kernels/tests/test_matmul.py index 7a1922885cd7..73cb38f680aa 100644 --- a/python/triton_kernels/tests/test_matmul.py +++ b/python/triton_kernels/tests/test_matmul.py @@ -365,8 +365,16 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).") expt_is_inner = (inner_expt_opt is not None) - if expt_is_inner and (mode != "ragged" or "mx" in act_dtype_str or "mx" in weight_dtype_str): - pytest.skip("Not supported yet") + if expt_is_inner: + if mode != "ragged": + pytest.skip("inner_expt_opt only meaningful with ragged") + if "mx" in act_dtype_str and inner_expt_opt != "pad_x": + pytest.skip("inner_expt_opt and act mx only supported with pad_x") + if "mx" in weight_dtype_str: + if inner_expt_opt != "pad_w": + pytest.skip("inner_expt_opt and weight mx only supported with pad_w") + if is_persistent and not hbm_swizzling: + pytest.skip("FIXME: Fatal Python error: Aborted") # launch metadata for batched / mx types may not work yet. torch.manual_seed(0) @@ -398,6 +406,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o opt_flags.update_opt_flags_constraints(constraints) weight_mxfp = weight_dtype_str.startswith("mx") + weight_mxfp4 = weight_mxfp and "float4" in weight_dtype_str if weight_mxfp: weight_dtype_str = weight_dtype_str[2:] act_mxfp8 = act_dtype_str.startswith("mx") @@ -421,6 +430,9 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o rdata = gindx = sindx = None padding_block_k = 32 + if hbm_swizzling and is_persistent and torch.cuda.get_device_capability()[0] >= 10: + # Blackwell scale swizzling constraint + padding_block_k = 128 x_tri, w_tri, bias_tri, gs0_tri, gs1_tri = init_compute_data(m, n, k, rdata, gindx, sindx, n_expts_tot, n_expts_act, mode, torch.bfloat16 if act_mxfp8 else act_dtype, # torch.bfloat16 if weight_mxfp else weight_dtype, @@ -456,7 +468,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o # compute layouts w_layout, w_layout_opts = layout.StridedLayout, dict() w_scale_layout, w_scale_layout_opts = layout.StridedLayout, dict() - if hbm_swizzling and "float4" in weight_dtype_str: + if hbm_swizzling and weight_mxfp4: w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=mx_axis) w_scale_layout, w_scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout( mx_axis=mx_axis, num_warps=8) @@ -465,7 +477,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o if colmajor_mxfp_weight: w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis) w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis) - w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype + w_tri_dtype = FP4 if weight_mxfp4 else weight_dtype w_tri = wrap_torch_tensor(w_tri, w_tri_dtype) w_scale_tri = wrap_torch_tensor(w_scale_tri) # convert layouts @@ -567,8 +579,8 @@ def _pad_and_block(x: torch.Tensor) -> torch.Tensor: tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt, gammas=gs1_ref, epilogue=epilogue, y=y_tri_in, inner_routing_data=inner_routing_data) - except (opt_flags.InapplicableConstraint, NotImplementedError): - pytest.skip("inapplicable opt_flags constraint") + except (opt_flags.InapplicableConstraint, NotImplementedError) as e: + pytest.skip(f"inapplicable opt_flags constraint {e}") if y_tri_in is not None: assert tri_y.data_ptr() == y_tri_in.data_ptr() assert tri_y.shape == y_tri_in.shape @@ -601,7 +613,7 @@ def scale(val, scal): ref_y = upcast_from_mxfp_torch(ref_y_quant, ref_y_scale, target_dtype=ref_y.dtype, axis=-1) maxtol = 4e-1 rmstol = 4e-2 - elif weight_mxfp and "float4_e2m1" in weight_dtype_str: + elif weight_mxfp4: if act_is_float8: maxtol = 8e-2 else: diff --git a/python/triton_kernels/triton_kernels/matmul_ogs.py b/python/triton_kernels/triton_kernels/matmul_ogs.py index 685b44323343..bfefbe5f7f0e 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs.py @@ -644,11 +644,20 @@ def matmul_ogs(x, w, bias, y_tensor_or_tma = y_storage.make_tma(y_tma_block_size, y_tma_mode) if y_has_tma else y_storage.data # create tma descriptor for w w_has_tma = opt_flags.is_persistent - w_tensor_or_tma = w_storage.make_tma([1, opt_flags.block_k, opt_flags.block_n], "dense") if w_has_tma else w_storage.data + w_tma_block_size = [1, opt_flags.block_k, opt_flags.block_n] + w_tensor_or_tma = w_storage.make_tma(w_tma_block_size, "dense") if w_has_tma else w_storage.data # create tma descriptor for w_scale - w_scale_tensor_or_tma = w_scale w_scale_has_tma = opt_flags.is_persistent and w_scale is not None - w_scale_tensor_or_tma = w_scale.storage.make_tma([opt_flags.block_n, opt_flags.block_k], "dense") if w_scale_has_tma else w_scale + w_transpose = w_storage.data.stride()[-2] == 1 + if w_scale_has_tma: + w_scale_storage = w_scale.storage + w_scale_tma_block_size = [opt_flags.block_n, opt_flags.block_k] if w_transpose else [opt_flags.block_k, opt_flags.block_n] + if isinstance(w_scale.storage.layout, StridedLayout): + w_scale_storage = _canonicalize_storage(w_scale.storage, 3, None) + w_scale_tma_block_size = [1] + w_scale_tma_block_size + w_scale_tensor_or_tma = w_scale_storage.make_tma(w_scale_tma_block_size, "dense") + else: + w_scale_tensor_or_tma = w_scale # canonicalize strides x_strides = [0]*(3 - x_storage.data.ndim) + list(x_storage.data.stride()) x_scale_strides = x_scale.stride() if x_has_mx else (None, None, None) @@ -663,7 +672,6 @@ def matmul_ogs(x, w, bias, # (i.e. col-wise). Since this matters when w_has_mx is True and w_transpose # is True the fast code path, stride(-2) == 1 takes precedence, e.g., vs. # w_transpose = w_storage.data.stride()[-1] != 1 - w_transpose = w_storage.data.stride()[-2] == 1 fused_comm_kwargs = { "pYPtrs": fused_comm.out_handles, "ScatterShardIndx": fused_comm.scatter_shard_indx, diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py index d411d3255c84..37edf93b55af 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py @@ -98,12 +98,13 @@ def _load_tile_attrs( tl.static_assert(M is not None) expt_id, pid_z, pid_z_out, start_m, block_id, eM = 0, 0, pid_e, 0, pid_m, M k_tiles = tl.cdiv(tl.load(ExptHist + pid_e), BLOCK_K) - padded_start_off = tl.load(ExptTileOffs + pid_e) * BLOCK_K + padded_start_off_raw = tl.load(ExptTileOffs + pid_e) + padded_start_off = padded_start_off_raw * BLOCK_K unpadded_start_off = tl.load(ExptOffs + pid_e) off_k_x = padded_start_off if X_IS_PADDED else unpadded_start_off # K_W is only used for non-TMA kernel (W bound is handled by TMA on TMA kernel). if W_IS_PADDED: - off_k_w = padded_start_off + off_k_w = padded_start_off_raw * PACKED_BLOCK_K_W K_W = tl.load(ExptTileOffs + pid_e + 1) * BLOCK_K else: off_k_w = unpadded_start_off @@ -147,7 +148,6 @@ def _load_tile_attrs( def make_matmul_repr(base_name, order): - def matmul_repr(specialization): signature = specialization.signature constants = specialization.constants @@ -266,5 +266,6 @@ def matmul_launch_metadata(grid, kernel, args): @triton.jit def threadfence_system(): - tl.inline_asm_elementwise("mov.u32 $0, 0x0; fence.sc.sys;", args=(), dtype=(tl.int32, ), is_pure=False, pack=1, - constraints="=r") + tl.inline_asm_elementwise( + "mov.u32 $0, 0x0; fence.sc.sys;", args=(), dtype=(tl.int32,), is_pure=False, pack=1, constraints="=r" + ) diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py index f670c5da462e..5b671cb11bfa 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py @@ -131,7 +131,7 @@ def _matmul_ogs( tl.static_assert(w_type == tl.uint8 or (w_type == tl.float8e4nv or w_type == tl.float8e5), "mx_weight_ptr must be uint8 or fp8") tl.static_assert(WMxScale.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8") - tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR") + tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, f"{BLOCK_K=} must be a multiple of {MX_PACK_DIVISOR=}") tl.static_assert(SWIZZLE_MX_VALUE == "HOPPER_VALUE" or SWIZZLE_MX_VALUE is None, "Only Hopper swizzling is supported for values") # TODO: refactor if/else when triton front end improves @@ -281,7 +281,7 @@ def _matmul_ogs( offs_n_scale = tl.max_contiguous(tl.multiple_of(offs_n_scale, SCALE_BLOCK_N), SCALE_BLOCK_N) # K dimension must be the last dimension for the scales tl.static_assert(not EXPT_IS_INNER or W_IS_PADDED) - offs_k_scale = off_k_x // BLOCK_K * PACKED_MX_BLOCK + tl.arange(0, PACKED_MX_BLOCK) + offs_k_scale = off_k_w // PACKED_BLOCK_K_W * PACKED_MX_BLOCK + tl.arange(0, PACKED_MX_BLOCK) WMxScalePtrs = WMxScale + offs_k_scale.to(index_type)[None, :] * stride_scale_k + offs_n_scale.to(index_type)[:, None] * stride_w_mx_n else: WMxScalePtrs = None diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py index 865de1477469..fe7af378e6c5 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py @@ -125,7 +125,6 @@ def _p_matmul_ogs( tl.static_assert(get_dtype(WMxScale) == tl.uint8, "mx_scale_ptr must be uint8") tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR") tl.static_assert(SWIZZLE_MX_SCALE == "BLACKWELL_SCALE" or SWIZZLE_MX_SCALE is None, "Only Blackwell swizzling is supported for scales") - tl.static_assert(not EXPT_IS_INNER, "Not supported yet") # We have pack 2 fp4 values in a byte W_PACK_DIVISOR: tl.constexpr = 2 if w_type == tl.uint8 else 1 diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py index 72e0998f5c5f..7c3e5c0c5da2 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py @@ -229,14 +229,13 @@ def make_default_opt_flags_nvidia( is_persistent = False block_n = block_n_tma if is_persistent else block_n # block k - if constraints.get("block_k", None) is not None: - block_k = constraints["block_k"] - else: - block_k = opt_flags_nvidia.compute_block_k(m, k, is_persistent, lhs_dtype, rhs_dtype, precision_config, has_y_acc_in) + block_k = opt_flags_nvidia.compute_block_k(m, k, is_persistent, lhs_dtype, rhs_dtype, precision_config, has_y_acc_in) if block_n == 256 and block_k == 128 and block_m <= 64 and is_persistent and rhs_dtype == FP4 and k >= 4096 and tokens_per_expt > 1: # Swap block_n and block_k for mxfp4 weights so that block_k is a full cacheline, so long as K is sufficiently large. # TODO: swizzle the HBM layout of the weights instead block_n, block_k = block_k, block_n + if constraints.get("block_k", None) is not None: + block_k = constraints["block_k"] # split_k if constraints.get("max_allowable_mn", 0) > 0 and constraints.get("split_k") is not None: split_k = max_allowable_mn(constraints["max_allowable_mn"], m, n, constraints.get("split_k")) diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py index a9964a625a82..064c36ef599e 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py @@ -1,9 +1,9 @@ import torch import triton from triton_kernels import target_info -from triton_kernels.tensor import get_layout, bitwidth, FP4 -from triton_kernels.tensor_details.layout import HopperAmpereMXScaleLayout from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE +from triton_kernels.tensor import FP4, bitwidth, get_layout +from triton_kernels.tensor_details.layout import HopperAmpereMXScaleLayout def compute_grid_size(routing_data, batch_size, m, n, block_m, block_n): diff --git a/python/triton_kernels/triton_kernels/tensor.py b/python/triton_kernels/triton_kernels/tensor.py index 64dbc73b0545..9a78f1b2ef96 100644 --- a/python/triton_kernels/triton_kernels/tensor.py +++ b/python/triton_kernels/triton_kernels/tensor.py @@ -2,12 +2,13 @@ from typing import Type import torch -from triton.tools.tensor_descriptor import TensorDescriptor from triton.tools.ragged_tma import create_ragged_descriptor +from triton.tools.tensor_descriptor import TensorDescriptor + from .target_info import cuda_capability_geq -from .tensor_details.layout import Layout, StridedLayout -from .tensor_details import ragged_tensor as ragged_tensor_details from .tensor_details import bitmatrix as bitmatrix_details +from .tensor_details import ragged_tensor as ragged_tensor_details +from .tensor_details.layout import BlackwellMXValueLayout, Layout, StridedLayout from .tensor_details.ragged_tensor import RaggedTensorMetadata @@ -46,26 +47,28 @@ def is_tma_compliant(self): compliant = [strides[i] * bitwidth % 128 == 0 for i in range(ndim) if i != major_dim] return all(compliant) - def make_dense_tma(self, block_shape, transpose=False): + def make_dense_tma(self, block_shape): strides = list(self.data.stride()) shape = list(self.data.shape) - transpose = self.data.stride()[-1] != 1 + transpose = strides[-1] != 1 if transpose: block_shape = block_shape[:-2] + [block_shape[-1], block_shape[-2]] shape = shape[:-2] + [shape[-1], shape[-2]] strides = strides[:-2] + [strides[-1], strides[-2]] - if self.data.dtype == torch.uint8 and self.layout.name == "BLACKWELL_VALUE": + if self.data.dtype == torch.uint8 and (self.layout.name is None or "_SCALE" not in self.layout.name): indx = strides.index(1) block_shape[indx] = block_shape[indx] // 2 - if shape[-1] % 128 != 0: - raise ValueError("inner shape need to be multiple of 128 for " - "mxfp4 (CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B) TMAs.") + if isinstance(self.layout, BlackwellMXValueLayout): + if shape[-1] % 128 != 0: + raise ValueError( + "inner shape need to be multiple of 128 for mxfp4 (CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B) TMAs." + ) block_shape = self.layout.swizzle_block_shape(block_shape) return TensorDescriptor(self.data, shape, strides, block_shape) - def make_tma(self, block_shape, mode, transpose=False): + def make_tma(self, block_shape, mode): if mode in ["dense", "gather", "scatter"]: - return self.make_dense_tma(block_shape, transpose) + return self.make_dense_tma(block_shape) assert mode == "ragged" ragged_dim = len(self.data.shape) - 2 return create_ragged_descriptor(self.data, block_shape, ragged_dim=ragged_dim) @@ -176,7 +179,6 @@ def size(self, i=None): # ---------------------------------------------------------------------------- # @dataclass class Bitmatrix(Tensor): - def __post_init__(self): assert self.dtype == BIT super().__post_init__() @@ -195,6 +197,7 @@ class RaggedTensor: A ragged `tensor` is a collection of 2D tensors that share the same number of columns. Each tensor in this collection is called a `slice`. """ + # slice_sizes[i] is the number of rows in slice `i` slice_sizes: torch.Tensor # ragged tensors are stored in memory as (potentially padded) 2D tensors of shape diff --git a/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py b/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py index 7df29947fcec..e01e321bd1e6 100644 --- a/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py +++ b/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py @@ -1,7 +1,9 @@ import math + +import torch import triton import triton.language as tl -import torch + from .base import Layout SWIZZLE_ALIGN_INNER = tl.constexpr(8) @@ -14,7 +16,11 @@ class BlackwellMXScaleLayout(Layout): def __init__(self, shape) -> None: super().__init__(shape) - *self.leading_shape, self.K, self.N, = shape + ( + *self.leading_shape, + self.K, + self.N, + ) = shape self.B = math.prod(self.leading_shape) self.ALIGN_K = 8 self.ALIGN_N = 128 @@ -25,30 +31,36 @@ def __init__(self, shape) -> None: def swizzle_data(self, data): data = torch.nn.functional.pad(data, (0, self.N_pad - self.N, 0, self.K_pad - self.K)) data = data.transpose(-1, -2).contiguous() - data = data.reshape(self.B, self.N_pad // self.ALIGN_N, self.ALIGN_N // 32, 32, self.K_pad // self.SWIZZLE_K, - self.SWIZZLE_K) + data = data.reshape( + self.B, self.N_pad // self.ALIGN_N, self.ALIGN_N // 32, 32, self.K_pad // self.SWIZZLE_K, self.SWIZZLE_K + ) data = data.transpose(2, 4).contiguous() data = data.view(1, self.B * self.N_pad // 128, self.K_pad // 4, 2, 256) return data def unswizzle_data(self, data): - data = data.reshape(self.B, self.N_pad // self.ALIGN_N, self.K_pad // self.SWIZZLE_K, 32, self.ALIGN_N // 32, - self.SWIZZLE_K) + data = data.reshape( + self.B, self.N_pad // self.ALIGN_N, self.K_pad // self.SWIZZLE_K, 32, self.ALIGN_N // 32, self.SWIZZLE_K + ) data = data.transpose(2, 4) data = data.reshape(*self.leading_shape, self.N_pad, self.K_pad) data = data.transpose(-1, -2) - return data[..., :self.K, :self.N] + return data[..., : self.K, : self.N] def swizzle_block_shape(self, block_shape): MX_PACK_DIVISOR = 32 MX_SCALE_BLOCK_K = block_shape[1] // MX_PACK_DIVISOR + assert block_shape[0] >= 128, f"{block_shape[0]=} must be >= 128" return [1, block_shape[0] // 128, MX_SCALE_BLOCK_K // 4, 2, 256] @triton.jit -def unswizzle_mx_scale_bw(x, SIZE_OUTER: tl.constexpr = SWIZZLE_SIZE_OUTER, - SIZE_INNER: tl.constexpr = SWIZZLE_SIZE_INNER, - ALIGN_INNER: tl.constexpr = SWIZZLE_ALIGN_INNER): +def unswizzle_mx_scale_bw( + x, + SIZE_OUTER: tl.constexpr = SWIZZLE_SIZE_OUTER, + SIZE_INNER: tl.constexpr = SWIZZLE_SIZE_INNER, + ALIGN_INNER: tl.constexpr = SWIZZLE_ALIGN_INNER, +): shape_0: tl.constexpr = x.shape[0] shape_1: tl.constexpr = x.shape[1] tl.static_assert(shape_1 % SIZE_OUTER == 0) From 3f9220e94a855c9ad721eccd5cf9f40348bdfa28 Mon Sep 17 00:00:00 2001 From: jongsoo-openai Date: Sun, 19 Oct 2025 22:59:33 -0700 Subject: [PATCH 3/4] yapf --- .../triton_kernels/matmul_ogs_details/_common.py | 6 +++--- python/triton_kernels/triton_kernels/tensor.py | 1 + .../tensor_details/layout_details/blackwell_scale.py | 12 +++++------- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py index 37edf93b55af..d0149bbfb592 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py @@ -148,6 +148,7 @@ def _load_tile_attrs( def make_matmul_repr(base_name, order): + def matmul_repr(specialization): signature = specialization.signature constants = specialization.constants @@ -266,6 +267,5 @@ def matmul_launch_metadata(grid, kernel, args): @triton.jit def threadfence_system(): - tl.inline_asm_elementwise( - "mov.u32 $0, 0x0; fence.sc.sys;", args=(), dtype=(tl.int32,), is_pure=False, pack=1, constraints="=r" - ) + tl.inline_asm_elementwise("mov.u32 $0, 0x0; fence.sc.sys;", args=(), dtype=(tl.int32, ), is_pure=False, pack=1, + constraints="=r") diff --git a/python/triton_kernels/triton_kernels/tensor.py b/python/triton_kernels/triton_kernels/tensor.py index 9a78f1b2ef96..48e762a69b88 100644 --- a/python/triton_kernels/triton_kernels/tensor.py +++ b/python/triton_kernels/triton_kernels/tensor.py @@ -179,6 +179,7 @@ def size(self, i=None): # ---------------------------------------------------------------------------- # @dataclass class Bitmatrix(Tensor): + def __post_init__(self): assert self.dtype == BIT super().__post_init__() diff --git a/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py b/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py index e01e321bd1e6..ec2637c75013 100644 --- a/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py +++ b/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py @@ -31,21 +31,19 @@ def __init__(self, shape) -> None: def swizzle_data(self, data): data = torch.nn.functional.pad(data, (0, self.N_pad - self.N, 0, self.K_pad - self.K)) data = data.transpose(-1, -2).contiguous() - data = data.reshape( - self.B, self.N_pad // self.ALIGN_N, self.ALIGN_N // 32, 32, self.K_pad // self.SWIZZLE_K, self.SWIZZLE_K - ) + data = data.reshape(self.B, self.N_pad // self.ALIGN_N, self.ALIGN_N // 32, 32, self.K_pad // self.SWIZZLE_K, + self.SWIZZLE_K) data = data.transpose(2, 4).contiguous() data = data.view(1, self.B * self.N_pad // 128, self.K_pad // 4, 2, 256) return data def unswizzle_data(self, data): - data = data.reshape( - self.B, self.N_pad // self.ALIGN_N, self.K_pad // self.SWIZZLE_K, 32, self.ALIGN_N // 32, self.SWIZZLE_K - ) + data = data.reshape(self.B, self.N_pad // self.ALIGN_N, self.K_pad // self.SWIZZLE_K, 32, self.ALIGN_N // 32, + self.SWIZZLE_K) data = data.transpose(2, 4) data = data.reshape(*self.leading_shape, self.N_pad, self.K_pad) data = data.transpose(-1, -2) - return data[..., : self.K, : self.N] + return data[..., :self.K, :self.N] def swizzle_block_shape(self, block_shape): MX_PACK_DIVISOR = 32 From cac7ad5e5f19066aa38a503575e6dc5836cc6a14 Mon Sep 17 00:00:00 2001 From: jongsoo-openai Date: Mon, 20 Oct 2025 07:17:58 -0700 Subject: [PATCH 4/4] more fixes --- python/triton_kernels/tests/test_matmul.py | 12 +++++++++--- python/triton_kernels/triton_kernels/matmul_ogs.py | 3 +-- .../triton_kernels/matmul_ogs_details/_common.py | 2 +- .../opt_flags_details/opt_flags_nvidia.py | 7 +++++-- 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/python/triton_kernels/tests/test_matmul.py b/python/triton_kernels/tests/test_matmul.py index 73cb38f680aa..5a6e64517f00 100644 --- a/python/triton_kernels/tests/test_matmul.py +++ b/python/triton_kernels/tests/test_matmul.py @@ -375,6 +375,8 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o pytest.skip("inner_expt_opt and weight mx only supported with pad_w") if is_persistent and not hbm_swizzling: pytest.skip("FIXME: Fatal Python error: Aborted") + if is_hip() and act_dtype_str == "bfloat16": + pytest.skip("FIXME: failed to translate module to LLVM IR") # launch metadata for batched / mx types may not work yet. torch.manual_seed(0) @@ -430,9 +432,13 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o rdata = gindx = sindx = None padding_block_k = 32 - if hbm_swizzling and is_persistent and torch.cuda.get_device_capability()[0] >= 10: - # Blackwell scale swizzling constraint - padding_block_k = 128 + if hbm_swizzling: + if torch.cuda.get_device_capability()[0] >= 10: + # Blackwell scale swizzling constraint + # https://github.com/triton-lang/triton/blob/814b862166c756d9f33238844f4ac047e0243388/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py#L45 + padding_block_k = 128 + elif not is_persistent: + padding_block_k = 64 x_tri, w_tri, bias_tri, gs0_tri, gs1_tri = init_compute_data(m, n, k, rdata, gindx, sindx, n_expts_tot, n_expts_act, mode, torch.bfloat16 if act_mxfp8 else act_dtype, # torch.bfloat16 if weight_mxfp else weight_dtype, diff --git a/python/triton_kernels/triton_kernels/matmul_ogs.py b/python/triton_kernels/triton_kernels/matmul_ogs.py index bfefbe5f7f0e..5f5a3bf4334f 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs.py @@ -644,8 +644,7 @@ def matmul_ogs(x, w, bias, y_tensor_or_tma = y_storage.make_tma(y_tma_block_size, y_tma_mode) if y_has_tma else y_storage.data # create tma descriptor for w w_has_tma = opt_flags.is_persistent - w_tma_block_size = [1, opt_flags.block_k, opt_flags.block_n] - w_tensor_or_tma = w_storage.make_tma(w_tma_block_size, "dense") if w_has_tma else w_storage.data + w_tensor_or_tma = w_storage.make_tma([1, opt_flags.block_k, opt_flags.block_n], "dense") if w_has_tma else w_storage.data # create tma descriptor for w_scale w_scale_has_tma = opt_flags.is_persistent and w_scale is not None w_transpose = w_storage.data.stride()[-2] == 1 diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py index d0149bbfb592..1d71b9bc272d 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py @@ -105,7 +105,7 @@ def _load_tile_attrs( # K_W is only used for non-TMA kernel (W bound is handled by TMA on TMA kernel). if W_IS_PADDED: off_k_w = padded_start_off_raw * PACKED_BLOCK_K_W - K_W = tl.load(ExptTileOffs + pid_e + 1) * BLOCK_K + K_W = tl.load(ExptTileOffs + pid_e + 1) * PACKED_BLOCK_K_W else: off_k_w = unpadded_start_off K_W = tl.load(ExptOffs + pid_e + 1) diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py index 064c36ef599e..4b41d4aae686 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py @@ -18,8 +18,11 @@ def compute_grid_size(routing_data, batch_size, m, n, block_m, block_n): def compute_block_n(n: int, arch, precision_config): # block_n: layout = get_layout(precision_config.weight_scale) - if isinstance(layout, HopperAmpereMXScaleLayout) and layout.num_warps == 4: - return 128, 128 + if isinstance(layout, HopperAmpereMXScaleLayout): + if layout.num_warps in [4, 8]: + # https://github.com/triton-lang/triton/blob/814b862166c756d9f33238844f4ac047e0243388/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py#L265 + block_n = 2 * layout.num_warps * 2 * 8 + return block_n, block_n elif precision_config.max_num_imprecise_acc is None and n > 128: return 256, 256 else: