-
Notifications
You must be signed in to change notification settings - Fork 40
Open
Labels
Description
Comparing Tritonbench's results with those of other benchmarks is a common task.
When the two benchmarks disagree, we often need a script that strips away Tritonbench's framework code to get a minimal reproduction of the benchmark. In other words, we need to generate a standalone script that runs a kernel without any dependency on the `tritonbench.*` package. Running with only the first input data point is usually sufficient.
For example, for the softmax operator with the `triton_softmax` backend, such a script could look like this:
from typing import Generator, List
import torch
import triton
import triton.language as tl
# Default element type and device for the generated benchmark inputs.
dtype = torch.bfloat16
device = "cuda"


def get_input_iter(
    *, dtype: torch.dtype = dtype, device: str = device
) -> Generator[tuple, None, None]:
    """Yield the softmax benchmark inputs, one argument tuple at a time.

    Every input has 4096 rows; the column count grows as 128 * i for
    i in 2..99 (i.e. 256, 384, ..., 12672). Each item is a 1-tuple ``(x,)``
    so it can be splatted into the benchmarked function as ``fn(*args)``.
    Tensors are created lazily, one per ``next()`` call.

    Args:
        dtype: element type of the generated tensors (defaults to the
            module-level ``dtype``).
        device: device on which tensors are allocated (defaults to the
            module-level ``device``).

    Yields:
        ``(x,)`` where ``x`` is a ``[4096, 128 * i]`` tensor drawn from
        ``torch.randn``.
    """
    num_rows = 4096
    for i in range(2, 100):
        num_cols = 128 * i
        yield (torch.randn([num_rows, num_cols], dtype=dtype, device=device),)
@triton.jit
def softmax_kernel(
    output_ptr,
    input_ptr,
    input_row_stride,
    output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """Compute a numerically stable softmax of one row per program instance.

    The row index is taken from ``tl.program_id(0)``, so the caller is expected
    to launch one instance per row. Elements within a row are addressed as
    ``row_start + col_offsets``, i.e. the kernel assumes a unit stride along
    the column dimension for both input and output; only the row strides are
    passed in. ``BLOCK_SIZE`` must be >= ``n_cols`` so a whole row fits in a
    single block; lanes past ``n_cols`` are masked off.
    """
    # The rows of the softmax are independent, so we parallelize across those:
    # each program instance handles the row selected by its program id.
    row_idx = tl.program_id(0)
    # The stride represents how much we need to increase the pointer to advance 1 row
    row_start_ptr = input_ptr + row_idx * input_row_stride
    # BLOCK_SIZE is chosen by the caller as a power of two >= n_cols, so each
    # row fits in a single block
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets
    # Load the row, masking lanes past n_cols (BLOCK_SIZE may be > n_cols).
    # Masked lanes read -inf so they cannot win the max and contribute
    # exp(-inf) == 0 to the sum below.
    row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float("inf"))
    # Subtract maximum for numerical stability
    row_minus_max = row - tl.max(row, axis=0)
    # NOTE(review): the original comment states tl.exp is fast but approximate
    # (think __expf in CUDA) — confirm for the Triton version in use.
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator
    # Write the row back to the output, with the same mask as the load
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
def triton_softmax(x):
    """Prepare a benchmarkable launch of ``softmax_kernel`` over the rows of ``x``.

    Args:
        x: 2-D tensor; softmax is taken along the last dimension, which must
            have unit stride because ``softmax_kernel`` addresses columns as
            ``row_start + col_offsets``.

    Returns:
        A zero-argument callable that launches the kernel once and returns
        the output tensor. The output buffer is allocated once, here, and is
        overwritten by every call, so timing the callable (for example with
        ``triton.testing.do_bench``) does not include the allocation.

    Raises:
        ValueError: if ``x`` is not 2-D, or its last dimension is strided.
    """
    if x.dim() != 2:
        raise ValueError(
            f"triton_softmax expects a 2-D tensor, got shape {tuple(x.shape)}"
        )
    n_rows, n_cols = x.shape
    # With a single column the column stride is irrelevant, so only reject
    # strided layouts that the kernel would actually read incorrectly.
    if n_cols > 1 and x.stride(1) != 1:
        raise ValueError(
            "triton_softmax requires unit stride along the last dimension"
        )
    # The block size is the smallest power of two greater than or equal to
    # the number of columns in `x`, so one block covers a whole row.
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    # Heuristic: spread wider rows over more warps (more threads per row).
    num_warps = 4
    if BLOCK_SIZE >= 2048:
        num_warps = 8
    if BLOCK_SIZE >= 4096:
        num_warps = 16
    # Allocate output
    y = torch.empty_like(x)

    def _inner():
        # Enqueue the kernel. The 1D launch grid is simple: one kernel
        # instance per row of the input matrix.
        # Call the module-level kernel directly: this script is standalone and
        # has no Tritonbench `Operator` class to look it up on.
        softmax_kernel[(n_rows,)](
            y,
            x,
            x.stride(0),
            y.stride(0),
            n_cols,
            num_warps=num_warps,
            BLOCK_SIZE=BLOCK_SIZE,
        )
        return y

    return _inner
def run():
    """Benchmark ``triton_softmax`` on the first input from ``get_input_iter``.

    Prints whatever ``triton.testing.do_bench`` returns: a timing in
    milliseconds, where the statistic depends on ``do_bench``'s defaults.
    """
    input_iter = get_input_iter()
    first_input = next(input_iter)
    # get_input_iter yields argument tuples `(x,)`; unpack it so that
    # triton_softmax receives the tensor itself rather than the tuple.
    fn = triton_softmax(*first_input)
    results = triton.testing.do_bench(fn)
    print(f"Benchmarking results: {results}")


if __name__ == "__main__":
    run()