
Commit 0993770

add triton-based milestone from @PaliC
For some reason, these are not compatible with the T4, which gives us a nice test case and motivation for allowing certain GPUs to be excluded from milestones.
1 parent 4d88498 commit 0993770

File tree

6 files changed: +159 -8 lines changed

  examples/matmul_py/task.yml
  examples/matmul_py/triton_ref.py
  src/kernelbot/cogs/admin_cog.py
  src/libkernelbot/db_types.py
  src/libkernelbot/leaderboard_db.py
  src/libkernelbot/task.py

examples/matmul_py/task.yml

Lines changed: 9 additions & 3 deletions
@@ -9,9 +9,15 @@ files:
 milestones:
   - {
-      "name": "pytorch reference",
-      "source": "submission.py",
-      "description": "PyTorch reference implementation as a performance baseline for matmul"
+      name: "pytorch",
+      source: "submission.py",
+      description: "PyTorch reference implementation as a performance baseline for matmul"
+    }
+  - {
+      name: "triton",
+      source: "triton_ref.py",
+      description: "Triton reference implementation as a performance baseline for matmul",
+      exclude_gpus: ['T4']
     }
 
 lang: "py"
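
A rough sketch (not the repo's actual loader) of how a milestone entry like the new triton one could be parsed from task.yml into the MilestoneData dataclass extended later in this commit. It assumes PyYAML is available and that the real loader reads the referenced source file's contents into the dataclass's code field; the MilestoneData definition here just mirrors src/libkernelbot/task.py after this commit.

import dataclasses
import yaml

@dataclasses.dataclass
class MilestoneData:
    name: str
    code: str
    description: str = ""
    exclude_gpus: list[str] = dataclasses.field(default_factory=list)

snippet = """
milestones:
  - { name: "triton",
      source: "triton_ref.py",
      description: "Triton reference implementation as a performance baseline for matmul",
      exclude_gpus: ['T4'] }
"""

milestones = [
    MilestoneData(
        name=entry["name"],
        code=entry["source"],  # the real loader presumably substitutes the file's contents
        description=entry.get("description", ""),
        exclude_gpus=entry.get("exclude_gpus", []),
    )
    for entry in yaml.safe_load(snippet)["milestones"]
]
print(milestones[0].exclude_gpus)  # ['T4']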

examples/matmul_py/triton_ref.py

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
#!POPCORN leaderboard matmul_py
import triton
import triton.language as tl
import torch
from task import input_t, output_t


@triton.jit
def matmul_kernel(
    # Pointers to matrices
    a_ptr, b_ptr, c_ptr,
    # Matrix dimensions
    M, N, K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
    # by to get the element one row down (A has M rows).
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 cache hit rates.
    # See above `L2 Cache Optimizations` section for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # See above `Pointer Arithmetic` section for details
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher precision.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # You can fuse arbitrary activation functions here
    # while the accumulator is still in FP32!
    c = accumulator.to(tl.float16)

    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


def triton_matmul(a, b):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    assert b.is_contiguous(), "Matrix B must be contiguous"
    M, K = a.shape
    K, N = b.shape
    # Allocate output.
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_SIZE_M=128, BLOCK_SIZE_N=128, BLOCK_SIZE_K=32,
        GROUP_SIZE_M=8,
    )
    return c


def custom_kernel(data: input_t) -> output_t:
    a, b = data
    # Convert to torch tensors if they aren't already
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a, dtype=torch.float16).cuda()
    if not isinstance(b, torch.Tensor):
        b = torch.tensor(b, dtype=torch.float16).cuda()

    # Ensure tensors are on GPU and contiguous
    if not a.is_cuda:
        a = a.cuda()
    if not b.is_cuda:
        b = b.cuda()

    a = a.contiguous()
    b = b.contiguous()

    # Use our custom Triton matmul
    result = triton_matmul(a, b)

    # Convert back to the expected output format
    return result
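
A hedged usage sketch (not part of the commit) for sanity-checking triton_matmul against torch.matmul on a GPU the kernel supports; per the commit message, the T4 is excluded for this milestone. Tolerances are loose because the kernel accumulates in fp32 but stores fp16.

import torch
# triton_matmul is the function defined in triton_ref.py above; importing that
# module also requires task.py (for input_t/output_t) to be importable.

def sanity_check(M: int = 512, N: int = 512, K: int = 512) -> bool:
    a = torch.randn((M, K), device="cuda", dtype=torch.float16)
    b = torch.randn((K, N), device="cuda", dtype=torch.float16)
    c_triton = triton_matmul(a, b)   # Triton kernel above
    c_torch = torch.matmul(a, b)     # cuBLAS reference
    # fp16 outputs with fp32 accumulation: compare with a loose tolerance
    return torch.allclose(c_triton, c_torch, atol=1e-1, rtol=1e-2)

print(sanity_check())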

src/kernelbot/cogs/admin_cog.py

Lines changed: 11 additions & 1 deletion
@@ -381,6 +381,7 @@ async def create_leaderboard_in_db(
                     milestone.name,
                     milestone.code,
                     description=milestone.description,
+                    exclude_gpus=milestone.exclude_gpus,
                 )
             except KernelBotError as e:
                 await send_discord_message(
@@ -457,10 +458,19 @@ async def submit_milestone(milestone, gpu, reporter):
             if gpu in [r["runner"] for r in existing_runs]:
                 await send_discord_message(
                     interaction,
-                    f"Skipping {gpu}; milestone run already exists.",
+                    f"Skipping {gpu} for {milestone['name']}; milestone run already exists.",
                     ephemeral=True,
                 )
                 continue
+
+            if gpu in milestone["exclude_gpus"]:
+                await send_discord_message(
+                    interaction,
+                    f"Skipping {gpu} for {milestone['name']}; is excluded.",
+                    ephemeral=True,
+                )
+                continue
+
             submit_tasks.append(
                 submit_milestone(
                     milestone,
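
The two skip conditions above (an existing run for the GPU, or a GPU listed in exclude_gpus) amount to a filter over the candidate runners. A minimal sketch with hypothetical names (runners_to_submit is not the cog's real API), assuming milestone records look like the MilestoneItem dicts returned by get_leaderboard_milestones:

def runners_to_submit(milestone: dict, gpus: list, existing_runs: list) -> list:
    already_run = {r["runner"] for r in existing_runs}
    return [
        gpu
        for gpu in gpus
        if gpu not in already_run                          # milestone run already exists
        and gpu not in milestone.get("exclude_gpus", [])   # GPU excluded for this milestone
    ]

# The T4 is filtered out for the triton milestone; H100 already has a run.
milestone = {"name": "triton", "exclude_gpus": ["T4"]}
print(runners_to_submit(milestone, ["T4", "A100", "H100"], existing_runs=[{"runner": "H100"}]))
# -> ['A100']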

src/libkernelbot/db_types.py

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@ class MilestoneItem(TypedDict):
     code: str
     description: str
     created_at: datetime.datetime
+    exclude_gpus: list[str]


 __all__ = [LeaderboardItem, LeaderboardRankedEntry, RunItem, SubmissionItem, MilestoneItem]

src/libkernelbot/leaderboard_db.py

Lines changed: 10 additions & 4 deletions
@@ -250,18 +250,23 @@ def create_milestone(
         name: str,
         code: str,
         description: str = None,
+        exclude_gpus: list[str] = None,
     ) -> int:
         """Create a new milestone for a leaderboard"""
+        if exclude_gpus is None:
+            exclude = ""
+        else:
+            exclude = str.join(";", exclude_gpus)
         try:
             self.cursor.execute(
                 """
                 INSERT INTO leaderboard.milestones (
-                    leaderboard_id, name, code, description
+                    leaderboard_id, name, code, description, exclude_gpus
                 )
-                VALUES (%s, %s, %s, %s)
+                VALUES (%s, %s, %s, %s, %s)
                 RETURNING id
                 """,
-                (leaderboard_id, name, code, description),
+                (leaderboard_id, name, code, description, exclude),
             )
             milestone_id = self.cursor.fetchone()[0]
             self.connection.commit()
@@ -275,7 +280,7 @@ def get_leaderboard_milestones(self, leaderboard_id: int) -> "list[MilestoneItem
         """Get all milestones for a leaderboard"""
         self.cursor.execute(
             """
-            SELECT id, name, code, description, created_at
+            SELECT id, name, code, description, created_at, exclude_gpus
             FROM leaderboard.milestones
             WHERE leaderboard_id = %s
             ORDER BY created_at
@@ -289,6 +294,7 @@ def get_leaderboard_milestones(self, leaderboard_id: int) -> "list[MilestoneItem
                 "code": row[2],
                 "description": row[3],
                 "created_at": row[4],
+                "exclude_gpus": str.split(row[5], ";"),
             }
             for row in self.cursor.fetchall()
         ]
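
The storage format above is a ";"-joined string written by create_milestone and split back in get_leaderboard_milestones. A small sketch of that round trip, with hypothetical helper names; note that "".split(";") returns [""] rather than [], which is harmless for membership checks like gpu in milestone["exclude_gpus"] but worth normalizing if an empty list is expected back:

def encode_exclude_gpus(exclude_gpus: list = None) -> str:
    # create_milestone stores "" for None, otherwise a "T4;L4"-style string
    return ";".join(exclude_gpus or [])

def decode_exclude_gpus(stored: str) -> list:
    # get_leaderboard_milestones uses str.split(row[5], ";"); an empty column
    # would decode to [""] there. Filtering out empties normalizes that.
    return [g for g in stored.split(";") if g]

assert encode_exclude_gpus(["T4", "L4"]) == "T4;L4"
assert decode_exclude_gpus("T4;L4") == ["T4", "L4"]
assert decode_exclude_gpus("") == []   # raw str.split would give [""]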

src/libkernelbot/task.py

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ class MilestoneData:
     name: str
     code: str
     description: str = ""
+    exclude_gpus: list[str] = dataclasses.field(default_factory=list)


 TestCaseType = Dict[str, Union[int, str]]
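
A short illustration (not from the commit) of why the new field uses dataclasses.field(default_factory=list): a bare mutable default would be rejected by dataclasses, and the factory keeps each instance's list independent.

import dataclasses

@dataclasses.dataclass
class _Demo:
    # exclude_gpus: list[str] = [] would raise ValueError (mutable default)
    exclude_gpus: list = dataclasses.field(default_factory=list)

a, b = _Demo(), _Demo()
a.exclude_gpus.append("T4")
print(b.exclude_gpus)  # [] -- not shared with a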
