From bd5b91f51a92f877e9ef0c14a0b781f5f67a54d4 Mon Sep 17 00:00:00 2001 From: Xianzhe Dong Date: Sat, 13 Apr 2024 08:30:29 -0400 Subject: [PATCH 1/5] [op] optimize layernorm kernel for half2 type --- src/kernels/layernorm_kernels.cu | 112 +++++++++++++++++++++++++++++++ src/kernels/layernorm_kernels.h | 24 +++++++ 2 files changed, 136 insertions(+) diff --git a/src/kernels/layernorm_kernels.cu b/src/kernels/layernorm_kernels.cu index 3e32bd8b..354c0ee9 100644 --- a/src/kernels/layernorm_kernels.cu +++ b/src/kernels/layernorm_kernels.cu @@ -3,6 +3,7 @@ #include "dispatch.h" #include "reduce_kernel_utils.cuh" +#include "layernorm_kernels.h" namespace llm::kernel { @@ -173,6 +174,67 @@ __global__ void layer_norm_kernel(T* __restrict__ out, } } +// equation: x -> (x - E[x]) / sqrt(Var[x] + eps) * w + b +// The mean and standard-deviation are calculated over the last dimension +template <> +__global__ void layer_norm_kernel(half2* __restrict__ out, + const half2* __restrict__ input, + const half2* __restrict__ weight, + const half2* __restrict__ bias, + const float epsilon, + int n) { + const int tidx = threadIdx.x; + const int bidx = blockIdx.x; + + __shared__ half s_mean; + __shared__ half s_variance; + half2 mean = make_half2(__float2half(0.0f), __float2half(0.0f)); + half2 variance = make_half2(__float2half(0.0f), __float2half(0.0f)); + + // calculate mean of the input. + for (int i = tidx; i < n; i += blockDim.x) { + const int idx = bidx * n + i; + mean = __hadd2(mean, __ldg(&input[idx])); + } + mean = block_reduce_sum(mean); + if (tidx == 0) { + s_mean = __hdiv(__hadd(mean.x, mean.y), __float2half((float)n * 2)); + } + __syncthreads(); + + // calculate variance of the input. + for (int i = tidx; i < n; i += blockDim.x) { + const half2 x = __hsub2(input[bidx * n + i], make_half2(s_mean, s_mean)); + variance = __hadd2(variance, __hmul2(x, x)); + } + variance = block_reduce_sum(variance); + if (tidx == 0) { + // const half2 e = make_half2(__float2half(epsilon), __float2half(epsilon)); + s_variance = __hadd(variance.x, variance.y); + s_variance = __hdiv(s_variance, __float2half((float)n * 2)); + s_variance = __hadd(s_variance, __float2half(epsilon)); + s_variance = hrsqrt(s_variance); + } + __syncthreads(); + + for (int i = tidx; i < n; i += blockDim.x) { + const int idx = bidx * n + i; + // float local_out = + // (__ldg(&input[idx]) - s_mean) * s_variance * __ldg(&weight[i]); + // if (bias != nullptr) { + // local_out += __ldg(&bias[i]); + // } + half2 local_out = __ldg(&input[idx]); + local_out = __hsub2(local_out, make_half2(s_mean, s_mean)); + local_out = __hmul2(local_out, make_half2(s_variance, s_variance)); + local_out = __hmul2(local_out, __ldg(&weight[i])); + if (bias != nullptr){ + local_out = __hadd2(local_out, __ldg(&bias[i])); + } + out[idx] = local_out; + } +} + void layer_norm(torch::Tensor& out, torch::Tensor input, torch::Tensor weight, @@ -197,4 +259,54 @@ void layer_norm(torch::Tensor& out, }); } +template +void invoke_layernorm_kernel(T* out, + const T* input, + const T* weight, + const T* bias, + const float epsilon, + int m, + int n) { + layer_norm_kernel<<>>(out, input, weight, bias, epsilon, n); +} + +template <> +void invoke_layernorm_kernel(half2* out, + const half2* input, + const half2* weight, + const half2* bias, + const float epsilon, + int m, + int n) { + layer_norm_kernel<<>>(out, input, weight, bias, epsilon, n); +} +template <> +void invoke_layernorm_kernel(float* out, + const float* input, + const float* weight, + const float* bias, + const float epsilon, + 
int m, + int n) { + layer_norm_kernel<<>>(out, input, weight, bias, epsilon, n); + } +// void invoke_float_layernorm_kernel(float* out, +// const float* input, +// const float* weight, +// const float* bias, +// const float epsilon, +// int m, +// int n){ +// layer_norm_kernel<<>>(out, input, weight, bias, epsilon, n); +// } + +// void invoke_half2_layernorm_kernel(half2* out, +// const half2* input, +// const half2* weight, +// const half2* bias, +// const float epsilon, +// int m, +// int n){ +// layer_norm_kernel<<>>(out, input, weight, bias, epsilon, n); +// } } // namespace llm::kernel diff --git a/src/kernels/layernorm_kernels.h b/src/kernels/layernorm_kernels.h index 496622bb..fe72e1b1 100644 --- a/src/kernels/layernorm_kernels.h +++ b/src/kernels/layernorm_kernels.h @@ -20,4 +20,28 @@ void layer_norm(torch::Tensor& out, torch::Tensor bias, float epsilon); +template +void invoke_layernorm_kernel(T* out, + const T* input, + const T* weight, + const T* bias, + const float epsilon, + int m, + int n); + +// void invoke_float_layernorm_kernel(float* out, +// const float* input, +// const float* weight, +// const float* bias, +// const float epsilon, +// int m, +// int n); + +// void invoke_half2_layernorm_kernel(half2* out, +// const half2* input, +// const half2* weight, +// const half2* bias, +// const float epsilon, +// int m, +// int n); } // namespace llm::kernel From 14a99f786ba51406b037a76a1aef1aa706cc4f43 Mon Sep 17 00:00:00 2001 From: Xianzhe Dong Date: Wed, 17 Apr 2024 02:32:46 -0400 Subject: [PATCH 2/5] [ut] add layernorm kernel unitest --- src/kernels/CMakeLists.txt | 21 ++++- src/kernels/layernorm_kernels.cu | 80 ++++++---------- src/kernels/layernorm_kernels.h | 16 ---- src/kernels/layernrom_kernels_test.cu | 129 ++++++++++++++++++++++++++ src/kernels/reduce_kernel_utils.cuh | 59 ++++++++++++ 5 files changed, 236 insertions(+), 69 deletions(-) create mode 100644 src/kernels/layernrom_kernels_test.cu diff --git a/src/kernels/CMakeLists.txt b/src/kernels/CMakeLists.txt index 397123cb..85d601c6 100644 --- a/src/kernels/CMakeLists.txt +++ b/src/kernels/CMakeLists.txt @@ -1,4 +1,5 @@ include(cc_library) +include(cc_binary) cc_library( NAME @@ -72,6 +73,24 @@ cc_library( torch ) +# cc_test( +# NAME +# layernorm_kernels_test +# SRCS +# layernrom_kernels_test.cu +# layernorm_kernels.cu +# DEPS +# DEFINES +# ) +cc_binary( + NAME + layernorm_kernels_test + SRCS + layernrom_kernels_test.cu + layernorm_kernels.cu + DEPS + torch +) + add_subdirectory(flash_attn) add_subdirectory(flash_infer) - diff --git a/src/kernels/layernorm_kernels.cu b/src/kernels/layernorm_kernels.cu index 354c0ee9..8787d458 100644 --- a/src/kernels/layernorm_kernels.cu +++ b/src/kernels/layernorm_kernels.cu @@ -2,8 +2,8 @@ #include #include "dispatch.h" -#include "reduce_kernel_utils.cuh" #include "layernorm_kernels.h" +#include "reduce_kernel_utils.cuh" namespace llm::kernel { @@ -178,11 +178,11 @@ __global__ void layer_norm_kernel(T* __restrict__ out, // The mean and standard-deviation are calculated over the last dimension template <> __global__ void layer_norm_kernel(half2* __restrict__ out, - const half2* __restrict__ input, - const half2* __restrict__ weight, - const half2* __restrict__ bias, - const float epsilon, - int n) { + const half2* __restrict__ input, + const half2* __restrict__ weight, + const half2* __restrict__ bias, + const float epsilon, + int n) { const int tidx = threadIdx.x; const int bidx = blockIdx.x; @@ -209,7 +209,6 @@ __global__ void layer_norm_kernel(half2* __restrict__ out, } 
variance = block_reduce_sum(variance); if (tidx == 0) { - // const half2 e = make_half2(__float2half(epsilon), __float2half(epsilon)); s_variance = __hadd(variance.x, variance.y); s_variance = __hdiv(s_variance, __float2half((float)n * 2)); s_variance = __hadd(s_variance, __float2half(epsilon)); @@ -219,16 +218,11 @@ __global__ void layer_norm_kernel(half2* __restrict__ out, for (int i = tidx; i < n; i += blockDim.x) { const int idx = bidx * n + i; - // float local_out = - // (__ldg(&input[idx]) - s_mean) * s_variance * __ldg(&weight[i]); - // if (bias != nullptr) { - // local_out += __ldg(&bias[i]); - // } half2 local_out = __ldg(&input[idx]); local_out = __hsub2(local_out, make_half2(s_mean, s_mean)); local_out = __hmul2(local_out, make_half2(s_variance, s_variance)); local_out = __hmul2(local_out, __ldg(&weight[i])); - if (bias != nullptr){ + if (bias != nullptr) { local_out = __hadd2(local_out, __ldg(&bias[i])); } out[idx] = local_out; @@ -261,52 +255,34 @@ void layer_norm(torch::Tensor& out, template void invoke_layernorm_kernel(T* out, - const T* input, - const T* weight, - const T* bias, - const float epsilon, - int m, - int n) { + const T* input, + const T* weight, + const T* bias, + const float epsilon, + int m, + int n) { layer_norm_kernel<<>>(out, input, weight, bias, epsilon, n); } template <> void invoke_layernorm_kernel(half2* out, - const half2* input, - const half2* weight, - const half2* bias, - const float epsilon, - int m, - int n) { + const half2* input, + const half2* weight, + const half2* bias, + const float epsilon, + int m, + int n) { layer_norm_kernel<<>>(out, input, weight, bias, epsilon, n); } template <> void invoke_layernorm_kernel(float* out, - const float* input, - const float* weight, - const float* bias, - const float epsilon, - int m, - int n) { + const float* input, + const float* weight, + const float* bias, + const float epsilon, + int m, + int n) { layer_norm_kernel<<>>(out, input, weight, bias, epsilon, n); - } -// void invoke_float_layernorm_kernel(float* out, -// const float* input, -// const float* weight, -// const float* bias, -// const float epsilon, -// int m, -// int n){ -// layer_norm_kernel<<>>(out, input, weight, bias, epsilon, n); -// } - -// void invoke_half2_layernorm_kernel(half2* out, -// const half2* input, -// const half2* weight, -// const half2* bias, -// const float epsilon, -// int m, -// int n){ -// layer_norm_kernel<<>>(out, input, weight, bias, epsilon, n); -// } -} // namespace llm::kernel +} + +} // namespace llm::kernel \ No newline at end of file diff --git a/src/kernels/layernorm_kernels.h b/src/kernels/layernorm_kernels.h index fe72e1b1..57e8cf0b 100644 --- a/src/kernels/layernorm_kernels.h +++ b/src/kernels/layernorm_kernels.h @@ -28,20 +28,4 @@ void invoke_layernorm_kernel(T* out, const float epsilon, int m, int n); - -// void invoke_float_layernorm_kernel(float* out, -// const float* input, -// const float* weight, -// const float* bias, -// const float epsilon, -// int m, -// int n); - -// void invoke_half2_layernorm_kernel(half2* out, -// const half2* input, -// const half2* weight, -// const half2* bias, -// const float epsilon, -// int m, -// int n); } // namespace llm::kernel diff --git a/src/kernels/layernrom_kernels_test.cu b/src/kernels/layernrom_kernels_test.cu new file mode 100644 index 00000000..119bc4a6 --- /dev/null +++ b/src/kernels/layernrom_kernels_test.cu @@ -0,0 +1,129 @@ +#include + +#include + +#include "layernorm_kernels.h" + +template +void printMatrix(T* a, int m, int n) { + for (int i = 0; i < 
m; i++) { + for (int j = 0; j < n; j++) { + printf("%f ", (float)a[i * n + j]); + } + puts(""); + } + puts(""); +} + +template <> +void printMatrix(half2* a, int m, int n) { + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + printf( + "%f %f ", __half2float(a[i * n + j].x), __half2float(a[i * n + j].y)); + } + puts(""); + } + puts(""); +} + +void layernorm_kernel_half2_test() { + float epsilon = 1e-6; + int m = 2; + int n = 2; + + half2* out = (half2*)malloc(m * n * sizeof(half2)); + half2* input = (half2*)malloc(m * n * sizeof(half2)); + half2* weight = (half2*)malloc(m * n * sizeof(half2)); + half2* bias = (half2*)malloc(m * n * sizeof(half2)); + + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + input[i * n + j] = half2(__float2half((float)(i * n + j * 2)), + __float2half((float)(i * n + j * 2 + 1))); + weight[i * n + j] = half2(__float2half(1.), __float2half(1.)); + bias[i * n + j] = half2(__float2half(0.), __float2half(0.)); + } + } + + half2* dout; + half2* dinput; + half2* dweight; + half2* dbias; + cudaMalloc((void**)&dout, sizeof(half2) * m * n); + cudaMalloc((void**)&dinput, sizeof(half2) * m * n); + cudaMalloc((void**)&dweight, sizeof(half2) * m * n); + cudaMalloc((void**)&dbias, sizeof(half2) * m * n); + + cudaMemcpy(dinput, input, sizeof(half2) * m * n, cudaMemcpyHostToDevice); + cudaMemcpy(dweight, weight, sizeof(half2) * m * n, cudaMemcpyHostToDevice); + cudaMemcpy(dbias, bias, sizeof(half2) * m * n, cudaMemcpyHostToDevice); + + llm::kernel::invoke_layernorm_kernel( + dout, dinput, dweight, dbias, epsilon, m, n); + + cudaMemcpy(out, dout, sizeof(half2) * m * n, cudaMemcpyDeviceToHost); + + printf("---------- test half2 layernorm kernel -----------\n"); + printf("input:\n"); + printMatrix(input, m, n); + printf("weights:\n"); + printMatrix(weight, m, n); + printf("bias:\n"); + printMatrix(bias, m, n); + printf("outputs:\n"); + printMatrix(out, m, n); +} + +void layernorm_kernel_float_test() { + float epsilon = 1e-6; + int m = 2; + int n = 4; + + float* out = (float*)malloc(m * n * sizeof(float)); + float* input = (float*)malloc(m * n * sizeof(float)); + float* weight = (float*)malloc(m * n * sizeof(float)); + float* bias = (float*)malloc(m * n * sizeof(float)); + + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + input[i * n + j] = (float)(i * n + j); + weight[i * n + j] = 1.; + bias[i * n + j] = 0.; + } + } + + float* dout; + float* dinput; + float* dweight; + float* dbias; + cudaMalloc((void**)&dout, sizeof(float) * m * n); + cudaMalloc((void**)&dinput, sizeof(float) * m * n); + cudaMalloc((void**)&dweight, sizeof(float) * m * n); + cudaMalloc((void**)&dbias, sizeof(float) * m * n); + + cudaMemcpy(dinput, input, sizeof(float) * m * n, cudaMemcpyHostToDevice); + cudaMemcpy(dweight, weight, sizeof(float) * m * n, cudaMemcpyHostToDevice); + cudaMemcpy(dbias, bias, sizeof(float) * m * n, cudaMemcpyHostToDevice); + + llm::kernel::invoke_layernorm_kernel( + dout, dinput, dweight, dbias, epsilon, m, n); + + cudaMemcpy(out, dout, sizeof(float) * m * n, cudaMemcpyDeviceToHost); + + printf("---------- test float layernorm kernel -----------\n"); + printf("input:\n"); + printMatrix(input, m, n); + printf("weights:\n"); + printMatrix(weight, m, n); + printf("bias:\n"); + printMatrix(bias, m, n); + printf("outputs:\n"); + printMatrix(out, m, n); +} + +int main() { + layernorm_kernel_float_test(); + layernorm_kernel_half2_test(); + return 0; +} \ No newline at end of file diff --git a/src/kernels/reduce_kernel_utils.cuh 
b/src/kernels/reduce_kernel_utils.cuh
index 5e414dff..edb9d953 100644
--- a/src/kernels/reduce_kernel_utils.cuh
+++ b/src/kernels/reduce_kernel_utils.cuh
@@ -24,6 +24,36 @@ __inline__ __device__ T warp_reduce_sum(T val) {
   return val;
 }
 
+// performs a parallel reduction operation across the threads within a single
+// warp (32 threads).
+// - val: The value to be reduced within a warp.
+template <>
+__inline__ __device__ half warp_reduce_sum(half val) {
+  // uses bitwise operations to perform a parallel reduction
+  // within a warp. The 'mask' is right-shifted by 1 in each iteration
+  // until it reaches zero, effectively summing all values within the warp.
+#pragma unroll
+  for (int mask = 16; mask > 0; mask >>= 1) {
+    val = __hadd(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32));
+  }
+  return val;
+}
+
+// performs a parallel reduction operation across the threads within a single
+// warp (32 threads).
+// - val: The value to be reduced within a warp.
+template <>
+__inline__ __device__ half2 warp_reduce_sum(half2 val) {
+  // uses bitwise operations to perform a parallel reduction
+  // within a warp. The 'mask' is right-shifted by 1 in each iteration
+  // until it reaches zero, effectively summing all values within the warp.
+#pragma unroll
+  for (int mask = 16; mask > 0; mask >>= 1) {
+    val = __hadd2(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32));
+  }
+  return val;
+}
+
 // performs a parallel reduction operation across the threads within a single
 // warp (32 threads).
 // - val: The value to be reduced within a warp.
@@ -63,6 +93,35 @@ __inline__ __device__ T block_reduce_sum(T val) {
   return val;
 }
 
+/* Calculate the sum of all elements in a thread block */
+template <>
+__inline__ __device__ half2 block_reduce_sum(half2 val) {
+  // up to 32 warps in a block
+  static __shared__ half2 shared[32];
+  // lane id in a warp
+  int lane = threadIdx.x & 0x1f;
+  // warp id: threadIdx.x / 32
+  int wid = threadIdx.x >> 5;
+
+  // perform a parallel reduction across the threads within each warp
+  val = warp_reduce_sum(val);
+
+  if (lane == 0) {
+    // write the sum of each warp to shared memory
+    shared[wid] = val;
+  }
+  // wait for all warps to finish
+  __syncthreads();
+
+  // Use blockDim.x / 32.f rather than blockDim.x >> 5 so this also works
+  // when blockDim.x is not a multiple of 32.
+  val = (threadIdx.x < (blockDim.x / 32.f))
+            ?
shared[lane] + : make_half2(__float2half(0.0f), __float2half(0.0f)); + val = warp_reduce_sum(val); + return val; +} + /* Calculate the max of all elements in a thread block */ template __inline__ __device__ T block_reduce_max(T val) { From 6ff173895b624466ac487cbcc2be8c3ab691a94c Mon Sep 17 00:00:00 2001 From: Xianzhe Dong Date: Wed, 17 Apr 2024 06:58:07 -0400 Subject: [PATCH 3/5] use gtest library rewrite layernorm kernel unitest --- src/kernels/CMakeLists.txt | 26 ++++------ src/kernels/layernrom_kernels_test.cu | 70 ++++++++++++++++++++++++--- 2 files changed, 73 insertions(+), 23 deletions(-) diff --git a/src/kernels/CMakeLists.txt b/src/kernels/CMakeLists.txt index 85d601c6..78ec3391 100644 --- a/src/kernels/CMakeLists.txt +++ b/src/kernels/CMakeLists.txt @@ -73,23 +73,15 @@ cc_library( torch ) -# cc_test( -# NAME -# layernorm_kernels_test -# SRCS -# layernrom_kernels_test.cu -# layernorm_kernels.cu -# DEPS -# DEFINES -# ) -cc_binary( - NAME - layernorm_kernels_test - SRCS - layernrom_kernels_test.cu - layernorm_kernels.cu - DEPS - torch +cc_test( + NAME + layernorm_kernels_test + SRCS + layernrom_kernels_test.cu + layernorm_kernels.cu + DEPS + torch + GTest::gtest_main ) add_subdirectory(flash_attn) diff --git a/src/kernels/layernrom_kernels_test.cu b/src/kernels/layernrom_kernels_test.cu index 119bc4a6..94c0b5b0 100644 --- a/src/kernels/layernrom_kernels_test.cu +++ b/src/kernels/layernrom_kernels_test.cu @@ -1,8 +1,8 @@ #include - #include - #include "layernorm_kernels.h" +#include +#include template void printMatrix(T* a, int m, int n) { @@ -15,6 +15,28 @@ void printMatrix(T* a, int m, int n) { puts(""); } +template <> +void printMatrix(float* a, int m, int n) { + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + printf("%f ", a[i * n + j]); + } + puts(""); + } + puts(""); +} + +template <> +void printMatrix(half* a, int m, int n) { + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + printf("%f ", __half2float(a[i * n + j])); + } + puts(""); + } + puts(""); +} + template <> void printMatrix(half2* a, int m, int n) { for (int i = 0; i < m; i++) { @@ -122,8 +144,44 @@ void layernorm_kernel_float_test() { printMatrix(out, m, n); } -int main() { - layernorm_kernel_float_test(); - layernorm_kernel_half2_test(); - return 0; +TEST(NormalizationKernelTest, LayernormFloatTest) { + float epsilon = 1e-6; + int m = 32; + int n = 512; + + auto input = torch::randn({m, n}); + auto weight = torch::randn({n}); + auto bias = torch::randn({n}); + auto desired_out = torch::nn::functional::layer_norm(input, torch::nn::functional::LayerNormFuncOptions({n}).weight(weight).bias(bias)); + + float* hout = (float*)malloc(m * n * sizeof(float)); + float* hinput = input.data_ptr(); + float* hweight = weight.data_ptr(); + float* hbias = bias.data_ptr(); + + float* dout; + float* dinput; + float* dweight; + float* dbias; + cudaMalloc((void**)&dout, sizeof(float) * m * n); + cudaMalloc((void**)&dinput, sizeof(float) * m * n); + cudaMalloc((void**)&dweight, sizeof(float) * n); + cudaMalloc((void**)&dbias, sizeof(float) * n); + + cudaMemcpy(dinput, hinput, sizeof(float) * m * n, cudaMemcpyHostToDevice); + cudaMemcpy(dweight, hweight, sizeof(float) * n, cudaMemcpyHostToDevice); + cudaMemcpy(dbias, hbias, sizeof(float) * n, cudaMemcpyHostToDevice); + + llm::kernel::invoke_layernorm_kernel( + dout, dinput, dweight, dbias, epsilon, m, n); + + cudaMemcpy(hout, dout, sizeof(float) * m * n, cudaMemcpyDeviceToHost); + + auto out = torch::from_blob(hout, {m, n}); + 
EXPECT_TRUE(torch::allclose(out, desired_out, 1e-3, 1e-5)); + free(hout); + cudaFree(dout); + cudaFree(dinput); + cudaFree(dweight); + cudaFree(dbias); } \ No newline at end of file From bc9f7e27f5279741e06ce65c719036fa79e26a65 Mon Sep 17 00:00:00 2001 From: Xianzhe Dong Date: Sat, 20 Apr 2024 01:17:03 -0400 Subject: [PATCH 4/5] added layernorm kernel half2 unit test using gtest library --- src/kernels/layernorm_kernels.cu | 18 +++ src/kernels/layernrom_kernels_test.cu | 212 +++++++++----------------- src/kernels/reduce_kernel_utils.cuh | 5 +- 3 files changed, 89 insertions(+), 146 deletions(-) diff --git a/src/kernels/layernorm_kernels.cu b/src/kernels/layernorm_kernels.cu index 8787d458..0a336657 100644 --- a/src/kernels/layernorm_kernels.cu +++ b/src/kernels/layernorm_kernels.cu @@ -285,4 +285,22 @@ void invoke_layernorm_kernel(float* out, layer_norm_kernel<<>>(out, input, weight, bias, epsilon, n); } +template <> +void invoke_layernorm_kernel(half* out, + const half* input, + const half* weight, + const half* bias, + const float epsilon, + int m, + int n) { + int half_n = n / 2; + half2* out_ptr = (half2*)out; + const half2* input_ptr = (const half2*)input; + const half2* weight_ptr = (const half2*)weight; + const half2* bias_ptr = (const half2*)bias; + + dim3 block(std::min(half_n, 1024)); + layer_norm_kernel + <<>>(out_ptr, input_ptr, weight_ptr, bias_ptr, epsilon, half_n); +} } // namespace llm::kernel \ No newline at end of file diff --git a/src/kernels/layernrom_kernels_test.cu b/src/kernels/layernrom_kernels_test.cu index 94c0b5b0..df472225 100644 --- a/src/kernels/layernrom_kernels_test.cu +++ b/src/kernels/layernrom_kernels_test.cu @@ -1,148 +1,10 @@ #include -#include -#include "layernorm_kernels.h" -#include #include +#include -template -void printMatrix(T* a, int m, int n) { - for (int i = 0; i < m; i++) { - for (int j = 0; j < n; j++) { - printf("%f ", (float)a[i * n + j]); - } - puts(""); - } - puts(""); -} - -template <> -void printMatrix(float* a, int m, int n) { - for (int i = 0; i < m; i++) { - for (int j = 0; j < n; j++) { - printf("%f ", a[i * n + j]); - } - puts(""); - } - puts(""); -} - -template <> -void printMatrix(half* a, int m, int n) { - for (int i = 0; i < m; i++) { - for (int j = 0; j < n; j++) { - printf("%f ", __half2float(a[i * n + j])); - } - puts(""); - } - puts(""); -} - -template <> -void printMatrix(half2* a, int m, int n) { - for (int i = 0; i < m; i++) { - for (int j = 0; j < n; j++) { - printf( - "%f %f ", __half2float(a[i * n + j].x), __half2float(a[i * n + j].y)); - } - puts(""); - } - puts(""); -} - -void layernorm_kernel_half2_test() { - float epsilon = 1e-6; - int m = 2; - int n = 2; - - half2* out = (half2*)malloc(m * n * sizeof(half2)); - half2* input = (half2*)malloc(m * n * sizeof(half2)); - half2* weight = (half2*)malloc(m * n * sizeof(half2)); - half2* bias = (half2*)malloc(m * n * sizeof(half2)); - - for (int i = 0; i < m; i++) { - for (int j = 0; j < n; j++) { - input[i * n + j] = half2(__float2half((float)(i * n + j * 2)), - __float2half((float)(i * n + j * 2 + 1))); - weight[i * n + j] = half2(__float2half(1.), __float2half(1.)); - bias[i * n + j] = half2(__float2half(0.), __float2half(0.)); - } - } - - half2* dout; - half2* dinput; - half2* dweight; - half2* dbias; - cudaMalloc((void**)&dout, sizeof(half2) * m * n); - cudaMalloc((void**)&dinput, sizeof(half2) * m * n); - cudaMalloc((void**)&dweight, sizeof(half2) * m * n); - cudaMalloc((void**)&dbias, sizeof(half2) * m * n); - - cudaMemcpy(dinput, input, sizeof(half2) * m * n, 
cudaMemcpyHostToDevice); - cudaMemcpy(dweight, weight, sizeof(half2) * m * n, cudaMemcpyHostToDevice); - cudaMemcpy(dbias, bias, sizeof(half2) * m * n, cudaMemcpyHostToDevice); - - llm::kernel::invoke_layernorm_kernel( - dout, dinput, dweight, dbias, epsilon, m, n); - - cudaMemcpy(out, dout, sizeof(half2) * m * n, cudaMemcpyDeviceToHost); - - printf("---------- test half2 layernorm kernel -----------\n"); - printf("input:\n"); - printMatrix(input, m, n); - printf("weights:\n"); - printMatrix(weight, m, n); - printf("bias:\n"); - printMatrix(bias, m, n); - printf("outputs:\n"); - printMatrix(out, m, n); -} - -void layernorm_kernel_float_test() { - float epsilon = 1e-6; - int m = 2; - int n = 4; - - float* out = (float*)malloc(m * n * sizeof(float)); - float* input = (float*)malloc(m * n * sizeof(float)); - float* weight = (float*)malloc(m * n * sizeof(float)); - float* bias = (float*)malloc(m * n * sizeof(float)); - - for (int i = 0; i < m; i++) { - for (int j = 0; j < n; j++) { - input[i * n + j] = (float)(i * n + j); - weight[i * n + j] = 1.; - bias[i * n + j] = 0.; - } - } - - float* dout; - float* dinput; - float* dweight; - float* dbias; - cudaMalloc((void**)&dout, sizeof(float) * m * n); - cudaMalloc((void**)&dinput, sizeof(float) * m * n); - cudaMalloc((void**)&dweight, sizeof(float) * m * n); - cudaMalloc((void**)&dbias, sizeof(float) * m * n); - - cudaMemcpy(dinput, input, sizeof(float) * m * n, cudaMemcpyHostToDevice); - cudaMemcpy(dweight, weight, sizeof(float) * m * n, cudaMemcpyHostToDevice); - cudaMemcpy(dbias, bias, sizeof(float) * m * n, cudaMemcpyHostToDevice); - - llm::kernel::invoke_layernorm_kernel( - dout, dinput, dweight, dbias, epsilon, m, n); +#include - cudaMemcpy(out, dout, sizeof(float) * m * n, cudaMemcpyDeviceToHost); - - printf("---------- test float layernorm kernel -----------\n"); - printf("input:\n"); - printMatrix(input, m, n); - printf("weights:\n"); - printMatrix(weight, m, n); - printf("bias:\n"); - printMatrix(bias, m, n); - printf("outputs:\n"); - printMatrix(out, m, n); -} +#include "layernorm_kernels.h" TEST(NormalizationKernelTest, LayernormFloatTest) { float epsilon = 1e-6; @@ -152,7 +14,10 @@ TEST(NormalizationKernelTest, LayernormFloatTest) { auto input = torch::randn({m, n}); auto weight = torch::randn({n}); auto bias = torch::randn({n}); - auto desired_out = torch::nn::functional::layer_norm(input, torch::nn::functional::LayerNormFuncOptions({n}).weight(weight).bias(bias)); + auto desired_out = torch::nn::functional::layer_norm( + input, + torch::nn::functional::LayerNormFuncOptions({n}).weight(weight).bias( + bias)); float* hout = (float*)malloc(m * n * sizeof(float)); float* hinput = input.data_ptr(); @@ -184,4 +49,65 @@ TEST(NormalizationKernelTest, LayernormFloatTest) { cudaFree(dinput); cudaFree(dweight); cudaFree(dbias); -} \ No newline at end of file +} + +TEST(NormalizationKernelTest, LayernormHalfTest) { + float epsilon = 1e-6; + int m = 4; + int n = 512; + auto input = torch::randn({m, n}); + auto weight = torch::randn({n}); + auto bias = torch::randn({n}); + auto desired_out = torch::nn::functional::layer_norm( + input, + torch::nn::functional::LayerNormFuncOptions({n}).weight(weight).bias( + bias)); + + half* hout = (half*)malloc(m * n * sizeof(half)); + half* hinput = (half*)malloc(m * n * sizeof(half)); + half* hweight = (half*)malloc(n * sizeof(half)); + half* hbias = (half*)malloc(n * sizeof(half)); + + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + hinput[i * n + j] = __float2half(input[i][j].item()); + } + } 
+ for (int i = 0; i < weight.numel(); i++) + hweight[i] = __float2half(weight[i].item()); + for (int i = 0; i < bias.numel(); i++) + hbias[i] = __float2half(bias[i].item()); + + half* dout; + half* dinput; + half* dweight; + half* dbias; + cudaMalloc((void**)&dout, sizeof(half) * m * n); + cudaMalloc((void**)&dinput, sizeof(half) * m * n); + cudaMalloc((void**)&dweight, sizeof(half) * n); + cudaMalloc((void**)&dbias, sizeof(half) * n); + + cudaMemcpy(dinput, hinput, sizeof(half) * m * n, cudaMemcpyHostToDevice); + cudaMemcpy(dweight, hweight, sizeof(half) * n, cudaMemcpyHostToDevice); + cudaMemcpy(dbias, hbias, sizeof(half) * n, cudaMemcpyHostToDevice); + + llm::kernel::invoke_layernorm_kernel( + dout, dinput, dweight, dbias, epsilon, m, n); + + cudaMemcpy(hout, dout, sizeof(half) * m * n, cudaMemcpyDeviceToHost); + + float* float_hout = (float*)malloc(m * n * sizeof(float)); + for (int i = 0; i < m * n; i++) float_hout[i] = __half2float(hout[i]); + + auto out = torch::from_blob(float_hout, {m, n}); + EXPECT_TRUE(torch::allclose(out, desired_out, 0.05, 1e-3)); + free(hout); + free(hinput); + free(hweight); + free(hbias); + free(float_hout); + cudaFree(dout); + cudaFree(dinput); + cudaFree(dweight); + cudaFree(dbias); +} diff --git a/src/kernels/reduce_kernel_utils.cuh b/src/kernels/reduce_kernel_utils.cuh index edb9d953..16e077ad 100644 --- a/src/kernels/reduce_kernel_utils.cuh +++ b/src/kernels/reduce_kernel_utils.cuh @@ -198,9 +198,8 @@ struct TopK { // operator for cub::BlockReduce to get topk across a thread block template -__device__ __forceinline__ TopK reduce_topk_op( - const TopK& a, - const TopK& b) { +__device__ __forceinline__ TopK reduce_topk_op(const TopK& a, + const TopK& b) { TopK res = a; for (int i = 0; i < K; ++i) { res.insert(b.u[i], b.p[i]); From faaec0742d65eee9dae16ffddfb58e68ce10b8c6 Mon Sep 17 00:00:00 2001 From: Xianzhe Dong Date: Sun, 28 Apr 2024 01:51:49 -0400 Subject: [PATCH 5/5] [refactor] use torch::tensor to allocate memory in layernorm kernel unitest and just test llm::kernel::layer_norm --- src/kernels/layernorm_kernels.cu | 2 +- src/kernels/layernrom_kernels_test.cu | 101 ++++++-------------------- 2 files changed, 22 insertions(+), 81 deletions(-) diff --git a/src/kernels/layernorm_kernels.cu b/src/kernels/layernorm_kernels.cu index 0a336657..4ce7c38c 100644 --- a/src/kernels/layernorm_kernels.cu +++ b/src/kernels/layernorm_kernels.cu @@ -182,7 +182,7 @@ __global__ void layer_norm_kernel(half2* __restrict__ out, const half2* __restrict__ weight, const half2* __restrict__ bias, const float epsilon, - int n) { + int64_t n) { const int tidx = threadIdx.x; const int bidx = blockIdx.x; diff --git a/src/kernels/layernrom_kernels_test.cu b/src/kernels/layernrom_kernels_test.cu index df472225..c50c18cf 100644 --- a/src/kernels/layernrom_kernels_test.cu +++ b/src/kernels/layernrom_kernels_test.cu @@ -11,103 +11,44 @@ TEST(NormalizationKernelTest, LayernormFloatTest) { int m = 32; int n = 512; - auto input = torch::randn({m, n}); - auto weight = torch::randn({n}); - auto bias = torch::randn({n}); + auto out = torch::zeros({m, n}, torch::TensorOptions().device(torch::kCUDA)); + auto input = + torch::randn({m, n}, torch::TensorOptions().device(torch::kCUDA)); + auto weight = torch::randn({n}, torch::TensorOptions().device(torch::kCUDA)); + auto bias = torch::randn({n}, torch::TensorOptions().device(torch::kCUDA)); auto desired_out = torch::nn::functional::layer_norm( input, torch::nn::functional::LayerNormFuncOptions({n}).weight(weight).bias( bias)); - float* hout = 
(float*)malloc(m * n * sizeof(float)); - float* hinput = input.data_ptr(); - float* hweight = weight.data_ptr(); - float* hbias = bias.data_ptr(); + llm::kernel::layer_norm(out, input, weight, bias, epsilon); - float* dout; - float* dinput; - float* dweight; - float* dbias; - cudaMalloc((void**)&dout, sizeof(float) * m * n); - cudaMalloc((void**)&dinput, sizeof(float) * m * n); - cudaMalloc((void**)&dweight, sizeof(float) * n); - cudaMalloc((void**)&dbias, sizeof(float) * n); - - cudaMemcpy(dinput, hinput, sizeof(float) * m * n, cudaMemcpyHostToDevice); - cudaMemcpy(dweight, hweight, sizeof(float) * n, cudaMemcpyHostToDevice); - cudaMemcpy(dbias, hbias, sizeof(float) * n, cudaMemcpyHostToDevice); - - llm::kernel::invoke_layernorm_kernel( - dout, dinput, dweight, dbias, epsilon, m, n); - - cudaMemcpy(hout, dout, sizeof(float) * m * n, cudaMemcpyDeviceToHost); - - auto out = torch::from_blob(hout, {m, n}); EXPECT_TRUE(torch::allclose(out, desired_out, 1e-3, 1e-5)); - free(hout); - cudaFree(dout); - cudaFree(dinput); - cudaFree(dweight); - cudaFree(dbias); } TEST(NormalizationKernelTest, LayernormHalfTest) { float epsilon = 1e-6; int m = 4; int n = 512; - auto input = torch::randn({m, n}); - auto weight = torch::randn({n}); - auto bias = torch::randn({n}); + + auto out = torch::zeros( + {m, n}, + torch::TensorOptions().dtype(at::ScalarType::Half).device(torch::kCUDA)); + auto input = torch::randn( + {m, n}, + torch::TensorOptions().dtype(at::ScalarType::Half).device(torch::kCUDA)); + auto weight = torch::randn( + {n}, + torch::TensorOptions().dtype(at::ScalarType::Half).device(torch::kCUDA)); + auto bias = torch::randn( + {n}, + torch::TensorOptions().dtype(at::ScalarType::Half).device(torch::kCUDA)); auto desired_out = torch::nn::functional::layer_norm( input, torch::nn::functional::LayerNormFuncOptions({n}).weight(weight).bias( bias)); - half* hout = (half*)malloc(m * n * sizeof(half)); - half* hinput = (half*)malloc(m * n * sizeof(half)); - half* hweight = (half*)malloc(n * sizeof(half)); - half* hbias = (half*)malloc(n * sizeof(half)); - - for (int i = 0; i < m; i++) { - for (int j = 0; j < n; j++) { - hinput[i * n + j] = __float2half(input[i][j].item()); - } - } - for (int i = 0; i < weight.numel(); i++) - hweight[i] = __float2half(weight[i].item()); - for (int i = 0; i < bias.numel(); i++) - hbias[i] = __float2half(bias[i].item()); + llm::kernel::layer_norm(out, input, weight, bias, epsilon); - half* dout; - half* dinput; - half* dweight; - half* dbias; - cudaMalloc((void**)&dout, sizeof(half) * m * n); - cudaMalloc((void**)&dinput, sizeof(half) * m * n); - cudaMalloc((void**)&dweight, sizeof(half) * n); - cudaMalloc((void**)&dbias, sizeof(half) * n); - - cudaMemcpy(dinput, hinput, sizeof(half) * m * n, cudaMemcpyHostToDevice); - cudaMemcpy(dweight, hweight, sizeof(half) * n, cudaMemcpyHostToDevice); - cudaMemcpy(dbias, hbias, sizeof(half) * n, cudaMemcpyHostToDevice); - - llm::kernel::invoke_layernorm_kernel( - dout, dinput, dweight, dbias, epsilon, m, n); - - cudaMemcpy(hout, dout, sizeof(half) * m * n, cudaMemcpyDeviceToHost); - - float* float_hout = (float*)malloc(m * n * sizeof(float)); - for (int i = 0; i < m * n; i++) float_hout[i] = __half2float(hout[i]); - - auto out = torch::from_blob(float_hout, {m, n}); EXPECT_TRUE(torch::allclose(out, desired_out, 0.05, 1e-3)); - free(hout); - free(hinput); - free(hweight); - free(hbias); - free(float_hout); - cudaFree(dout); - cudaFree(dinput); - cudaFree(dweight); - cudaFree(dbias); -} +} \ No newline at end of file