96 changes: 96 additions & 0 deletions python/triton_kernels/bench/bench_batchnorm.py
@@ -0,0 +1,96 @@
import argparse

import torch


def cuda_time(fn, iters=30, warmup=20):
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA required for benchmarking")
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    # warmup
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    ms = start.elapsed_time(end)
    return ms / iters


def bench_case(shape, dtype, training, iters=30, warmup=20):
    from triton_kernels.batchnorm import batchnorm_forward

    device = "cuda"
    x = torch.randn(*shape, device=device, dtype=dtype)
    # Channel dimension is axis 1 for both (N, C) and (N, C, H, W) inputs.
    C = x.shape[1]
    gamma = torch.randn(C, device=device, dtype=torch.float32)
    beta = torch.randn(C, device=device, dtype=torch.float32)
    running_mean = torch.zeros(C, device=device, dtype=torch.float32)
    running_var = torch.ones(C, device=device, dtype=torch.float32)
    eps = 1e-5
    momentum = 0.1

    def fn_triton():
        y, m, v = batchnorm_forward(
            x, gamma, beta, eps=eps, training=training,
            running_mean=running_mean, running_var=running_var, momentum=momentum, layout="NCHW"
        )
        return y

    def fn_torch():
        y = torch.nn.functional.batch_norm(
            x.float(),
            None if training else running_mean,
            None if training else running_var,
            gamma, beta, training=training, momentum=momentum, eps=eps,
        ).to(x.dtype)
        return y

    t_triton = cuda_time(fn_triton, iters=iters, warmup=warmup)
    t_torch = cuda_time(fn_torch, iters=iters, warmup=warmup)
    return t_triton, t_torch


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--iters", type=int, default=30)
    parser.add_argument("--warmup", type=int, default=20)
    args = parser.parse_args()

    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.set_float32_matmul_precision("high")

    shapes = [
        (64, 128),
        (64, 128, 32, 32),
        (32, 256, 64, 64),
    ]
    dtypes = [torch.float32, torch.float16, torch.bfloat16]
    modes = [True, False]  # training, eval

    header = ["shape", "dtype", "mode", "triton_ms", "torch_ms", "speedup(x)"]
    print("\t".join(header))
    for shape in shapes:
        for dtype in dtypes:
            for training in modes:
                mode = "train" if training else "eval"
                dtype_name = str(dtype).split(".")[-1]
                try:
                    t_tri, t_ref = bench_case(shape, dtype, training, iters=args.iters, warmup=args.warmup)
                    speedup = t_ref / max(t_tri, 1e-6)
                    print(f"{shape}\t{dtype_name}\t{mode}\t{t_tri:.3f}\t{t_ref:.3f}\t{speedup:.2f}")
                except Exception as e:
                    print(f"{shape}\t{dtype_name}\t{mode}\tERROR\tERROR\t{e}")


if __name__ == "__main__":
    main()


70 changes: 70 additions & 0 deletions python/triton_kernels/tests/test_batchnorm.py
@@ -0,0 +1,70 @@
import pytest
import torch


@pytest.mark.parametrize("shape", [
(8, 16),
(64, 128),
(2, 8, 32, 32),
])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("training", [True, False])
def test_batchnorm_forward_matches_torch(shape, dtype, training):
if not torch.cuda.is_available():
pytest.skip("CUDA required")
device = "cuda"
x = torch.randn(*shape, device=device, dtype=dtype)
if x.ndim == 2:
N, C = x.shape
gamma = torch.randn(C, device=device, dtype=torch.float32)
beta = torch.randn(C, device=device, dtype=torch.float32)
else:
N, C, H, W = x.shape
gamma = torch.randn(C, device=device, dtype=torch.float32)
beta = torch.randn(C, device=device, dtype=torch.float32)

eps = 1e-5
momentum = 0.1
running_mean = torch.zeros(C, device=device, dtype=torch.float32)
running_var = torch.ones(C, device=device, dtype=torch.float32)

# reference
y_ref = torch.nn.functional.batch_norm(
x.float(),
running_mean.clone() if not training else None,
running_var.clone() if not training else None,
gamma, beta, training=training, momentum=momentum, eps=eps,
).to(dtype)

# under test
from triton_kernels.batchnorm import batchnorm_forward
y_tri, saved_mean, saved_var = batchnorm_forward(
x, gamma, beta, eps=eps, training=training,
running_mean=running_mean, running_var=running_var, momentum=momentum, layout="NCHW"
)

rtol = 1e-5 if dtype is torch.float32 else 3e-2
atol = 1e-6 if dtype is torch.float32 else 3e-3
torch.testing.assert_close(y_ref, y_tri.to(dtype), rtol=rtol, atol=atol)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_batchnorm_eps_and_identity(dtype):
if not torch.cuda.is_available():
pytest.skip("CUDA required")
device = "cuda"
x = torch.randn(4, 8, device=device, dtype=dtype)
C = x.shape[1]
gamma = torch.ones(C, device=device, dtype=torch.float32)
beta = torch.zeros(C, device=device, dtype=torch.float32)

for eps in (1e-5, 1e-3):
y_ref = torch.nn.functional.batch_norm(x.float(), None, None, gamma, beta, training=True, eps=eps).to(dtype)
from triton_kernels.batchnorm import batchnorm_forward
y_tri, m, v = batchnorm_forward(x, gamma, beta, eps=eps, training=True, layout="NCHW")
rtol = 1e-5 if dtype is torch.float32 else 3e-2
atol = 1e-6 if dtype is torch.float32 else 3e-3
torch.testing.assert_close(y_ref, y_tri.to(dtype), rtol=rtol, atol=atol)


3 changes: 3 additions & 0 deletions python/triton_kernels/triton_kernels/__init__.py
@@ -0,0 +1,3 @@
from .batchnorm import batchnorm_forward # re-export for discoverability
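# Usage sketch (illustration only, assuming a CUDA tensor `x` and fp32 `gamma`/`beta`):
# the re-export lets callers import the entry point from the package root, e.g.
#
#     from triton_kernels import batchnorm_forward
#     y, mean, var = batchnorm_forward(x, gamma, beta, training=True)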


179 changes: 179 additions & 0 deletions python/triton_kernels/triton_kernels/batchnorm.py
@@ -0,0 +1,179 @@
from typing import Optional, Tuple
import torch
import triton
import triton.language as tl


def batchnorm_forward(
    x: torch.Tensor,
    gamma: torch.Tensor,
    beta: torch.Tensor,
    eps: float = 1e-5,
    training: bool = True,
    running_mean: Optional[torch.Tensor] = None,
    running_var: Optional[torch.Tensor] = None,
    momentum: float = 0.1,
    layout: str = "NCHW",
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
BatchNorm forward (NCHW only) using Triton kernels.

Args:
x: Input tensor of shape (N, C, H, W) or (N, C), CUDA device.
gamma: Scale parameters of shape (C,), fp32 recommended.
beta: Bias parameters of shape (C,), fp32 recommended.
eps: Small epsilon added to variance for numerical stability.
training: If True, compute batch statistics; else use running stats.
running_mean: If provided and training=False, used as mean (shape (C,)).
running_var: If provided and training=False, used as var (shape (C,)).
momentum: Placeholder for API symmetry; running stats are not updated
in this function. Callers can update externally if desired.
layout: Currently only "NCHW" supported.

Returns:
(y, saved_mean, saved_var):
y: Output tensor with same shape/dtype as x
saved_mean: Per-channel mean (fp32)
saved_var: Per-channel variance (population, fp32)

Notes:
- Statistics are accumulated in fp32; for half types accuracy is
generally sufficient and validated by tests.
- In eval mode, saved_mean/var mirror running_mean/var provided.
- This function does not update running statistics in-place.
"""
    if not x.is_cuda:
        raise ValueError("CUDA tensor required")
    if x.ndim not in (2, 4):
        raise ValueError(f"Unsupported ndim={x.ndim}; expected 2 or 4")
    if layout not in ("NCHW",):
        raise ValueError("Only NCHW layout supported in v1")
    C = x.shape[1]
    if gamma is None or beta is None:
        raise ValueError("gamma and beta must be provided")
    if gamma.numel() != C or beta.numel() != C:
        raise ValueError("gamma/beta must have size equal to channel dimension")
    if eps <= 0:
        raise ValueError("eps must be positive")
    if not training:
        if running_mean is None or running_var is None:
            raise ValueError("running_mean/var required for eval mode")
        if running_mean.numel() != C or running_var.numel() != C:
            raise ValueError("running stats must match channel dimension")

    # Prepare a contiguous channel-first view of shape (C, R), where R is the
    # number of elements reduced per channel.
    if x.ndim == 2:
        N = x.shape[0]
        R = N
        x_cf = x.transpose(0, 1).contiguous()  # (C, N)
        out_shape = (N, C)
    else:
        N, C_, H, W = x.shape
        assert C_ == C
        R = N * H * W
        x_cf = x.contiguous().permute(1, 0, 2, 3).contiguous().view(C, R)  # (C, R)
        out_shape = (N, C, H, W)

    device = x.device
    x_dtype = x.dtype

    # Per-channel statistics (mean, var) in fp32.
    mean = torch.empty((C,), device=device, dtype=torch.float32)
    var = torch.empty((C,), device=device, dtype=torch.float32)

    if training:
        # Stats kernel: one program per channel, accumulating sums and sums of
        # squares across the R elements of that channel.
        BLOCK_R = 1024
        grid = (C,)
        _bn_stats_kernel[grid](
            x_cf, mean, var,
            C, R,
            BLOCK_R=BLOCK_R,
        )
        saved_mean, saved_var = mean, var
    else:
        # Use the provided running stats.
        saved_mean = running_mean.to(torch.float32, copy=True)
        saved_var = running_var.to(torch.float32, copy=True)

    # Precompute inv_std = 1 / sqrt(var + eps).
    inv_std = (saved_var + eps).rsqrt()

    # Normalize: y = (x - mean) * inv_std * gamma + beta, per channel.
    y_cf = torch.empty_like(x_cf, dtype=x_dtype)
    BLOCK_R = 1024
    grid = (C,)
    _bn_norm_kernel[grid](
        x_cf, y_cf,
        saved_mean, inv_std,
        gamma.to(torch.float32), beta.to(torch.float32),
        C, R,
        BLOCK_R=BLOCK_R,
    )

    # Restore the original layout.
    if x.ndim == 2:
        y = y_cf.transpose(0, 1).contiguous().view(out_shape)
    else:
        y = y_cf.view(C, N, H, W).permute(1, 0, 2, 3).contiguous()

    return y, saved_mean, saved_var
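
# Sketch (illustration only, not part of the kernel API): as noted in the
# docstring, running statistics are never updated in-place here. A caller in
# training mode could fold the returned batch statistics into its own buffers
# with the usual EMA update, e.g.
#
#     y, batch_mean, batch_var = batchnorm_forward(x, gamma, beta, training=True)
#     with torch.no_grad():
#         running_mean.mul_(1.0 - momentum).add_(batch_mean, alpha=momentum)
#         running_var.mul_(1.0 - momentum).add_(batch_var, alpha=momentum)
#
# (torch.nn.BatchNorm uses the *unbiased* variance for this update, so apply an
# n/(n-1) correction to batch_var if exact parity with torch is required.)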


@triton.jit
def _bn_stats_kernel(x_cf, mean_out, var_out, C, R, BLOCK_R: tl.constexpr):
    c = tl.program_id(0)
    if c >= C:
        return
    # Base pointer of this channel's row in the (C, R) view.
    x_ptr = x_cf + c * R
    # Accumulate sum, sum of squares, and element count in fp32.
    s = 0.0
    s2 = 0.0
    cnt = 0.0

    offs = tl.arange(0, BLOCK_R)
    for r0 in range(0, R, BLOCK_R):
        idx = r0 + offs
        mask = idx < R
        vals = tl.load(x_ptr + idx, mask=mask, other=0).to(tl.float32)
        s += tl.sum(vals, axis=0)
        s2 += tl.sum(vals * vals, axis=0)
        cnt += tl.sum(mask.to(tl.float32), axis=0)

    # Finalize population statistics: mean = E[x], var = E[x^2] - E[x]^2.
    mean = tl.where(cnt > 0.0, s / cnt, 0.0)
    var = tl.where(cnt > 0.0, s2 / cnt - mean * mean, 0.0)
    tl.store(mean_out + c, mean)
    tl.store(var_out + c, var)


@triton.jit
def _bn_norm_kernel(x_cf, y_cf, mean, inv_std, gamma, beta, C, R, BLOCK_R: tl.constexpr):
    c = tl.program_id(0)
    if c >= C:
        return
    x_ptr = x_cf + c * R
    y_ptr = y_cf + c * R
    # Per-channel statistics and affine parameters, all fp32.
    m = tl.load(mean + c)
    istd = tl.load(inv_std + c)
    g = tl.load(gamma + c)
    b = tl.load(beta + c)

    offs = tl.arange(0, BLOCK_R)
    for r0 in range(0, R, BLOCK_R):
        idx = r0 + offs
        mask = idx < R
        x = tl.load(x_ptr + idx, mask=mask, other=0)
        x_f = x.to(tl.float32)
        # y = (x - mean) * inv_std * gamma + beta, computed in fp32.
        y = (x_f - m) * istd
        y = y * g + b
        # Cast back to the output dtype (e.g. fp16/bf16) on store.
        tl.store(y_ptr + idx, y.to(y_cf.dtype.element_ty), mask=mask)