
Commit bd5b91f

Author: Xianzhe Dong
Commit message: [op] optimize layernorm kernel for half2 type
Parent: 0408bc1

File tree: 2 files changed, +136 -0 lines changed


src/kernels/layernorm_kernels.cu

Lines changed: 112 additions & 0 deletions
@@ -3,6 +3,7 @@
 #include "dispatch.h"
 #include "reduce_kernel_utils.cuh"
+#include "layernorm_kernels.h"

 namespace llm::kernel {

@@ -173,6 +174,67 @@ __global__ void layer_norm_kernel(T* __restrict__ out,
   }
 }

+// equation: x -> (x - E[x]) / sqrt(Var[x] + eps) * w + b
+// The mean and standard deviation are calculated over the last dimension.
+// Each half2 element packs two half values, so a row of n half2 elements
+// covers 2 * n scalar values.
+template <>
+__global__ void layer_norm_kernel<half2>(half2* __restrict__ out,
+                                         const half2* __restrict__ input,
+                                         const half2* __restrict__ weight,
+                                         const half2* __restrict__ bias,
+                                         const float epsilon,
+                                         int n) {
+  const int tidx = threadIdx.x;
+  const int bidx = blockIdx.x;
+
+  __shared__ half s_mean;
+  __shared__ half s_variance;
+  half2 mean = make_half2(__float2half(0.0f), __float2half(0.0f));
+  half2 variance = make_half2(__float2half(0.0f), __float2half(0.0f));
+
+  // calculate the mean of the input: accumulate a half2 partial sum per
+  // thread, reduce across the block, then fold the two lanes together.
+  for (int i = tidx; i < n; i += blockDim.x) {
+    const int idx = bidx * n + i;
+    mean = __hadd2(mean, __ldg(&input[idx]));
+  }
+  mean = block_reduce_sum<half2>(mean);
+  if (tidx == 0) {
+    s_mean = __hdiv(__hadd(mean.x, mean.y), __float2half((float)n * 2));
+  }
+  __syncthreads();
+
+  // calculate the variance of the input in the same two-lane fashion.
+  for (int i = tidx; i < n; i += blockDim.x) {
+    const half2 x = __hsub2(input[bidx * n + i], make_half2(s_mean, s_mean));
+    variance = __hadd2(variance, __hmul2(x, x));
+  }
+  variance = block_reduce_sum<half2>(variance);
+  if (tidx == 0) {
+    s_variance = __hadd(variance.x, variance.y);
+    s_variance = __hdiv(s_variance, __float2half((float)n * 2));
+    s_variance = __hadd(s_variance, __float2half(epsilon));
+    s_variance = hrsqrt(s_variance);
+  }
+  __syncthreads();
+
+  // normalize: out = (x - mean) * rsqrt(var + eps) * weight (+ bias)
+  for (int i = tidx; i < n; i += blockDim.x) {
+    const int idx = bidx * n + i;
+    half2 local_out = __ldg(&input[idx]);
+    local_out = __hsub2(local_out, make_half2(s_mean, s_mean));
+    local_out = __hmul2(local_out, make_half2(s_variance, s_variance));
+    local_out = __hmul2(local_out, __ldg(&weight[i]));
+    if (bias != nullptr) {
+      local_out = __hadd2(local_out, __ldg(&bias[i]));
+    }
+    out[idx] = local_out;
+  }
+}
+
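
The block-wide sums above rely on block_reduce_sum<half2> from reduce_kernel_utils.cuh, which is not part of this diff. As a rough, hypothetical sketch of what such a reduction typically looks like (warp shuffles followed by one shared-memory step; the repository's actual helper may differ), assuming the result is only consumed by thread 0 as in the kernel above:

// Illustrative sketch only; not the implementation in reduce_kernel_utils.cuh.
#include <cuda_fp16.h>

__inline__ __device__ half2 warp_reduce_sum_half2(half2 val) {
  // butterfly reduction within a warp using register shuffles
  #pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1) {
    val = __hadd2(val, __shfl_xor_sync(0xffffffff, val, mask, 32));
  }
  return val;
}

__inline__ __device__ half2 block_reduce_sum_half2(half2 val) {
  __shared__ half2 partial[32];         // one partial sum per warp
  const int lane = threadIdx.x & 31;
  const int warp = threadIdx.x >> 5;

  val = warp_reduce_sum_half2(val);     // reduce within each warp
  if (lane == 0) partial[warp] = val;   // stash per-warp results
  __syncthreads();

  // the first warp reduces the per-warp partial sums; only warp 0 holds
  // the final value, which suffices because the caller reads it on tidx == 0.
  const int num_warps = (blockDim.x + 31) >> 5;
  val = (lane < num_warps) ? partial[lane] : __float2half2_rn(0.0f);
  if (warp == 0) val = warp_reduce_sum_half2(val);
  return val;
}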
 void layer_norm(torch::Tensor& out,
                 torch::Tensor input,
                 torch::Tensor weight,
@@ -197,4 +259,54 @@ void layer_norm(torch::Tensor& out,
   });
 }

+template <typename T>
+void invoke_layernorm_kernel(T* out,
+                             const T* input,
+                             const T* weight,
+                             const T* bias,
+                             const float epsilon,
+                             int m,
+                             int n) {
+  // one block per row (m rows), n threads per block (n elements per row)
+  layer_norm_kernel<T><<<m, n>>>(out, input, weight, bias, epsilon, n);
+}
+
+template <>
+void invoke_layernorm_kernel<half2>(half2* out,
+                                    const half2* input,
+                                    const half2* weight,
+                                    const half2* bias,
+                                    const float epsilon,
+                                    int m,
+                                    int n) {
+  layer_norm_kernel<half2><<<m, n>>>(out, input, weight, bias, epsilon, n);
+}
+
+template <>
+void invoke_layernorm_kernel<float>(float* out,
+                                    const float* input,
+                                    const float* weight,
+                                    const float* bias,
+                                    const float epsilon,
+                                    int m,
+                                    int n) {
+  layer_norm_kernel<float><<<m, n>>>(out, input, weight, bias, epsilon, n);
+}
+
 } // namespace llm::kernel
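
For context on how the new launcher might be called, here is a hypothetical usage sketch; the wrapper name run_layernorm_half2 and the assumption of contiguous rows of half values are not from this commit. Each row of hidden_size half values is viewed as hidden_size / 2 half2 elements, so the n passed to invoke_layernorm_kernel is half the scalar width. Note that the launcher uses n threads per block (<<<m, n>>>), so this assumes n fits within the 1024-thread block limit.

// Hypothetical caller; not part of the commit.
#include <cuda_fp16.h>
#include "layernorm_kernels.h"

void run_layernorm_half2(half* out, const half* input, const half* weight,
                         const half* bias, float epsilon,
                         int rows, int hidden_size) {
  // hidden_size must be even so each row packs cleanly into half2 pairs
  const int n = hidden_size / 2;
  llm::kernel::invoke_layernorm_kernel<half2>(
      reinterpret_cast<half2*>(out),
      reinterpret_cast<const half2*>(input),
      reinterpret_cast<const half2*>(weight),
      reinterpret_cast<const half2*>(bias),
      epsilon,
      /*m=*/rows,
      /*n=*/n);
}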

src/kernels/layernorm_kernels.h

Lines changed: 24 additions & 0 deletions
@@ -20,4 +20,28 @@ void layer_norm(torch::Tensor& out,
                 torch::Tensor bias,
                 float epsilon);

+template <typename T>
+void invoke_layernorm_kernel(T* out,
+                             const T* input,
+                             const T* weight,
+                             const T* bias,
+                             const float epsilon,
+                             int m,
+                             int n);
+
 } // namespace llm::kernel
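
For reference, the normalization that both specializations implement, per the equation in the kernel comment, written as a plain C++ function over one row of float values. This is an illustrative reference (for example, for sanity-checking kernel output), not code from the commit.

// Illustrative CPU reference for one row; not part of the commit.
#include <cmath>
#include <vector>

std::vector<float> layer_norm_reference(const std::vector<float>& x,
                                        const std::vector<float>& w,
                                        const std::vector<float>& b,
                                        float eps) {
  const size_t n = x.size();

  float mean = 0.0f;
  for (float v : x) mean += v;
  mean /= static_cast<float>(n);

  float var = 0.0f;
  for (float v : x) var += (v - mean) * (v - mean);
  var /= static_cast<float>(n);

  // x -> (x - E[x]) / sqrt(Var[x] + eps) * w + b
  const float inv_std = 1.0f / std::sqrt(var + eps);
  std::vector<float> out(n);
  for (size_t i = 0; i < n; ++i) {
    out[i] = (x[i] - mean) * inv_std * w[i] + b[i];
  }
  return out;
}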
