From 10e7a7ba16c9cd2d92332e744bc04e404870f34a Mon Sep 17 00:00:00 2001 From: Afroz Mohiuddin Date: Wed, 8 Oct 2025 08:37:56 -0700 Subject: [PATCH 01/13] Update constraints dict to accept callable for split_k and invoke that callable to determine split_k. --- .../matmul_ogs_details/opt_flags.py | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py index 9f39a9464c97..467c00e02001 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py @@ -8,6 +8,10 @@ import torch from .opt_flags_details import opt_flags_amd, opt_flags_nvidia from triton_kernels.tensor import bitwidth +from typing import Callable + +# Function type: takes four ints (batch_size, m, n, k) and a output dtype, returns an int +CallableSplitK = Callable[[int, int, int, int, torch.dtype], int] @dataclass @@ -33,6 +37,19 @@ def __post_init__(self): raise ValueError("Not supported") +def get_split_k_from_constraints( + constraints_split_k: int | CallableSplitK, + batch_size: int, + m: int, + n: int, + k: int, + out_dtype: torch.dtype) -> int: + if isinstance(constraints_split_k, int): + return constraints_split_k + + return constraints_split_k(batch_size, m, n, k, out_dtype) + + def make_default_opt_flags_amd( out_dtype, lhs_dtype, @@ -91,7 +108,8 @@ def make_default_opt_flags_amd( is_persistent = constraints.get("is_persistent", False) # split_k: if constraints.get("split_k", None) is not None: - split_k = constraints["split_k"] + split_k = get_split_k_from_constraints( + constraints["split_k"], batch_size, m, n, k, out_dtype) elif is_persistent or enforce_bitwise_invariance: split_k = 1 else: @@ -221,7 +239,8 @@ def make_default_opt_flags_nvidia( block_n, block_k = block_k, block_n # split_k if constraints.get("split_k", None) is not None: - split_k = constraints["split_k"] + split_k = get_split_k_from_constraints( + constraints["split_k"], batch_size, m, n, k, out_dtype) elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None: split_k = 1 else: @@ -293,7 +312,7 @@ def make_default_opt_flags_nvidia( _opt_flags_constraints: dict = dict() _opt_flags: OptFlags | None = None -def update_opt_flags_constraints(constraints: dict[str, int]): +def update_opt_flags_constraints(constraints: dict[str, int | Callable]): global _opt_flags_constraints _opt_flags_constraints.update(constraints) From 9f20705f2d1641ec295a7555fc14a9a32a2294c2 Mon Sep 17 00:00:00 2001 From: Afroz Mohiuddin Date: Wed, 8 Oct 2025 08:45:41 -0700 Subject: [PATCH 02/13] Update constraints dict to accept callable for split_k and invoke that callable to determine split_k. Update constraints dict to accept callable for split_k and invoke that callable to determine split_k. 
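
An illustrative sketch of how a caller might use this change (the helper name
`pick_split_k` and its policy are assumptions for illustration, not part of this
patch):

    import torch
    from triton_kernels.matmul_ogs_details.opt_flags import update_opt_flags_constraints

    def pick_split_k(batch_size: int, m: int, n: int, k: int, out_dtype: torch.dtype) -> int:
        # Hypothetical policy: split the K dimension more aggressively for deep
        # reductions over small output tiles, otherwise leave it unsplit.
        return 4 if k >= 4096 and m * n <= 256 * 256 else 1

    # The constraints dict now accepts either an int or such a callable for "split_k";
    # make_default_opt_flags_{amd,nvidia} invoke it with (batch_size, m, n, k, out_dtype).
    update_opt_flags_constraints({"split_k": pick_split_k})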
--- .../matmul_ogs_details/opt_flags.py | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py index 9f39a9464c97..467c00e02001 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py @@ -8,6 +8,10 @@ import torch from .opt_flags_details import opt_flags_amd, opt_flags_nvidia from triton_kernels.tensor import bitwidth +from typing import Callable + +# Function type: takes four ints (batch_size, m, n, k) and a output dtype, returns an int +CallableSplitK = Callable[[int, int, int, int, torch.dtype], int] @dataclass @@ -33,6 +37,19 @@ def __post_init__(self): raise ValueError("Not supported") +def get_split_k_from_constraints( + constraints_split_k: int | CallableSplitK, + batch_size: int, + m: int, + n: int, + k: int, + out_dtype: torch.dtype) -> int: + if isinstance(constraints_split_k, int): + return constraints_split_k + + return constraints_split_k(batch_size, m, n, k, out_dtype) + + def make_default_opt_flags_amd( out_dtype, lhs_dtype, @@ -91,7 +108,8 @@ def make_default_opt_flags_amd( is_persistent = constraints.get("is_persistent", False) # split_k: if constraints.get("split_k", None) is not None: - split_k = constraints["split_k"] + split_k = get_split_k_from_constraints( + constraints["split_k"], batch_size, m, n, k, out_dtype) elif is_persistent or enforce_bitwise_invariance: split_k = 1 else: @@ -221,7 +239,8 @@ def make_default_opt_flags_nvidia( block_n, block_k = block_k, block_n # split_k if constraints.get("split_k", None) is not None: - split_k = constraints["split_k"] + split_k = get_split_k_from_constraints( + constraints["split_k"], batch_size, m, n, k, out_dtype) elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None: split_k = 1 else: @@ -293,7 +312,7 @@ def make_default_opt_flags_nvidia( _opt_flags_constraints: dict = dict() _opt_flags: OptFlags | None = None -def update_opt_flags_constraints(constraints: dict[str, int]): +def update_opt_flags_constraints(constraints: dict[str, int | Callable]): global _opt_flags_constraints _opt_flags_constraints.update(constraints) From d2efa7c82939e609e8b79bbfc40803838c4229a2 Mon Sep 17 00:00:00 2001 From: Afroz Mohiuddin Date: Wed, 8 Oct 2025 10:17:39 -0700 Subject: [PATCH 03/13] Make notes that one needs to update constraint checking --- .../triton_kernels/matmul_ogs_details/opt_flags.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py index 467c00e02001..a100dee6340b 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py @@ -167,6 +167,7 @@ def replace_with_valid_constraint(k: str, v): target_kernel_kwargs=target_kernel_kwargs, ) # check constraints + # TODO(afroz): Update this later. assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}" return ret @@ -302,6 +303,7 @@ def make_default_opt_flags_nvidia( idle_sms=constraints.get("idle_sms", 0), ) # check constraints + # TODO(afroz): Update this later. 
assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}" return ret From e89313e9e05f3b37d9ba20c1e7f1c477ce5e75c5 Mon Sep 17 00:00:00 2001 From: Afroz Mohiuddin Date: Wed, 8 Oct 2025 12:44:24 -0700 Subject: [PATCH 04/13] Add tests --- .../tests/test_opt_flags_split_k.py | 226 ++++++++++++++++++ 1 file changed, 226 insertions(+) create mode 100644 python/triton_kernels/tests/test_opt_flags_split_k.py diff --git a/python/triton_kernels/tests/test_opt_flags_split_k.py b/python/triton_kernels/tests/test_opt_flags_split_k.py new file mode 100644 index 000000000000..e6c63c23cf63 --- /dev/null +++ b/python/triton_kernels/tests/test_opt_flags_split_k.py @@ -0,0 +1,226 @@ +# isort: off +# fmt: off +import types +from typing import Callable +import pytest + +torch = pytest.importorskip("torch") + +import triton_kernels.matmul_ogs_details.opt_flags as opt_flags + + +class _DummyPrecisionConfig: + def __init__(self): + self.weight_scale = None + self.max_num_imprecise_acc = None + self.act_scale = None + self.out_scale = None + self.enforce_bitwise_invariance = False + + +def _stub_cuda_props(*_args, **_kwargs): + return types.SimpleNamespace(multi_processor_count=16) + + +def setup_amd(monkeypatch): + monkeypatch.setattr(opt_flags, "get_cdna_version", lambda: 3) + monkeypatch.setattr(opt_flags.torch.cuda, "get_device_properties", _stub_cuda_props) + monkeypatch.setattr( + opt_flags.opt_flags_amd, + "compute_block_nk", + lambda *args, **kwargs: (64, 32), + ) + + +def setup_nvidia(monkeypatch): + monkeypatch.setattr(opt_flags.torch.cuda, "get_device_properties", _stub_cuda_props) + monkeypatch.setattr(opt_flags.torch.cuda, "get_device_capability", lambda: (9, 0)) + monkeypatch.setattr( + opt_flags.opt_flags_nvidia, + "compute_block_n", + lambda n, arch, precision_config: (64, 32), + ) + monkeypatch.setattr( + opt_flags.opt_flags_nvidia, + "compute_grid_size", + lambda routing_data, batch_size, m, n, block_m, block_n: 4, + ) + monkeypatch.setattr( + opt_flags.opt_flags_nvidia, + "compute_block_k", + lambda m, k, is_persistent, lhs_dtype, rhs_dtype, precision_config, has_y_acc_in: 32, + ) + monkeypatch.setattr( + opt_flags.opt_flags_nvidia, + "compute_split_k", + lambda block_k, k, estimated_actual_grid_size: 1, + ) + monkeypatch.setattr( + opt_flags.opt_flags_nvidia, + "compute_num_stages", + lambda *args, **kwargs: 2, + ) + monkeypatch.setattr( + opt_flags.opt_flags_nvidia, + "compute_num_warps", + lambda block_m, block_n, is_persistent, precision_config: 4, + ) + + +def make_split_k_limiter( + max_size_bytes: float, + max_split_k: int, +) -> Callable[[int, int, int, int, torch.dtype], int]: + """Create a ki_split_k callback that respects a memory ceiling and max_split_k. + + Args: + max_size_bytes: Maximum intermediate size in bytes. + max_split_k: Maximum allowable split_k value. + + Returns: + A callable that computes the maximum split_k that keeps the + intermediate matrix ``split_k * b * m * n`` of the provided dtype under the + size limit. The value is clamped between 1 and ``max_split_k`` for positive shapes and + raises ``ValueError`` for non-positive arguments or invalid dtypes. 
+ """ + + if max_size_bytes <= 0: + raise ValueError("max_size_bytes must be positive") + if max_split_k < 1: + raise ValueError("max_split_k must be at least 1") + + def _limit_split_k(b: int, m: int, n: int, k: int, dtype: torch.dtype) -> int: + del k # unused but kept for signature compatibility + elem_size = torch.empty((), dtype=dtype).element_size() + bytes_per_split = b * m * n * elem_size + + if bytes_per_split <= 0: + raise ValueError( + "Invalid arguments: " + f"{bytes_per_split=} = {b=} * {m=} * {n=} * size(dtype)={elem_size}" + ) + + max_split = int(max_size_bytes // bytes_per_split) + return min(max_split_k, max(1, max_split)) + + return _limit_split_k + + +def test_make_default_opt_flags_amd_split_k_callable(monkeypatch): + setup_amd(monkeypatch) + + captured_args = {} + + def split_k_callable(batch_size, m, n, k, out_dtype): + captured_args["value"] = (batch_size, m, n, k, out_dtype) + return 5 + + precision_config = _DummyPrecisionConfig() + flags = opt_flags.make_default_opt_flags_amd( + torch.float16, + torch.float16, + torch.float16, + precision_config, + 2, + 128, + 64, + 32, + None, + False, + False, + False, + 0, + False, + False, + {"split_k": split_k_callable}, + ) + + assert flags.split_k == 5 + assert captured_args["value"] == (2, 128, 64, 32, torch.float16) + + +def test_make_default_opt_flags_nvidia_split_k_callable(monkeypatch): + setup_nvidia(monkeypatch) + + captured_args = {} + + def split_k_callable(batch_size, m, n, k, out_dtype): + captured_args["value"] = (batch_size, m, n, k, out_dtype) + return 3 + + precision_config = _DummyPrecisionConfig() + flags = opt_flags.make_default_opt_flags_nvidia( + torch.float16, + torch.float16, + torch.float16, + precision_config, + 4, + 256, + 128, + 64, + None, + False, + False, + False, + 0, + False, + False, + {"split_k": split_k_callable}, + ) + + assert flags.split_k == 3 + assert captured_args["value"] == (4, 256, 128, 64, torch.float16) + + +def test_split_k_callable_with_max_size_callable(monkeypatch): + setup_nvidia(monkeypatch) + + batch_size, m, n, k = 4, 256, 128, 64 + bytes_float16 = 2 + intermediate_size = batch_size * m * n * bytes_float16 + + def get_flags(_split_k_callable): + + return opt_flags.make_default_opt_flags_nvidia( + torch.float16, + torch.float16, + torch.float16, + _DummyPrecisionConfig(), + batch_size, + m, + n, + k, + None, + False, + False, + False, + 0, + False, + False, + { "split_k": _split_k_callable}, + ) + + # Test with a very small allowance that only allows split_k=allowance + allowance = 2 + max_allowable_split_k = 4 + split_k_callable = make_split_k_limiter(allowance * intermediate_size, max_allowable_split_k) + flags = get_flags(split_k_callable) + + assert flags.split_k == allowance + + # With a larger allowance, we should bump against the max allowable split_k + allowance = 8 + max_allowable_split_k = 4 + split_k_callable = make_split_k_limiter(allowance * intermediate_size, max_allowable_split_k) + flags = get_flags(split_k_callable) + + assert flags.split_k == max_allowable_split_k + + # If we bump up the max_allowable_split_k, we should get the allowance + allowance = 8 + max_allowable_split_k = 8 + split_k_callable = make_split_k_limiter(allowance * intermediate_size, max_allowable_split_k) + flags = get_flags(split_k_callable) + + assert flags.split_k == max_allowable_split_k + From fa3ef72c88560c0eb054ad21088bcad05ada3452 Mon Sep 17 00:00:00 2001 From: Afroz Mohiuddin Date: Thu, 9 Oct 2025 11:41:38 -0700 Subject: [PATCH 05/13] Rework split_k callable to dynamic_split_k 
instead and add tests --- .../tests/test_opt_flags_split_k.py | 136 +++++++----------- .../matmul_ogs_details/opt_flags.py | 93 +++++++++--- 2 files changed, 123 insertions(+), 106 deletions(-) diff --git a/python/triton_kernels/tests/test_opt_flags_split_k.py b/python/triton_kernels/tests/test_opt_flags_split_k.py index e6c63c23cf63..1eb413c30ee2 100644 --- a/python/triton_kernels/tests/test_opt_flags_split_k.py +++ b/python/triton_kernels/tests/test_opt_flags_split_k.py @@ -67,54 +67,9 @@ def setup_nvidia(monkeypatch): ) -def make_split_k_limiter( - max_size_bytes: float, - max_split_k: int, -) -> Callable[[int, int, int, int, torch.dtype], int]: - """Create a ki_split_k callback that respects a memory ceiling and max_split_k. - - Args: - max_size_bytes: Maximum intermediate size in bytes. - max_split_k: Maximum allowable split_k value. - - Returns: - A callable that computes the maximum split_k that keeps the - intermediate matrix ``split_k * b * m * n`` of the provided dtype under the - size limit. The value is clamped between 1 and ``max_split_k`` for positive shapes and - raises ``ValueError`` for non-positive arguments or invalid dtypes. - """ - - if max_size_bytes <= 0: - raise ValueError("max_size_bytes must be positive") - if max_split_k < 1: - raise ValueError("max_split_k must be at least 1") - - def _limit_split_k(b: int, m: int, n: int, k: int, dtype: torch.dtype) -> int: - del k # unused but kept for signature compatibility - elem_size = torch.empty((), dtype=dtype).element_size() - bytes_per_split = b * m * n * elem_size - - if bytes_per_split <= 0: - raise ValueError( - "Invalid arguments: " - f"{bytes_per_split=} = {b=} * {m=} * {n=} * size(dtype)={elem_size}" - ) - - max_split = int(max_size_bytes // bytes_per_split) - return min(max_split_k, max(1, max_split)) - - return _limit_split_k - - -def test_make_default_opt_flags_amd_split_k_callable(monkeypatch): +def test_make_default_opt_flags_amd_split_k_constraint(monkeypatch): setup_amd(monkeypatch) - captured_args = {} - - def split_k_callable(batch_size, m, n, k, out_dtype): - captured_args["value"] = (batch_size, m, n, k, out_dtype) - return 5 - precision_config = _DummyPrecisionConfig() flags = opt_flags.make_default_opt_flags_amd( torch.float16, @@ -132,22 +87,15 @@ def split_k_callable(batch_size, m, n, k, out_dtype): 0, False, False, - {"split_k": split_k_callable}, + {"split_k": 5}, ) assert flags.split_k == 5 - assert captured_args["value"] == (2, 128, 64, 32, torch.float16) -def test_make_default_opt_flags_nvidia_split_k_callable(monkeypatch): +def test_make_default_opt_flags_nvidia_split_k_constraint(monkeypatch): setup_nvidia(monkeypatch) - captured_args = {} - - def split_k_callable(batch_size, m, n, k, out_dtype): - captured_args["value"] = (batch_size, m, n, k, out_dtype) - return 3 - precision_config = _DummyPrecisionConfig() flags = opt_flags.make_default_opt_flags_nvidia( torch.float16, @@ -165,22 +113,23 @@ def split_k_callable(batch_size, m, n, k, out_dtype): 0, False, False, - {"split_k": split_k_callable}, + {"split_k": 3}, ) assert flags.split_k == 3 - assert captured_args["value"] == (4, 256, 128, 64, torch.float16) -def test_split_k_callable_with_max_size_callable(monkeypatch): +def test_dynamic_split_k(monkeypatch): setup_nvidia(monkeypatch) - batch_size, m, n, k = 4, 256, 128, 64 + batch_size, m, n = 4, 256, 128 + k = (2**6) * 3 + bytes_float16 = 2 intermediate_size = batch_size * m * n * bytes_float16 - def get_flags(_split_k_callable): - + def get_flags(split_k, dynamic_split_k_max_size_bytes, 
dynamic_split_k_max_split_k): + dynamic_split_k = dynamic_split_k_max_size_bytes is not None return opt_flags.make_default_opt_flags_nvidia( torch.float16, torch.float16, @@ -197,30 +146,47 @@ def get_flags(_split_k_callable): 0, False, False, - { "split_k": _split_k_callable}, + { + "split_k": split_k, + "dynamic_split_k": dynamic_split_k, + "dynamic_split_k_max_size_bytes": dynamic_split_k_max_size_bytes, + "dynamic_split_k_max_split_k": dynamic_split_k_max_split_k, + }, ) - # Test with a very small allowance that only allows split_k=allowance - allowance = 2 - max_allowable_split_k = 4 - split_k_callable = make_split_k_limiter(allowance * intermediate_size, max_allowable_split_k) - flags = get_flags(split_k_callable) - - assert flags.split_k == allowance - - # With a larger allowance, we should bump against the max allowable split_k - allowance = 8 - max_allowable_split_k = 4 - split_k_callable = make_split_k_limiter(allowance * intermediate_size, max_allowable_split_k) - flags = get_flags(split_k_callable) - - assert flags.split_k == max_allowable_split_k - - # If we bump up the max_allowable_split_k, we should get the allowance - allowance = 8 - max_allowable_split_k = 8 - split_k_callable = make_split_k_limiter(allowance * intermediate_size, max_allowable_split_k) - flags = get_flags(split_k_callable) - - assert flags.split_k == max_allowable_split_k + # If `dynamic_split_k` is not specified, we get the specified split_k + for split_k in [1, 2, 4, 8]: + flags = get_flags(split_k, None, None) + assert flags.split_k == split_k + + # If `dynamic_split_k` is specified, then it is computed, and the specified split_k is ignored. + possible_splits = 6 + # So 6 splits are possible + assert k % possible_splits == 0 + allowance = possible_splits * intermediate_size + given_split_k = 3 + flags = get_flags(given_split_k, allowance, None) + assert flags.split_k == possible_splits + + # If we specify a max split size in the above scenario, it is respected, even though more splits are possible. + max_split_k = 4 + flags = get_flags(given_split_k, allowance, max_split_k) + assert flags.split_k == max_split_k + + # When the allowance is low enough, no splits are possible. + allowance = intermediate_size + flags = get_flags(given_split_k, allowance, max_split_k) + assert flags.split_k == 1 + + # Extreme case, split_k = k + allowance = k * intermediate_size + flags = get_flags(given_split_k, allowance, None) + assert flags.split_k == k + + # Split k doesn't need to be a divisor of k + non_divisor_k = 5 + assert k % non_divisor_k != 0 + allowance = non_divisor_k * intermediate_size + flags = get_flags(None, allowance, None) + assert flags.split_k == non_divisor_k diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py index a100dee6340b..93a6a71a38f2 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py @@ -37,17 +37,41 @@ def __post_init__(self): raise ValueError("Not supported") -def get_split_k_from_constraints( - constraints_split_k: int | CallableSplitK, - batch_size: int, - m: int, - n: int, - k: int, - out_dtype: torch.dtype) -> int: - if isinstance(constraints_split_k, int): - return constraints_split_k +# NOTE(afroz): We need bitwise identical results with variable batch sizes, so we shouldn't take +# batch size into account when we do this. Rework accordingly. 
+def dynamic_split_k( + max_size_bytes: float, + batch_size: int, + m: int, + n: int, + k: int, + output_dtype: torch.dtype, + max_split_k: int | None = None, + ) -> int: + """Returns split_k value respecting max_size_bytes and optionally max_split_k constraints.""" - return constraints_split_k(batch_size, m, n, k, out_dtype) + elem_size = torch.empty((), dtype=output_dtype).element_size() + bytes_per_split = batch_size * m * n * elem_size + + # max_split can only be as high as the allowance from max_size_bytes + max_split = int(max_size_bytes // bytes_per_split) + + # max_split can only be as high as the allowable max_split_k, if specified. + if max_split_k is not None: + max_split = min(max_split_k, max(1, max_split)) + + # NOTE: max_split doesn't need to divide k + # while k % max_split != 0 and max_split > 1: + # max_split -= 1 + + return max_split + + +def all_constraints_satisfied(opt_flags: OptFlags, constraints: dict) -> bool: + _split_k_constraints = ['split_k', 'dynamic_split_k', 'dynamic_split_k_max_size_bytes', 'dynamic_split_k_max_split_k'] + assert all(getattr(opt_flags, ck) == cv for ck, cv in constraints.items() if cv is not None and ck not in _split_k_constraints) + if constraints.get('split_k') and not constraints.get('dynamic_split_k'): + assert opt_flags.split_k == constraints['split_k'] def make_default_opt_flags_amd( @@ -68,7 +92,8 @@ def make_default_opt_flags_amd( has_y_acc_in, constraints, ): - constraints_supported = ["block_m", "block_n", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile"] + constraints_supported = ["block_m", "block_n", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile", + "dynamic_split_k", "dynamic_split_k_max_size_bytes", "dynamic_split_k_max_split_k"] assert not any([c not in constraints_supported for c in constraints]), constraints.keys() # tokens per expert if routing_data is None: @@ -107,9 +132,18 @@ def make_default_opt_flags_amd( ) is_persistent = constraints.get("is_persistent", False) # split_k: - if constraints.get("split_k", None) is not None: - split_k = get_split_k_from_constraints( - constraints["split_k"], batch_size, m, n, k, out_dtype) + if constraints.get("dynamic_split_k", False): + split_k = dynamic_split_k( + constraints["dynamic_split_k_max_size_bytes"], + batch_size, + m, + n, + k, + out_dtype, + constraints.get("dynamic_split_k_max_split_k"), + ) + elif constraints.get("split_k", None) is not None: + split_k = constraints["split_k"] elif is_persistent or enforce_bitwise_invariance: split_k = 1 else: @@ -168,7 +202,8 @@ def replace_with_valid_constraint(k: str, v): ) # check constraints # TODO(afroz): Update this later. 
- assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}" + # assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}" + all_constraints_satisfied(ret, constraints) return ret def make_default_opt_flags_nvidia( @@ -189,7 +224,8 @@ def make_default_opt_flags_nvidia( has_y_acc_in, constraints, ): - constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages", "idle_sms"] + constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages", + "idle_sms", "dynamic_split_k", "dynamic_split_k_max_size_bytes", "dynamic_split_k_max_split_k"] assert not any([c not in constraints_supported for c in constraints]), constraints.keys() # tokens per expert if routing_data is None or batch_size > 1: @@ -239,9 +275,18 @@ def make_default_opt_flags_nvidia( # TODO: swizzle the HBM layout of the weights instead block_n, block_k = block_k, block_n # split_k - if constraints.get("split_k", None) is not None: - split_k = get_split_k_from_constraints( - constraints["split_k"], batch_size, m, n, k, out_dtype) + if constraints.get("dynamic_split_k", False): + split_k = dynamic_split_k( + constraints["dynamic_split_k_max_size_bytes"], + batch_size, + m, + n, + k, + out_dtype, + constraints.get("dynamic_split_k_max_split_k"), + ) + elif constraints.get("split_k", None) is not None: + split_k = constraints["split_k"] elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None: split_k = 1 else: @@ -303,8 +348,7 @@ def make_default_opt_flags_nvidia( idle_sms=constraints.get("idle_sms", 0), ) # check constraints - # TODO(afroz): Update this later. 
- assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}" + all_constraints_satisfied(ret, constraints) return ret # -------------- @@ -352,6 +396,13 @@ def make_opt_flags( raise InapplicableConstraint("cannot enforce `is_persistent=True` constraint") if _opt_flags_constraints.get("fused_scatter", False) and not can_use_fused_scatter: raise InapplicableConstraint("cannot enforce `fused_scatter=True` constraint") + if _opt_flags_constraints.get("dynamic_split_k"): + # Ensure dynamic_split_k_max_size_bytes + if _opt_flags_constraints.get("dynamic_split_k_max_size_bytes", 0) <= 0: + raise InapplicableConstraint("dynamic_split_k_max_size_bytes must be > 0.") + # dynamic_split_k_max_split_k - If specified, must be at least 1 + if "dynamic_split_k_max_split_k" in _opt_flags_constraints and _opt_flags_constraints["dynamic_split_k_max_split_k"] < 1: + raise InapplicableConstraint("dynamic_split_k_max_split_k must be at least 1 if specified") enforce_bitwise_invariance = precision_config.enforce_bitwise_invariance if _opt_flags is not None: assert not _opt_flags_constraints From 04cc7d26783a06c4b34b16f82c66aec29a4d500d Mon Sep 17 00:00:00 2001 From: Afroz Mohiuddin Date: Thu, 9 Oct 2025 11:54:06 -0700 Subject: [PATCH 06/13] Change dynamic_split_k to only return 1 for the cursed shape, else 4 --- .../matmul_ogs_details/opt_flags.py | 46 +++++++++++++------ 1 file changed, 33 insertions(+), 13 deletions(-) diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py index 93a6a71a38f2..d3b76062cd18 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py @@ -39,6 +39,34 @@ def __post_init__(self): # NOTE(afroz): We need bitwise identical results with variable batch sizes, so we shouldn't take # batch size into account when we do this. Rework accordingly. +# def dynamic_split_k( +# max_size_bytes: float, +# batch_size: int, +# m: int, +# n: int, +# k: int, +# output_dtype: torch.dtype, +# max_split_k: int | None = None, +# ) -> int: +# """Returns split_k value respecting max_size_bytes and optionally max_split_k constraints.""" + +# elem_size = torch.empty((), dtype=output_dtype).element_size() +# bytes_per_split = batch_size * m * n * elem_size + +# # max_split can only be as high as the allowance from max_size_bytes +# max_split = int(max_size_bytes // bytes_per_split) + +# # max_split can only be as high as the allowable max_split_k, if specified. +# if max_split_k is not None: +# max_split = min(max_split_k, max(1, max_split)) + +# # NOTE: max_split doesn't need to divide k +# # while k % max_split != 0 and max_split > 1: +# # max_split -= 1 + +# return max_split + +# Return max_split_k except the one shape that blows up! 
def dynamic_split_k( max_size_bytes: float, batch_size: int, @@ -48,23 +76,15 @@ def dynamic_split_k( output_dtype: torch.dtype, max_split_k: int | None = None, ) -> int: - """Returns split_k value respecting max_size_bytes and optionally max_split_k constraints.""" + """Return max_split_k except the one shape that blows up!""" - elem_size = torch.empty((), dtype=output_dtype).element_size() - bytes_per_split = batch_size * m * n * elem_size + del output_dtype, batch_size, k, max_size_bytes - # max_split can only be as high as the allowance from max_size_bytes - max_split = int(max_size_bytes // bytes_per_split) + if m * n == (128 * 1024) * (201088 // 8): + return 1 - # max_split can only be as high as the allowable max_split_k, if specified. - if max_split_k is not None: - max_split = min(max_split_k, max(1, max_split)) + return max_split_k or 4 - # NOTE: max_split doesn't need to divide k - # while k % max_split != 0 and max_split > 1: - # max_split -= 1 - - return max_split def all_constraints_satisfied(opt_flags: OptFlags, constraints: dict) -> bool: From 206ee736cf0c1a8afa1d2113cddc54291b464b22 Mon Sep 17 00:00:00 2001 From: Afroz Mohiuddin Date: Fri, 10 Oct 2025 09:56:15 -0700 Subject: [PATCH 07/13] Function that checks shape only on m * n, doesn't involve b --- .../tests/test_opt_flags_split_k.py | 53 +++-------- .../matmul_ogs_details/opt_flags.py | 93 ++++--------------- 2 files changed, 28 insertions(+), 118 deletions(-) diff --git a/python/triton_kernels/tests/test_opt_flags_split_k.py b/python/triton_kernels/tests/test_opt_flags_split_k.py index 1eb413c30ee2..a85480c1513c 100644 --- a/python/triton_kernels/tests/test_opt_flags_split_k.py +++ b/python/triton_kernels/tests/test_opt_flags_split_k.py @@ -119,17 +119,15 @@ def test_make_default_opt_flags_nvidia_split_k_constraint(monkeypatch): assert flags.split_k == 3 -def test_dynamic_split_k(monkeypatch): +def test_max_allowable_mn(monkeypatch): setup_nvidia(monkeypatch) - batch_size, m, n = 4, 256, 128 - k = (2**6) * 3 + batch_size, m, n, k = 1, 256, 256, 256 bytes_float16 = 2 intermediate_size = batch_size * m * n * bytes_float16 - def get_flags(split_k, dynamic_split_k_max_size_bytes, dynamic_split_k_max_split_k): - dynamic_split_k = dynamic_split_k_max_size_bytes is not None + def get_flags(split_k, max_mn): return opt_flags.make_default_opt_flags_nvidia( torch.float16, torch.float16, @@ -148,45 +146,16 @@ def get_flags(split_k, dynamic_split_k_max_size_bytes, dynamic_split_k_max_split False, { "split_k": split_k, - "dynamic_split_k": dynamic_split_k, - "dynamic_split_k_max_size_bytes": dynamic_split_k_max_size_bytes, - "dynamic_split_k_max_split_k": dynamic_split_k_max_split_k, + "max_allowable_mn": max_mn, }, ) - # If `dynamic_split_k` is not specified, we get the specified split_k - for split_k in [1, 2, 4, 8]: - flags = get_flags(split_k, None, None) - assert flags.split_k == split_k - - # If `dynamic_split_k` is specified, then it is computed, and the specified split_k is ignored. - possible_splits = 6 - # So 6 splits are possible - assert k % possible_splits == 0 - allowance = possible_splits * intermediate_size - given_split_k = 3 - flags = get_flags(given_split_k, allowance, None) - assert flags.split_k == possible_splits - - # If we specify a max split size in the above scenario, it is respected, even though more splits are possible. - max_split_k = 4 - flags = get_flags(given_split_k, allowance, max_split_k) - assert flags.split_k == max_split_k - - # When the allowance is low enough, no splits are possible. 
- allowance = intermediate_size - flags = get_flags(given_split_k, allowance, max_split_k) + split_k = 6 + max_mn = (m * n) // 2 + flags = get_flags(split_k, max_mn) assert flags.split_k == 1 - # Extreme case, split_k = k - allowance = k * intermediate_size - flags = get_flags(given_split_k, allowance, None) - assert flags.split_k == k - - # Split k doesn't need to be a divisor of k - non_divisor_k = 5 - assert k % non_divisor_k != 0 - allowance = non_divisor_k * intermediate_size - flags = get_flags(None, allowance, None) - assert flags.split_k == non_divisor_k - + split_k = 6 + max_mn = (m * n) * 2 + flags = get_flags(split_k, max_mn) + assert flags.split_k == split_k diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py index d3b76062cd18..747cc7a46aee 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py @@ -37,60 +37,22 @@ def __post_init__(self): raise ValueError("Not supported") -# NOTE(afroz): We need bitwise identical results with variable batch sizes, so we shouldn't take -# batch size into account when we do this. Rework accordingly. -# def dynamic_split_k( -# max_size_bytes: float, -# batch_size: int, -# m: int, -# n: int, -# k: int, -# output_dtype: torch.dtype, -# max_split_k: int | None = None, -# ) -> int: -# """Returns split_k value respecting max_size_bytes and optionally max_split_k constraints.""" - -# elem_size = torch.empty((), dtype=output_dtype).element_size() -# bytes_per_split = batch_size * m * n * elem_size - -# # max_split can only be as high as the allowance from max_size_bytes -# max_split = int(max_size_bytes // bytes_per_split) - -# # max_split can only be as high as the allowable max_split_k, if specified. -# if max_split_k is not None: -# max_split = min(max_split_k, max(1, max_split)) - -# # NOTE: max_split doesn't need to divide k -# # while k % max_split != 0 and max_split > 1: -# # max_split -= 1 - -# return max_split - -# Return max_split_k except the one shape that blows up! 
-def dynamic_split_k( - max_size_bytes: float, - batch_size: int, +def max_allowable_mn( + max_mn: int, m: int, n: int, - k: int, - output_dtype: torch.dtype, - max_split_k: int | None = None, + split_k: int, ) -> int: - """Return max_split_k except the one shape that blows up!""" - - del output_dtype, batch_size, k, max_size_bytes - - if m * n == (128 * 1024) * (201088 // 8): + if m * n >= max_mn: return 1 - return max_split_k or 4 - + return split_k def all_constraints_satisfied(opt_flags: OptFlags, constraints: dict) -> bool: - _split_k_constraints = ['split_k', 'dynamic_split_k', 'dynamic_split_k_max_size_bytes', 'dynamic_split_k_max_split_k'] + _split_k_constraints = ['split_k', 'max_allowable_mn'] assert all(getattr(opt_flags, ck) == cv for ck, cv in constraints.items() if cv is not None and ck not in _split_k_constraints) - if constraints.get('split_k') and not constraints.get('dynamic_split_k'): + if constraints.get('split_k') and not constraints.get('max_allowable_mn'): assert opt_flags.split_k == constraints['split_k'] @@ -112,8 +74,7 @@ def make_default_opt_flags_amd( has_y_acc_in, constraints, ): - constraints_supported = ["block_m", "block_n", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile", - "dynamic_split_k", "dynamic_split_k_max_size_bytes", "dynamic_split_k_max_split_k"] + constraints_supported = ["block_m", "block_n", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile", "max_allowable_mn"] assert not any([c not in constraints_supported for c in constraints]), constraints.keys() # tokens per expert if routing_data is None: @@ -152,16 +113,8 @@ def make_default_opt_flags_amd( ) is_persistent = constraints.get("is_persistent", False) # split_k: - if constraints.get("dynamic_split_k", False): - split_k = dynamic_split_k( - constraints["dynamic_split_k_max_size_bytes"], - batch_size, - m, - n, - k, - out_dtype, - constraints.get("dynamic_split_k_max_split_k"), - ) + if constraints.get("max_allowable_mn", 0) > 0 and constraints.get("split_k") is not None: + split_k = max_allowable_mn(constraints["max_allowable_mn"], m, n, constraints.get("split_k")) elif constraints.get("split_k", None) is not None: split_k = constraints["split_k"] elif is_persistent or enforce_bitwise_invariance: @@ -245,7 +198,7 @@ def make_default_opt_flags_nvidia( constraints, ): constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages", - "idle_sms", "dynamic_split_k", "dynamic_split_k_max_size_bytes", "dynamic_split_k_max_split_k"] + "idle_sms", "max_allowable_mn"] assert not any([c not in constraints_supported for c in constraints]), constraints.keys() # tokens per expert if routing_data is None or batch_size > 1: @@ -295,16 +248,8 @@ def make_default_opt_flags_nvidia( # TODO: swizzle the HBM layout of the weights instead block_n, block_k = block_k, block_n # split_k - if constraints.get("dynamic_split_k", False): - split_k = dynamic_split_k( - constraints["dynamic_split_k_max_size_bytes"], - batch_size, - m, - n, - k, - out_dtype, - constraints.get("dynamic_split_k_max_split_k"), - ) + if constraints.get("max_allowable_mn", 0) > 0 and constraints.get("split_k") is not None: + split_k = max_allowable_mn(constraints["max_allowable_mn"], m, n, constraints.get("split_k")) elif constraints.get("split_k", None) is not None: split_k = constraints["split_k"] elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None: @@ 
-378,7 +323,7 @@ def make_default_opt_flags_nvidia( _opt_flags_constraints: dict = dict() _opt_flags: OptFlags | None = None -def update_opt_flags_constraints(constraints: dict[str, int | Callable]): +def update_opt_flags_constraints(constraints: dict[str, int]): global _opt_flags_constraints _opt_flags_constraints.update(constraints) @@ -416,13 +361,9 @@ def make_opt_flags( raise InapplicableConstraint("cannot enforce `is_persistent=True` constraint") if _opt_flags_constraints.get("fused_scatter", False) and not can_use_fused_scatter: raise InapplicableConstraint("cannot enforce `fused_scatter=True` constraint") - if _opt_flags_constraints.get("dynamic_split_k"): - # Ensure dynamic_split_k_max_size_bytes - if _opt_flags_constraints.get("dynamic_split_k_max_size_bytes", 0) <= 0: - raise InapplicableConstraint("dynamic_split_k_max_size_bytes must be > 0.") - # dynamic_split_k_max_split_k - If specified, must be at least 1 - if "dynamic_split_k_max_split_k" in _opt_flags_constraints and _opt_flags_constraints["dynamic_split_k_max_split_k"] < 1: - raise InapplicableConstraint("dynamic_split_k_max_split_k must be at least 1 if specified") + if _opt_flags_constraints.get("max_allowable_mn"): + if not _opt_flags_constraints.get("split_k"): + raise InapplicableConstraint("split_k also needs to be provided with max_allowable_mn") enforce_bitwise_invariance = precision_config.enforce_bitwise_invariance if _opt_flags is not None: assert not _opt_flags_constraints From 6967d01979b78d851cc6e73fcfca9aa0e9a4aa5f Mon Sep 17 00:00:00 2001 From: Afroz Mohiuddin Date: Fri, 10 Oct 2025 10:05:09 -0700 Subject: [PATCH 08/13] . --- .../triton_kernels/matmul_ogs_details/opt_flags.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py index 747cc7a46aee..971b1cc0d741 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py @@ -8,10 +8,6 @@ import torch from .opt_flags_details import opt_flags_amd, opt_flags_nvidia from triton_kernels.tensor import bitwidth -from typing import Callable - -# Function type: takes four ints (batch_size, m, n, k) and a output dtype, returns an int -CallableSplitK = Callable[[int, int, int, int, torch.dtype], int] @dataclass @@ -174,8 +170,6 @@ def replace_with_valid_constraint(k: str, v): target_kernel_kwargs=target_kernel_kwargs, ) # check constraints - # TODO(afroz): Update this later. - # assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}" all_constraints_satisfied(ret, constraints) return ret From 5276531468262bc4791e311ab9d4f8591b018950 Mon Sep 17 00:00:00 2001 From: Afroz Mohiuddin Date: Fri, 10 Oct 2025 10:06:55 -0700 Subject: [PATCH 09/13] . 
--- .../triton_kernels/matmul_ogs_details/opt_flags.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py index 971b1cc0d741..c3fd88db9404 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py @@ -191,8 +191,7 @@ def make_default_opt_flags_nvidia( has_y_acc_in, constraints, ): - constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages", - "idle_sms", "max_allowable_mn"] + constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages", "idle_sms", "max_allowable_mn"] assert not any([c not in constraints_supported for c in constraints]), constraints.keys() # tokens per expert if routing_data is None or batch_size > 1: From f5b5ebbad657820c8b9dda3ea5d13e48ac730533 Mon Sep 17 00:00:00 2001 From: Afroz Mohiuddin Date: Fri, 10 Oct 2025 11:02:04 -0700 Subject: [PATCH 10/13] More and better tests --- .../tests/test_opt_flags_split_k.py | 71 +++++++++++++++---- .../matmul_ogs_details/opt_flags.py | 5 +- 2 files changed, 59 insertions(+), 17 deletions(-) diff --git a/python/triton_kernels/tests/test_opt_flags_split_k.py b/python/triton_kernels/tests/test_opt_flags_split_k.py index a85480c1513c..2eb2db977023 100644 --- a/python/triton_kernels/tests/test_opt_flags_split_k.py +++ b/python/triton_kernels/tests/test_opt_flags_split_k.py @@ -1,10 +1,10 @@ # isort: off # fmt: off -import types -from typing import Callable import pytest +import types -torch = pytest.importorskip("torch") +import torch +import triton import triton_kernels.matmul_ogs_details.opt_flags as opt_flags @@ -31,6 +31,12 @@ def setup_amd(monkeypatch): lambda *args, **kwargs: (64, 32), ) + fake_target = types.SimpleNamespace(backend="hip") + monkeypatch.setattr( + "triton.runtime.driver.active.get_current_target", + lambda: fake_target, + ) + def setup_nvidia(monkeypatch): monkeypatch.setattr(opt_flags.torch.cuda, "get_device_properties", _stub_cuda_props) @@ -66,6 +72,12 @@ def setup_nvidia(monkeypatch): lambda block_m, block_n, is_persistent, precision_config: 4, ) + fake_target = types.SimpleNamespace(backend="cuda") + monkeypatch.setattr( + "triton.runtime.driver.active.get_current_target", + lambda: fake_target, + ) + def test_make_default_opt_flags_amd_split_k_constraint(monkeypatch): setup_amd(monkeypatch) @@ -118,17 +130,52 @@ def test_make_default_opt_flags_nvidia_split_k_constraint(monkeypatch): assert flags.split_k == 3 +def test_max_allowable_mn_and_split_k_constraints(monkeypatch): + setup_nvidia(monkeypatch) + + opt_flags._opt_flags = None + opt_flags.reset_opt_flags_constraints() + opt_flags.update_opt_flags_constraints( + { + "max_allowable_mn": 256, + # Without split_k, this should raise an error + } + ) + + with pytest.raises(opt_flags.InapplicableConstraint): + opt_flags.make_opt_flags( + torch.float16, + torch.float16, + torch.float16, + _DummyPrecisionConfig(), + 1, + 256, + 256, + 256, + None, + False, + False, + False, + 0, + False, + None, + ) def test_max_allowable_mn(monkeypatch): setup_nvidia(monkeypatch) batch_size, m, n, k = 1, 256, 256, 256 - bytes_float16 = 2 - intermediate_size = batch_size * m * n * bytes_float16 - def get_flags(split_k, max_mn): - return opt_flags.make_default_opt_flags_nvidia( + opt_flags._opt_flags = 
None + opt_flags.reset_opt_flags_constraints() + opt_flags.update_opt_flags_constraints( + { + "split_k": split_k, + "max_allowable_mn": max_mn, + } + ) + return opt_flags.make_opt_flags( torch.float16, torch.float16, torch.float16, @@ -143,19 +190,17 @@ def get_flags(split_k, max_mn): False, 0, False, - False, - { - "split_k": split_k, - "max_allowable_mn": max_mn, - }, + None, ) split_k = 6 + # Allowable mn is less than actual mn, so split_k should be set to 1 max_mn = (m * n) // 2 flags = get_flags(split_k, max_mn) assert flags.split_k == 1 split_k = 6 + # Allowable mn is more than actual mn, so split_k should be unchanged max_mn = (m * n) * 2 flags = get_flags(split_k, max_mn) - assert flags.split_k == split_k + assert flags.split_k == split_k \ No newline at end of file diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py index c3fd88db9404..72e0998f5c5f 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py @@ -39,10 +39,7 @@ def max_allowable_mn( n: int, split_k: int, ) -> int: - if m * n >= max_mn: - return 1 - - return split_k + return 1 if m * n >= max_mn else split_k def all_constraints_satisfied(opt_flags: OptFlags, constraints: dict) -> bool: From 7fc4959bdecdfea1b9e082a3a04e70a65a389e54 Mon Sep 17 00:00:00 2001 From: Afroz Mohiuddin Date: Fri, 10 Oct 2025 11:23:38 -0700 Subject: [PATCH 11/13] . --- python/triton_kernels/tests/test_opt_flags_split_k.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/triton_kernels/tests/test_opt_flags_split_k.py b/python/triton_kernels/tests/test_opt_flags_split_k.py index 2eb2db977023..95906c0ce520 100644 --- a/python/triton_kernels/tests/test_opt_flags_split_k.py +++ b/python/triton_kernels/tests/test_opt_flags_split_k.py @@ -203,4 +203,4 @@ def get_flags(split_k, max_mn): # Allowable mn is more than actual mn, so split_k should be unchanged max_mn = (m * n) * 2 flags = get_flags(split_k, max_mn) - assert flags.split_k == split_k \ No newline at end of file + assert flags.split_k == split_k From 899749f3ac5212b7c4793ebf822ed6f45cad8b48 Mon Sep 17 00:00:00 2001 From: Afroz Mohiuddin Date: Fri, 10 Oct 2025 11:31:21 -0700 Subject: [PATCH 12/13] , --- python/triton_kernels/tests/test_opt_flags_split_k.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/triton_kernels/tests/test_opt_flags_split_k.py b/python/triton_kernels/tests/test_opt_flags_split_k.py index 95906c0ce520..d26a81ab3641 100644 --- a/python/triton_kernels/tests/test_opt_flags_split_k.py +++ b/python/triton_kernels/tests/test_opt_flags_split_k.py @@ -4,7 +4,6 @@ import types import torch -import triton import triton_kernels.matmul_ogs_details.opt_flags as opt_flags From 30f4e69f216577e001c4aba18ac5cd060ddcc8b7 Mon Sep 17 00:00:00 2001 From: Afroz Mohiuddin Date: Fri, 10 Oct 2025 13:05:09 -0700 Subject: [PATCH 13/13] Move test_opt_flags_split_k.py into a new test_matmul_details folder --- .../tests/{ => test_matmul_details}/test_opt_flags_split_k.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename python/triton_kernels/tests/{ => test_matmul_details}/test_opt_flags_split_k.py (100%) diff --git a/python/triton_kernels/tests/test_opt_flags_split_k.py b/python/triton_kernels/tests/test_matmul_details/test_opt_flags_split_k.py similarity index 100% rename from python/triton_kernels/tests/test_opt_flags_split_k.py rename to 
python/triton_kernels/tests/test_matmul_details/test_opt_flags_split_k.py