From e9cc36b1a1d0f42b351bf8c428b7bba406f7da1e Mon Sep 17 00:00:00 2001
From: 0xrushi <6279035+0xrushi@users.noreply.github.com>
Date: Sun, 21 Sep 2025 18:31:24 -0400
Subject: [PATCH 1/5] add bn

---
 benchmarking/benchmark_batchnorm.py | 172 ++++++++++++++++++++++++++++
 kernels/__init__.py                 |   3 +
 kernels/batchnorm.py                | 170 +++++++++++++++++++++++++++
 test/test_batchnorm.py              |  86 ++++++++++++++
 4 files changed, 431 insertions(+)
 create mode 100644 benchmarking/benchmark_batchnorm.py
 create mode 100644 kernels/batchnorm.py
 create mode 100644 test/test_batchnorm.py

diff --git a/benchmarking/benchmark_batchnorm.py b/benchmarking/benchmark_batchnorm.py
new file mode 100644
index 0000000..ba930f4
--- /dev/null
+++ b/benchmarking/benchmark_batchnorm.py
@@ -0,0 +1,172 @@
+import time
+import torch
+import triton
+import kernels
+
+def create_test_data(batch_size: int, channels: int, device: str = 'cuda'):
+    input_tensor = torch.randn(batch_size, channels, device=device, dtype=torch.float32)
+    gamma = torch.ones(channels, device=device, dtype=torch.float32)
+    beta = torch.zeros(channels, device=device, dtype=torch.float32)
+    return input_tensor, gamma, beta
+
+def check_correctness(batch_size: int, channels: int, eps: float = 1e-5):
+    print(f"Testing correctness for batch_size={batch_size}, channels={channels}...")
+    input_tensor, gamma, beta = create_test_data(batch_size, channels)
+
+    output_custom = kernels.batchnorm(input_tensor, gamma, beta, eps)
+
+    bn = torch.nn.BatchNorm1d(channels, eps=eps, affine=True, track_running_stats=False, device=input_tensor.device)
+    bn.weight.data = gamma
+    bn.bias.data = beta
+    bn.train()
+    with torch.no_grad():
+        output_pytorch = bn(input_tensor)
+
+    max_diff = torch.max(torch.abs(output_custom - output_pytorch)).item()
+    rel_error = max_diff / (torch.max(torch.abs(output_pytorch)).item() + 1e-8)
+    is_correct = max_diff < 1e-3 or rel_error < 1e-3
+
+    print(f"  Max difference: {max_diff:.2e}")
+    print(f"  Relative error: {rel_error:.2e}")
+    print(f"  Result: {'✓ PASS' if is_correct else '✗ FAIL'}")
+    return is_correct
+
+def benchmark_batch_norm(batch_size: int, channels: int, eps: float = 1e-5, iters: int = 100):
+    print(f"Benchmarking batch_size={batch_size}, channels={channels}...")
+    input_tensor, gamma, beta = create_test_data(batch_size, channels)
+
+    for _ in range(30):
+        _ = kernels.batchnorm(input_tensor, gamma, beta, eps)
+    torch.cuda.synchronize()
+
+    triton_times = []
+    for _ in range(iters):
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+        start.record()
+        _ = kernels.batchnorm(input_tensor, gamma, beta, eps)
+        end.record()
+        torch.cuda.synchronize()
+        triton_times.append(start.elapsed_time(end))
+    triton_time = sum(triton_times) / len(triton_times)
+
+    bn = torch.nn.BatchNorm1d(channels, eps=eps, affine=True, track_running_stats=False, device=input_tensor.device)
+    bn.weight.data = gamma
+    bn.bias.data = beta
+    bn.train()
+    for _ in range(30):
+        with torch.no_grad():
+            _ = bn(input_tensor)
+    torch.cuda.synchronize()
+
+    pytorch_times = []
+    for _ in range(iters):
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+        start.record()
+        with torch.no_grad():
+            _ = bn(input_tensor)
+        end.record()
+        torch.cuda.synchronize()
+        pytorch_times.append(start.elapsed_time(end))
+    pytorch_time = sum(pytorch_times) / len(pytorch_times)
+
+    speedup = pytorch_time / triton_time if triton_time > 0 else 0
+    bytes_accessed = 2 * batch_size * channels * 4
+    bandwidth_gb = bytes_accessed / (triton_time * 1e6)
+
+    print(f"  Triton time: {triton_time:.3f} ms")
+    print(f"  PyTorch time: {pytorch_time:.3f} ms")
+    print(f"  Speedup: {speedup:.2f}x {'🚀' if speedup > 1.0 else ''}")
+    print(f"  Bandwidth: {bandwidth_gb:.1f} GB/s")
+
+    return {'triton_ms': triton_time, 'pytorch_ms': pytorch_time,
+            'speedup': speedup, 'bandwidth_gb': bandwidth_gb}
+
+def run_comprehensive_batch_norm_tests():
+    print("=" * 60)
+    print("OPTIMIZED BATCH NORMALIZATION (N x C)")
+    print("=" * 60)
+
+    test_configs = [
+        (32, 128),
+        (64, 256),
+        (128, 512),
+        (256, 1024),
+        (512, 2048),
+        (1024, 4096),
+        (2048, 8192),
+        (4096, 16384),
+        (8192, 32768),
+        (16384, 65536),
+    ]
+
+    print("\n1. CORRECTNESS TESTS")
+    print("-" * 30)
+    all_correct = True
+    for batch_size, channels in test_configs[:5]:
+        is_correct = check_correctness(batch_size, channels)
+        all_correct = all_correct and is_correct
+        print()
+    print(f"Overall Correctness: {'✓ ALL PASS' if all_correct else '✗ SOME FAILED'}")
+
+    print("\n2. PERFORMANCE BENCHMARKS")
+    print("-" * 30)
+    results = []
+    for batch_size, channels in test_configs:
+        result = benchmark_batch_norm(batch_size, channels, iters=100)
+        results.append(((batch_size, channels), result))
+        print()
+
+    print("\n3. PERFORMANCE SUMMARY")
+    print("-" * 30)
+    print(f"{'Config':<15} {'Triton (ms)':<12} {'PyTorch (ms)':<12} {'Speedup':<10} {'BW (GB/s)':<10}")
+    print("-" * 70)
+
+    winning_configs = 0
+    for (batch_size, channels), result in results:
+        config_str = f"{batch_size}x{channels}"
+        speedup_str = f"{result['speedup']:.2f}x"
+        if result['speedup'] > 1.0:
+            speedup_str += " 🚀"
+            winning_configs += 1
+        print(f"{config_str:<15} {result['triton_ms']:<12.3f} {result['pytorch_ms']:<12.3f} {speedup_str:<10} {result['bandwidth_gb']:<10.1f}")
+
+    avg_speedup = sum(r[1]['speedup'] for r in results) / len(results)
+    geometric_mean = (torch.prod(torch.tensor([r[1]['speedup'] for r in results])) ** (1/len(results))).item()
+
+    print(f"\nAverage Speedup: {avg_speedup:.2f}x")
+    print(f"Geometric Mean Speedup: {geometric_mean:.2f}x")
+    print(f"Winning configurations: {winning_configs}/{len(results)}")
+
+    best = max(results, key=lambda x: x[1]['speedup'])
+    worst = min(results, key=lambda x: x[1]['speedup'])
+    print(f"\nBest speedup: {best[0][0]}x{best[0][1]} = {best[1]['speedup']:.2f}x")
+    print(f"Worst speedup: {worst[0][0]}x{worst[0][1]} = {worst[1]['speedup']:.2f}x")
+
+    print("\n4. ANALYSIS")
+    print("-" * 30)
+    print("For large batches, PyTorch likely uses:")
+    print("- Optimized cuDNN kernels with tensor core utilization")
+    print("- Multi-stage reduction algorithms")
+    print("- Better memory access patterns for high bandwidth")
+    print("\nPossible improvements:")
+    print("- Tune fused kernel configs (bigger BLOCK_C on high-end GPUs)")
+    print("- Use mixed precision IO (fp16/bf16) with fp32 math to cut bandwidth")
+    print("- Consider Welford variance for extreme N if needed")
+
+    return all_correct, results
+
+if __name__ == "__main__":
+    if not torch.cuda.is_available():
+        print("CUDA is not available. This benchmark requires a GPU.")
This benchmark requires GPU.") + else: + print(f"Running on {torch.cuda.get_device_name()}") + print(f"PyTorch version: {torch.__version__}") + print(f"Triton version: {triton.__version__}\n") + + print("Quick Correctness Check:") + check_correctness(128, 256) + print() + + run_comprehensive_batch_norm_tests() \ No newline at end of file diff --git a/kernels/__init__.py b/kernels/__init__.py index dd492d3..32febe1 100644 --- a/kernels/__init__.py +++ b/kernels/__init__.py @@ -1,11 +1,14 @@ # from .conv import _conv, conv from . import blocksparse +from .batchnorm import _batchnorm, batchnorm from .cross_entropy import _cross_entropy, cross_entropy from .flash_attention import attention from .matmul import _matmul, get_higher_dtype, matmul __all__ = [ "blocksparse", + "_batchnorm", + "batchnorm", "_cross_entropy", "cross_entropy", "_matmul", diff --git a/kernels/batchnorm.py b/kernels/batchnorm.py new file mode 100644 index 0000000..a0ee89b --- /dev/null +++ b/kernels/batchnorm.py @@ -0,0 +1,170 @@ +import torch +import triton +import triton.language as tl + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_N': 256, 'BLOCK_C': 128}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_N': 512, 'BLOCK_C': 128}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_N': 256, 'BLOCK_C': 256}, num_warps=8, num_stages=2), + triton.Config({'BLOCK_N': 128, 'BLOCK_C': 256}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_N': 512, 'BLOCK_C': 256}, num_warps=8, num_stages=3), + ], + key=['N', 'C', 'stride_in_c'], +) +@triton.jit +def bn_reduce_kernel(x_ptr, sum_ptr, sumsq_ptr, N, C, stride_in_n, stride_in_c, + BLOCK_N: tl.constexpr, BLOCK_C: tl.constexpr): + pid_n = tl.program_id(0) + pid_c = tl.program_id(1) + + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_c = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) + + mask_n = offs_n < N + mask_c = offs_c < C + + x_ptrs = x_ptr + offs_n[:, None] * stride_in_n + offs_c[None, :] * stride_in_c + mask_2d = mask_n[:, None] & mask_c[None, :] + + x = tl.load(x_ptrs, mask=mask_2d, other=0).to(tl.float32) + psum = tl.sum(x, axis=0) + psumsq = tl.sum(x * x, axis=0) + + tl.atomic_add(sum_ptr + offs_c, psum, mask=mask_c) + tl.atomic_add(sumsq_ptr + offs_c, psumsq, mask=mask_c) + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_N': 256, 'BLOCK_C': 128}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_N': 512, 'BLOCK_C': 128}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_N': 256, 'BLOCK_C': 256}, num_warps=8, num_stages=2), + triton.Config({'BLOCK_N': 128, 'BLOCK_C': 256}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_N': 512, 'BLOCK_C': 256}, num_warps=8, num_stages=3), + ], + key=['N', 'C', 'stride_in_c'], +) +@triton.jit +def bn_norm_kernel(x_ptr, gamma_ptr, beta_ptr, y_ptr, sum_ptr, sumsq_ptr, + N, C, eps, stride_in_n, stride_in_c, stride_out_n, stride_out_c, + BLOCK_N: tl.constexpr, BLOCK_C: tl.constexpr): + pid_n = tl.program_id(0) + pid_c = tl.program_id(1) + + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_c = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) + mask_n = offs_n < N + mask_c = offs_c < C + + s = tl.load(sum_ptr + offs_c, mask=mask_c, other=0.).to(tl.float32) + ss = tl.load(sumsq_ptr + offs_c, mask=mask_c, other=0.).to(tl.float32) + Nf = tl.full((), N, tl.float32) + + mean = s / Nf + var = ss / Nf - mean * mean + inv_std = tl.rsqrt(var + eps) + + gamma = tl.load(gamma_ptr + offs_c, mask=mask_c, other=1.).to(tl.float32) + beta = tl.load(beta_ptr + offs_c, mask=mask_c, other=0.).to(tl.float32) + + scale = gamma * inv_std 
+ shift = beta - mean * scale + + x_ptrs = x_ptr + offs_n[:, None] * stride_in_n + offs_c[None, :] * stride_in_c + y_ptrs = y_ptr + offs_n[:, None] * stride_out_n + offs_c[None, :] * stride_out_c + mask_2d = mask_n[:, None] & mask_c[None, :] + + x = tl.load(x_ptrs, mask=mask_2d, other=0.).to(tl.float32) + y = x * scale[None, :] + shift[None, :] + tl.store(y_ptrs, y, mask=mask_2d) + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_C': 64, 'BLOCK_N': 128}, num_warps=2, num_stages=3), + triton.Config({'BLOCK_C': 128, 'BLOCK_N': 128}, num_warps=4, num_stages=3), + triton.Config({'BLOCK_C': 128, 'BLOCK_N': 256}, num_warps=4, num_stages=4), + triton.Config({'BLOCK_C': 256, 'BLOCK_N': 128}, num_warps=8, num_stages=4), + ], + key=['N', 'C', 'stride_in_c'], +) +@triton.jit +def bn_fused_kernel(x_ptr, y_ptr, gamma_ptr, beta_ptr, N, C, eps, + stride_in_n, stride_in_c, stride_out_n, stride_out_c, + BLOCK_C: tl.constexpr, BLOCK_N: tl.constexpr): + pid_c = tl.program_id(0) + offs_c = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) + mask_c = offs_c < C + + gamma = tl.load(gamma_ptr + offs_c, mask=mask_c, other=1.).to(tl.float32) + beta = tl.load(beta_ptr + offs_c, mask=mask_c, other=0.).to(tl.float32) + + s = tl.zeros([BLOCK_C], dtype=tl.float32) + ss = tl.zeros([BLOCK_C], dtype=tl.float32) + + tl.max_contiguous(offs_c, BLOCK_C) + + for n0 in range(0, N, BLOCK_N): + offs_n = n0 + tl.arange(0, BLOCK_N) + mask_n = offs_n < N + x_ptrs = x_ptr + offs_n[:, None] * stride_in_n + offs_c[None, :] * stride_in_c + m = mask_n[:, None] & mask_c[None, :] + x = tl.load(x_ptrs, mask=m, other=0.).to(tl.float32) + s += tl.sum(x, axis=0) + ss += tl.sum(x * x, axis=0) + + Nf = tl.full((), N, tl.float32) + mean = s / Nf + var = ss / Nf - mean * mean + inv_std = tl.rsqrt(var + eps) + scale = gamma * inv_std + shift = beta - mean * scale + + for n0 in range(0, N, BLOCK_N): + offs_n = n0 + tl.arange(0, BLOCK_N) + mask_n = offs_n < N + x_ptrs = x_ptr + offs_n[:, None] * stride_in_n + offs_c[None, :] * stride_in_c + y_ptrs = y_ptr + offs_n[:, None] * stride_out_n + offs_c[None, :] * stride_out_c + m = mask_n[:, None] & mask_c[None, :] + x = tl.load(x_ptrs, mask=m, other=0.).to(tl.float32) + y = x * scale[None, :] + shift[None, :] + tl.store(y_ptrs, y, mask=m) + +def _batchnorm(input: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor, + output: torch.Tensor, N: int, C: int, eps: float): + assert input.is_cuda and output.is_cuda and gamma.is_cuda and beta.is_cuda + assert input.dtype in (torch.float16, torch.bfloat16, torch.float32) + + problem_size = N * C + use_fused = problem_size < (1 << 24) + + if use_fused: + def grid(meta): + return (triton.cdiv(C, meta['BLOCK_C']),) + bn_fused_kernel[grid]( + input, output, gamma, beta, + N, C, eps, + input.stride(0), input.stride(1), + output.stride(0), output.stride(1), + ) + return + + sum_buf = torch.zeros(C, device=input.device, dtype=torch.float32) + sumsq_buf = torch.zeros(C, device=input.device, dtype=torch.float32) + + def grid(meta): + return (triton.cdiv(N, meta['BLOCK_N']), triton.cdiv(C, meta['BLOCK_C'])) + + bn_reduce_kernel[grid](input, sum_buf, sumsq_buf, + N, C, input.stride(0), input.stride(1)) + + bn_norm_kernel[grid](input, gamma, beta, output, + sum_buf, sumsq_buf, + N, C, eps, + input.stride(0), input.stride(1), + output.stride(0), output.stride(1)) + +def batchnorm(input: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor, eps: float = 1e-5): + N, C = input.shape + output = torch.empty_like(input) + _batchnorm(input, gamma, beta, output, N, C, eps) + return 
\ No newline at end of file
diff --git a/test/test_batchnorm.py b/test/test_batchnorm.py
new file mode 100644
index 0000000..ed9cc66
--- /dev/null
+++ b/test/test_batchnorm.py
@@ -0,0 +1,86 @@
+import pytest
+import torch
+import kernels
+
+def create_test_data(batch_size: int, channels: int, device: str = 'cuda'):
+    input_tensor = torch.randn(batch_size, channels, device=device, dtype=torch.float32)
+    gamma = torch.ones(channels, device=device, dtype=torch.float32)
+    beta = torch.zeros(channels, device=device, dtype=torch.float32)
+    return input_tensor, gamma, beta
+
+@pytest.mark.parametrize("batch_size,channels", [
+    (32, 128),
+    (64, 256),
+    (128, 512),
+    (256, 1024),
+    (512, 2048),
+])
+def test_batchnorm_correctness(batch_size, channels):
+    eps = 1e-5
+    input_tensor, gamma, beta = create_test_data(batch_size, channels)
+
+    output_custom = kernels.batchnorm(input_tensor, gamma, beta, eps)
+
+    bn = torch.nn.BatchNorm1d(channels, eps=eps, affine=True, track_running_stats=False, device=input_tensor.device)
+    bn.weight.data = gamma
+    bn.bias.data = beta
+    bn.train()
+    with torch.no_grad():
+        output_pytorch = bn(input_tensor)
+
+    max_diff = torch.max(torch.abs(output_custom - output_pytorch)).item()
+    rel_error = max_diff / (torch.max(torch.abs(output_pytorch)).item() + 1e-8)
+
+    assert max_diff < 1e-3 or rel_error < 1e-3, f"Max diff: {max_diff:.2e}, Rel error: {rel_error:.2e}"
+
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
+def test_batchnorm_dtypes(dtype):
+    batch_size, channels = 128, 256
+    eps = 1e-5
+
+    input_tensor = torch.randn(batch_size, channels, device='cuda', dtype=dtype)
+    gamma = torch.ones(channels, device='cuda', dtype=dtype)
+    beta = torch.zeros(channels, device='cuda', dtype=dtype)
+
+    output = kernels.batchnorm(input_tensor, gamma, beta, eps)
+    assert output.dtype == dtype
+    assert output.shape == input_tensor.shape
+
+def test_batchnorm_backward_compatibility():
+    batch_size, channels = 64, 128
+    eps = 1e-5
+    input_tensor, gamma, beta = create_test_data(batch_size, channels)
+
+    N, C = input_tensor.shape
+    output = torch.empty_like(input_tensor)
+    kernels.batchnorm._batchnorm(input_tensor, gamma, beta, output, N, C, eps)
+
+    bn = torch.nn.BatchNorm1d(C, eps=eps, affine=True, track_running_stats=False, device=input_tensor.device)
+    bn.weight.data = gamma
+    bn.bias.data = beta
+    bn.train()
+    with torch.no_grad():
+        expected = bn(input_tensor)
+
+    torch.testing.assert_close(output, expected, atol=1e-3, rtol=1e-3)
+
+@pytest.mark.parametrize("use_small_problem", [True, False])
+def test_fused_vs_two_pass(use_small_problem):
+    if use_small_problem:
+        batch_size, channels = 32, 64
+    else:
+        batch_size, channels = 8192, 4096
+
+    eps = 1e-5
+    input_tensor, gamma, beta = create_test_data(batch_size, channels)
+
+    output = kernels.batchnorm(input_tensor, gamma, beta, eps)
+
+    bn = torch.nn.BatchNorm1d(channels, eps=eps, affine=True, track_running_stats=False, device=input_tensor.device)
+    bn.weight.data = gamma
+    bn.bias.data = beta
+    bn.train()
+    with torch.no_grad():
+        expected = bn(input_tensor)
+
+    torch.testing.assert_close(output, expected, atol=1e-3, rtol=1e-3)
\ No newline at end of file

From d9c950d60fdf3b8bfb7df2587816e2f090ab672a Mon Sep 17 00:00:00 2001
From: 0xrushi <6279035+0xrushi@users.noreply.github.com>
Date: Sun, 21 Sep 2025 19:47:04 -0400
Subject: [PATCH 2/5] bm

---
 benchmarking/benchmark_batchnorm.py | 172 ----------------------------
 kernels/batchnorm.py                |   2 +
 main.py                             |  32 ++++++
 pyproject.toml                      |  26 +++++
 test/test_batchnorm.py              |   7 +-
 5 files changed, 64 insertions(+), 175 deletions(-)
 delete mode 100644 benchmarking/benchmark_batchnorm.py

diff --git a/benchmarking/benchmark_batchnorm.py b/benchmarking/benchmark_batchnorm.py
deleted file mode 100644
index ba930f4..0000000
--- a/benchmarking/benchmark_batchnorm.py
+++ /dev/null
@@ -1,172 +0,0 @@
-import time
-import torch
-import triton
-import kernels
-
-def create_test_data(batch_size: int, channels: int, device: str = 'cuda'):
-    input_tensor = torch.randn(batch_size, channels, device=device, dtype=torch.float32)
-    gamma = torch.ones(channels, device=device, dtype=torch.float32)
-    beta = torch.zeros(channels, device=device, dtype=torch.float32)
-    return input_tensor, gamma, beta
-
-def check_correctness(batch_size: int, channels: int, eps: float = 1e-5):
-    print(f"Testing correctness for batch_size={batch_size}, channels={channels}...")
-    input_tensor, gamma, beta = create_test_data(batch_size, channels)
-
-    output_custom = kernels.batchnorm(input_tensor, gamma, beta, eps)
-
-    bn = torch.nn.BatchNorm1d(channels, eps=eps, affine=True, track_running_stats=False, device=input_tensor.device)
-    bn.weight.data = gamma
-    bn.bias.data = beta
-    bn.train()
-    with torch.no_grad():
-        output_pytorch = bn(input_tensor)
-
-    max_diff = torch.max(torch.abs(output_custom - output_pytorch)).item()
-    rel_error = max_diff / (torch.max(torch.abs(output_pytorch)).item() + 1e-8)
-    is_correct = max_diff < 1e-3 or rel_error < 1e-3
-
-    print(f"  Max difference: {max_diff:.2e}")
-    print(f"  Relative error: {rel_error:.2e}")
-    print(f"  Result: {'✓ PASS' if is_correct else '✗ FAIL'}")
-    return is_correct
-
-def benchmark_batch_norm(batch_size: int, channels: int, eps: float = 1e-5, iters: int = 100):
-    print(f"Benchmarking batch_size={batch_size}, channels={channels}...")
-    input_tensor, gamma, beta = create_test_data(batch_size, channels)
-
-    for _ in range(30):
-        _ = kernels.batchnorm(input_tensor, gamma, beta, eps)
-    torch.cuda.synchronize()
-
-    triton_times = []
-    for _ in range(iters):
-        start = torch.cuda.Event(enable_timing=True)
-        end = torch.cuda.Event(enable_timing=True)
-        start.record()
-        _ = kernels.batchnorm(input_tensor, gamma, beta, eps)
-        end.record()
-        torch.cuda.synchronize()
-        triton_times.append(start.elapsed_time(end))
-    triton_time = sum(triton_times) / len(triton_times)
-
-    bn = torch.nn.BatchNorm1d(channels, eps=eps, affine=True, track_running_stats=False, device=input_tensor.device)
-    bn.weight.data = gamma
-    bn.bias.data = beta
-    bn.train()
-    for _ in range(30):
-        with torch.no_grad():
-            _ = bn(input_tensor)
-    torch.cuda.synchronize()
-
-    pytorch_times = []
-    for _ in range(iters):
-        start = torch.cuda.Event(enable_timing=True)
-        end = torch.cuda.Event(enable_timing=True)
-        start.record()
-        with torch.no_grad():
-            _ = bn(input_tensor)
-        end.record()
-        torch.cuda.synchronize()
-        pytorch_times.append(start.elapsed_time(end))
-    pytorch_time = sum(pytorch_times) / len(pytorch_times)
-
-    speedup = pytorch_time / triton_time if triton_time > 0 else 0
-    bytes_accessed = 2 * batch_size * channels * 4
-    bandwidth_gb = bytes_accessed / (triton_time * 1e6)
-
-    print(f"  Triton time: {triton_time:.3f} ms")
-    print(f"  PyTorch time: {pytorch_time:.3f} ms")
-    print(f"  Speedup: {speedup:.2f}x {'🚀' if speedup > 1.0 else ''}")
-    print(f"  Bandwidth: {bandwidth_gb:.1f} GB/s")
-
-    return {'triton_ms': triton_time, 'pytorch_ms': pytorch_time,
-            'speedup': speedup, 'bandwidth_gb': bandwidth_gb}
-
-def run_comprehensive_batch_norm_tests():
-    print("=" * 60)
-    print("OPTIMIZED BATCH NORMALIZATION (N x C)")
-    print("=" * 60)
-
-    test_configs = [
-        (32, 128),
-        (64, 256),
-        (128, 512),
-        (256, 1024),
-        (512, 2048),
-        (1024, 4096),
-        (2048, 8192),
-        (4096, 16384),
-        (8192, 32768),
-        (16384, 65536),
-    ]
-
-    print("\n1. CORRECTNESS TESTS")
-    print("-" * 30)
-    all_correct = True
-    for batch_size, channels in test_configs[:5]:
-        is_correct = check_correctness(batch_size, channels)
-        all_correct = all_correct and is_correct
-        print()
-    print(f"Overall Correctness: {'✓ ALL PASS' if all_correct else '✗ SOME FAILED'}")
-
-    print("\n2. PERFORMANCE BENCHMARKS")
-    print("-" * 30)
-    results = []
-    for batch_size, channels in test_configs:
-        result = benchmark_batch_norm(batch_size, channels, iters=100)
-        results.append(((batch_size, channels), result))
-        print()
-
-    print("\n3. PERFORMANCE SUMMARY")
-    print("-" * 30)
-    print(f"{'Config':<15} {'Triton (ms)':<12} {'PyTorch (ms)':<12} {'Speedup':<10} {'BW (GB/s)':<10}")
-    print("-" * 70)
-
-    winning_configs = 0
-    for (batch_size, channels), result in results:
-        config_str = f"{batch_size}x{channels}"
-        speedup_str = f"{result['speedup']:.2f}x"
-        if result['speedup'] > 1.0:
-            speedup_str += " 🚀"
-            winning_configs += 1
-        print(f"{config_str:<15} {result['triton_ms']:<12.3f} {result['pytorch_ms']:<12.3f} {speedup_str:<10} {result['bandwidth_gb']:<10.1f}")
-
-    avg_speedup = sum(r[1]['speedup'] for r in results) / len(results)
-    geometric_mean = (torch.prod(torch.tensor([r[1]['speedup'] for r in results])) ** (1/len(results))).item()
-
-    print(f"\nAverage Speedup: {avg_speedup:.2f}x")
-    print(f"Geometric Mean Speedup: {geometric_mean:.2f}x")
-    print(f"Winning configurations: {winning_configs}/{len(results)}")
-
-    best = max(results, key=lambda x: x[1]['speedup'])
-    worst = min(results, key=lambda x: x[1]['speedup'])
-    print(f"\nBest speedup: {best[0][0]}x{best[0][1]} = {best[1]['speedup']:.2f}x")
-    print(f"Worst speedup: {worst[0][0]}x{worst[0][1]} = {worst[1]['speedup']:.2f}x")
-
-    print("\n4. ANALYSIS")
-    print("-" * 30)
-    print("For large batches, PyTorch likely uses:")
-    print("- Optimized cuDNN kernels with tensor core utilization")
-    print("- Multi-stage reduction algorithms")
-    print("- Better memory access patterns for high bandwidth")
-    print("\nPossible improvements:")
-    print("- Tune fused kernel configs (bigger BLOCK_C on high-end GPUs)")
-    print("- Use mixed precision IO (fp16/bf16) with fp32 math to cut bandwidth")
-    print("- Consider Welford variance for extreme N if needed")
-
-    return all_correct, results
-
-if __name__ == "__main__":
-    if not torch.cuda.is_available():
-        print("CUDA is not available. This benchmark requires a GPU.")
-    else:
-        print(f"Running on {torch.cuda.get_device_name()}")
-        print(f"PyTorch version: {torch.__version__}")
-        print(f"Triton version: {triton.__version__}\n")
-
-        print("Quick Correctness Check:")
-        check_correctness(128, 256)
-        print()
-
-        run_comprehensive_batch_norm_tests()
\ No newline at end of file
diff --git a/kernels/batchnorm.py b/kernels/batchnorm.py
index a0ee89b..29fa6fc 100644
--- a/kernels/batchnorm.py
+++ b/kernels/batchnorm.py
@@ -1,6 +1,7 @@
 import torch
 import triton
 import triton.language as tl
+from benchmarking import Profiler
 
 @triton.autotune(
     configs=[
@@ -163,6 +164,7 @@ def grid(meta):
         input.stride(0), input.stride(1),
         output.stride(0), output.stride(1))
 
+@Profiler.profiling_decorator("batchnorm")
 def batchnorm(input: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor, eps: float = 1e-5):
     N, C = input.shape
     output = torch.empty_like(input)
diff --git a/main.py b/main.py
index eac1c68..b69f33b 100644
--- a/main.py
+++ b/main.py
@@ -6,6 +6,7 @@
 
 from models.llama import llama_example_chat_completion, llama_example_text_completion
 from benchmarking import Profiler, compare_benchmarks
+import kernels
 
 import pprint
 
@@ -66,11 +67,42 @@ def main(operation: str, profile=False, benchmark=False, **kwargs):
 
     print("\n==================================\n")
 
+def batchnorm_benchmark(batch_size=512, channels=2048, use_triton=True, suppress_prints=False):
+    """Benchmark batchnorm implementation"""
+    if not suppress_prints:
+        print(f"Running BatchNorm benchmark: {batch_size}x{channels}, use_triton={use_triton}")
+
+    device = 'cuda'
+    eps = 1e-5
+
+    # Create test data
+    input_tensor = torch.randn(batch_size, channels, device=device, dtype=torch.float32)
+    gamma = torch.ones(channels, device=device, dtype=torch.float32)
+    beta = torch.zeros(channels, device=device, dtype=torch.float32)
+
+    if use_triton:
+        # Use custom Triton implementation
+        output = kernels.batchnorm(input_tensor, gamma, beta, eps)
+    else:
+        # Use PyTorch baseline
+        bn = torch.nn.BatchNorm1d(channels, eps=eps, affine=True, track_running_stats=False, device=device)
+        bn.weight.data = gamma
+        bn.bias.data = beta
+        bn.train()
+        with torch.no_grad():
+            output = bn(input_tensor)
+
+    if not suppress_prints:
+        print(f"Output shape: {output.shape}, dtype: {output.dtype}")
+
+
 def runner(operation: str, kwargs):
     if operation == "llama_chat_completion":
         llama_example_chat_completion(**kwargs)
     elif operation == "llama_text_completion":
         llama_example_text_completion(**kwargs)
+    elif operation == "batchnorm_benchmark":
+        batchnorm_benchmark(**kwargs)
     else:
         raise ValueError(f"Unknown operation: {operation}")
 
diff --git a/pyproject.toml b/pyproject.toml
index e69de29..3d2128f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -0,0 +1,26 @@
+[build-system]
+requires = ["setuptools>=64", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "triton-kernels"
+version = "0.1.0"
+description = "High-performance Triton kernels for machine learning"
+readme = "README.md"
+requires-python = ">=3.8"
+dependencies = [
+    "torch",
+    "triton",
+    "numpy",
+]
+
+[project.optional-dependencies]
+test = ["pytest", "pytest-xdist"]
+dev = ["pytest", "pytest-xdist", "black", "ruff"]
+
+[tool.setuptools.packages.find]
+include = ["kernels*"]
+exclude = ["test*", "benchmarking*", "models*"]
+
+[tool.setuptools.package-dir]
+"" = "."
\ No newline at end of file
diff --git a/test/test_batchnorm.py b/test/test_batchnorm.py
index ed9cc66..f26620b 100644
--- a/test/test_batchnorm.py
+++ b/test/test_batchnorm.py
@@ -1,6 +1,7 @@
 import pytest
 import torch
 import kernels
+from kernels import batchnorm
 
 def create_test_data(batch_size: int, channels: int, device: str = 'cuda'):
     input_tensor = torch.randn(batch_size, channels, device=device, dtype=torch.float32)
@@ -19,7 +20,7 @@ def test_batchnorm_correctness(batch_size, channels):
     eps = 1e-5
     input_tensor, gamma, beta = create_test_data(batch_size, channels)
 
-    output_custom = kernels.batchnorm(input_tensor, gamma, beta, eps)
+    output_custom = batchnorm(input_tensor, gamma, beta, eps)
 
     bn = torch.nn.BatchNorm1d(channels, eps=eps, affine=True, track_running_stats=False, device=input_tensor.device)
     bn.weight.data = gamma
@@ -42,7 +43,7 @@ def test_batchnorm_dtypes(dtype):
     gamma = torch.ones(channels, device='cuda', dtype=dtype)
     beta = torch.zeros(channels, device='cuda', dtype=dtype)
 
-    output = kernels.batchnorm(input_tensor, gamma, beta, eps)
+    output = batchnorm(input_tensor, gamma, beta, eps)
     assert output.dtype == dtype
     assert output.shape == input_tensor.shape
 
@@ -74,7 +75,7 @@ def test_fused_vs_two_pass(use_small_problem):
     eps = 1e-5
     input_tensor, gamma, beta = create_test_data(batch_size, channels)
 
-    output = kernels.batchnorm(input_tensor, gamma, beta, eps)
+    output = batchnorm(input_tensor, gamma, beta, eps)
 
     bn = torch.nn.BatchNorm1d(channels, eps=eps, affine=True, track_running_stats=False, device=input_tensor.device)
     bn.weight.data = gamma

From 0060d5ba37a6ef7b2b3367bf94735578fea187a5 Mon Sep 17 00:00:00 2001
From: 0xrushi <6279035+0xrushi@users.noreply.github.com>
Date: Sun, 21 Sep 2025 19:54:33 -0400
Subject: [PATCH 3/5] benchmark

---
 benchmarking/profiler.py |  2 +-
 main.py                  | 11 +++++++----
 test/test_batchnorm.py   |  4 ++--
 3 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/benchmarking/profiler.py b/benchmarking/profiler.py
index a01a981..af814d0 100644
--- a/benchmarking/profiler.py
+++ b/benchmarking/profiler.py
@@ -77,5 +77,5 @@ def get_benchmark_vals(cls):
     @classmethod
     def get_profiling_data(cls):
         if cls._instance and cls._instance.profiler:
-            return self.profiler.key_averages()
+            return cls._instance.profiler.key_averages()
         return None
diff --git a/main.py b/main.py
index b69f33b..0e76aae 100644
--- a/main.py
+++ b/main.py
@@ -32,8 +32,8 @@ def main(operation: str, profile=False, benchmark=False, **kwargs):
         p = Profiler(profile, benchmark)
         torch.cuda.empty_cache()
         runner(operation, kwargs)
-        benchmarks["triton"] = Profiler.get_benchmark_vals()
-        profiles["triton"] = Profiler.get_profiling_data()
+        benchmarks["non_triton"] = Profiler.get_benchmark_vals()
+        profiles["non_triton"] = Profiler.get_profiling_data()
 
         Profiler.reset()
         p = Profiler(profile, benchmark)
@@ -43,8 +43,8 @@ def main(operation: str, profile=False, benchmark=False, **kwargs):
         p = Profiler(profile, benchmark)
         torch.cuda.empty_cache()
         runner(operation, kwargs)
-        benchmarks["non_triton"] = Profiler.get_benchmark_vals()
-        profiles["non_triton"] = Profiler.get_profiling_data()
+        benchmarks["triton"] = Profiler.get_benchmark_vals()
+        profiles["triton"] = Profiler.get_profiling_data()
     elif profile:
         runner(operation, kwargs)
         data = Profiler.get_profiling_data()
@@ -67,6 +67,7 @@ def main(operation: str, profile=False, benchmark=False, **kwargs):
 
     print("\n==================================\n")
 
+@Profiler.profiling_decorator("batchnorm_benchmark")
 def batchnorm_benchmark(batch_size=512, channels=2048, use_triton=True, suppress_prints=False):
     """Benchmark batchnorm implementation"""
     if not suppress_prints:
@@ -95,6 +96,8 @@ def batchnorm_benchmark(batch_size=512, channels=2048, use_triton=True, suppress
     if not suppress_prints:
         print(f"Output shape: {output.shape}, dtype: {output.dtype}")
 
+    return output
+
 
 def runner(operation: str, kwargs):
     if operation == "llama_chat_completion":
diff --git a/test/test_batchnorm.py b/test/test_batchnorm.py
index f26620b..1bf9aec 100644
--- a/test/test_batchnorm.py
+++ b/test/test_batchnorm.py
@@ -1,7 +1,7 @@
 import pytest
 import torch
 import kernels
-from kernels import batchnorm
+from kernels import batchnorm, _batchnorm
 
 def create_test_data(batch_size: int, channels: int, device: str = 'cuda'):
     input_tensor = torch.randn(batch_size, channels, device=device, dtype=torch.float32)
@@ -54,7 +54,7 @@ def test_batchnorm_backward_compatibility():
 
     N, C = input_tensor.shape
     output = torch.empty_like(input_tensor)
-    kernels.batchnorm._batchnorm(input_tensor, gamma, beta, output, N, C, eps)
+    _batchnorm(input_tensor, gamma, beta, output, N, C, eps)
 
     bn = torch.nn.BatchNorm1d(C, eps=eps, affine=True, track_running_stats=False, device=input_tensor.device)
     bn.weight.data = gamma

From 564b1be9ed8985c8153d0e2468d34d624bada891 Mon Sep 17 00:00:00 2001
From: 0xrushi <6279035+0xrushi@users.noreply.github.com>
Date: Sun, 21 Sep 2025 19:54:56 -0400
Subject: [PATCH 4/5] gitignore

---
 .gitignore | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.gitignore b/.gitignore
index d300e33..e223f1c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,3 +4,4 @@ __pycache__/
 .pytest_cache
 **/.cache
 **/meta-llama/**/*
+*.egg-info

From 87c9bc6abd0c031da16444df7217e51fe70093f5 Mon Sep 17 00:00:00 2001
From: 0xrushi <6279035+0xrushi@users.noreply.github.com>
Date: Sun, 21 Sep 2025 20:00:59 -0400
Subject: [PATCH 5/5] deps

---
 pyproject.toml | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 3d2128f..d6c2ba6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,6 +12,10 @@ dependencies = [
     "torch",
     "triton",
     "numpy",
+    "pandas>=2.0.3",
+    "fire>=0.7.1",
+    "fairscale>=0.4.13",
+    "tiktoken>=0.7.0",
 ]
 
 [project.optional-dependencies]
@@ -23,4 +27,4 @@ test = ["pytest", "pytest-xdist"]
 dev = ["pytest", "pytest-xdist", "black", "ruff"]
 
 [tool.setuptools.packages.find]
 include = ["kernels*"]
 exclude = ["test*", "benchmarking*", "models*"]
 
 [tool.setuptools.package-dir]
-"" = "."
\ No newline at end of file
+"" = "."
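
Editor's note, appended after the series and not itself a patch: the ANALYSIS
output of the (later deleted) benchmark suggests "Welford variance for
extreme N". The kernels above derive var = E[x^2] - E[x]^2 from sum and
sum-of-squares partials, a formula that loses precision once mean^2 dwarfs
the true variance. As a reference for that follow-up, here is a minimal
host-side Python sketch of the parallel Welford merge (Chan et al.); the
names welford_merge and chunked_mean_var and the chunking scheme are
illustrative assumptions, not an API in this repository.

# Hypothetical sketch, not code from this series: per-chunk statistics
# (count, mean, M2) are merged pairwise, where M2 is the sum of squared
# deviations from the chunk mean. This stays numerically stable where the
# sum/sumsq formula would cancel catastrophically.

def welford_merge(a, b):
    """Combine two partial statistics of the form (count, mean, M2)."""
    n_a, mean_a, m2_a = a
    n_b, mean_b, m2_b = b
    n = n_a + n_b
    if n == 0:
        return 0, 0.0, 0.0
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    m2 = m2_a + m2_b + delta * delta * n_a * n_b / n
    return n, mean, m2

def chunked_mean_var(values, chunk_size=1024):
    """Reduce per-chunk partials to (mean, biased variance) as BatchNorm needs."""
    total = (0, 0.0, 0.0)
    for i in range(0, len(values), chunk_size):
        chunk = values[i:i + chunk_size]
        n = len(chunk)
        mean = sum(chunk) / n
        m2 = sum((v - mean) ** 2 for v in chunk)
        total = welford_merge(total, (n, mean, m2))
    n, mean, m2 = total
    return mean, m2 / n  # biased variance, matching var = E[x^2] - E[x]^2

# Example in the regime the ANALYSIS hints at: huge mean, tiny variance.
data = [1e8 + 0.1 * k for k in range(8)]
print(chunked_mean_var(data, chunk_size=3))  # ~ (100000000.35, 0.0525)

In bn_reduce_kernel this would amount to each program emitting per-block
(count, mean, M2) partials instead of sum/sumsq atomics, with the merge
above applied across blocks in a second pass.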