[BE] Tool to generate minimal reproduction script #351

@xuzhao9

Description

It is a common task to compare Tritonbench's results with those of other benchmarks.

When there are discrepancies between two benchmarks, we often need a minimal reproduction with Tritonbench's framework code stripped away. That is, we need to generate a script that can run a kernel without any dependency on the "tritonbench.*" package. Usually, running with the first input data point is sufficient.

For example, for the softmax operator with the triton_softmax backend, the generated script could look like this:

import torch
import triton
import triton.language as tl

dtype = torch.bfloat16
device = "cuda"

def get_input_iter():
    M = 4096
    shapes = [(M, 128 * i) for i in range(2, 100)]
    for M, N in shapes:
        yield (torch.randn([M, N], dtype=dtype, device=device),)

@triton.jit
def softmax_kernel(
    output_ptr,
    input_ptr,
    input_row_stride,
    output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    # The rows of the softmax are independent, so we parallelize across those
    row_idx = tl.program_id(0)
    # The stride represents how much we need to increase the pointer to advance 1 row
    row_start_ptr = input_ptr + row_idx * input_row_stride
    # The block size is the next power of two greater than n_cols, so we can fit each
    # row in a single block
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets
    # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
    row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float("inf"))
    # Subtract maximum for numerical stability
    row_minus_max = row - tl.max(row, axis=0)
    # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator
    # Write back output to DRAM
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)


def triton_softmax(x):
    n_rows, n_cols = x.shape
    # The block size is the smallest power of two greater than the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 4
    if BLOCK_SIZE >= 2048:
        num_warps = 8
    if BLOCK_SIZE >= 4096:
        num_warps = 16
    # Allocate output
    y = torch.empty_like(x)

    # Enqueue kernel. The 1D launch grid is simple: we have one kernel instance
    # per row of the input matrix.
    def _inner():
        softmax_kernel[(n_rows,)](
            y,
            x,
            x.stride(0),
            y.stride(0),
            n_cols,
            num_warps=num_warps,
            BLOCK_SIZE=BLOCK_SIZE,
        )
        return y

    return _inner

def run():
    input_iter = get_input_iter()
    # Only the first input data point is needed for a minimal reproduction.
    first_input = next(input_iter)
    fn = triton_softmax(*first_input)
    results = triton.testing.do_bench(fn)
    print(f"Benchmarking results: {results}")



if __name__ == "__main__":
    run()
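
A rough sketch of how the generator tool itself could assemble such a script: pull the source of the input generator, the Triton kernel, and the entry-point wrapper with inspect, prepend the shared imports, and append a run() harness that benchmarks only the first input. Everything below is an illustrative assumption rather than an existing Tritonbench API; in particular emit_repro_script, its parameters, and the unwrapping of @triton.jit kernels through a .fn attribute are made up for the sketch.

import inspect
import textwrap
from typing import Callable, Iterable

# Common header shared by every generated script.
PRELUDE = textwrap.dedent(
    """\
    import torch
    import triton
    import triton.language as tl

    dtype = torch.bfloat16
    device = "cuda"
    """
)

# Harness appended to every generated script; benchmarks the first input only.
RUNNER_TEMPLATE = textwrap.dedent(
    """\
    def run():
        input_iter = {input_iter}()
        first_input = next(input_iter)
        fn = {entry_point}(*first_input)
        results = triton.testing.do_bench(fn)
        print(f"Benchmarking results: {{results}}")


    if __name__ == "__main__":
        run()
    """
)


def emit_repro_script(
    out_path: str,
    input_iter_fn: Callable,
    entry_point: Callable,
    extra_defs: Iterable = (),
) -> None:
    # extra_defs carries any additional definitions the entry point needs,
    # e.g. the @triton.jit kernel it launches.
    parts = [PRELUDE]
    for obj in (input_iter_fn, *extra_defs, entry_point):
        # @triton.jit wraps the Python function; assume the original function
        # is still reachable (e.g. via a .fn attribute) so inspect can read it.
        unwrapped = getattr(obj, "fn", obj)
        parts.append(textwrap.dedent(inspect.getsource(unwrapped)))
    parts.append(
        RUNNER_TEMPLATE.format(
            input_iter=input_iter_fn.__name__,
            entry_point=entry_point.__name__,
        )
    )
    with open(out_path, "w") as f:
        f.write("\n\n".join(parts))

For the softmax example this might be invoked roughly as emit_repro_script("repro_softmax.py", get_input_iter, triton_softmax, extra_defs=[softmax_kernel]). In Tritonbench these callables are typically defined on the operator class, so a real tool would also have to rewrite class-qualified references such as Operator.softmax_kernel into the local names used by the emitted script.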

