 #include "dispatch.h"
 #include "reduce_kernel_utils.cuh"
+#include "layernorm_kernels.h"

 namespace llm::kernel {

@@ -173,6 +174,67 @@ __global__ void layer_norm_kernel(T* __restrict__ out,
   }
 }

+// equation: x -> (x - E[x]) / sqrt(Var[x] + eps) * w + b
+// The mean and standard deviation are calculated over the last dimension.
+// Specialization for half2: each thread handles two half values at a time,
+// so n is the number of half2 elements per row (hidden_size / 2).
+template <>
+__global__ void layer_norm_kernel<half2>(half2* __restrict__ out,
+                                         const half2* __restrict__ input,
+                                         const half2* __restrict__ weight,
+                                         const half2* __restrict__ bias,
+                                         const float epsilon,
+                                         int n) {
+  const int tidx = threadIdx.x;
+  const int bidx = blockIdx.x;
+
+  __shared__ half s_mean;
+  __shared__ half s_variance;
+  half2 mean = make_half2(__float2half(0.0f), __float2half(0.0f));
+  half2 variance = make_half2(__float2half(0.0f), __float2half(0.0f));
+
+  // calculate the mean of the input.
+  for (int i = tidx; i < n; i += blockDim.x) {
+    const int idx = bidx * n + i;
+    mean = __hadd2(mean, __ldg(&input[idx]));
+  }
+  mean = block_reduce_sum<half2>(mean);
+  if (tidx == 0) {
+    // fold the two lanes together: each row holds 2 * n half values.
+    s_mean = __hdiv(__hadd(mean.x, mean.y), __float2half((float)n * 2));
+  }
+  __syncthreads();
+
+  // calculate the variance of the input.
+  for (int i = tidx; i < n; i += blockDim.x) {
+    const half2 x = __hsub2(input[bidx * n + i], make_half2(s_mean, s_mean));
+    variance = __hadd2(variance, __hmul2(x, x));
+  }
+  variance = block_reduce_sum<half2>(variance);
+  if (tidx == 0) {
+    // store rsqrt(Var[x] + eps) so the loop below only needs multiplies.
+    s_variance = __hadd(variance.x, variance.y);
+    s_variance = __hdiv(s_variance, __float2half((float)n * 2));
+    s_variance = __hadd(s_variance, __float2half(epsilon));
+    s_variance = hrsqrt(s_variance);
+  }
+  __syncthreads();
+
+  // normalize: out = (x - E[x]) * rsqrt(Var[x] + eps) * w + b
+  for (int i = tidx; i < n; i += blockDim.x) {
+    const int idx = bidx * n + i;
+    half2 local_out = __ldg(&input[idx]);
+    local_out = __hsub2(local_out, make_half2(s_mean, s_mean));
+    local_out = __hmul2(local_out, make_half2(s_variance, s_variance));
+    local_out = __hmul2(local_out, __ldg(&weight[i]));
+    if (bias != nullptr) {
+      local_out = __hadd2(local_out, __ldg(&bias[i]));
+    }
+    out[idx] = local_out;
+  }
+}
+
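The half2 path computes the same statistics as the float path, just two lanes at a time. To sanity-check its output, a plain single-threaded float reference of the equation above can be handy; a minimal sketch (hypothetical helper, not part of this commit; it uses biased variance, dividing by n, just as the kernel divides by 2n half values):

#include <cmath>
#include <vector>

// Hypothetical CPU reference, not part of the commit: normalizes each of the
// m rows of a row-major [m, n] float matrix with biased variance, matching
// the kernels above.
std::vector<float> layer_norm_reference(const std::vector<float>& x,
                                        const std::vector<float>& w,
                                        const std::vector<float>& b,
                                        float epsilon, int m, int n) {
  std::vector<float> out(x.size());
  for (int row = 0; row < m; ++row) {
    const float* in = x.data() + row * n;
    float mean = 0.0f;
    for (int i = 0; i < n; ++i) mean += in[i];
    mean /= n;
    float var = 0.0f;
    for (int i = 0; i < n; ++i) var += (in[i] - mean) * (in[i] - mean);
    var /= n;
    const float rstd = 1.0f / std::sqrt(var + epsilon);
    for (int i = 0; i < n; ++i) {
      out[row * n + i] = (in[i] - mean) * rstd * w[i] + b[i];
    }
  }
  return out;
}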
 void layer_norm(torch::Tensor& out,
                 torch::Tensor input,
                 torch::Tensor weight,
@@ -197,4 +259,54 @@ void layer_norm(torch::Tensor& out,
   });
 }

+template <typename T>
+void invoke_layernorm_kernel(T* out,
+                             const T* input,
+                             const T* weight,
+                             const T* bias,
+                             const float epsilon,
+                             int m,
+                             int n) {
+  // launch one block per row: m rows, n threads per block.
+  layer_norm_kernel<T><<<m, n>>>(out, input, weight, bias, epsilon, n);
+}
+
+// Explicit specializations: the definitions live in this .cu file, so
+// callers in other translation units can link against them.
+template <>
+void invoke_layernorm_kernel<half2>(half2* out,
+                                    const half2* input,
+                                    const half2* weight,
+                                    const half2* bias,
+                                    const float epsilon,
+                                    int m,
+                                    int n) {
+  layer_norm_kernel<half2><<<m, n>>>(out, input, weight, bias, epsilon, n);
+}
+
+template <>
+void invoke_layernorm_kernel<float>(float* out,
+                                    const float* input,
+                                    const float* weight,
+                                    const float* bias,
+                                    const float epsilon,
+                                    int m,
+                                    int n) {
+  layer_norm_kernel<float><<<m, n>>>(out, input, weight, bias, epsilon, n);
+}
+
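Note that both wrappers launch with <<<m, n>>>: one block per row and n threads per block. CUDA caps a block at 1024 threads, so rows wider than that would fail to launch as written. Since the kernel loops stride by blockDim.x, clamping the block size is a natural extension; a sketch with a hypothetical name, not part of this commit:

#include <algorithm>

// Hypothetical variant, not in the commit: clamps the block size to CUDA's
// 1024-thread limit. The kernels' internal loops already stride by
// blockDim.x, so every element is still covered when n > 1024.
template <typename T>
void invoke_layernorm_kernel_capped(T* out,
                                    const T* input,
                                    const T* weight,
                                    const T* bias,
                                    const float epsilon,
                                    int m,
                                    int n) {
  const int threads = std::min(n, 1024);
  layer_norm_kernel<T><<<m, threads>>>(out, input, weight, bias, epsilon, n);
}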
 }  // namespace llm::kernel
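For completeness, a minimal host-side usage sketch of the float specialization (assuming layernorm_kernels.h declares invoke_layernorm_kernel, n <= 1024, and with error checking omitted):

#include <cuda_runtime.h>
#include <vector>
#include "layernorm_kernels.h"

int main() {
  const int m = 4, n = 1024;  // 4 rows, 1024 features per row
  const float epsilon = 1e-5f;
  std::vector<float> h_in(m * n, 1.0f), h_w(n, 1.0f), h_b(n, 0.0f), h_out(m * n);

  // device buffers for input, weight, bias, and output
  float *d_in, *d_w, *d_b, *d_out;
  cudaMalloc(&d_in, m * n * sizeof(float));
  cudaMalloc(&d_w, n * sizeof(float));
  cudaMalloc(&d_b, n * sizeof(float));
  cudaMalloc(&d_out, m * n * sizeof(float));
  cudaMemcpy(d_in, h_in.data(), m * n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_w, h_w.data(), n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_b, h_b.data(), n * sizeof(float), cudaMemcpyHostToDevice);

  llm::kernel::invoke_layernorm_kernel<float>(
      d_out, d_in, d_w, d_b, epsilon, m, n);

  cudaMemcpy(h_out.data(), d_out, m * n * sizeof(float), cudaMemcpyDeviceToHost);
  cudaFree(d_in);
  cudaFree(d_w);
  cudaFree(d_b);
  cudaFree(d_out);
  return 0;
}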