1 change: 1 addition & 0 deletions .gitignore
@@ -4,3 +4,4 @@ __pycache__/
.pytest_cache
**/.cache
**/meta-llama/**/*
*.egg-info
2 changes: 1 addition & 1 deletion benchmarking/profiler.py
@@ -77,5 +77,5 @@ def get_benchmark_vals(cls):
@classmethod
def get_profiling_data(cls):
if cls._instance and cls._instance.profiler:
- return self.profiler.key_averages()
+ return cls._instance.profiler.key_averages()
return None
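The one-line fix above replaces `self`, which is undefined inside a `@classmethod`, with the singleton instance stored on the class. A minimal sketch of the pattern, using a stand-in profiler object rather than the repository's real torch profiler wiring:

class _StubProfiler:
    # Stand-in for the torch profiler object the real Profiler wraps.
    def key_averages(self):
        return "<aggregated kernel stats>"

class Profiler:
    _instance = None  # singleton reference, set whenever a Profiler is constructed

    def __init__(self, profile=False):
        self.profiler = _StubProfiler() if profile else None
        Profiler._instance = self

    @classmethod
    def get_profiling_data(cls):
        # A classmethod receives the class, not an instance, so `self` does not
        # exist here; the active instance is reached through cls._instance.
        if cls._instance and cls._instance.profiler:
            return cls._instance.profiler.key_averages()
        return None

Profiler(profile=True)
print(Profiler.get_profiling_data())  # "<aggregated kernel stats>"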
3 changes: 3 additions & 0 deletions kernels/__init__.py
@@ -1,11 +1,14 @@
# from .conv import _conv, conv
from . import blocksparse
from .batchnorm import _batchnorm, batchnorm
from .cross_entropy import _cross_entropy, cross_entropy
from .flash_attention import attention
from .matmul import _matmul, get_higher_dtype, matmul

__all__ = [
"blocksparse",
"_batchnorm",
"batchnorm",
"_cross_entropy",
"cross_entropy",
"_matmul",
…
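With these exports in place, the batchnorm kernel is reachable from the package root like the existing kernels. A quick usage sketch (shapes are illustrative; a CUDA device is assumed):

import torch
import kernels  # now re-exports batchnorm / _batchnorm alongside matmul, cross_entropy, etc.

x = torch.randn(256, 1024, device="cuda")
gamma = torch.ones(1024, device="cuda")
beta = torch.zeros(1024, device="cuda")

y = kernels.batchnorm(x, gamma, beta)  # eps defaults to 1e-5
print(y.shape)  # torch.Size([256, 1024])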
172 changes: 172 additions & 0 deletions kernels/batchnorm.py
@@ -0,0 +1,172 @@
import torch
import triton
import triton.language as tl
from benchmarking import Profiler

@triton.autotune(
configs=[
triton.Config({'BLOCK_N': 256, 'BLOCK_C': 128}, num_warps=4, num_stages=2),
triton.Config({'BLOCK_N': 512, 'BLOCK_C': 128}, num_warps=4, num_stages=2),
triton.Config({'BLOCK_N': 256, 'BLOCK_C': 256}, num_warps=8, num_stages=2),
triton.Config({'BLOCK_N': 128, 'BLOCK_C': 256}, num_warps=4, num_stages=2),
triton.Config({'BLOCK_N': 512, 'BLOCK_C': 256}, num_warps=8, num_stages=3),
],
key=['N', 'C', 'stride_in_c'],
)
@triton.jit
def bn_reduce_kernel(x_ptr, sum_ptr, sumsq_ptr, N, C, stride_in_n, stride_in_c,
BLOCK_N: tl.constexpr, BLOCK_C: tl.constexpr):
pid_n = tl.program_id(0)
pid_c = tl.program_id(1)

offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_c = pid_c * BLOCK_C + tl.arange(0, BLOCK_C)

mask_n = offs_n < N
mask_c = offs_c < C

x_ptrs = x_ptr + offs_n[:, None] * stride_in_n + offs_c[None, :] * stride_in_c
mask_2d = mask_n[:, None] & mask_c[None, :]

x = tl.load(x_ptrs, mask=mask_2d, other=0).to(tl.float32)
psum = tl.sum(x, axis=0)
psumsq = tl.sum(x * x, axis=0)

tl.atomic_add(sum_ptr + offs_c, psum, mask=mask_c)
tl.atomic_add(sumsq_ptr + offs_c, psumsq, mask=mask_c)

@triton.autotune(
configs=[
triton.Config({'BLOCK_N': 256, 'BLOCK_C': 128}, num_warps=4, num_stages=2),
triton.Config({'BLOCK_N': 512, 'BLOCK_C': 128}, num_warps=4, num_stages=2),
triton.Config({'BLOCK_N': 256, 'BLOCK_C': 256}, num_warps=8, num_stages=2),
triton.Config({'BLOCK_N': 128, 'BLOCK_C': 256}, num_warps=4, num_stages=2),
triton.Config({'BLOCK_N': 512, 'BLOCK_C': 256}, num_warps=8, num_stages=3),
],
key=['N', 'C', 'stride_in_c'],
)
@triton.jit
def bn_norm_kernel(x_ptr, gamma_ptr, beta_ptr, y_ptr, sum_ptr, sumsq_ptr,
N, C, eps, stride_in_n, stride_in_c, stride_out_n, stride_out_c,
BLOCK_N: tl.constexpr, BLOCK_C: tl.constexpr):
pid_n = tl.program_id(0)
pid_c = tl.program_id(1)

offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_c = pid_c * BLOCK_C + tl.arange(0, BLOCK_C)
mask_n = offs_n < N
mask_c = offs_c < C

s = tl.load(sum_ptr + offs_c, mask=mask_c, other=0.).to(tl.float32)
ss = tl.load(sumsq_ptr + offs_c, mask=mask_c, other=0.).to(tl.float32)
Nf = tl.full((), N, tl.float32)

mean = s / Nf
var = ss / Nf - mean * mean
inv_std = tl.rsqrt(var + eps)

gamma = tl.load(gamma_ptr + offs_c, mask=mask_c, other=1.).to(tl.float32)
beta = tl.load(beta_ptr + offs_c, mask=mask_c, other=0.).to(tl.float32)

scale = gamma * inv_std
shift = beta - mean * scale

x_ptrs = x_ptr + offs_n[:, None] * stride_in_n + offs_c[None, :] * stride_in_c
y_ptrs = y_ptr + offs_n[:, None] * stride_out_n + offs_c[None, :] * stride_out_c
mask_2d = mask_n[:, None] & mask_c[None, :]

x = tl.load(x_ptrs, mask=mask_2d, other=0.).to(tl.float32)
y = x * scale[None, :] + shift[None, :]
tl.store(y_ptrs, y, mask=mask_2d)

@triton.autotune(
configs=[
triton.Config({'BLOCK_C': 64, 'BLOCK_N': 128}, num_warps=2, num_stages=3),
triton.Config({'BLOCK_C': 128, 'BLOCK_N': 128}, num_warps=4, num_stages=3),
triton.Config({'BLOCK_C': 128, 'BLOCK_N': 256}, num_warps=4, num_stages=4),
triton.Config({'BLOCK_C': 256, 'BLOCK_N': 128}, num_warps=8, num_stages=4),
],
key=['N', 'C', 'stride_in_c'],
)
@triton.jit
def bn_fused_kernel(x_ptr, y_ptr, gamma_ptr, beta_ptr, N, C, eps,
stride_in_n, stride_in_c, stride_out_n, stride_out_c,
BLOCK_C: tl.constexpr, BLOCK_N: tl.constexpr):
pid_c = tl.program_id(0)
offs_c = pid_c * BLOCK_C + tl.arange(0, BLOCK_C)
mask_c = offs_c < C

gamma = tl.load(gamma_ptr + offs_c, mask=mask_c, other=1.).to(tl.float32)
beta = tl.load(beta_ptr + offs_c, mask=mask_c, other=0.).to(tl.float32)

s = tl.zeros([BLOCK_C], dtype=tl.float32)
ss = tl.zeros([BLOCK_C], dtype=tl.float32)

offs_c = tl.max_contiguous(offs_c, BLOCK_C)  # contiguity hint; it only takes effect when the result is reassigned

for n0 in range(0, N, BLOCK_N):
offs_n = n0 + tl.arange(0, BLOCK_N)
mask_n = offs_n < N
x_ptrs = x_ptr + offs_n[:, None] * stride_in_n + offs_c[None, :] * stride_in_c
m = mask_n[:, None] & mask_c[None, :]
x = tl.load(x_ptrs, mask=m, other=0.).to(tl.float32)
s += tl.sum(x, axis=0)
ss += tl.sum(x * x, axis=0)

Nf = tl.full((), N, tl.float32)
mean = s / Nf
var = ss / Nf - mean * mean
inv_std = tl.rsqrt(var + eps)
scale = gamma * inv_std
shift = beta - mean * scale

for n0 in range(0, N, BLOCK_N):
offs_n = n0 + tl.arange(0, BLOCK_N)
mask_n = offs_n < N
x_ptrs = x_ptr + offs_n[:, None] * stride_in_n + offs_c[None, :] * stride_in_c
y_ptrs = y_ptr + offs_n[:, None] * stride_out_n + offs_c[None, :] * stride_out_c
m = mask_n[:, None] & mask_c[None, :]
x = tl.load(x_ptrs, mask=m, other=0.).to(tl.float32)
y = x * scale[None, :] + shift[None, :]
tl.store(y_ptrs, y, mask=m)

def _batchnorm(input: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor,
output: torch.Tensor, N: int, C: int, eps: float):
assert input.is_cuda and output.is_cuda and gamma.is_cuda and beta.is_cuda
assert input.dtype in (torch.float16, torch.bfloat16, torch.float32)

problem_size = N * C
use_fused = problem_size < (1 << 24)

if use_fused:
def grid(meta):
return (triton.cdiv(C, meta['BLOCK_C']),)
bn_fused_kernel[grid](
input, output, gamma, beta,
N, C, eps,
input.stride(0), input.stride(1),
output.stride(0), output.stride(1),
)
return

sum_buf = torch.zeros(C, device=input.device, dtype=torch.float32)
sumsq_buf = torch.zeros(C, device=input.device, dtype=torch.float32)

def grid(meta):
return (triton.cdiv(N, meta['BLOCK_N']), triton.cdiv(C, meta['BLOCK_C']))

bn_reduce_kernel[grid](input, sum_buf, sumsq_buf,
N, C, input.stride(0), input.stride(1))

bn_norm_kernel[grid](input, gamma, beta, output,
sum_buf, sumsq_buf,
N, C, eps,
input.stride(0), input.stride(1),
output.stride(0), output.stride(1))

@Profiler.profiling_decorator("batchnorm")
def batchnorm(input: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor, eps: float = 1e-5):
N, C = input.shape
output = torch.empty_like(input)
_batchnorm(input, gamma, beta, output, N, C, eps)
return output
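For reference, both Triton paths, the single-program fused kernel and the atomic two-pass reduce-then-normalize pair, compute per-channel statistics over the batch dimension with a biased variance and accumulate in float32. The dispatch in `_batchnorm` sends problems under 1 << 24 (about 16.8M) elements down the fused path and larger ones down the two-pass path. A plain-PyTorch sketch of the same numerics (not part of the diff; the function name is illustrative):

import torch

def batchnorm_reference(x, gamma, beta, eps=1e-5):
    # Per-channel mean and *biased* variance over the batch dimension,
    # matching sum(x)/N and sum(x*x)/N - mean**2 accumulated in float32.
    xf = x.float()
    mean = xf.mean(dim=0)                 # shape (C,)
    var = xf.var(dim=0, unbiased=False)   # shape (C,)
    inv_std = torch.rsqrt(var + eps)
    scale = gamma.float() * inv_std       # what the kernels call `scale`
    shift = beta.float() - mean * scale   # what the kernels call `shift`
    return (xf * scale + shift).to(x.dtype)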
43 changes: 39 additions & 4 deletions main.py
@@ -6,6 +6,7 @@

from models.llama import llama_example_chat_completion, llama_example_text_completion
from benchmarking import Profiler, compare_benchmarks
import kernels
import pprint


@@ -31,8 +32,8 @@ def main(operation: str, profile=False, benchmark=False, **kwargs):
p = Profiler(profile, benchmark)
torch.cuda.empty_cache()
runner(operation, kwargs)
- benchmarks["triton"] = Profiler.get_benchmark_vals()
- profiles["triton"] = Profiler.get_profiling_data()
+ benchmarks["non_triton"] = Profiler.get_benchmark_vals()
+ profiles["non_triton"] = Profiler.get_profiling_data()
Profiler.reset()
p = Profiler(profile, benchmark)

@@ -42,8 +43,8 @@ def main(operation: str, profile=False, benchmark=False, **kwargs):
p = Profiler(profile, benchmark)
torch.cuda.empty_cache()
runner(operation, kwargs)
- benchmarks["non_triton"] = Profiler.get_benchmark_vals()
- profiles["non_triton"] = Profiler.get_profiling_data()
+ benchmarks["triton"] = Profiler.get_benchmark_vals()
+ profiles["triton"] = Profiler.get_profiling_data()
elif profile:
runner(operation, kwargs)
data = Profiler.get_profiling_data()
@@ -66,11 +67,45 @@ def main(operation: str, profile=False, benchmark=False, **kwargs):
print("\n==================================\n")


@Profiler.profiling_decorator("batchnorm_benchmark")
def batchnorm_benchmark(batch_size=512, channels=2048, use_triton=True, suppress_prints=False):
"""Benchmark batchnorm implementation"""
if not suppress_prints:
print(f"Running BatchNorm benchmark: {batch_size}x{channels}, use_triton={use_triton}")

device = 'cuda'
eps = 1e-5

# Create test data
input_tensor = torch.randn(batch_size, channels, device=device, dtype=torch.float32)
gamma = torch.ones(channels, device=device, dtype=torch.float32)
beta = torch.zeros(channels, device=device, dtype=torch.float32)

if use_triton:
# Use custom Triton implementation
output = kernels.batchnorm(input_tensor, gamma, beta, eps)
else:
# Use PyTorch baseline
bn = torch.nn.BatchNorm1d(channels, eps=eps, affine=True, track_running_stats=False, device=device)
bn.weight.data = gamma
bn.bias.data = beta
bn.train()
with torch.no_grad():
output = bn(input_tensor)

if not suppress_prints:
print(f"Output shape: {output.shape}, dtype: {output.dtype}")

return output


def runner(operation: str, kwargs):
if operation == "llama_chat_completion":
llama_example_chat_completion(**kwargs)
elif operation == "llama_text_completion":
llama_example_text_completion(**kwargs)
elif operation == "batchnorm_benchmark":
batchnorm_benchmark(**kwargs)
else:
raise ValueError(f"Unknown operation: {operation}")

…
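A rough standalone way to time the two code paths that `batchnorm_benchmark` exercises, sketched with `triton.testing.do_bench` under the assumption that the package is installed and a CUDA device is available; the numbers it prints are illustrative and not results from this PR:

import torch
import triton
import kernels

def time_batchnorm(batch_size=512, channels=2048, eps=1e-5):
    x = torch.randn(batch_size, channels, device="cuda")
    gamma = torch.ones(channels, device="cuda")
    beta = torch.zeros(channels, device="cuda")

    # PyTorch baseline configured the same way as in batchnorm_benchmark.
    bn = torch.nn.BatchNorm1d(channels, eps=eps, affine=True,
                              track_running_stats=False, device="cuda").train()
    bn.weight.data, bn.bias.data = gamma, beta

    ms_triton = triton.testing.do_bench(lambda: kernels.batchnorm(x, gamma, beta, eps))
    with torch.no_grad():
        ms_torch = triton.testing.do_bench(lambda: bn(x))
    print(f"triton: {ms_triton:.3f} ms   torch.nn.BatchNorm1d: {ms_torch:.3f} ms")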
30 changes: 30 additions & 0 deletions pyproject.toml
@@ -0,0 +1,30 @@
[build-system]
requires = ["setuptools>=64", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "triton-kernels"
version = "0.1.0"
description = "High-performance Triton kernels for machine learning"
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"torch",
"triton",
"numpy",
"pandas>=2.0.3",
"fire>=0.7.1",
"fairscale>=0.4.13",
"tiktoken>=0.7.0",
]

[project.optional-dependencies]
test = ["pytest", "pytest-xdist"]
dev = ["pytest", "pytest-xdist", "black", "ruff"]

[tool.setuptools.packages.find]
include = ["kernels*"]
exclude = ["test*", "benchmarking*", "models*"]

[tool.setuptools.package-dir]
"" = "."
87 changes: 87 additions & 0 deletions test/test_batchnorm.py
@@ -0,0 +1,87 @@
import pytest
import torch
import kernels
from kernels import batchnorm, _batchnorm

def create_test_data(batch_size: int, channels: int, device: str = 'cuda'):
input_tensor = torch.randn(batch_size, channels, device=device, dtype=torch.float32)
gamma = torch.ones(channels, device=device, dtype=torch.float32)
beta = torch.zeros(channels, device=device, dtype=torch.float32)
return input_tensor, gamma, beta

@pytest.mark.parametrize("batch_size,channels", [
(32, 128),
(64, 256),
(128, 512),
(256, 1024),
(512, 2048),
])
def test_batchnorm_correctness(batch_size, channels):
eps = 1e-5
input_tensor, gamma, beta = create_test_data(batch_size, channels)

output_custom = batchnorm(input_tensor, gamma, beta, eps)

bn = torch.nn.BatchNorm1d(channels, eps=eps, affine=True, track_running_stats=False, device=input_tensor.device)
bn.weight.data = gamma
bn.bias.data = beta
bn.train()
with torch.no_grad():
output_pytorch = bn(input_tensor)

max_diff = torch.max(torch.abs(output_custom - output_pytorch)).item()
rel_error = max_diff / (torch.max(torch.abs(output_pytorch)).item() + 1e-8)

assert max_diff < 1e-3 or rel_error < 1e-3, f"Max diff: {max_diff:.2e}, Rel error: {rel_error:.2e}"

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
def test_batchnorm_dtypes(dtype):
batch_size, channels = 128, 256
eps = 1e-5

input_tensor = torch.randn(batch_size, channels, device='cuda', dtype=dtype)
gamma = torch.ones(channels, device='cuda', dtype=dtype)
beta = torch.zeros(channels, device='cuda', dtype=dtype)

output = batchnorm(input_tensor, gamma, beta, eps)
assert output.dtype == dtype
assert output.shape == input_tensor.shape

def test_batchnorm_backward_compatibility():
batch_size, channels = 64, 128
eps = 1e-5
input_tensor, gamma, beta = create_test_data(batch_size, channels)

N, C = input_tensor.shape
output = torch.empty_like(input_tensor)
_batchnorm(input_tensor, gamma, beta, output, N, C, eps)

bn = torch.nn.BatchNorm1d(C, eps=eps, affine=True, track_running_stats=False, device=input_tensor.device)
bn.weight.data = gamma
bn.bias.data = beta
bn.train()
with torch.no_grad():
expected = bn(input_tensor)

torch.testing.assert_close(output, expected, atol=1e-3, rtol=1e-3)

@pytest.mark.parametrize("use_small_problem", [True, False])
def test_fused_vs_two_pass(use_small_problem):
if use_small_problem:
batch_size, channels = 32, 64
else:
batch_size, channels = 16384, 2048  # 2**25 elements, above the 1 << 24 fused-path cutoff in _batchnorm

eps = 1e-5
input_tensor, gamma, beta = create_test_data(batch_size, channels)

output = batchnorm(input_tensor, gamma, beta, eps)

bn = torch.nn.BatchNorm1d(channels, eps=eps, affine=True, track_running_stats=False, device=input_tensor.device)
bn.weight.data = gamma
bn.bias.data = beta
bn.train()
with torch.no_grad():
expected = bn(input_tensor)

torch.testing.assert_close(output, expected, atol=1e-3, rtol=1e-3)