Skip to content

Commit 14a99f7

Browse files
author
Xianzhe Dong
committed
[ut] add layernorm kernel unitest
1 parent bd5b91f commit 14a99f7

File tree

5 files changed

+236
-69
lines changed

5 files changed

+236
-69
lines changed

src/kernels/CMakeLists.txt

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
include(cc_library)
2+
include(cc_binary)
23

34
cc_library(
45
NAME
@@ -72,6 +73,24 @@ cc_library(
7273
torch
7374
)
7475

76+
# cc_test(
77+
# NAME
78+
# layernorm_kernels_test
79+
# SRCS
80+
# layernrom_kernels_test.cu
81+
# layernorm_kernels.cu
82+
# DEPS
83+
# DEFINES
84+
# )
85+
cc_binary(
86+
NAME
87+
layernorm_kernels_test
88+
SRCS
89+
layernrom_kernels_test.cu
90+
layernorm_kernels.cu
91+
DEPS
92+
torch
93+
)
94+
7595
add_subdirectory(flash_attn)
7696
add_subdirectory(flash_infer)
77-

src/kernels/layernorm_kernels.cu

Lines changed: 28 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
#include <torch/torch.h>
33

44
#include "dispatch.h"
5-
#include "reduce_kernel_utils.cuh"
65
#include "layernorm_kernels.h"
6+
#include "reduce_kernel_utils.cuh"
77

88
namespace llm::kernel {
99

@@ -178,11 +178,11 @@ __global__ void layer_norm_kernel(T* __restrict__ out,
178178
// The mean and standard-deviation are calculated over the last dimension
179179
template <>
180180
__global__ void layer_norm_kernel<half2>(half2* __restrict__ out,
181-
const half2* __restrict__ input,
182-
const half2* __restrict__ weight,
183-
const half2* __restrict__ bias,
184-
const float epsilon,
185-
int n) {
181+
const half2* __restrict__ input,
182+
const half2* __restrict__ weight,
183+
const half2* __restrict__ bias,
184+
const float epsilon,
185+
int n) {
186186
const int tidx = threadIdx.x;
187187
const int bidx = blockIdx.x;
188188

@@ -209,7 +209,6 @@ __global__ void layer_norm_kernel<half2>(half2* __restrict__ out,
209209
}
210210
variance = block_reduce_sum<half2>(variance);
211211
if (tidx == 0) {
212-
// const half2 e = make_half2(__float2half(epsilon), __float2half(epsilon));
213212
s_variance = __hadd(variance.x, variance.y);
214213
s_variance = __hdiv(s_variance, __float2half((float)n * 2));
215214
s_variance = __hadd(s_variance, __float2half(epsilon));
@@ -219,16 +218,11 @@ __global__ void layer_norm_kernel<half2>(half2* __restrict__ out,
219218

220219
for (int i = tidx; i < n; i += blockDim.x) {
221220
const int idx = bidx * n + i;
222-
// float local_out =
223-
// (__ldg(&input[idx]) - s_mean) * s_variance * __ldg(&weight[i]);
224-
// if (bias != nullptr) {
225-
// local_out += __ldg(&bias[i]);
226-
// }
227221
half2 local_out = __ldg(&input[idx]);
228222
local_out = __hsub2(local_out, make_half2(s_mean, s_mean));
229223
local_out = __hmul2(local_out, make_half2(s_variance, s_variance));
230224
local_out = __hmul2(local_out, __ldg(&weight[i]));
231-
if (bias != nullptr){
225+
if (bias != nullptr) {
232226
local_out = __hadd2(local_out, __ldg(&bias[i]));
233227
}
234228
out[idx] = local_out;
@@ -261,52 +255,34 @@ void layer_norm(torch::Tensor& out,
261255

262256
template <typename T>
263257
void invoke_layernorm_kernel(T* out,
264-
const T* input,
265-
const T* weight,
266-
const T* bias,
267-
const float epsilon,
268-
int m,
269-
int n) {
258+
const T* input,
259+
const T* weight,
260+
const T* bias,
261+
const float epsilon,
262+
int m,
263+
int n) {
270264
layer_norm_kernel<T><<<m, n>>>(out, input, weight, bias, epsilon, n);
271265
}
272266

273267
template <>
274268
void invoke_layernorm_kernel<half2>(half2* out,
275-
const half2* input,
276-
const half2* weight,
277-
const half2* bias,
278-
const float epsilon,
279-
int m,
280-
int n) {
269+
const half2* input,
270+
const half2* weight,
271+
const half2* bias,
272+
const float epsilon,
273+
int m,
274+
int n) {
281275
layer_norm_kernel<half2><<<m, n>>>(out, input, weight, bias, epsilon, n);
282276
}
283277
template <>
284278
void invoke_layernorm_kernel<float>(float* out,
285-
const float* input,
286-
const float* weight,
287-
const float* bias,
288-
const float epsilon,
289-
int m,
290-
int n) {
279+
const float* input,
280+
const float* weight,
281+
const float* bias,
282+
const float epsilon,
283+
int m,
284+
int n) {
291285
layer_norm_kernel<float><<<m, n>>>(out, input, weight, bias, epsilon, n);
292-
}
293-
// void invoke_float_layernorm_kernel(float* out,
294-
// const float* input,
295-
// const float* weight,
296-
// const float* bias,
297-
// const float epsilon,
298-
// int m,
299-
// int n){
300-
// layer_norm_kernel<float><<<m, n>>>(out, input, weight, bias, epsilon, n);
301-
// }
302-
303-
// void invoke_half2_layernorm_kernel(half2* out,
304-
// const half2* input,
305-
// const half2* weight,
306-
// const half2* bias,
307-
// const float epsilon,
308-
// int m,
309-
// int n){
310-
// layer_norm_kernel<half2><<<m, n>>>(out, input, weight, bias, epsilon, n);
311-
// }
312-
} // namespace llm::kernel
286+
}
287+
288+
} // namespace llm::kernel

src/kernels/layernorm_kernels.h

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,4 @@ void invoke_layernorm_kernel(T* out,
2828
const float epsilon,
2929
int m,
3030
int n);
31-
32-
// void invoke_float_layernorm_kernel(float* out,
33-
// const float* input,
34-
// const float* weight,
35-
// const float* bias,
36-
// const float epsilon,
37-
// int m,
38-
// int n);
39-
40-
// void invoke_half2_layernorm_kernel(half2* out,
41-
// const half2* input,
42-
// const half2* weight,
43-
// const half2* bias,
44-
// const float epsilon,
45-
// int m,
46-
// int n);
4731
} // namespace llm::kernel
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
#include <cuda_fp16.h>
2+
3+
#include <cstdio>
4+
5+
#include "layernorm_kernels.h"
6+
7+
template <typename T>
8+
void printMatrix(T* a, int m, int n) {
9+
for (int i = 0; i < m; i++) {
10+
for (int j = 0; j < n; j++) {
11+
printf("%f ", (float)a[i * n + j]);
12+
}
13+
puts("");
14+
}
15+
puts("");
16+
}
17+
18+
template <>
19+
void printMatrix<half2>(half2* a, int m, int n) {
20+
for (int i = 0; i < m; i++) {
21+
for (int j = 0; j < n; j++) {
22+
printf(
23+
"%f %f ", __half2float(a[i * n + j].x), __half2float(a[i * n + j].y));
24+
}
25+
puts("");
26+
}
27+
puts("");
28+
}
29+
30+
void layernorm_kernel_half2_test() {
31+
float epsilon = 1e-6;
32+
int m = 2;
33+
int n = 2;
34+
35+
half2* out = (half2*)malloc(m * n * sizeof(half2));
36+
half2* input = (half2*)malloc(m * n * sizeof(half2));
37+
half2* weight = (half2*)malloc(m * n * sizeof(half2));
38+
half2* bias = (half2*)malloc(m * n * sizeof(half2));
39+
40+
for (int i = 0; i < m; i++) {
41+
for (int j = 0; j < n; j++) {
42+
input[i * n + j] = half2(__float2half((float)(i * n + j * 2)),
43+
__float2half((float)(i * n + j * 2 + 1)));
44+
weight[i * n + j] = half2(__float2half(1.), __float2half(1.));
45+
bias[i * n + j] = half2(__float2half(0.), __float2half(0.));
46+
}
47+
}
48+
49+
half2* dout;
50+
half2* dinput;
51+
half2* dweight;
52+
half2* dbias;
53+
cudaMalloc((void**)&dout, sizeof(half2) * m * n);
54+
cudaMalloc((void**)&dinput, sizeof(half2) * m * n);
55+
cudaMalloc((void**)&dweight, sizeof(half2) * m * n);
56+
cudaMalloc((void**)&dbias, sizeof(half2) * m * n);
57+
58+
cudaMemcpy(dinput, input, sizeof(half2) * m * n, cudaMemcpyHostToDevice);
59+
cudaMemcpy(dweight, weight, sizeof(half2) * m * n, cudaMemcpyHostToDevice);
60+
cudaMemcpy(dbias, bias, sizeof(half2) * m * n, cudaMemcpyHostToDevice);
61+
62+
llm::kernel::invoke_layernorm_kernel<half2>(
63+
dout, dinput, dweight, dbias, epsilon, m, n);
64+
65+
cudaMemcpy(out, dout, sizeof(half2) * m * n, cudaMemcpyDeviceToHost);
66+
67+
printf("---------- test half2 layernorm kernel -----------\n");
68+
printf("input:\n");
69+
printMatrix<half2>(input, m, n);
70+
printf("weights:\n");
71+
printMatrix<half2>(weight, m, n);
72+
printf("bias:\n");
73+
printMatrix<half2>(bias, m, n);
74+
printf("outputs:\n");
75+
printMatrix<half2>(out, m, n);
76+
}
77+
78+
void layernorm_kernel_float_test() {
79+
float epsilon = 1e-6;
80+
int m = 2;
81+
int n = 4;
82+
83+
float* out = (float*)malloc(m * n * sizeof(float));
84+
float* input = (float*)malloc(m * n * sizeof(float));
85+
float* weight = (float*)malloc(m * n * sizeof(float));
86+
float* bias = (float*)malloc(m * n * sizeof(float));
87+
88+
for (int i = 0; i < m; i++) {
89+
for (int j = 0; j < n; j++) {
90+
input[i * n + j] = (float)(i * n + j);
91+
weight[i * n + j] = 1.;
92+
bias[i * n + j] = 0.;
93+
}
94+
}
95+
96+
float* dout;
97+
float* dinput;
98+
float* dweight;
99+
float* dbias;
100+
cudaMalloc((void**)&dout, sizeof(float) * m * n);
101+
cudaMalloc((void**)&dinput, sizeof(float) * m * n);
102+
cudaMalloc((void**)&dweight, sizeof(float) * m * n);
103+
cudaMalloc((void**)&dbias, sizeof(float) * m * n);
104+
105+
cudaMemcpy(dinput, input, sizeof(float) * m * n, cudaMemcpyHostToDevice);
106+
cudaMemcpy(dweight, weight, sizeof(float) * m * n, cudaMemcpyHostToDevice);
107+
cudaMemcpy(dbias, bias, sizeof(float) * m * n, cudaMemcpyHostToDevice);
108+
109+
llm::kernel::invoke_layernorm_kernel<float>(
110+
dout, dinput, dweight, dbias, epsilon, m, n);
111+
112+
cudaMemcpy(out, dout, sizeof(float) * m * n, cudaMemcpyDeviceToHost);
113+
114+
printf("---------- test float layernorm kernel -----------\n");
115+
printf("input:\n");
116+
printMatrix<float>(input, m, n);
117+
printf("weights:\n");
118+
printMatrix<float>(weight, m, n);
119+
printf("bias:\n");
120+
printMatrix<float>(bias, m, n);
121+
printf("outputs:\n");
122+
printMatrix<float>(out, m, n);
123+
}
124+
125+
int main() {
126+
layernorm_kernel_float_test();
127+
layernorm_kernel_half2_test();
128+
return 0;
129+
}

src/kernels/reduce_kernel_utils.cuh

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,36 @@ __inline__ __device__ T warp_reduce_sum(T val) {
2424
return val;
2525
}
2626

27+
// performs a parallel reduction operation across the threads within a single
28+
// warp (32 threads).
29+
// - val: The value to be reduced within a warp.
30+
template <>
31+
__inline__ __device__ half warp_reduce_sum<half>(half val) {
32+
// uses bitwise operations to perform a parallel reduction
33+
// within a warp. The 'mask' is right-shifted by 1 in each iteration
34+
// until it reaches zero, effectively summing all values within the warp.
35+
#pragma unroll
36+
for (int mask = 16; mask > 0; mask >>= 1) {
37+
val = __hadd(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32));
38+
}
39+
return val;
40+
}
41+
42+
// performs a parallel reduction operation across the threads within a single
43+
// warp (32 threads).
44+
// - val: The value to be reduced within a warp.
45+
template <>
46+
__inline__ __device__ half2 warp_reduce_sum<half2>(half2 val) {
47+
// uses bitwise operations to perform a parallel reduction
48+
// within a warp. The 'mask' is right-shifted by 1 in each iteration
49+
// until it reaches zero, effectively summing all values within the warp.
50+
#pragma unroll
51+
for (int mask = 16; mask > 0; mask >>= 1) {
52+
val = __hadd2(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32));
53+
}
54+
return val;
55+
}
56+
2757
// performs a parallel reduction operation across the threads within a single
2858
// warp (32 threads).
2959
// - val: The value to be reduced within a warp.
@@ -63,6 +93,35 @@ __inline__ __device__ T block_reduce_sum(T val) {
6393
return val;
6494
}
6595

96+
/* Calculate the sum of all elements in a thread block */
97+
template <>
98+
__inline__ __device__ half2 block_reduce_sum<half2>(half2 val) {
99+
// up to 32 warps in a block
100+
static __shared__ half2 shared[32];
101+
// lane id in a warp
102+
int lane = threadIdx.x & 0x1f;
103+
// wrap id: threadIdx.x / 32
104+
int wid = threadIdx.x >> 5;
105+
106+
// perform a parallel reduction across the threads within each warp
107+
val = warp_reduce_sum<half2>(val);
108+
109+
if (lane == 0) {
110+
// write the sum of each warp to shared memory
111+
shared[wid] = val;
112+
}
113+
// wait for all warps to finish
114+
__syncthreads();
115+
116+
// Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
117+
// blockDim.x is not divided by 32
118+
val = (threadIdx.x < (blockDim.x / 32.f))
119+
? shared[lane]
120+
: make_half2(__float2half(0.0f), __float2half(0.0f));
121+
val = warp_reduce_sum<half2>(val);
122+
return val;
123+
}
124+
66125
/* Calculate the max of all elements in a thread block */
67126
template <typename T>
68127
__inline__ __device__ T block_reduce_max(T val) {

0 commit comments

Comments
 (0)