From 68f4c5d18c23d7da52bf9b73425c2488e34b54f7 Mon Sep 17 00:00:00 2001 From: Ying Hu Date: Wed, 30 Jul 2025 02:12:51 +0000 Subject: [PATCH 1/4] feat: bitnet gemm kernel --- gpu/bitnet_kernels/bitgemm.cu | 236 ++++++++++++++++++++++++++++++++++ gpu/model.py | 53 +++++++- gpu/test_gemm.py | 59 +++++++++ 3 files changed, 344 insertions(+), 4 deletions(-) create mode 100644 gpu/bitnet_kernels/bitgemm.cu create mode 100644 gpu/test_gemm.py diff --git a/gpu/bitnet_kernels/bitgemm.cu b/gpu/bitnet_kernels/bitgemm.cu new file mode 100644 index 00000000..423c15c0 --- /dev/null +++ b/gpu/bitnet_kernels/bitgemm.cu @@ -0,0 +1,236 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +template +__device__ void decode_i2s_to_i8s(T1 *_i2s, T2 *_i8s, const int N = 16) { + // convert 8 int2b_t to 8 int8b_t -> 2 int32 + uint *i8s = reinterpret_cast(_i8s); + + // i2s = {e7,e6,e5,e4,e3,e2,e1,e0} + // also require interleave {e7,e3,e6,e2,e5,e1,e4,e0} + uint const i2s = *_i2s; + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 + static constexpr uint BOTTOM_MASK = 0x03030303; // 0xf -> 0b11 select 0,3 + static constexpr uint I4s_TO_I8s_MAGIC_NUM = 0x00000000; // 1024 + +#pragma unroll + for (int i = 0; i < (N / 4); i++) { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(i8s[i]) + : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), + "n"(I4s_TO_I8s_MAGIC_NUM), "n"(immLut)); + i8s[i] = __vsubss4(i8s[i], 0x02020202); + } +} + + +template +__global__ void int8_int2_gemm_tensor_core( + const int8_t *__restrict__ A, // M x K matrix, row-major + const int32_t *__restrict__ B_compressed, // Compressed int2 data for N x K matrix, column-major + int32_t *__restrict__ C, // M x N output matrix, row-major + int M) +{ + // Define WMMA dimensions - all constant + constexpr int WMMA_M = 16; + constexpr int WMMA_N = 16; + constexpr int WMMA_K = 16; + + // Define block tile dimensions - all constant + constexpr int BLOCK_SIZE_M = 64; // Multiple of WMMA_M + constexpr int BLOCK_SIZE_N = 64; // Multiple of WMMA_N + constexpr int BLOCK_SIZE_K = 32; // K dimension as requested + + // Calculate thread block position + const int blockM = blockIdx.y * BLOCK_SIZE_M; + const int blockN = blockIdx.x * BLOCK_SIZE_N; + + // Calculate thread ID and warp IDs + const int warpM = threadIdx.y; // 0-1 (2 warps in M dimension) + const int warpN = threadIdx.z; // 0-1 (2 warps in N dimension) + const int tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y; + + // Add padding to shared memory to avoid bank conflicts + constexpr int PAD_A = 16; // Padding for A matrix + constexpr int PAD_B = 16; // Padding for B matrix + + // Allocate shared memory for A and B matrices with padding + __shared__ int8_t shared_A[BLOCK_SIZE_M][BLOCK_SIZE_K + PAD_A]; + __shared__ int8_t shared_B[BLOCK_SIZE_N][BLOCK_SIZE_K + PAD_B]; + + // Define fragments for all tiles this warp will handle - static allocation + nvcuda::wmma::fragment c_frags[2][2]; + nvcuda::wmma::fragment a_frag; + nvcuda::wmma::fragment b_frag; + + // Initialize all accumulator fragments to zero (unrolled) + #pragma unroll + for (int m_iter = 0; m_iter < 2; m_iter++) { + #pragma unroll + for (int n_iter = 0; n_iter < 2; n_iter++) { + nvcuda::wmma::fill_fragment(c_frags[m_iter][n_iter], 0); + } + } + + // Only check M bounds at the beginning + const bool m_valid = blockM < M; + + // Loop over K dimension in chunks of BLOCK_SIZE_K + #pragma unroll 4 // 
Partial unroll of K-dimension loop + for (int k_block = 0; k_block < K; k_block += BLOCK_SIZE_K) { + // Clear shared memory first + __syncthreads(); + + // Load A matrix tiles into shared memory using vectorized loads + // Each thread handles multiple elements based on its ID + for (int load_idx = tid; load_idx < (BLOCK_SIZE_M * BLOCK_SIZE_K / 16); load_idx += blockDim.x * blockDim.y * blockDim.z) { + int local_m = (load_idx * 16) / BLOCK_SIZE_K; + int local_k = (load_idx * 16) % BLOCK_SIZE_K; + + int global_m = blockM + local_m; + int global_k = k_block + local_k; + + // Use vector loads for A - 16 bytes at a time (int4 = 4 integers = 16 bytes) + if (m_valid && global_m < M) { + // Vector load from A to shared memory + *((int4*)&shared_A[local_m][local_k]) = *((int4*)&A[global_m * K + global_k]); + } else { + // Zero out if M is out of bounds + *((int4*)&shared_A[local_m][local_k]) = {0}; + } + } + + // Load B matrix tiles into shared memory (always in bounds for N and K) + // Calculate which 16-element chunk this thread is responsible for + int chunk_n = (tid * 16 / BLOCK_SIZE_K); + int chunk_k = (tid * 16) % BLOCK_SIZE_K; + + if (chunk_n < BLOCK_SIZE_N) { + int global_n = blockN + chunk_n; + int global_k = k_block + chunk_k; + + // Calculate which compressed block this belongs to + int n_block = global_n / 16; + int k_block_32 = global_k / 32; + int k_offset_in_block = chunk_k % 32; + + // Get the specific compressed tile within the 16x32 block + int in_block_n = chunk_n % 16; + int compressed_block_idx = n_block * (K / 32) + k_block_32; + + // Calculate which tile within the compressed block + int tile_idx; + tile_idx = in_block_n / 8 * 16 + in_block_n % 8 + (k_offset_in_block / 16) * 8; + + // Extract and decompress the int2 values + int32_t compressed = B_compressed[compressed_block_idx * 32 + tile_idx]; + int8_t decompressed[16]; + decode_i2s_to_i8s(&compressed, decompressed); + + // Vector store to shared memory + *((int4*)&shared_B[chunk_n][chunk_k]) = *((int4*)decompressed); + } + + // Make sure all threads have finished loading into shared memory + __syncthreads(); + + // Process the 2x2 WMMA tiles for this K block + #pragma unroll + for (int m_iter = 0; m_iter < 2; m_iter++) { + #pragma unroll + for (int n_iter = 0; n_iter < 2; n_iter++) { + // Calculate the starting positions for this WMMA tile + #pragma unroll + for (int wmma_k = 0; wmma_k < BLOCK_SIZE_K; wmma_k += WMMA_K) { + // Fully unroll the m and n iterations + const int tile_m = (warpM * 2 + m_iter) * WMMA_M; + const int tile_n = (warpN * 2 + n_iter) * WMMA_N; + + // Load matrix A fragment from shared memory with padding + nvcuda::wmma::load_matrix_sync( + a_frag, &shared_A[tile_m][wmma_k], BLOCK_SIZE_K + PAD_A); + + // Load matrix B fragment from shared memory with padding + nvcuda::wmma::load_matrix_sync( + b_frag, &shared_B[tile_n][wmma_k], BLOCK_SIZE_K + PAD_B); + + // Perform matrix multiplication + nvcuda::wmma::mma_sync(c_frags[m_iter][n_iter], a_frag, b_frag, c_frags[m_iter][n_iter]); + } + } + } + } + + // Store results back to global memory - only check M bounds + #pragma unroll + for (int m_iter = 0; m_iter < 2; m_iter++) { + const int tile_m = (warpM * 2 + m_iter) * WMMA_M; + const int global_tile_m = blockM + tile_m; + + if (m_valid && global_tile_m < M) { + #pragma unroll + for (int n_iter = 0; n_iter < 2; n_iter++) { + const int tile_n = (warpN * 2 + n_iter) * WMMA_N; + const int global_tile_n = blockN + tile_n; + + // No need to check N bounds as it's always aligned + nvcuda::wmma::store_matrix_sync( + 
&C[global_tile_m * N + global_tile_n], + c_frags[m_iter][n_iter], N, nvcuda::wmma::mem_row_major); + } + } + } +} + +extern "C" void bitlinear_int8xint2(int8_t *input0, int8_t *input1, + int32_t *output0, int M, int N, int K, + cudaStream_t stream = 0) { + if (N == 3840 && K == 2560) { + int8_int2_gemm_tensor_core<3840, 2560> + <<>>( + input0, (int32_t *)input1, (int32_t *)output0, M); + } else if (N == 2560 && K == 2560) { + int8_int2_gemm_tensor_core<2560, 2560> + <<>>( + input0, (int32_t *)input1, (int32_t *)output0, M); + } else if (N == 13824 && K == 2560) { + int8_int2_gemm_tensor_core<13824, 2560> + <<>>( + input0, (int32_t *)input1, (int32_t *)output0, M); + } else if (N == 2560 && K == 6912) { + int8_int2_gemm_tensor_core<2560, 6912> + <<>>( + input0, (int32_t *)input1, (int32_t *)output0, M); + } else { + std::cerr << "Error: Unsupported matrix dimensions for bitlinear_int8xint2. " + << "Required kernel: M=" << M << ", N=" << N << ", K=" << K << std::endl; + std::cerr << "Supported configurations:" << std::endl; + std::cerr << " - N=3840, K=2560" << std::endl; + std::cerr << " - N=2560, K=2560" << std::endl; + std::cerr << " - N=13824, K=2560" << std::endl; + std::cerr << " - N=2560, K=6912" << std::endl; + throw std::runtime_error("Unsupported matrix dimensions for bitlinear_int8xint2"); + } + + // Check for CUDA launch errors + cudaError_t launch_error = cudaGetLastError(); + if (launch_error != cudaSuccess) { + std::cerr << "CUDA kernel launch failed: " << cudaGetErrorString(launch_error) << std::endl; + throw std::runtime_error("CUDA kernel launch failed"); + } + + // Synchronize and check for execution errors + cudaError_t sync_error = cudaStreamSynchronize(stream); + if (sync_error != cudaSuccess) { + std::cerr << "CUDA kernel execution failed: " << cudaGetErrorString(sync_error) << std::endl; + throw std::runtime_error("CUDA kernel execution failed"); + } +} \ No newline at end of file diff --git a/gpu/model.py b/gpu/model.py index cd5abec0..9a37062d 100755 --- a/gpu/model.py +++ b/gpu/model.py @@ -17,8 +17,11 @@ import ctypes bitnet_lib = ctypes.CDLL('bitnet_kernels/libbitnet.so') +gemm_lib = ctypes.CDLL('bitnet_kernels/libgemm.so') -def bitnet_int8xint2_linear(input0, input1, s, ws): +import numpy as np + +def bitnet_int8xint2_linear_gemv(input0, input1, s, ws): out_shape = list(input0.shape) out_shape[-1] = input1.shape[0] @@ -36,6 +39,42 @@ def bitnet_int8xint2_linear(input0, input1, s, ws): return ret +def bitnet_int8xint2_linear_gemm(input0, input1, s, ws): + out_shape = list(input0.shape) + out_shape[-1] = input1.shape[0] + + stream = torch.cuda.current_stream() + + M = input0.shape[0] + if len(out_shape) == 3: + M *= input0.shape[1] + N = input1.shape[0] + K = input1.shape[1] * 4 + + ret = torch.zeros(*out_shape, dtype=torch.int32, device=input0.device) + + gemm_lib.bitlinear_int8xint2(*[ctypes.c_void_p(input0.data_ptr()), ctypes.c_void_p(input1.data_ptr()), ctypes.c_void_p(ret.data_ptr()), ctypes.c_int(M), ctypes.c_int(N), ctypes.c_int(K), ctypes.c_void_p(stream.cuda_stream)]) + ret = ret.to(torch.bfloat16) + ret = ret / s + if N == 3840 and K == 2560: + #split last dim to 6 parts evenly + ret = ret.reshape(*ret.shape[:-1], 6, -1) + # devide each part by first 6 coresponding weight scale + ret = ret * ws[:6].reshape(1, 6, 1) + elif (N == 2560 and K == 2560): + # 1 part + ret = ret* ws[:1].reshape(1, 1, 1, 1) + elif (N == 13824 and K == 2560): + # 2 parts + ret = ret.reshape(*ret.shape[:-1], 2, -1) + # devide each part by first 2 coresponding weight scale + ret = 
ret * ws[:2].reshape(1, 1, 2, 1) + elif (N == 2560 and K == 6912): + # 1 part + ret = ret * ws[:1].reshape(1, 1, 1, 1) + + return ret.reshape(*out_shape) + @dataclass class ModelArgs: dim: int = 2560 @@ -63,16 +102,22 @@ def __init__(self, in_features: int, out_features: int, bias: bool = False): self.out_features = out_features self.weight = torch.nn.Parameter(torch.zeros(out_features, in_features//4, dtype=torch.int8), requires_grad=False) - self.weight_scale = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16), requires_grad=False) + self.weight_scale = torch.nn.Parameter(torch.zeros(6, dtype=torch.bfloat16), requires_grad=False) @torch.compile def quant_input(self, input): s = 127 / input.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5) return (input * s).round().clamp(-128, 127).to(torch.int8), s - def forward(self, input): + def forward(self, input, weight_int8=None): input, s = self.quant_input(input) - return bitnet_int8xint2_linear(input, self.weight, s, self.weight_scale) + weight_np = weight_int8.cpu().to(torch.int32).T.numpy() + input_np = input.cpu().to(torch.int32).numpy() + out_np = np.matmul(input_np, weight_np) + if input.shape[0] == 1: + return bitnet_int8xint2_linear_gemv(input, self.weight, s, self.weight_scale) + else: + return bitnet_int8xint2_linear_gemm(input, self.weight, s, self.weight_scale) class BitLinear(nn.Linear): @torch.compile diff --git a/gpu/test_gemm.py b/gpu/test_gemm.py new file mode 100644 index 00000000..903e15c3 --- /dev/null +++ b/gpu/test_gemm.py @@ -0,0 +1,59 @@ +import torch +from model import BitLinear, BitLinearKernel +from pack_weight import convert_weight_int8_to_int2 + +from torch import nn + + +def quant_weight_int8(weight): + s = 1.0 / weight.abs().mean().clamp_(min=1e-5) + new_weight = (weight * s).round().clamp(-1, 1).to(torch.int8) + new_scale = (1.0 / s).to(torch.bfloat16) + return new_weight, new_scale.reshape(1).repeat(6) + +def quant_weight(weight): + s = 1.0 / weight.abs().mean().clamp_(min=1e-5) + new_weight = (weight * s).round().clamp(-1, 1) / s + + return new_weight + +def convert_int8_to_int2(weight): + return convert_weight_int8_to_int2(weight) + + +def test_BitLinear(): + in_dim = 2560 # 64 + out_dim = 3840 # 32 + default_dtype = torch.bfloat16 + x = torch.randn(128, in_dim, dtype=default_dtype).cuda() # (batch, in_features) + + layer0 = BitLinear(in_dim, out_dim).cuda() + layer0 = layer0.to(default_dtype) + nn.init.kaiming_uniform_(layer0.weight, nonlinearity='relu') + with torch.no_grad(): + layer0.weight.copy_( quant_weight(layer0.weight)) + nn.init.zeros_(layer0.bias) + out0 = layer0(x) + + assert not torch.isnan(out0).any() + assert layer0.weight.dtype == default_dtype + # print(layer0.weight.dtype, layer0.weight.shape) + + layer1 = BitLinearKernel(in_dim, out_dim).cuda() + weight_int8, scale = quant_weight_int8(layer0.weight) + weight = convert_int8_to_int2(weight_int8) + + + with torch.no_grad(): + layer1.weight.copy_(weight) + layer1.weight_scale.copy_(scale) + print(layer1.weight, layer1.weight_scale) + out1 = layer1(x, weight_int8) + assert out1.dtype == default_dtype + + print(f"Non-kernel output: {out0}, Kernel output: {out1}") + assert torch.equal(out0, out1), "Outputs from non-kernel and kernel paths should match" + + +if __name__ == "__main__": + test_BitLinear() From aff3d8189c4dad8bb8640114ff965b43ce6bde16 Mon Sep 17 00:00:00 2001 From: Ying Hu Date: Wed, 30 Jul 2025 02:43:04 +0000 Subject: [PATCH 2/4] feat: in2 prefill --- gpu/bitnet_kernels/bitgemm.cu | 7 ------- 
gpu/bitnet_kernels/bitnet_kernels.cu | 2 +- gpu/bitnet_kernels/compile.sh | 1 + gpu/convert_checkpoint.py | 6 +++--- gpu/generate.py | 6 ++---- gpu/model.py | 5 +---- 6 files changed, 8 insertions(+), 19 deletions(-) diff --git a/gpu/bitnet_kernels/bitgemm.cu b/gpu/bitnet_kernels/bitgemm.cu index 423c15c0..4d12e34a 100644 --- a/gpu/bitnet_kernels/bitgemm.cu +++ b/gpu/bitnet_kernels/bitgemm.cu @@ -226,11 +226,4 @@ extern "C" void bitlinear_int8xint2(int8_t *input0, int8_t *input1, std::cerr << "CUDA kernel launch failed: " << cudaGetErrorString(launch_error) << std::endl; throw std::runtime_error("CUDA kernel launch failed"); } - - // Synchronize and check for execution errors - cudaError_t sync_error = cudaStreamSynchronize(stream); - if (sync_error != cudaSuccess) { - std::cerr << "CUDA kernel execution failed: " << cudaGetErrorString(sync_error) << std::endl; - throw std::runtime_error("CUDA kernel execution failed"); - } } \ No newline at end of file diff --git a/gpu/bitnet_kernels/bitnet_kernels.cu b/gpu/bitnet_kernels/bitnet_kernels.cu index 6e615809..16650005 100644 --- a/gpu/bitnet_kernels/bitnet_kernels.cu +++ b/gpu/bitnet_kernels/bitnet_kernels.cu @@ -2,7 +2,7 @@ extern "C" void bitlinear_int8xint2(int8_t* input0, int8_t* input1, __nv_bfloat16* output0, __nv_bfloat16* s, __nv_bfloat16* ws, int M, int N, int K, cudaStream_t stream){ if (M == 1 && N == 3840 && K == 2560){ - ladder_int8xint2_kernel<1, 3840, 2560, 3, 8, 16><<>>(input0, input1, output0, s, ws); + ladder_int8xint2_kernel<1, 3840, 2560, 6, 8, 16><<>>(input0, input1, output0, s, ws); } else if (M == 1 && N == 2560 && K == 2560){ ladder_int8xint2_kernel<1, 2560, 2560, 1, 8, 16><<>>(input0, input1, output0, s, ws); diff --git a/gpu/bitnet_kernels/compile.sh b/gpu/bitnet_kernels/compile.sh index 1e22741d..6d6e40ca 100644 --- a/gpu/bitnet_kernels/compile.sh +++ b/gpu/bitnet_kernels/compile.sh @@ -1,3 +1,4 @@ nvcc -std=c++17 -Xcudafe --diag_suppress=177 --compiler-options -fPIC -lineinfo --shared bitnet_kernels.cu -lcuda -gencode=arch=compute_80,code=compute_80 -o libbitnet.so +nvcc -std=c++17 -Xcudafe --diag_suppress=177 --compiler-options -fPIC -lineinfo --shared bitgemm.cu -lcuda -gencode=arch=compute_80,code=compute_80 -o libgemm.so diff --git a/gpu/convert_checkpoint.py b/gpu/convert_checkpoint.py index 797ad1db..0b0ba90f 100755 --- a/gpu/convert_checkpoint.py +++ b/gpu/convert_checkpoint.py @@ -47,7 +47,7 @@ def convert_int8_to_int2(weight): wk_weight, wb_scale = quant_weight_int8(wk) wv_weight, wc_scale = quant_weight_int8(wv) wqkv_weight = torch.cat([wq_weight, wk_weight, wv_weight], dim=0) - wqkv_scale = torch.cat([wa_scale, wb_scale, wc_scale, zero], dim=0) + wqkv_scale = torch.cat([wa_scale, wa_scale, wa_scale, wa_scale, wb_scale, wc_scale], dim=0) int2_result[key] = convert_int8_to_int2(wqkv_weight) int2_result[key.replace('weight', 'weight_scale')] = wqkv_scale @@ -62,7 +62,7 @@ def convert_int8_to_int2(weight): w1_weight, w1_scale = quant_weight_int8(w1) w3_weight, w3_scale = quant_weight_int8(w3) w13_weight = torch.cat([w1_weight, w3_weight], dim=0) - w13_scale = torch.cat([w1_scale, w3_scale, zero, zero], dim=0) + w13_scale = torch.cat([w1_scale, w3_scale, zero, zero, zero, zero], dim=0) int2_result[key] = convert_int8_to_int2(w13_weight) int2_result[key.replace('weight', 'weight_scale')] = w13_scale @@ -72,7 +72,7 @@ def convert_int8_to_int2(weight): fp16_result[key] = w13_weight elif 'w2' in key or 'wo' in key: weight, scale = quant_weight_int8(value) - scale = torch.cat([scale, zero, zero, zero], dim=0) + 
scale = torch.cat([scale, zero, zero, zero, zero, zero], dim=0) int2_result[key] = convert_int8_to_int2(weight) int2_result[key.replace('weight', 'weight_scale')] = scale diff --git a/gpu/generate.py b/gpu/generate.py index 638ed7b3..63415a07 100755 --- a/gpu/generate.py +++ b/gpu/generate.py @@ -53,7 +53,7 @@ def build( """ start_time = time.time() - model_args_prefill = fast.ModelArgs(use_kernel=False) + model_args_prefill = fast.ModelArgs(use_kernel=True) model_args_decode = fast.ModelArgs(use_kernel=True) tokenizer = Tokenizer("./tokenizer.model") @@ -63,11 +63,9 @@ def build( prefill_model = fast.Transformer(model_args_prefill) decode_model = fast.Transformer(model_args_decode) - fp16_ckpt_path = str(Path(ckpt_dir) / "model_state_fp16.pt") - fp16_checkpoint = torch.load(fp16_ckpt_path, map_location="cpu") int2_ckpt_path = str(Path(ckpt_dir) / "model_state_int2.pt") int2_checkpoint = torch.load(int2_ckpt_path, map_location="cpu") - prefill_model.load_state_dict(fp16_checkpoint, strict=True) + prefill_model.load_state_dict(int2_checkpoint, strict=True) decode_model.load_state_dict(int2_checkpoint, strict=True) torch.cuda.synchronize() diff --git a/gpu/model.py b/gpu/model.py index 9a37062d..1df8bd68 100755 --- a/gpu/model.py +++ b/gpu/model.py @@ -109,11 +109,8 @@ def quant_input(self, input): s = 127 / input.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5) return (input * s).round().clamp(-128, 127).to(torch.int8), s - def forward(self, input, weight_int8=None): + def forward(self, input): input, s = self.quant_input(input) - weight_np = weight_int8.cpu().to(torch.int32).T.numpy() - input_np = input.cpu().to(torch.int32).numpy() - out_np = np.matmul(input_np, weight_np) if input.shape[0] == 1: return bitnet_int8xint2_linear_gemv(input, self.weight, s, self.weight_scale) else: From 6dbe26574285e4866132df49fc6bc898edb7b21d Mon Sep 17 00:00:00 2001 From: Ying Hu Date: Wed, 30 Jul 2025 03:17:17 +0000 Subject: [PATCH 3/4] Add tests for gemm kernel --- gpu/test_gemm.py | 118 ++++++++++++++++++++++++----------------------- 1 file changed, 61 insertions(+), 57 deletions(-) diff --git a/gpu/test_gemm.py b/gpu/test_gemm.py index 903e15c3..bf9aa401 100644 --- a/gpu/test_gemm.py +++ b/gpu/test_gemm.py @@ -1,59 +1,63 @@ import torch -from model import BitLinear, BitLinearKernel from pack_weight import convert_weight_int8_to_int2 - -from torch import nn - - -def quant_weight_int8(weight): - s = 1.0 / weight.abs().mean().clamp_(min=1e-5) - new_weight = (weight * s).round().clamp(-1, 1).to(torch.int8) - new_scale = (1.0 / s).to(torch.bfloat16) - return new_weight, new_scale.reshape(1).repeat(6) - -def quant_weight(weight): - s = 1.0 / weight.abs().mean().clamp_(min=1e-5) - new_weight = (weight * s).round().clamp(-1, 1) / s - - return new_weight - -def convert_int8_to_int2(weight): - return convert_weight_int8_to_int2(weight) - - -def test_BitLinear(): - in_dim = 2560 # 64 - out_dim = 3840 # 32 - default_dtype = torch.bfloat16 - x = torch.randn(128, in_dim, dtype=default_dtype).cuda() # (batch, in_features) - - layer0 = BitLinear(in_dim, out_dim).cuda() - layer0 = layer0.to(default_dtype) - nn.init.kaiming_uniform_(layer0.weight, nonlinearity='relu') - with torch.no_grad(): - layer0.weight.copy_( quant_weight(layer0.weight)) - nn.init.zeros_(layer0.bias) - out0 = layer0(x) - - assert not torch.isnan(out0).any() - assert layer0.weight.dtype == default_dtype - # print(layer0.weight.dtype, layer0.weight.shape) - - layer1 = BitLinearKernel(in_dim, out_dim).cuda() - weight_int8, scale = 
quant_weight_int8(layer0.weight) - weight = convert_int8_to_int2(weight_int8) - - - with torch.no_grad(): - layer1.weight.copy_(weight) - layer1.weight_scale.copy_(scale) - print(layer1.weight, layer1.weight_scale) - out1 = layer1(x, weight_int8) - assert out1.dtype == default_dtype - - print(f"Non-kernel output: {out0}, Kernel output: {out1}") - assert torch.equal(out0, out1), "Outputs from non-kernel and kernel paths should match" - - -if __name__ == "__main__": - test_BitLinear() +from torch.profiler import profile, record_function, ProfilerActivity +import ctypes +import numpy as np +from torch.utils import benchmark + +gemm_lib = ctypes.CDLL('bitnet_kernels/libgemm.so') +# set all seed +torch.manual_seed(42) +np.random.seed(42) + +def bit_linear_int8xint2(input0, weight, out, M, N, K): + stream = torch.cuda.current_stream() + gemm_lib.bitlinear_int8xint2(*[ + ctypes.c_void_p(input0.data_ptr()), + ctypes.c_void_p(weight.data_ptr()), + ctypes.c_void_p(out.data_ptr()), + ctypes.c_int(M), + ctypes.c_int(N), + ctypes.c_int(K), + ctypes.c_void_p(stream.cuda_stream),]) + +M = 512 +test_list = [ + (2560, 2560), + (3840, 2560), + (13824, 2560), + (2560, 6912), +] +for N,K in test_list: + weight = torch.randint(-1, 2, (N, K), dtype=torch.int8, device='cuda') + weight_scale = torch.ones(1, dtype=torch.bfloat16, device='cuda') + weight_compressed = convert_weight_int8_to_int2(weight).to('cuda') + weight_np = weight.cpu().to(torch.int32).T.numpy() + stream = torch.cuda.current_stream() + input0 = torch.randint(-128,127,(M, K),dtype=torch.int8, device='cuda') + input0_np = input0.cpu().to(torch.int32).numpy() + out_np = np.matmul(input0_np, weight_np) + weight_bf16 = weight.to(torch.bfloat16).T + input0_bf16 = input0.to(torch.bfloat16) + s = torch.ones(1, dtype=torch.bfloat16, device='cuda') + ws = torch.ones(6, dtype=torch.bfloat16, device='cuda') + out = torch.empty(M, N, dtype=torch.int32, device='cuda') + t0 = benchmark.Timer( + stmt="bit_linear_int8xint2(input0, weight_compressed, out, M, N, K)", + setup="from __main__ import input0, weight_compressed, s, ws, out, bit_linear_int8xint2, M, N, K", + num_threads=1, + ) + + t1 = benchmark.Timer( + stmt="out_bf16 = torch.matmul(input0_bf16, weight_bf16)", + setup="from __main__ import input0_bf16, weight_bf16", + num_threads=1, + ) + + time0 = t0.timeit(50) + time1 = t1.timeit(50) + + print(f'Shape{M,N,K}, W2A8: {time0.mean * 1e6:.2f}us, torch BF16: {time1.mean * 1e6:.2f}us') + out_np = torch.tensor(out_np).cuda() + + print(f'custom == np {torch.all(out==out_np)}') From d1caa81e162722a145aa8b183fe2d2d74abac8c7 Mon Sep 17 00:00:00 2001 From: Ying Hu Date: Wed, 6 Aug 2025 07:00:28 +0000 Subject: [PATCH 4/4] Fuse convert to bf16 and scaling into kernel; make BLOCK_SIZE_M and BLOCK_SIZE_N adjustable --- gpu/bitnet_kernels/bitgemm.cu | 209 +++++++++++++++++----------------- gpu/model.py | 32 ++---- gpu/test_gemm.py | 15 +-- 3 files changed, 121 insertions(+), 135 deletions(-) diff --git a/gpu/bitnet_kernels/bitgemm.cu b/gpu/bitnet_kernels/bitgemm.cu index 4d12e34a..eda5e267 100644 --- a/gpu/bitnet_kernels/bitgemm.cu +++ b/gpu/bitnet_kernels/bitgemm.cu @@ -31,184 +31,180 @@ __device__ void decode_i2s_to_i8s(T1 *_i2s, T2 *_i8s, const int N = 16) { } } - -template -__global__ void int8_int2_gemm_tensor_core( - const int8_t *__restrict__ A, // M x K matrix, row-major - const int32_t *__restrict__ B_compressed, // Compressed int2 data for N x K matrix, column-major - int32_t *__restrict__ C, // M x N output matrix, row-major - int M) +template +__global__ 
void int8_int2_gemm_fused_kernel( + const int8_t *__restrict__ A, + const int32_t *__restrict__ B_compressed, + __nv_bfloat16 *__restrict__ C, + int M, + const __nv_bfloat16 *__restrict__ s, // MODIFICATION: s is now bfloat16 + const __nv_bfloat16 *__restrict__ ws) // MODIFICATION: ws is now bfloat16 { - // Define WMMA dimensions - all constant + // --- GEMM Calculation Stage (largely unchanged) --- constexpr int WMMA_M = 16; constexpr int WMMA_N = 16; constexpr int WMMA_K = 16; - - // Define block tile dimensions - all constant - constexpr int BLOCK_SIZE_M = 64; // Multiple of WMMA_M - constexpr int BLOCK_SIZE_N = 64; // Multiple of WMMA_N - constexpr int BLOCK_SIZE_K = 32; // K dimension as requested + constexpr int BLOCK_SIZE_K = 32; + constexpr int WARPS_M = 2; + constexpr int WARPS_N = 2; + constexpr int M_ITER = BLOCK_SIZE_M / WMMA_M / WARPS_M; + constexpr int N_ITER = BLOCK_SIZE_N / WMMA_N / WARPS_N; - // Calculate thread block position const int blockM = blockIdx.y * BLOCK_SIZE_M; const int blockN = blockIdx.x * BLOCK_SIZE_N; - - // Calculate thread ID and warp IDs - const int warpM = threadIdx.y; // 0-1 (2 warps in M dimension) - const int warpN = threadIdx.z; // 0-1 (2 warps in N dimension) + const int warpM = threadIdx.y; + const int warpN = threadIdx.z; const int tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y; - // Add padding to shared memory to avoid bank conflicts - constexpr int PAD_A = 16; // Padding for A matrix - constexpr int PAD_B = 16; // Padding for B matrix + constexpr int PAD_A = 16; + constexpr int PAD_B = 16; - // Allocate shared memory for A and B matrices with padding __shared__ int8_t shared_A[BLOCK_SIZE_M][BLOCK_SIZE_K + PAD_A]; __shared__ int8_t shared_B[BLOCK_SIZE_N][BLOCK_SIZE_K + PAD_B]; - // Define fragments for all tiles this warp will handle - static allocation - nvcuda::wmma::fragment c_frags[2][2]; + nvcuda::wmma::fragment c_frags[M_ITER][N_ITER]; nvcuda::wmma::fragment a_frag; nvcuda::wmma::fragment b_frag; - // Initialize all accumulator fragments to zero (unrolled) #pragma unroll - for (int m_iter = 0; m_iter < 2; m_iter++) { + for (int m_iter = 0; m_iter < M_ITER; m_iter++) { #pragma unroll - for (int n_iter = 0; n_iter < 2; n_iter++) { + for (int n_iter = 0; n_iter < N_ITER; n_iter++) { nvcuda::wmma::fill_fragment(c_frags[m_iter][n_iter], 0); } } - // Only check M bounds at the beginning const bool m_valid = blockM < M; - // Loop over K dimension in chunks of BLOCK_SIZE_K - #pragma unroll 4 // Partial unroll of K-dimension loop for (int k_block = 0; k_block < K; k_block += BLOCK_SIZE_K) { - // Clear shared memory first __syncthreads(); - - // Load A matrix tiles into shared memory using vectorized loads - // Each thread handles multiple elements based on its ID + // Load A tile for (int load_idx = tid; load_idx < (BLOCK_SIZE_M * BLOCK_SIZE_K / 16); load_idx += blockDim.x * blockDim.y * blockDim.z) { int local_m = (load_idx * 16) / BLOCK_SIZE_K; int local_k = (load_idx * 16) % BLOCK_SIZE_K; - int global_m = blockM + local_m; int global_k = k_block + local_k; - - // Use vector loads for A - 16 bytes at a time (int4 = 4 integers = 16 bytes) if (m_valid && global_m < M) { - // Vector load from A to shared memory *((int4*)&shared_A[local_m][local_k]) = *((int4*)&A[global_m * K + global_k]); } else { - // Zero out if M is out of bounds *((int4*)&shared_A[local_m][local_k]) = {0}; } } - - // Load B matrix tiles into shared memory (always in bounds for N and K) - // Calculate which 16-element chunk this thread is 
responsible for + // Load B tile int chunk_n = (tid * 16 / BLOCK_SIZE_K); int chunk_k = (tid * 16) % BLOCK_SIZE_K; - if (chunk_n < BLOCK_SIZE_N) { int global_n = blockN + chunk_n; int global_k = k_block + chunk_k; - - // Calculate which compressed block this belongs to int n_block = global_n / 16; int k_block_32 = global_k / 32; int k_offset_in_block = chunk_k % 32; - - // Get the specific compressed tile within the 16x32 block int in_block_n = chunk_n % 16; int compressed_block_idx = n_block * (K / 32) + k_block_32; - - // Calculate which tile within the compressed block - int tile_idx; - tile_idx = in_block_n / 8 * 16 + in_block_n % 8 + (k_offset_in_block / 16) * 8; - - // Extract and decompress the int2 values + int tile_idx = in_block_n / 8 * 16 + in_block_n % 8 + (k_offset_in_block / 16) * 8; int32_t compressed = B_compressed[compressed_block_idx * 32 + tile_idx]; int8_t decompressed[16]; decode_i2s_to_i8s(&compressed, decompressed); - - // Vector store to shared memory *((int4*)&shared_B[chunk_n][chunk_k]) = *((int4*)decompressed); } - - // Make sure all threads have finished loading into shared memory __syncthreads(); - - // Process the 2x2 WMMA tiles for this K block + // Perform MMA #pragma unroll - for (int m_iter = 0; m_iter < 2; m_iter++) { + for (int m_iter = 0; m_iter < M_ITER; m_iter++) { #pragma unroll - for (int n_iter = 0; n_iter < 2; n_iter++) { - // Calculate the starting positions for this WMMA tile + for (int n_iter = 0; n_iter < N_ITER; n_iter++) { #pragma unroll for (int wmma_k = 0; wmma_k < BLOCK_SIZE_K; wmma_k += WMMA_K) { - // Fully unroll the m and n iterations - const int tile_m = (warpM * 2 + m_iter) * WMMA_M; - const int tile_n = (warpN * 2 + n_iter) * WMMA_N; - - // Load matrix A fragment from shared memory with padding - nvcuda::wmma::load_matrix_sync( - a_frag, &shared_A[tile_m][wmma_k], BLOCK_SIZE_K + PAD_A); - - // Load matrix B fragment from shared memory with padding - nvcuda::wmma::load_matrix_sync( - b_frag, &shared_B[tile_n][wmma_k], BLOCK_SIZE_K + PAD_B); - - // Perform matrix multiplication + const int tile_m = (warpM + m_iter * WARPS_M) * WMMA_M; + const int tile_n = (warpN + n_iter * WARPS_N) * WMMA_N; + nvcuda::wmma::load_matrix_sync(a_frag, &shared_A[tile_m][wmma_k], BLOCK_SIZE_K + PAD_A); + nvcuda::wmma::load_matrix_sync(b_frag, &shared_B[tile_n][wmma_k], BLOCK_SIZE_K + PAD_B); nvcuda::wmma::mma_sync(c_frags[m_iter][n_iter], a_frag, b_frag, c_frags[m_iter][n_iter]); } } } } - // Store results back to global memory - only check M bounds + // --- Fused Post-Processing and Store Stage --- + __shared__ int32_t shared_C[BLOCK_SIZE_M][BLOCK_SIZE_N]; + #pragma unroll - for (int m_iter = 0; m_iter < 2; m_iter++) { - const int tile_m = (warpM * 2 + m_iter) * WMMA_M; - const int global_tile_m = blockM + tile_m; - - if (m_valid && global_tile_m < M) { - #pragma unroll - for (int n_iter = 0; n_iter < 2; n_iter++) { - const int tile_n = (warpN * 2 + n_iter) * WMMA_N; - const int global_tile_n = blockN + tile_n; - - // No need to check N bounds as it's always aligned - nvcuda::wmma::store_matrix_sync( - &C[global_tile_m * N + global_tile_n], - c_frags[m_iter][n_iter], N, nvcuda::wmma::mem_row_major); + for (int m_iter = 0; m_iter < M_ITER; m_iter++) { + #pragma unroll + for (int n_iter = 0; n_iter < N_ITER; n_iter++) { + const int tile_m = (warpM + m_iter * WARPS_M) * WMMA_M; + const int tile_n = (warpN + n_iter * WARPS_N) * WMMA_N; + nvcuda::wmma::store_matrix_sync( + &shared_C[tile_m][tile_n], + c_frags[m_iter][n_iter], + BLOCK_SIZE_N, + 
nvcuda::wmma::mem_row_major); + } + } + __syncthreads(); + + for (int i = tid; i < BLOCK_SIZE_M * BLOCK_SIZE_N; i += blockDim.x * blockDim.y * blockDim.z) { + const int m = i / BLOCK_SIZE_N; + const int n = i % BLOCK_SIZE_N; + + const int global_m = blockM + m; + const int global_n = blockN + n; + + if (global_m < M) { + int32_t val = shared_C[m][n]; + float float_val = static_cast(val); + + // MODIFICATION: Load bfloat16 scales and convert to float for calculation + float s_val = __bfloat162float(s[global_m]); + float_val /= s_val; + + int ws_idx = 0; + if (N == 3840) { + ws_idx = global_n / (3840 / 6); + } else if (N == 13824) { + ws_idx = global_n / (13824 / 2); } + + float ws_val = __bfloat162float(ws[ws_idx]); + float_val *= ws_val; + + __nv_bfloat16 bf16_val = __float2bfloat16(float_val); + C[global_m * N + global_n] = bf16_val; } } } -extern "C" void bitlinear_int8xint2(int8_t *input0, int8_t *input1, - int32_t *output0, int M, int N, int K, - cudaStream_t stream = 0) { +extern "C" void bitlinear_fused_int8xint2( + int8_t *input0, int8_t *input1, + __nv_bfloat16 *output0, + int M, int N, int K, + __nv_bfloat16 *s, // MODIFICATION: s is now bfloat16 + __nv_bfloat16 *ws, // MODIFICATION: ws is now bfloat16 + cudaStream_t stream = 0) { + + constexpr int BLOCK_SIZE_M = 64; + constexpr int BLOCK_SIZE_N = 64; + + const dim3 gridDim(N / BLOCK_SIZE_N, (M + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M, 1); + const dim3 blockDim(32, 2, 2); + + // Kernel launch now passes the bfloat16 pointers if (N == 3840 && K == 2560) { - int8_int2_gemm_tensor_core<3840, 2560> - <<>>( - input0, (int32_t *)input1, (int32_t *)output0, M); + int8_int2_gemm_fused_kernel<3840, 2560, BLOCK_SIZE_M, BLOCK_SIZE_N> + <<>>( + input0, (int32_t *)input1, output0, M, s, ws); } else if (N == 2560 && K == 2560) { - int8_int2_gemm_tensor_core<2560, 2560> - <<>>( - input0, (int32_t *)input1, (int32_t *)output0, M); + int8_int2_gemm_fused_kernel<2560, 2560, BLOCK_SIZE_M, BLOCK_SIZE_N> + <<>>( + input0, (int32_t *)input1, output0, M, s, ws); } else if (N == 13824 && K == 2560) { - int8_int2_gemm_tensor_core<13824, 2560> - <<>>( - input0, (int32_t *)input1, (int32_t *)output0, M); + int8_int2_gemm_fused_kernel<13824, 2560, BLOCK_SIZE_M, BLOCK_SIZE_N> + <<>>( + input0, (int32_t *)input1, output0, M, s, ws); } else if (N == 2560 && K == 6912) { - int8_int2_gemm_tensor_core<2560, 6912> - <<>>( - input0, (int32_t *)input1, (int32_t *)output0, M); + int8_int2_gemm_fused_kernel<2560, 6912, BLOCK_SIZE_M, BLOCK_SIZE_N> + <<>>( + input0, (int32_t *)input1, output0, M, s, ws); } else { std::cerr << "Error: Unsupported matrix dimensions for bitlinear_int8xint2. 
" << "Required kernel: M=" << M << ", N=" << N << ", K=" << K << std::endl; @@ -220,10 +216,9 @@ extern "C" void bitlinear_int8xint2(int8_t *input0, int8_t *input1, throw std::runtime_error("Unsupported matrix dimensions for bitlinear_int8xint2"); } - // Check for CUDA launch errors cudaError_t launch_error = cudaGetLastError(); if (launch_error != cudaSuccess) { std::cerr << "CUDA kernel launch failed: " << cudaGetErrorString(launch_error) << std::endl; throw std::runtime_error("CUDA kernel launch failed"); } -} \ No newline at end of file +} diff --git a/gpu/model.py b/gpu/model.py index 1df8bd68..fab4ad86 100755 --- a/gpu/model.py +++ b/gpu/model.py @@ -51,27 +51,17 @@ def bitnet_int8xint2_linear_gemm(input0, input1, s, ws): N = input1.shape[0] K = input1.shape[1] * 4 - ret = torch.zeros(*out_shape, dtype=torch.int32, device=input0.device) - - gemm_lib.bitlinear_int8xint2(*[ctypes.c_void_p(input0.data_ptr()), ctypes.c_void_p(input1.data_ptr()), ctypes.c_void_p(ret.data_ptr()), ctypes.c_int(M), ctypes.c_int(N), ctypes.c_int(K), ctypes.c_void_p(stream.cuda_stream)]) - ret = ret.to(torch.bfloat16) - ret = ret / s - if N == 3840 and K == 2560: - #split last dim to 6 parts evenly - ret = ret.reshape(*ret.shape[:-1], 6, -1) - # devide each part by first 6 coresponding weight scale - ret = ret * ws[:6].reshape(1, 6, 1) - elif (N == 2560 and K == 2560): - # 1 part - ret = ret* ws[:1].reshape(1, 1, 1, 1) - elif (N == 13824 and K == 2560): - # 2 parts - ret = ret.reshape(*ret.shape[:-1], 2, -1) - # devide each part by first 2 coresponding weight scale - ret = ret * ws[:2].reshape(1, 1, 2, 1) - elif (N == 2560 and K == 6912): - # 1 part - ret = ret * ws[:1].reshape(1, 1, 1, 1) + ret = torch.zeros(*out_shape, dtype=torch.bfloat16, device=input0.device) + + gemm_lib.bitlinear_fused_int8xint2(*[ctypes.c_void_p(input0.data_ptr()), + ctypes.c_void_p(input1.data_ptr()), + ctypes.c_void_p(ret.data_ptr()), + ctypes.c_int(M), + ctypes.c_int(N), + ctypes.c_int(K), + ctypes.c_void_p(s.data_ptr()), + ctypes.c_void_p(ws.data_ptr()), + ctypes.c_void_p(stream.cuda_stream)]) return ret.reshape(*out_shape) diff --git a/gpu/test_gemm.py b/gpu/test_gemm.py index bf9aa401..a3acf3c4 100644 --- a/gpu/test_gemm.py +++ b/gpu/test_gemm.py @@ -4,6 +4,7 @@ import ctypes import numpy as np from torch.utils import benchmark +from model import bitnet_int8xint2_linear_gemm gemm_lib = ctypes.CDLL('bitnet_kernels/libgemm.so') # set all seed @@ -39,12 +40,12 @@ def bit_linear_int8xint2(input0, weight, out, M, N, K): out_np = np.matmul(input0_np, weight_np) weight_bf16 = weight.to(torch.bfloat16).T input0_bf16 = input0.to(torch.bfloat16) - s = torch.ones(1, dtype=torch.bfloat16, device='cuda') + s = torch.ones(M, dtype=torch.bfloat16, device='cuda') ws = torch.ones(6, dtype=torch.bfloat16, device='cuda') - out = torch.empty(M, N, dtype=torch.int32, device='cuda') + out = bitnet_int8xint2_linear_gemm(input0, weight_compressed, s,ws) t0 = benchmark.Timer( - stmt="bit_linear_int8xint2(input0, weight_compressed, out, M, N, K)", - setup="from __main__ import input0, weight_compressed, s, ws, out, bit_linear_int8xint2, M, N, K", + stmt="out_kernel = bitnet_int8xint2_linear_gemm(input0, weight_compressed, s,ws)", + setup="from __main__ import input0, weight_compressed, s, ws, bitnet_int8xint2_linear_gemm", num_threads=1, ) @@ -54,10 +55,10 @@ def bit_linear_int8xint2(input0, weight, out, M, N, K): num_threads=1, ) - time0 = t0.timeit(50) - time1 = t1.timeit(50) + time0 = t0.timeit(10) + time1 = t1.timeit(10) print(f'Shape{M,N,K}, W2A8: 
{time0.mean * 1e6:.2f}us, torch BF16: {time1.mean * 1e6:.2f}us') - out_np = torch.tensor(out_np).cuda() + out_np = torch.tensor(out_np).cuda().to(torch.bfloat16) print(f'custom == np {torch.all(out==out_np)}')
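
Reference notes (Python sketches; names are illustrative unless they appear in the patches above).

The decode_i2s_to_i8s device function in bitgemm.cu recovers 16 signed 2-bit weights from one packed 32-bit word: with the magic operand set to zero, the lop3.b32 (immLut 0xEA, i.e. (a & b) | c) reduces to (word >> 2*i) & 0x03030303, and __vsubss4 then removes the +2 bias from each byte. The sketch below mirrors that shift/mask pattern for a single word; it assumes little-endian byte order (as on NVIDIA GPUs) and covers only the intra-word layout — the tile_idx arithmetic that orders words inside a 16x32 compressed block is a separate concern.

import numpy as np

def decode_word_i2_to_i8(word: int) -> np.ndarray:
    # Reference for decode_i2s_to_i8s: one packed uint32 -> 16 int8 values.
    # Output byte (4*i + j) holds the 2-bit field at bit offset (2*i + 8*j),
    # stored with a +2 bias (field 1/2/3 -> weight -1/0/+1).
    out = np.empty(16, dtype=np.int8)
    for i in range(4):          # which of the four decoded int32 words
        shifted = word >> (2 * i)
        for j in range(4):      # byte lane inside that word (little endian)
            field = (shifted >> (8 * j)) & 0x3
            out[4 * i + j] = field - 2
    return out

# Round-trip check against the same bit layout.
weights = np.array([-1, 0, 1, 0, 1, 1, -1, 0, 0, 0, 1, -1, 0, 1, 0, -1], dtype=np.int8)
word = 0
for k, w in enumerate(weights):
    i, j = k // 4, k % 4
    word |= int(w + 2) << (2 * i + 8 * j)
assert np.array_equal(decode_word_i2_to_i8(word), weights)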
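
Patch 4 fuses the epilogue into the kernel: each int32 accumulator entry is cast to float, divided by the per-row activation scale s[m], multiplied by a per-column-group weight scale from ws, and stored as bfloat16. N=3840 is split into six groups of 640 columns (four for wq, one each for wk and wv), N=13824 into two groups of 6912 (w1 and w3), and the remaining shapes use a single scale. The PyTorch reference below reproduces that arithmetic and is handy for checking the kernel against a plain int32 matmul; the function and tensor names are illustrative.

import torch

def reference_fused_epilogue(acc_i32: torch.Tensor,  # (M, N) int32 accumulator
                             s: torch.Tensor,        # (M,) activation scales, bfloat16
                             ws: torch.Tensor        # (6,) weight scales, bfloat16
                             ) -> torch.Tensor:
    # Mirrors the kernel's store stage: C[m, n] = acc[m, n] / s[m] * ws[group(n)].
    M, N = acc_i32.shape
    if N == 3840:        # fused wqkv: 4 x 640 columns for q, 640 for k, 640 for v
        groups = 6
    elif N == 13824:     # fused w13: 6912 columns each for w1 and w3
        groups = 2
    else:                # wo / w2: a single scale in ws[0]
        groups = 1
    group_idx = torch.arange(N, device=acc_i32.device) // (N // groups)
    out = acc_i32.float() / s.float().unsqueeze(1)      # per-row activation scale
    out = out * ws.float()[group_idx].unsqueeze(0)      # per-group weight scale
    return out.to(torch.bfloat16)

With s and ws set to ones, as in the updated test_gemm.py, this reduces to the int32 matmul cast to bfloat16, which is exactly what the test compares against.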
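
The group indexing above only lines up if weight_scale is packed the way the updated convert_checkpoint.py packs it: the wqkv scale vector repeats the q scale four times before the k and v scales, the w13 vector carries two scales plus zero padding, and wo/w2 carry one scale plus padding. A small helper capturing that layout follows; the function name and the 'kind' strings are illustrative, and the zero padding is never read by the kernel.

import torch

def pack_weight_scales(kind: str, *scales: torch.Tensor) -> torch.Tensor:
    # Build the 6-slot weight_scale vector expected by BitLinearKernel.
    #   kind='wqkv': (wq_scale, wk_scale, wv_scale) -> [q, q, q, q, k, v]
    #   kind='w13' : (w1_scale, w3_scale)           -> [w1, w3, 0, 0, 0, 0]
    #   kind='wo'  : (scale,)                       -> [s, 0, 0, 0, 0, 0]
    scales = [t.reshape(1).to(torch.bfloat16) for t in scales]
    zero = torch.zeros(1, dtype=torch.bfloat16, device=scales[0].device)
    if kind == "wqkv":
        q, k, v = scales
        parts = [q, q, q, q, k, v]
    elif kind == "w13":
        w1, w3 = scales
        parts = [w1, w3, zero, zero, zero, zero]
    else:  # 'wo' / 'w2'
        (s,) = scales
        parts = [s, zero, zero, zero, zero, zero]
    return torch.cat(parts, dim=0)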