@@ -0,0 +1,205 @@
# isort: off
# fmt: off
import pytest
import types

import torch

import triton_kernels.matmul_ogs_details.opt_flags as opt_flags


class _DummyPrecisionConfig:
    def __init__(self):
        self.weight_scale = None
        self.max_num_imprecise_acc = None
        self.act_scale = None
        self.out_scale = None
        self.enforce_bitwise_invariance = False


def _stub_cuda_props(*_args, **_kwargs):
    return types.SimpleNamespace(multi_processor_count=16)


def setup_amd(monkeypatch):
    monkeypatch.setattr(opt_flags, "get_cdna_version", lambda: 3)
    monkeypatch.setattr(opt_flags.torch.cuda, "get_device_properties", _stub_cuda_props)
    monkeypatch.setattr(
        opt_flags.opt_flags_amd,
        "compute_block_nk",
        lambda *args, **kwargs: (64, 32),
    )

    fake_target = types.SimpleNamespace(backend="hip")
    monkeypatch.setattr(
        "triton.runtime.driver.active.get_current_target",
        lambda: fake_target,
    )


def setup_nvidia(monkeypatch):
    monkeypatch.setattr(opt_flags.torch.cuda, "get_device_properties", _stub_cuda_props)
    monkeypatch.setattr(opt_flags.torch.cuda, "get_device_capability", lambda: (9, 0))
    monkeypatch.setattr(
        opt_flags.opt_flags_nvidia,
        "compute_block_n",
        lambda n, arch, precision_config: (64, 32),
    )
    monkeypatch.setattr(
        opt_flags.opt_flags_nvidia,
        "compute_grid_size",
        lambda routing_data, batch_size, m, n, block_m, block_n: 4,
    )
    monkeypatch.setattr(
        opt_flags.opt_flags_nvidia,
        "compute_block_k",
        lambda m, k, is_persistent, lhs_dtype, rhs_dtype, precision_config, has_y_acc_in: 32,
    )
    monkeypatch.setattr(
        opt_flags.opt_flags_nvidia,
        "compute_split_k",
        lambda block_k, k, estimated_actual_grid_size: 1,
    )
    monkeypatch.setattr(
        opt_flags.opt_flags_nvidia,
        "compute_num_stages",
        lambda *args, **kwargs: 2,
    )
    monkeypatch.setattr(
        opt_flags.opt_flags_nvidia,
        "compute_num_warps",
        lambda block_m, block_n, is_persistent, precision_config: 4,
    )

    fake_target = types.SimpleNamespace(backend="cuda")
    monkeypatch.setattr(
        "triton.runtime.driver.active.get_current_target",
        lambda: fake_target,
    )


def test_make_default_opt_flags_amd_split_k_constraint(monkeypatch):
    setup_amd(monkeypatch)

    precision_config = _DummyPrecisionConfig()
    flags = opt_flags.make_default_opt_flags_amd(
        torch.float16,
        torch.float16,
        torch.float16,
        precision_config,
        2,
        128,
        64,
        32,
        None,
        False,
        False,
        False,
        0,
        False,
        False,
        {"split_k": 5},
    )

    assert flags.split_k == 5


def test_make_default_opt_flags_nvidia_split_k_constraint(monkeypatch):
    setup_nvidia(monkeypatch)

    precision_config = _DummyPrecisionConfig()
    flags = opt_flags.make_default_opt_flags_nvidia(
        torch.float16,
        torch.float16,
        torch.float16,
        precision_config,
        4,
        256,
        128,
        64,
        None,
        False,
        False,
        False,
        0,
        False,
        False,
        {"split_k": 3},
    )

    assert flags.split_k == 3

def test_max_allowable_mn_and_split_k_constraints(monkeypatch):
    setup_nvidia(monkeypatch)

    opt_flags._opt_flags = None
    opt_flags.reset_opt_flags_constraints()
    opt_flags.update_opt_flags_constraints(
        {
            "max_allowable_mn": 256,
            # Without split_k, this should raise an error
        }
    )

    with pytest.raises(opt_flags.InapplicableConstraint):
        opt_flags.make_opt_flags(
            torch.float16,
            torch.float16,
            torch.float16,
            _DummyPrecisionConfig(),
            1,
            256,
            256,
            256,
            None,
            False,
            False,
            False,
            0,
            False,
            None,
        )

def test_max_allowable_mn(monkeypatch):
    setup_nvidia(monkeypatch)

    batch_size, m, n, k = 1, 256, 256, 256

    def get_flags(split_k, max_mn):
        opt_flags._opt_flags = None
        opt_flags.reset_opt_flags_constraints()
        opt_flags.update_opt_flags_constraints(
            {
                "split_k": split_k,
                "max_allowable_mn": max_mn,
            }
        )
        return opt_flags.make_opt_flags(
            torch.float16,
            torch.float16,
            torch.float16,
            _DummyPrecisionConfig(),
            batch_size,
            m,
            n,
            k,
            None,
            False,
            False,
            False,
            0,
            False,
            None,
        )

    split_k = 6
    # Allowable mn is less than actual mn, so split_k should be set to 1
    max_mn = (m * n) // 2
    flags = get_flags(split_k, max_mn)
    assert flags.split_k == 1

    split_k = 6
    # Allowable mn is more than actual mn, so split_k should be unchanged
    max_mn = (m * n) * 2
    flags = get_flags(split_k, max_mn)
    assert flags.split_k == split_k
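For context, here is a minimal sketch (not part of this diff) of how the constraint pair exercised above would be set from caller code; the split_k value and the 2**18 threshold are illustrative assumptions:

import triton_kernels.matmul_ogs_details.opt_flags as opt_flags

# Both keys must be supplied together: max_allowable_mn without split_k makes
# make_opt_flags raise InapplicableConstraint, as the test above asserts.
opt_flags.update_opt_flags_constraints({
    "split_k": 4,               # requested split-K factor
    "max_allowable_mn": 2**18,  # fall back to split_k=1 once m * n reaches this
})

The diff below contains the corresponding changes to the opt_flags module.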
@@ -33,6 +33,22 @@ def __post_init__(self):
             raise ValueError("Not supported")
 
 
+def max_allowable_mn(
+    max_mn: int,
+    m: int,
+    n: int,
+    split_k: int,
+) -> int:
+    return 1 if m * n >= max_mn else split_k
+
+
+def all_constraints_satisfied(opt_flags: OptFlags, constraints: dict) -> bool:
+    _split_k_constraints = ['split_k', 'max_allowable_mn']
+    assert all(getattr(opt_flags, ck) == cv for ck, cv in constraints.items() if cv is not None and ck not in _split_k_constraints)
+    if constraints.get('split_k') and not constraints.get('max_allowable_mn'):
+        assert opt_flags.split_k == constraints['split_k']
+
+
 def make_default_opt_flags_amd(
     out_dtype,
     lhs_dtype,
@@ -51,7 +67,7 @@ def make_default_opt_flags_amd(
     has_y_acc_in,
     constraints,
 ):
-    constraints_supported = ["block_m", "block_n", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile"]
+    constraints_supported = ["block_m", "block_n", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile", "max_allowable_mn"]
     assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
     # tokens per expert
     if routing_data is None:
@@ -90,7 +106,9 @@ def make_default_opt_flags_amd(
     )
     is_persistent = constraints.get("is_persistent", False)
     # split_k:
-    if constraints.get("split_k", None) is not None:
+    if constraints.get("max_allowable_mn", 0) > 0 and constraints.get("split_k") is not None:
+        split_k = max_allowable_mn(constraints["max_allowable_mn"], m, n, constraints.get("split_k"))
+    elif constraints.get("split_k", None) is not None:
         split_k = constraints["split_k"]
     elif is_persistent or enforce_bitwise_invariance:
         split_k = 1
@@ -149,7 +167,7 @@ def replace_with_valid_constraint(k: str, v):
         target_kernel_kwargs=target_kernel_kwargs,
     )
     # check constraints
-    assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
+    all_constraints_satisfied(ret, constraints)
     return ret
 
 def make_default_opt_flags_nvidia(
@@ -170,7 +188,7 @@ def make_default_opt_flags_nvidia(
     has_y_acc_in,
     constraints,
 ):
-    constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages", "idle_sms"]
+    constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages", "idle_sms", "max_allowable_mn"]
     assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
     # tokens per expert
     if routing_data is None or batch_size > 1:
@@ -220,7 +238,9 @@ def make_default_opt_flags_nvidia(
         # TODO: swizzle the HBM layout of the weights instead
         block_n, block_k = block_k, block_n
     # split_k
-    if constraints.get("split_k", None) is not None:
+    if constraints.get("max_allowable_mn", 0) > 0 and constraints.get("split_k") is not None:
+        split_k = max_allowable_mn(constraints["max_allowable_mn"], m, n, constraints.get("split_k"))
+    elif constraints.get("split_k", None) is not None:
         split_k = constraints["split_k"]
     elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None:
         split_k = 1
@@ -283,7 +303,7 @@ def make_default_opt_flags_nvidia(
         idle_sms=constraints.get("idle_sms", 0),
     )
     # check constraints
-    assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
+    all_constraints_satisfied(ret, constraints)
     return ret
 
 # --------------
@@ -331,6 +351,9 @@ def make_opt_flags(
         raise InapplicableConstraint("cannot enforce `is_persistent=True` constraint")
     if _opt_flags_constraints.get("fused_scatter", False) and not can_use_fused_scatter:
         raise InapplicableConstraint("cannot enforce `fused_scatter=True` constraint")
+    if _opt_flags_constraints.get("max_allowable_mn"):
+        if not _opt_flags_constraints.get("split_k"):
+            raise InapplicableConstraint("split_k also needs to be provided with max_allowable_mn")
     enforce_bitwise_invariance = precision_config.enforce_bitwise_invariance
     if _opt_flags is not None:
         assert not _opt_flags_constraints
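As a quick worked check of the new helper's threshold logic, using the same shapes as test_max_allowable_mn (these calls are illustrative, not part of the diff):

from triton_kernels.matmul_ogs_details.opt_flags import max_allowable_mn

m = n = 256  # m * n == 65536
assert max_allowable_mn(32768, m, n, 6) == 1   # 65536 >= 32768: split-K is disabled
assert max_allowable_mn(131072, m, n, 6) == 6  # 65536 < 131072: requested split_k is kept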