diff --git a/.gitignore b/.gitignore
index d300e33..e223f1c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,3 +4,4 @@ __pycache__/
 .pytest_cache
 **/.cache
 **/meta-llama/**/*
+*.egg-info
diff --git a/benchmarking/profiler.py b/benchmarking/profiler.py
index a01a981..af814d0 100644
--- a/benchmarking/profiler.py
+++ b/benchmarking/profiler.py
@@ -77,5 +77,5 @@ def get_benchmark_vals(cls):
     @classmethod
     def get_profiling_data(cls):
         if cls._instance and cls._instance.profiler:
-            return self.profiler.key_averages()
+            return cls._instance.profiler.key_averages()
         return None
diff --git a/kernels/__init__.py b/kernels/__init__.py
index dd492d3..32febe1 100644
--- a/kernels/__init__.py
+++ b/kernels/__init__.py
@@ -1,11 +1,14 @@
 # from .conv import _conv, conv
 from . import blocksparse
+from .batchnorm import _batchnorm, batchnorm
 from .cross_entropy import _cross_entropy, cross_entropy
 from .flash_attention import attention
 from .matmul import _matmul, get_higher_dtype, matmul
 
 __all__ = [
     "blocksparse",
+    "_batchnorm",
+    "batchnorm",
     "_cross_entropy",
     "cross_entropy",
     "_matmul",
diff --git a/kernels/batchnorm.py b/kernels/batchnorm.py
new file mode 100644
index 0000000..29fa6fc
--- /dev/null
+++ b/kernels/batchnorm.py
@@ -0,0 +1,172 @@
+import torch
+import triton
+import triton.language as tl
+from benchmarking import Profiler
+
+@triton.autotune(
+    configs=[
+        triton.Config({'BLOCK_N': 256, 'BLOCK_C': 128}, num_warps=4, num_stages=2),
+        triton.Config({'BLOCK_N': 512, 'BLOCK_C': 128}, num_warps=4, num_stages=2),
+        triton.Config({'BLOCK_N': 256, 'BLOCK_C': 256}, num_warps=8, num_stages=2),
+        triton.Config({'BLOCK_N': 128, 'BLOCK_C': 256}, num_warps=4, num_stages=2),
+        triton.Config({'BLOCK_N': 512, 'BLOCK_C': 256}, num_warps=8, num_stages=3),
+    ],
+    key=['N', 'C', 'stride_in_c'],
+)
+@triton.jit
+def bn_reduce_kernel(x_ptr, sum_ptr, sumsq_ptr, N, C, stride_in_n, stride_in_c,
+                     BLOCK_N: tl.constexpr, BLOCK_C: tl.constexpr):
+    pid_n = tl.program_id(0)
+    pid_c = tl.program_id(1)
+
+    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    offs_c = pid_c * BLOCK_C + tl.arange(0, BLOCK_C)
+
+    mask_n = offs_n < N
+    mask_c = offs_c < C
+
+    x_ptrs = x_ptr + offs_n[:, None] * stride_in_n + offs_c[None, :] * stride_in_c
+    mask_2d = mask_n[:, None] & mask_c[None, :]
+
+    x = tl.load(x_ptrs, mask=mask_2d, other=0).to(tl.float32)
+    psum = tl.sum(x, axis=0)
+    psumsq = tl.sum(x * x, axis=0)
+
+    tl.atomic_add(sum_ptr + offs_c, psum, mask=mask_c)
+    tl.atomic_add(sumsq_ptr + offs_c, psumsq, mask=mask_c)
+
+@triton.autotune(
+    configs=[
+        triton.Config({'BLOCK_N': 256, 'BLOCK_C': 128}, num_warps=4, num_stages=2),
+        triton.Config({'BLOCK_N': 512, 'BLOCK_C': 128}, num_warps=4, num_stages=2),
+        triton.Config({'BLOCK_N': 256, 'BLOCK_C': 256}, num_warps=8, num_stages=2),
+        triton.Config({'BLOCK_N': 128, 'BLOCK_C': 256}, num_warps=4, num_stages=2),
+        triton.Config({'BLOCK_N': 512, 'BLOCK_C': 256}, num_warps=8, num_stages=3),
+    ],
+    key=['N', 'C', 'stride_in_c'],
+)
+@triton.jit
+def bn_norm_kernel(x_ptr, gamma_ptr, beta_ptr, y_ptr, sum_ptr, sumsq_ptr,
+                   N, C, eps, stride_in_n, stride_in_c, stride_out_n, stride_out_c,
+                   BLOCK_N: tl.constexpr, BLOCK_C: tl.constexpr):
+    pid_n = tl.program_id(0)
+    pid_c = tl.program_id(1)
+
+    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    offs_c = pid_c * BLOCK_C + tl.arange(0, BLOCK_C)
+    mask_n = offs_n < N
+    mask_c = offs_c < C
+
+    s = tl.load(sum_ptr + offs_c, mask=mask_c, other=0.).to(tl.float32)
+    ss = tl.load(sumsq_ptr + offs_c, mask=mask_c, other=0.).to(tl.float32)
+    Nf = tl.full((), N, tl.float32)
+
+    mean = s / Nf
+    var = ss / Nf - mean * mean
+    inv_std = tl.rsqrt(var + eps)
+
+    gamma = tl.load(gamma_ptr + offs_c, mask=mask_c, other=1.).to(tl.float32)
+    beta = tl.load(beta_ptr + offs_c, mask=mask_c, other=0.).to(tl.float32)
+
+    scale = gamma * inv_std
+    shift = beta - mean * scale
+
+    x_ptrs = x_ptr + offs_n[:, None] * stride_in_n + offs_c[None, :] * stride_in_c
+    y_ptrs = y_ptr + offs_n[:, None] * stride_out_n + offs_c[None, :] * stride_out_c
+    mask_2d = mask_n[:, None] & mask_c[None, :]
+
+    x = tl.load(x_ptrs, mask=mask_2d, other=0.).to(tl.float32)
+    y = x * scale[None, :] + shift[None, :]
+    tl.store(y_ptrs, y, mask=mask_2d)
+
+@triton.autotune(
+    configs=[
+        triton.Config({'BLOCK_C': 64, 'BLOCK_N': 128}, num_warps=2, num_stages=3),
+        triton.Config({'BLOCK_C': 128, 'BLOCK_N': 128}, num_warps=4, num_stages=3),
+        triton.Config({'BLOCK_C': 128, 'BLOCK_N': 256}, num_warps=4, num_stages=4),
+        triton.Config({'BLOCK_C': 256, 'BLOCK_N': 128}, num_warps=8, num_stages=4),
+    ],
+    key=['N', 'C', 'stride_in_c'],
+)
+@triton.jit
+def bn_fused_kernel(x_ptr, y_ptr, gamma_ptr, beta_ptr, N, C, eps,
+                    stride_in_n, stride_in_c, stride_out_n, stride_out_c,
+                    BLOCK_C: tl.constexpr, BLOCK_N: tl.constexpr):
+    pid_c = tl.program_id(0)
+    offs_c = pid_c * BLOCK_C + tl.arange(0, BLOCK_C)
+    mask_c = offs_c < C
+
+    gamma = tl.load(gamma_ptr + offs_c, mask=mask_c, other=1.).to(tl.float32)
+    beta = tl.load(beta_ptr + offs_c, mask=mask_c, other=0.).to(tl.float32)
+
+    s = tl.zeros([BLOCK_C], dtype=tl.float32)
+    ss = tl.zeros([BLOCK_C], dtype=tl.float32)
+
+    tl.max_contiguous(offs_c, BLOCK_C)
+
+    for n0 in range(0, N, BLOCK_N):
+        offs_n = n0 + tl.arange(0, BLOCK_N)
+        mask_n = offs_n < N
+        x_ptrs = x_ptr + offs_n[:, None] * stride_in_n + offs_c[None, :] * stride_in_c
+        m = mask_n[:, None] & mask_c[None, :]
+        x = tl.load(x_ptrs, mask=m, other=0.).to(tl.float32)
+        s += tl.sum(x, axis=0)
+        ss += tl.sum(x * x, axis=0)
+
+    Nf = tl.full((), N, tl.float32)
+    mean = s / Nf
+    var = ss / Nf - mean * mean
+    inv_std = tl.rsqrt(var + eps)
+    scale = gamma * inv_std
+    shift = beta - mean * scale
+
+    for n0 in range(0, N, BLOCK_N):
+        offs_n = n0 + tl.arange(0, BLOCK_N)
+        mask_n = offs_n < N
+        x_ptrs = x_ptr + offs_n[:, None] * stride_in_n + offs_c[None, :] * stride_in_c
+        y_ptrs = y_ptr + offs_n[:, None] * stride_out_n + offs_c[None, :] * stride_out_c
+        m = mask_n[:, None] & mask_c[None, :]
+        x = tl.load(x_ptrs, mask=m, other=0.).to(tl.float32)
+        y = x * scale[None, :] + shift[None, :]
+        tl.store(y_ptrs, y, mask=m)
+
+def _batchnorm(input: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor,
+               output: torch.Tensor, N: int, C: int, eps: float):
+    assert input.is_cuda and output.is_cuda and gamma.is_cuda and beta.is_cuda
+    assert input.dtype in (torch.float16, torch.bfloat16, torch.float32)
+
+    problem_size = N * C
+    use_fused = problem_size < (1 << 24)
+
+    if use_fused:
+        def grid(meta):
+            return (triton.cdiv(C, meta['BLOCK_C']),)
+        bn_fused_kernel[grid](
+            input, output, gamma, beta,
+            N, C, eps,
+            input.stride(0), input.stride(1),
+            output.stride(0), output.stride(1),
+        )
+        return
+
+    sum_buf = torch.zeros(C, device=input.device, dtype=torch.float32)
+    sumsq_buf = torch.zeros(C, device=input.device, dtype=torch.float32)
+
+    def grid(meta):
+        return (triton.cdiv(N, meta['BLOCK_N']), triton.cdiv(C, meta['BLOCK_C']))
+
+    bn_reduce_kernel[grid](input, sum_buf, sumsq_buf,
+                           N, C, input.stride(0), input.stride(1))
+
+    bn_norm_kernel[grid](input, gamma, beta, output,
+                         sum_buf, sumsq_buf,
+                         N, C, eps,
+                         input.stride(0), input.stride(1),
+                         output.stride(0), output.stride(1))
+
+@Profiler.profiling_decorator("batchnorm")
+def batchnorm(input: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor, eps: float = 1e-5):
+    N, C = input.shape
+    output = torch.empty_like(input)
+    _batchnorm(input, gamma, beta, output, N, C, eps)
+    return output
\ No newline at end of file
diff --git a/main.py b/main.py
index eac1c68..0e76aae 100644
--- a/main.py
+++ b/main.py
@@ -6,6 +6,7 @@
 from models.llama import llama_example_chat_completion, llama_example_text_completion
 
 from benchmarking import Profiler, compare_benchmarks
+import kernels
 
 import pprint
 
@@ -31,8 +32,8 @@ def main(operation: str, profile=False, benchmark=False, **kwargs):
         p = Profiler(profile, benchmark)
         torch.cuda.empty_cache()
         runner(operation, kwargs)
-        benchmarks["triton"] = Profiler.get_benchmark_vals()
-        profiles["triton"] = Profiler.get_profiling_data()
+        benchmarks["non_triton"] = Profiler.get_benchmark_vals()
+        profiles["non_triton"] = Profiler.get_profiling_data()
 
         Profiler.reset()
         p = Profiler(profile, benchmark)
@@ -42,8 +43,8 @@ def main(operation: str, profile=False, benchmark=False, **kwargs):
         p = Profiler(profile, benchmark)
         torch.cuda.empty_cache()
         runner(operation, kwargs)
-        benchmarks["non_triton"] = Profiler.get_benchmark_vals()
-        profiles["non_triton"] = Profiler.get_profiling_data()
+        benchmarks["triton"] = Profiler.get_benchmark_vals()
+        profiles["triton"] = Profiler.get_profiling_data()
     elif profile:
         runner(operation, kwargs)
         data = Profiler.get_profiling_data()
@@ -66,11 +67,45 @@ def main(operation: str, profile=False, benchmark=False, **kwargs):
     print("\n==================================\n")
 
 
+@Profiler.profiling_decorator("batchnorm_benchmark")
+def batchnorm_benchmark(batch_size=512, channels=2048, use_triton=True, suppress_prints=False):
+    """Benchmark batchnorm implementation"""
+    if not suppress_prints:
+        print(f"Running BatchNorm benchmark: {batch_size}x{channels}, use_triton={use_triton}")
+
+    device = 'cuda'
+    eps = 1e-5
+
+    # Create test data
+    input_tensor = torch.randn(batch_size, channels, device=device, dtype=torch.float32)
+    gamma = torch.ones(channels, device=device, dtype=torch.float32)
+    beta = torch.zeros(channels, device=device, dtype=torch.float32)
+
+    if use_triton:
+        # Use custom Triton implementation
+        output = kernels.batchnorm(input_tensor, gamma, beta, eps)
+    else:
+        # Use PyTorch baseline
+        bn = torch.nn.BatchNorm1d(channels, eps=eps, affine=True, track_running_stats=False, device=device)
+        bn.weight.data = gamma
+        bn.bias.data = beta
+        bn.train()
+        with torch.no_grad():
+            output = bn(input_tensor)
+
+    if not suppress_prints:
+        print(f"Output shape: {output.shape}, dtype: {output.dtype}")
+
+    return output
+
+
 def runner(operation: str, kwargs):
     if operation == "llama_chat_completion":
         llama_example_chat_completion(**kwargs)
     elif operation == "llama_text_completion":
         llama_example_text_completion(**kwargs)
+    elif operation == "batchnorm_benchmark":
+        batchnorm_benchmark(**kwargs)
     else:
         raise ValueError(f"Unknown operation: {operation}")
 
diff --git a/pyproject.toml b/pyproject.toml
index e69de29..d6c2ba6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -0,0 +1,30 @@
+[build-system]
+requires = ["setuptools>=64", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "triton-kernels"
+version = "0.1.0"
+description = "High-performance Triton kernels for machine learning"
+readme = "README.md"
+requires-python = ">=3.8"
+dependencies = [
+    "torch",
+    "triton",
+    "numpy",
+    "pandas>=2.0.3",
"fire>=0.7.1", + "fairscale>=0.4.13", + "tiktoken>=0.7.0", +] + +[project.optional-dependencies] +test = ["pytest", "pytest-xdist"] +dev = ["pytest", "pytest-xdist", "black", "ruff"] + +[tool.setuptools.packages.find] +include = ["kernels*"] +exclude = ["test*", "benchmarking*", "models*"] + +[tool.setuptools.package-dir] +"" = "." diff --git a/test/test_batchnorm.py b/test/test_batchnorm.py new file mode 100644 index 0000000..1bf9aec --- /dev/null +++ b/test/test_batchnorm.py @@ -0,0 +1,87 @@ +import pytest +import torch +import kernels +from kernels import batchnorm, _batchnorm + +def create_test_data(batch_size: int, channels: int, device: str = 'cuda'): + input_tensor = torch.randn(batch_size, channels, device=device, dtype=torch.float32) + gamma = torch.ones(channels, device=device, dtype=torch.float32) + beta = torch.zeros(channels, device=device, dtype=torch.float32) + return input_tensor, gamma, beta + +@pytest.mark.parametrize("batch_size,channels", [ + (32, 128), + (64, 256), + (128, 512), + (256, 1024), + (512, 2048), +]) +def test_batchnorm_correctness(batch_size, channels): + eps = 1e-5 + input_tensor, gamma, beta = create_test_data(batch_size, channels) + + output_custom = batchnorm(input_tensor, gamma, beta, eps) + + bn = torch.nn.BatchNorm1d(channels, eps=eps, affine=True, track_running_stats=False, device=input_tensor.device) + bn.weight.data = gamma + bn.bias.data = beta + bn.train() + with torch.no_grad(): + output_pytorch = bn(input_tensor) + + max_diff = torch.max(torch.abs(output_custom - output_pytorch)).item() + rel_error = max_diff / (torch.max(torch.abs(output_pytorch)).item() + 1e-8) + + assert max_diff < 1e-3 or rel_error < 1e-3, f"Max diff: {max_diff:.2e}, Rel error: {rel_error:.2e}" + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +def test_batchnorm_dtypes(dtype): + batch_size, channels = 128, 256 + eps = 1e-5 + + input_tensor = torch.randn(batch_size, channels, device='cuda', dtype=dtype) + gamma = torch.ones(channels, device='cuda', dtype=dtype) + beta = torch.zeros(channels, device='cuda', dtype=dtype) + + output = batchnorm(input_tensor, gamma, beta, eps) + assert output.dtype == dtype + assert output.shape == input_tensor.shape + +def test_batchnorm_backward_compatibility(): + batch_size, channels = 64, 128 + eps = 1e-5 + input_tensor, gamma, beta = create_test_data(batch_size, channels) + + N, C = input_tensor.shape + output = torch.empty_like(input_tensor) + _batchnorm(input_tensor, gamma, beta, output, N, C, eps) + + bn = torch.nn.BatchNorm1d(C, eps=eps, affine=True, track_running_stats=False, device=input_tensor.device) + bn.weight.data = gamma + bn.bias.data = beta + bn.train() + with torch.no_grad(): + expected = bn(input_tensor) + + torch.testing.assert_close(output, expected, atol=1e-3, rtol=1e-3) + +@pytest.mark.parametrize("use_small_problem", [True, False]) +def test_fused_vs_two_pass(use_small_problem): + if use_small_problem: + batch_size, channels = 32, 64 + else: + batch_size, channels = 8192, 1024 + + eps = 1e-5 + input_tensor, gamma, beta = create_test_data(batch_size, channels) + + output = batchnorm(input_tensor, gamma, beta, eps) + + bn = torch.nn.BatchNorm1d(channels, eps=eps, affine=True, track_running_stats=False, device=input_tensor.device) + bn.weight.data = gamma + bn.bias.data = beta + bn.train() + with torch.no_grad(): + expected = bn(input_tensor) + + torch.testing.assert_close(output, expected, atol=1e-3, rtol=1e-3) \ No newline at end of file