diff --git a/python/triton_kernels/tests/test_matmul_details/test_opt_flags_split_k.py b/python/triton_kernels/tests/test_matmul_details/test_opt_flags_split_k.py
new file mode 100644
index 000000000000..d26a81ab3641
--- /dev/null
+++ b/python/triton_kernels/tests/test_matmul_details/test_opt_flags_split_k.py
@@ -0,0 +1,205 @@
+# isort: off
+# fmt: off
+import pytest
+import types
+
+import torch
+
+import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
+
+
+class _DummyPrecisionConfig:
+    def __init__(self):
+        self.weight_scale = None
+        self.max_num_imprecise_acc = None
+        self.act_scale = None
+        self.out_scale = None
+        self.enforce_bitwise_invariance = False
+
+
+def _stub_cuda_props(*_args, **_kwargs):
+    return types.SimpleNamespace(multi_processor_count=16)
+
+
+def setup_amd(monkeypatch):
+    monkeypatch.setattr(opt_flags, "get_cdna_version", lambda: 3)
+    monkeypatch.setattr(opt_flags.torch.cuda, "get_device_properties", _stub_cuda_props)
+    monkeypatch.setattr(
+        opt_flags.opt_flags_amd,
+        "compute_block_nk",
+        lambda *args, **kwargs: (64, 32),
+    )
+
+    fake_target = types.SimpleNamespace(backend="hip")
+    monkeypatch.setattr(
+        "triton.runtime.driver.active.get_current_target",
+        lambda: fake_target,
+    )
+
+
+def setup_nvidia(monkeypatch):
+    monkeypatch.setattr(opt_flags.torch.cuda, "get_device_properties", _stub_cuda_props)
+    monkeypatch.setattr(opt_flags.torch.cuda, "get_device_capability", lambda: (9, 0))
+    monkeypatch.setattr(
+        opt_flags.opt_flags_nvidia,
+        "compute_block_n",
+        lambda n, arch, precision_config: (64, 32),
+    )
+    monkeypatch.setattr(
+        opt_flags.opt_flags_nvidia,
+        "compute_grid_size",
+        lambda routing_data, batch_size, m, n, block_m, block_n: 4,
+    )
+    monkeypatch.setattr(
+        opt_flags.opt_flags_nvidia,
+        "compute_block_k",
+        lambda m, k, is_persistent, lhs_dtype, rhs_dtype, precision_config, has_y_acc_in: 32,
+    )
+    monkeypatch.setattr(
+        opt_flags.opt_flags_nvidia,
+        "compute_split_k",
+        lambda block_k, k, estimated_actual_grid_size: 1,
+    )
+    monkeypatch.setattr(
+        opt_flags.opt_flags_nvidia,
+        "compute_num_stages",
+        lambda *args, **kwargs: 2,
+    )
+    monkeypatch.setattr(
+        opt_flags.opt_flags_nvidia,
+        "compute_num_warps",
+        lambda block_m, block_n, is_persistent, precision_config: 4,
+    )
+
+    fake_target = types.SimpleNamespace(backend="cuda")
+    monkeypatch.setattr(
+        "triton.runtime.driver.active.get_current_target",
+        lambda: fake_target,
+    )
+
+
+def test_make_default_opt_flags_amd_split_k_constraint(monkeypatch):
+    setup_amd(monkeypatch)
+
+    precision_config = _DummyPrecisionConfig()
+    flags = opt_flags.make_default_opt_flags_amd(
+        torch.float16,
+        torch.float16,
+        torch.float16,
+        precision_config,
+        2,
+        128,
+        64,
+        32,
+        None,
+        False,
+        False,
+        False,
+        0,
+        False,
+        False,
+        {"split_k": 5},
+    )
+
+    assert flags.split_k == 5
+
+
+def test_make_default_opt_flags_nvidia_split_k_constraint(monkeypatch):
+    setup_nvidia(monkeypatch)
+
+    precision_config = _DummyPrecisionConfig()
+    flags = opt_flags.make_default_opt_flags_nvidia(
+        torch.float16,
+        torch.float16,
+        torch.float16,
+        precision_config,
+        4,
+        256,
+        128,
+        64,
+        None,
+        False,
+        False,
+        False,
+        0,
+        False,
+        False,
+        {"split_k": 3},
+    )
+
+    assert flags.split_k == 3
+
+def test_max_allowable_mn_and_split_k_constraints(monkeypatch):
+    setup_nvidia(monkeypatch)
+
+    opt_flags._opt_flags = None
+    opt_flags.reset_opt_flags_constraints()
+    opt_flags.update_opt_flags_constraints(
+        {
+            "max_allowable_mn": 256,
+            # Without split_k, this should raise an error
+        }
+    )
+
+    with pytest.raises(opt_flags.InapplicableConstraint):
+        opt_flags.make_opt_flags(
+            torch.float16,
+            torch.float16,
+            torch.float16,
+            _DummyPrecisionConfig(),
+            1,
+            256,
+            256,
+            256,
+            None,
+            False,
+            False,
+            False,
+            0,
+            False,
+            None,
+        )
+
+def test_max_allowable_mn(monkeypatch):
+    setup_nvidia(monkeypatch)
+
+    batch_size, m, n, k = 1, 256, 256, 256
+
+    def get_flags(split_k, max_mn):
+        opt_flags._opt_flags = None
+        opt_flags.reset_opt_flags_constraints()
+        opt_flags.update_opt_flags_constraints(
+            {
+                "split_k": split_k,
+                "max_allowable_mn": max_mn,
+            }
+        )
+        return opt_flags.make_opt_flags(
+            torch.float16,
+            torch.float16,
+            torch.float16,
+            _DummyPrecisionConfig(),
+            batch_size,
+            m,
+            n,
+            k,
+            None,
+            False,
+            False,
+            False,
+            0,
+            False,
+            None,
+        )
+
+    split_k = 6
+    # Allowable mn is less than actual mn, so split_k should be set to 1
+    max_mn = (m * n) // 2
+    flags = get_flags(split_k, max_mn)
+    assert flags.split_k == 1
+
+    split_k = 6
+    # Allowable mn is more than actual mn, so split_k should be unchanged
+    max_mn = (m * n) * 2
+    flags = get_flags(split_k, max_mn)
+    assert flags.split_k == split_k
diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py
index 9f39a9464c97..72e0998f5c5f 100644
--- a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py
+++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py
@@ -33,6 +33,22 @@ def __post_init__(self):
             raise ValueError("Not supported")
 
 
+def max_allowable_mn(
+    max_mn: int,
+    m: int,
+    n: int,
+    split_k: int,
+) -> int:
+    return 1 if m * n >= max_mn else split_k
+
+
+def all_constraints_satisfied(opt_flags: OptFlags, constraints: dict) -> bool:
+    _split_k_constraints = ['split_k', 'max_allowable_mn']
+    assert all(getattr(opt_flags, ck) == cv for ck, cv in constraints.items() if cv is not None and ck not in _split_k_constraints)
+    if constraints.get('split_k') and not constraints.get('max_allowable_mn'):
+        assert opt_flags.split_k == constraints['split_k']
+
+
 def make_default_opt_flags_amd(
     out_dtype,
     lhs_dtype,
@@ -51,7 +67,7 @@ def make_default_opt_flags_amd(
     has_y_acc_in,
     constraints,
 ):
-    constraints_supported = ["block_m", "block_n", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile"]
+    constraints_supported = ["block_m", "block_n", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile", "max_allowable_mn"]
     assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
     # tokens per expert
     if routing_data is None:
@@ -90,7 +106,9 @@ def make_default_opt_flags_amd(
     )
     is_persistent = constraints.get("is_persistent", False)
     # split_k:
-    if constraints.get("split_k", None) is not None:
+    if constraints.get("max_allowable_mn", 0) > 0 and constraints.get("split_k") is not None:
+        split_k = max_allowable_mn(constraints["max_allowable_mn"], m, n, constraints.get("split_k"))
+    elif constraints.get("split_k", None) is not None:
         split_k = constraints["split_k"]
     elif is_persistent or enforce_bitwise_invariance:
         split_k = 1
@@ -149,7 +167,7 @@ def replace_with_valid_constraint(k: str, v):
         target_kernel_kwargs=target_kernel_kwargs,
     )
     # check constraints
-    assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
+    all_constraints_satisfied(ret, constraints)
     return ret
 
 def make_default_opt_flags_nvidia(
@@ -170,7 +188,7 @@ def make_default_opt_flags_nvidia(
     has_y_acc_in,
     constraints,
 ):
-    constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages", "idle_sms"]
+    constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages", "idle_sms", "max_allowable_mn"]
     assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
     # tokens per expert
     if routing_data is None or batch_size > 1:
@@ -220,7 +238,9 @@ def make_default_opt_flags_nvidia(
         # TODO: swizzle the HBM layout of the weights instead
         block_n, block_k = block_k, block_n
     # split_k
-    if constraints.get("split_k", None) is not None:
+    if constraints.get("max_allowable_mn", 0) > 0 and constraints.get("split_k") is not None:
+        split_k = max_allowable_mn(constraints["max_allowable_mn"], m, n, constraints.get("split_k"))
+    elif constraints.get("split_k", None) is not None:
         split_k = constraints["split_k"]
     elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None:
         split_k = 1
@@ -283,7 +303,7 @@ def make_default_opt_flags_nvidia(
         idle_sms=constraints.get("idle_sms", 0),
    )
     # check constraints
-    assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
+    all_constraints_satisfied(ret, constraints)
     return ret
 
 # --------------
@@ -331,6 +351,9 @@ def make_opt_flags(
         raise InapplicableConstraint("cannot enforce `is_persistent=True` constraint")
     if _opt_flags_constraints.get("fused_scatter", False) and not can_use_fused_scatter:
         raise InapplicableConstraint("cannot enforce `fused_scatter=True` constraint")
+    if _opt_flags_constraints.get("max_allowable_mn"):
+        if not _opt_flags_constraints.get("split_k"):
+            raise InapplicableConstraint("split_k also needs to be provided with max_allowable_mn")
     enforce_bitwise_invariance = precision_config.enforce_bitwise_invariance
     if _opt_flags is not None:
         assert not _opt_flags_constraints
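For context, a minimal usage sketch of the new constraint pair, mirroring the tests above. The import path, `update_opt_flags_constraints`, and the `max_allowable_mn` helper come from this diff; the concrete shapes and the cap value below are illustrative assumptions only.

```python
from triton_kernels.matmul_ogs_details import opt_flags

# max_allowable_mn must be accompanied by split_k; otherwise make_opt_flags
# raises InapplicableConstraint (see the new check added to make_opt_flags).
opt_flags.update_opt_flags_constraints({
    "split_k": 4,                 # requested split-K factor
    "max_allowable_mn": 1 << 20,  # illustrative cap on m * n
})

# The new helper gates split_k on output size: once m * n reaches the cap,
# the requested split_k is overridden to 1; below the cap it is kept.
assert opt_flags.max_allowable_mn(1 << 20, 2048, 1024, 4) == 1  # 2048 * 1024 >= 2**20
assert opt_flags.max_allowable_mn(1 << 20, 256, 256, 4) == 4    # 256 * 256 < 2**20
```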