Commit 4e0d041

[TRITON_KERNELS] constraint on split_k on m * n (#8404)
Update the constraints dict so that `split_k` can be chosen dynamically. Spoke to Philippe Tillet offline; there are cases where we need to decide `split_k` dynamically based on the shapes of the operands. Notably, in our case we want batch-size invariance, but we also want to restrict the amount of memory that split-k otherwise allocates, i.e. the `[split_k, m, n]` scratch allocation used for the `bmk,kn->bmn` multiplication. So we pass an additional constraint, `max_allowable_mn`, which drops `split_k` down to 1 if `m * n >= max_allowable_mn`. A short usage sketch follows the checklist below.

# New contributor declaration

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `ADDED`.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
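Usage sketch (values are illustrative; `update_opt_flags_constraints` and `make_opt_flags` are the entry points exercised by the test added in this PR, and the surrounding matmul call is omitted):

import triton_kernels.matmul_ogs_details.opt_flags as opt_flags

# Request split_k = 4 for batch-size invariance, but cap the split-k scratch
# memory: once m * n reaches the (illustrative) budget below, the effective
# split_k falls back to 1.
opt_flags.update_opt_flags_constraints({
    "split_k": 4,
    "max_allowable_mn": 8192 * 8192,
})

# Subsequent calls resolve their flags through opt_flags.make_opt_flags(...),
# which now raises InapplicableConstraint if max_allowable_mn is given
# without split_k.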
1 parent faa5033 commit 4e0d041

File tree

2 files changed: +234, -6 lines
Lines changed: 205 additions & 0 deletions

@@ -0,0 +1,205 @@
# isort: off
# fmt: off
import pytest
import types

import torch

import triton_kernels.matmul_ogs_details.opt_flags as opt_flags


class _DummyPrecisionConfig:
    def __init__(self):
        self.weight_scale = None
        self.max_num_imprecise_acc = None
        self.act_scale = None
        self.out_scale = None
        self.enforce_bitwise_invariance = False


def _stub_cuda_props(*_args, **_kwargs):
    return types.SimpleNamespace(multi_processor_count=16)


def setup_amd(monkeypatch):
    monkeypatch.setattr(opt_flags, "get_cdna_version", lambda: 3)
    monkeypatch.setattr(opt_flags.torch.cuda, "get_device_properties", _stub_cuda_props)
    monkeypatch.setattr(
        opt_flags.opt_flags_amd,
        "compute_block_nk",
        lambda *args, **kwargs: (64, 32),
    )

    fake_target = types.SimpleNamespace(backend="hip")
    monkeypatch.setattr(
        "triton.runtime.driver.active.get_current_target",
        lambda: fake_target,
    )


def setup_nvidia(monkeypatch):
    monkeypatch.setattr(opt_flags.torch.cuda, "get_device_properties", _stub_cuda_props)
    monkeypatch.setattr(opt_flags.torch.cuda, "get_device_capability", lambda: (9, 0))
    monkeypatch.setattr(
        opt_flags.opt_flags_nvidia,
        "compute_block_n",
        lambda n, arch, precision_config: (64, 32),
    )
    monkeypatch.setattr(
        opt_flags.opt_flags_nvidia,
        "compute_grid_size",
        lambda routing_data, batch_size, m, n, block_m, block_n: 4,
    )
    monkeypatch.setattr(
        opt_flags.opt_flags_nvidia,
        "compute_block_k",
        lambda m, k, is_persistent, lhs_dtype, rhs_dtype, precision_config, has_y_acc_in: 32,
    )
    monkeypatch.setattr(
        opt_flags.opt_flags_nvidia,
        "compute_split_k",
        lambda block_k, k, estimated_actual_grid_size: 1,
    )
    monkeypatch.setattr(
        opt_flags.opt_flags_nvidia,
        "compute_num_stages",
        lambda *args, **kwargs: 2,
    )
    monkeypatch.setattr(
        opt_flags.opt_flags_nvidia,
        "compute_num_warps",
        lambda block_m, block_n, is_persistent, precision_config: 4,
    )

    fake_target = types.SimpleNamespace(backend="cuda")
    monkeypatch.setattr(
        "triton.runtime.driver.active.get_current_target",
        lambda: fake_target,
    )


def test_make_default_opt_flags_amd_split_k_constraint(monkeypatch):
    setup_amd(monkeypatch)

    precision_config = _DummyPrecisionConfig()
    flags = opt_flags.make_default_opt_flags_amd(
        torch.float16,
        torch.float16,
        torch.float16,
        precision_config,
        2,
        128,
        64,
        32,
        None,
        False,
        False,
        False,
        0,
        False,
        False,
        {"split_k": 5},
    )

    assert flags.split_k == 5


def test_make_default_opt_flags_nvidia_split_k_constraint(monkeypatch):
    setup_nvidia(monkeypatch)

    precision_config = _DummyPrecisionConfig()
    flags = opt_flags.make_default_opt_flags_nvidia(
        torch.float16,
        torch.float16,
        torch.float16,
        precision_config,
        4,
        256,
        128,
        64,
        None,
        False,
        False,
        False,
        0,
        False,
        False,
        {"split_k": 3},
    )

    assert flags.split_k == 3


def test_max_allowable_mn_and_split_k_constraints(monkeypatch):
    setup_nvidia(monkeypatch)

    opt_flags._opt_flags = None
    opt_flags.reset_opt_flags_constraints()
    opt_flags.update_opt_flags_constraints(
        {
            "max_allowable_mn": 256,
            # Without split_k, this should raise an error
        }
    )

    with pytest.raises(opt_flags.InapplicableConstraint):
        opt_flags.make_opt_flags(
            torch.float16,
            torch.float16,
            torch.float16,
            _DummyPrecisionConfig(),
            1,
            256,
            256,
            256,
            None,
            False,
            False,
            False,
            0,
            False,
            None,
        )


def test_max_allowable_mn(monkeypatch):
    setup_nvidia(monkeypatch)

    batch_size, m, n, k = 1, 256, 256, 256

    def get_flags(split_k, max_mn):
        opt_flags._opt_flags = None
        opt_flags.reset_opt_flags_constraints()
        opt_flags.update_opt_flags_constraints(
            {
                "split_k": split_k,
                "max_allowable_mn": max_mn,
            }
        )
        return opt_flags.make_opt_flags(
            torch.float16,
            torch.float16,
            torch.float16,
            _DummyPrecisionConfig(),
            batch_size,
            m,
            n,
            k,
            None,
            False,
            False,
            False,
            0,
            False,
            None,
        )

    split_k = 6
    # Allowable mn is less than actual mn, so split_k should be set to 1
    max_mn = (m * n) // 2
    flags = get_flags(split_k, max_mn)
    assert flags.split_k == 1

    split_k = 6
    # Allowable mn is more than actual mn, so split_k should be unchanged
    max_mn = (m * n) * 2
    flags = get_flags(split_k, max_mn)
    assert flags.split_k == split_k

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py

Lines changed: 29 additions & 6 deletions
@@ -33,6 +33,22 @@ def __post_init__(self):
            raise ValueError("Not supported")


+def max_allowable_mn(
+    max_mn: int,
+    m: int,
+    n: int,
+    split_k: int,
+) -> int:
+    return 1 if m * n >= max_mn else split_k
+
+
+def all_constraints_satisfied(opt_flags: OptFlags, constraints: dict) -> bool:
+    _split_k_constraints = ['split_k', 'max_allowable_mn']
+    assert all(getattr(opt_flags, ck) == cv for ck, cv in constraints.items() if cv is not None and ck not in _split_k_constraints)
+    if constraints.get('split_k') and not constraints.get('max_allowable_mn'):
+        assert opt_flags.split_k == constraints['split_k']
+
+
def make_default_opt_flags_amd(
    out_dtype,
    lhs_dtype,
@@ -51,7 +67,7 @@ def make_default_opt_flags_amd(
    has_y_acc_in,
    constraints,
):
-    constraints_supported = ["block_m", "block_n", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile"]
+    constraints_supported = ["block_m", "block_n", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile", "max_allowable_mn"]
    assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
    # tokens per expert
    if routing_data is None:
@@ -90,7 +106,9 @@
    )
    is_persistent = constraints.get("is_persistent", False)
    # split_k:
-    if constraints.get("split_k", None) is not None:
+    if constraints.get("max_allowable_mn", 0) > 0 and constraints.get("split_k") is not None:
+        split_k = max_allowable_mn(constraints["max_allowable_mn"], m, n, constraints.get("split_k"))
+    elif constraints.get("split_k", None) is not None:
        split_k = constraints["split_k"]
    elif is_persistent or enforce_bitwise_invariance:
        split_k = 1
@@ -149,7 +167,7 @@ def replace_with_valid_constraint(k: str, v):
        target_kernel_kwargs=target_kernel_kwargs,
    )
    # check constraints
-    assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
+    all_constraints_satisfied(ret, constraints)
    return ret

def make_default_opt_flags_nvidia(
@@ -170,7 +188,7 @@
    has_y_acc_in,
    constraints,
):
-    constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages", "idle_sms"]
+    constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages", "idle_sms", "max_allowable_mn"]
    assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
    # tokens per expert
    if routing_data is None or batch_size > 1:
@@ -220,7 +238,9 @@
        # TODO: swizzle the HBM layout of the weights instead
        block_n, block_k = block_k, block_n
    # split_k
-    if constraints.get("split_k", None) is not None:
+    if constraints.get("max_allowable_mn", 0) > 0 and constraints.get("split_k") is not None:
+        split_k = max_allowable_mn(constraints["max_allowable_mn"], m, n, constraints.get("split_k"))
+    elif constraints.get("split_k", None) is not None:
        split_k = constraints["split_k"]
    elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None:
        split_k = 1
@@ -283,7 +303,7 @@
        idle_sms=constraints.get("idle_sms", 0),
    )
    # check constraints
-    assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
+    all_constraints_satisfied(ret, constraints)
    return ret

# --------------
@@ -331,6 +351,9 @@ def make_opt_flags(
        raise InapplicableConstraint("cannot enforce `is_persistent=True` constraint")
    if _opt_flags_constraints.get("fused_scatter", False) and not can_use_fused_scatter:
        raise InapplicableConstraint("cannot enforce `fused_scatter=True` constraint")
+    if _opt_flags_constraints.get("max_allowable_mn"):
+        if not _opt_flags_constraints.get("split_k"):
+            raise InapplicableConstraint("split_k also needs to be provided with max_allowable_mn")
    enforce_bitwise_invariance = precision_config.enforce_bitwise_invariance
    if _opt_flags is not None:
        assert not _opt_flags_constraints
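For reference, a minimal standalone sketch of the gate this diff introduces (same logic as the `max_allowable_mn` helper above; the shapes mirror the new test and are otherwise arbitrary):

def max_allowable_mn(max_mn: int, m: int, n: int, split_k: int) -> int:
    # The split-k scratch buffer scales with split_k * m * n, so fall back to
    # split_k = 1 once the output tile m * n reaches the allowed budget.
    return 1 if m * n >= max_mn else split_k

m, n = 256, 256
assert max_allowable_mn(2 * m * n, m, n, 6) == 6   # under the budget: keep split_k
assert max_allowable_mn(m * n // 2, m, n, 6) == 1  # at/over the budget: drop to 1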
