diff --git a/python/triton_kernels/bench/bench_batchnorm.py b/python/triton_kernels/bench/bench_batchnorm.py
new file mode 100644
index 000000000000..15417ae4ff1b
--- /dev/null
+++ b/python/triton_kernels/bench/bench_batchnorm.py
@@ -0,0 +1,92 @@
+import argparse
+import torch
+
+
+def cuda_time(fn, iters=30, warmup=20):
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA required for benchmarking")
+    start = torch.cuda.Event(enable_timing=True)
+    end = torch.cuda.Event(enable_timing=True)
+    # warmup
+    for _ in range(warmup):
+        fn()
+    torch.cuda.synchronize()
+    start.record()
+    for _ in range(iters):
+        fn()
+    end.record()
+    torch.cuda.synchronize()
+    ms = start.elapsed_time(end)
+    return ms / iters
+
+
+def bench_case(shape, dtype, training, iters=30, warmup=20):
+    from triton_kernels.batchnorm import batchnorm_forward
+
+    device = "cuda"
+    x = torch.randn(*shape, device=device, dtype=dtype)
+    C = x.shape[1]  # channel dim is axis 1 for both (N, C) and (N, C, H, W)
+    gamma = torch.randn(C, device=device, dtype=torch.float32)
+    beta = torch.randn(C, device=device, dtype=torch.float32)
+    running_mean = torch.zeros(C, device=device, dtype=torch.float32)
+    running_var = torch.ones(C, device=device, dtype=torch.float32)
+    eps = 1e-5
+    momentum = 0.1
+
+    def fn_triton():
+        y, m, v = batchnorm_forward(
+            x, gamma, beta, eps=eps, training=training,
+            running_mean=running_mean, running_var=running_var, momentum=momentum, layout="NCHW"
+        )
+        return y
+
+    def fn_torch():
+        y = torch.nn.functional.batch_norm(
+            x.float(),
+            None if training else running_mean,
+            None if training else running_var,
+            gamma, beta, training=training, momentum=momentum, eps=eps,
+        ).to(x.dtype)
+        return y
+
+    t_triton = cuda_time(fn_triton, iters=iters, warmup=warmup)
+    t_torch = cuda_time(fn_torch, iters=iters, warmup=warmup)
+    return t_triton, t_torch
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--iters", type=int, default=30)
+    parser.add_argument("--warmup", type=int, default=20)
+    args = parser.parse_args()
+
+    torch.manual_seed(0)
+    torch.cuda.manual_seed_all(0)
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.set_float32_matmul_precision("high")
+
+    shapes = [
+        (64, 128),
+        (64, 128, 32, 32),
+        (32, 256, 64, 64),
+    ]
+    dtypes = [torch.float32, torch.float16, torch.bfloat16]
+    modes = [True, False]  # training, eval
+
+    header = ["shape", "dtype", "mode", "triton_ms", "torch_ms", "speedup(x)"]
+    print("\t".join(header))
+    for shape in shapes:
+        for dtype in dtypes:
+            for training in modes:
+                try:
+                    t_tri, t_ref = bench_case(shape, dtype, training, args.iters, args.warmup)
+                    speedup = t_ref / max(t_tri, 1e-6)
+                    print(f"{shape}\t{str(dtype).split('.')[-1]}\t{'train' if training else 'eval'}\t{t_tri:.3f}\t{t_ref:.3f}\t{speedup:.2f}")
+                except Exception as e:
+                    print(f"{shape}\t{str(dtype).split('.')[-1]}\t{'train' if training else 'eval'}\tERROR\tERROR\t{e}")
+
+
+if __name__ == "__main__":
+    main()
+
+
diff --git a/python/triton_kernels/tests/test_batchnorm.py b/python/triton_kernels/tests/test_batchnorm.py
new file mode 100644
index 000000000000..dd4956d1c430
--- /dev/null
+++ b/python/triton_kernels/tests/test_batchnorm.py
@@ -0,0 +1,64 @@
+import pytest
+import torch
+
+
+@pytest.mark.parametrize("shape", [
+    (8, 16),
+    (64, 128),
+    (2, 8, 32, 32),
+])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("training", [True, False])
+def test_batchnorm_forward_matches_torch(shape, dtype, training):
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA required")
+    device = "cuda"
+    x = torch.randn(*shape, device=device, dtype=dtype)
+    C = x.shape[1]  # channel dim is axis 1 for both (N, C) and (N, C, H, W)
+    gamma = torch.randn(C, device=device, dtype=torch.float32)
+    beta = torch.randn(C, device=device, dtype=torch.float32)
+
+    eps = 1e-5
+    momentum = 0.1
+    running_mean = torch.zeros(C, device=device, dtype=torch.float32)
+    running_var = torch.ones(C, device=device, dtype=torch.float32)
+
+    # reference
+    y_ref = torch.nn.functional.batch_norm(
+        x.float(),
+        running_mean.clone() if not training else None,
+        running_var.clone() if not training else None,
+        gamma, beta, training=training, momentum=momentum, eps=eps,
+    ).to(dtype)
+
+    # under test
+    from triton_kernels.batchnorm import batchnorm_forward
+    y_tri, saved_mean, saved_var = batchnorm_forward(
+        x, gamma, beta, eps=eps, training=training,
+        running_mean=running_mean, running_var=running_var, momentum=momentum, layout="NCHW"
+    )
+
+    rtol = 1e-5 if dtype is torch.float32 else 3e-2
+    atol = 1e-6 if dtype is torch.float32 else 3e-3
+    torch.testing.assert_close(y_ref, y_tri.to(dtype), rtol=rtol, atol=atol)
+
+
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
+def test_batchnorm_eps_and_identity(dtype):
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA required")
+    device = "cuda"
+    x = torch.randn(4, 8, device=device, dtype=dtype)
+    C = x.shape[1]
+    gamma = torch.ones(C, device=device, dtype=torch.float32)
+    beta = torch.zeros(C, device=device, dtype=torch.float32)
+
+    from triton_kernels.batchnorm import batchnorm_forward
+    for eps in (1e-5, 1e-3):
+        y_ref = torch.nn.functional.batch_norm(x.float(), None, None, gamma, beta, training=True, eps=eps).to(dtype)
+        y_tri, m, v = batchnorm_forward(x, gamma, beta, eps=eps, training=True, layout="NCHW")
+        rtol = 1e-5 if dtype is torch.float32 else 3e-2
+        atol = 1e-6 if dtype is torch.float32 else 3e-3
+        torch.testing.assert_close(y_ref, y_tri.to(dtype), rtol=rtol, atol=atol)
+
+
diff --git a/python/triton_kernels/triton_kernels/__init__.py b/python/triton_kernels/triton_kernels/__init__.py
index e69de29bb2d1..3a6a81512fa0 100644
--- a/python/triton_kernels/triton_kernels/__init__.py
+++ b/python/triton_kernels/triton_kernels/__init__.py
@@ -0,0 +1,3 @@
+from .batchnorm import batchnorm_forward  # re-export for discoverability
+
+
diff --git a/python/triton_kernels/triton_kernels/batchnorm.py b/python/triton_kernels/triton_kernels/batchnorm.py
new file mode 100644
index 000000000000..38e712e598f9
--- /dev/null
+++ b/python/triton_kernels/triton_kernels/batchnorm.py
@@ -0,0 +1,179 @@
+from typing import Optional, Tuple
+import torch
+import triton
+import triton.language as tl
+
+
+def batchnorm_forward(
+    x: torch.Tensor,
+    gamma: torch.Tensor,
+    beta: torch.Tensor,
+    eps: float = 1e-5,
+    training: bool = True,
+    running_mean: Optional[torch.Tensor] = None,
+    running_var: Optional[torch.Tensor] = None,
+    momentum: float = 0.1,
+    layout: str = "NCHW",
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    BatchNorm forward (NCHW only) using Triton kernels.
+
+    Args:
+        x: Input tensor of shape (N, C, H, W) or (N, C), CUDA device.
+        gamma: Scale parameters of shape (C,), fp32 recommended.
+        beta: Bias parameters of shape (C,), fp32 recommended.
+        eps: Small epsilon added to variance for numerical stability.
+        training: If True, compute batch statistics; else use running stats.
+        running_mean: If provided and training=False, used as mean (shape (C,)).
+        running_var: If provided and training=False, used as var (shape (C,)).
+        momentum: Placeholder for API symmetry; running stats are not updated
+            in this function. Callers can update externally if desired.
+        layout: Currently only "NCHW" supported.
+
+    Returns:
+        (y, saved_mean, saved_var):
+        y: Output tensor with same shape/dtype as x
+        saved_mean: Per-channel mean (fp32)
+        saved_var: Per-channel variance (population, fp32)
+
+    Notes:
+        - Statistics are accumulated in fp32; for half types accuracy is
+          generally sufficient and validated by tests.
+        - In eval mode, saved_mean/var mirror running_mean/var provided.
+        - This function does not update running statistics in-place.
+    """
+    if not x.is_cuda:
+        raise ValueError("CUDA tensor required")
+    if x.ndim not in (2, 4):
+        raise ValueError(f"Unsupported ndim={x.ndim}; expected 2 or 4")
+    if layout not in ("NCHW",):
+        raise ValueError("Only NCHW layout supported in v1")
+    C = x.shape[1]
+    if gamma is None or beta is None:
+        raise ValueError("gamma and beta must be provided")
+    if gamma.numel() != C or beta.numel() != C:
+        raise ValueError("gamma/beta must have size equal to channel dimension")
+    if eps <= 0:
+        raise ValueError("eps must be positive")
+    if not training:
+        if running_mean is None or running_var is None:
+            raise ValueError("running_mean/var required for eval mode")
+        if running_mean.numel() != C or running_var.numel() != C:
+            raise ValueError("running stats must match channel dimension")
+
+    # Prepare contiguous channel-first view: (C, R)
+    if x.ndim == 2:
+        N = x.shape[0]
+        R = N
+        x_cf = x.transpose(0, 1).contiguous()  # (C, N)
+        inv_perm_to_orig = (1, 0)
+        out_shape = (N, C)
+    else:
+        N, C_, H, W = x.shape
+        assert C_ == C
+        R = N * H * W
+        x_cf = x.contiguous().permute(1, 0, 2, 3).contiguous().view(C, R)  # (C, R)
+        inv_perm_to_orig = (1, 0, 2, 3)
+        out_shape = (N, C, H, W)
+
+    device = x.device
+    x_dtype = x.dtype
+
+    # Stats (mean, var) in fp32
+    mean = torch.empty((C,), device=device, dtype=torch.float32)
+    var = torch.empty((C,), device=device, dtype=torch.float32)
+
+    if training:
+        # Launch stats kernel: one program per channel reduces over all R elements
+        BLOCK_R = 1024
+        grid = (C,)
+        _bn_stats_kernel[grid](
+            x_cf, mean, var,
+            C, R,
+            BLOCK_R=BLOCK_R,
+        )
+        saved_mean, saved_var = mean, var
+    else:
+        # Use provided running stats
+        saved_mean = running_mean.to(torch.float32, copy=True)
+        saved_var = running_var.to(torch.float32, copy=True)
+
+    # Precompute inv_std
+    inv_std = (saved_var + eps).rsqrt()
+
+    # Normalize
+    y_cf = torch.empty_like(x_cf, dtype=x_dtype)
+    BLOCK_R = 1024
+    grid = (C,)
+    _bn_norm_kernel[grid](
+        x_cf, y_cf,
+        saved_mean, inv_std,
+        gamma.to(torch.float32), beta.to(torch.float32),
+        C, R,
+        BLOCK_R=BLOCK_R,
+    )
+
+    # Restore original layout
+    if x.ndim == 2:
+        y = y_cf.permute(*inv_perm_to_orig).contiguous().view(out_shape)
+    else:
+        y = y_cf.view(C, N, H, W).permute(*inv_perm_to_orig).contiguous()
+
+    return y, saved_mean, saved_var
+
+
+@triton.jit
+def _bn_stats_kernel(x_cf, mean_out, var_out, C, R, BLOCK_R: tl.constexpr):
+    c = tl.program_id(0)
+    if c >= C:
+        return
+    # Pointers
+    x_ptr = x_cf + c * R
+    # Accumulators in fp32
+    s = 0.0
+    s2 = 0.0
+    cnt = 0.0
+
+    offs = tl.arange(0, BLOCK_R)
+    r0 = 0
+    while r0 < R:
+        idx = r0 + offs
+        mask = idx < R
+        vals = tl.load(x_ptr + idx, mask=mask, other=0).to(tl.float32)
+        s += tl.sum(vals, axis=0)
+        s2 += tl.sum(vals * vals, axis=0)
+        # count in-bounds elements; cast the mask before reducing
+        cnt += tl.sum(mask.to(tl.float32), axis=0)
+        r0 += BLOCK_R
+
+    # Finalize (population variance)
+    mean = tl.where(cnt > 0.0, s / cnt, 0.0)
+    var = tl.where(cnt > 0.0, s2 / cnt - mean * mean, 0.0)
+    tl.store(mean_out + c, mean)
+    tl.store(var_out + c, var)
+
+
+@triton.jit
+def _bn_norm_kernel(x_cf, y_cf, mean, inv_std, gamma, beta, C, R, BLOCK_R: tl.constexpr):
+    c = tl.program_id(0)
+    if c >= C:
+        return
+    x_ptr = x_cf + c * R
+    y_ptr = y_cf + c * R
+    m = tl.load(mean + c)
+    istd = tl.load(inv_std + c)
+    g = tl.load(gamma + c)
+    b = tl.load(beta + c)
+
+    offs = tl.arange(0, BLOCK_R)
+    for r0 in range(0, R, BLOCK_R):
+        idx = r0 + offs
+        mask = idx < R
+        x = tl.load(x_ptr + idx, mask=mask, other=0)
+        x_f = x.to(tl.float32)
+        y = (x_f - m) * istd
+        y = y * g + b
+        tl.store(y_ptr + idx, y.to(y_cf.dtype.element_ty), mask=mask)  # cast back to output dtype
+
+
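Usage sketch (illustrative, not applied by this patch): the sequence below assumes a CUDA device and that `triton_kernels` is importable; the running-stat update is one possible caller-side policy, since `batchnorm_forward` never updates running statistics itself. Per channel c the two kernels compute y = gamma[c] * (x - mean[c]) / sqrt(var[c] + eps) + beta[c], with mean and population variance accumulated in fp32.

    import torch
    from triton_kernels.batchnorm import batchnorm_forward

    C = 64
    x = torch.randn(16, C, 28, 28, device="cuda", dtype=torch.float16)
    gamma = torch.ones(C, device="cuda", dtype=torch.float32)
    beta = torch.zeros(C, device="cuda", dtype=torch.float32)
    running_mean = torch.zeros(C, device="cuda", dtype=torch.float32)
    running_var = torch.ones(C, device="cuda", dtype=torch.float32)

    # Training step: per-channel batch statistics come from the Triton stats kernel.
    y, mean, var = batchnorm_forward(x, gamma, beta, eps=1e-5, training=True, layout="NCHW")

    # Optional caller-side running-stat update. Note that torch.nn.BatchNorm2d uses the
    # unbiased variance for this update; the returned population variance is used here
    # for simplicity.
    momentum = 0.1
    running_mean.mul_(1 - momentum).add_(mean, alpha=momentum)
    running_var.mul_(1 - momentum).add_(var, alpha=momentum)

    # Eval step: normalize with the externally maintained running statistics.
    y_eval, _, _ = batchnorm_forward(
        x, gamma, beta, eps=1e-5, training=False,
        running_mean=running_mean, running_var=running_var, layout="NCHW",
    )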