#include <cuda_fp16.h>

#include <cstdio>
#include <cstdlib>

#include "layernorm_kernels.h"
6+
// Prints an m x n matrix stored contiguously row by row (element (i, j) is
// read from a[i * n + j]) to stdout.
//
// Each element is cast to float and formatted with "%f". One puts() call
// terminates every row, and one more puts() call follows the last row.
//
// a: host-readable pointer to at least m * n elements.
// m: number of rows.
// n: number of elements per row.
template <typename T>
void printMatrix(T* a, int m, int n) {
  for (int row = 0; row < m; ++row) {
    // Start of the current row inside the flat buffer.
    T* row_begin = a + row * n;
    for (int col = 0; col < n; ++col) {
      printf(" %f ", (float)row_begin[col]);
    }
    puts(" ");
  }
  puts(" ");
}
17+
// half2 specialization of printMatrix.
//
// Element (i, j) is read from a[i * n + j]; its .x and .y halves are each
// converted with __half2float and printed as a pair of "%f" values. One
// puts() call terminates every row, and one more follows the last row.
template <>
void printMatrix<half2>(half2* a, int m, int n) {
  for (int row = 0; row < m; ++row) {
    // Start of the current row inside the flat buffer.
    half2* row_begin = a + row * n;
    for (int col = 0; col < n; ++col) {
      const half2& pair = row_begin[col];
      const float first = __half2float(pair.x);
      const float second = __half2float(pair.y);
      printf(" %f %f ", first, second);
    }
    puts(" ");
  }
  puts(" ");
}
29+
// Runs llm::kernel::invoke_layernorm_kernel<half2> on a 2 x 2 matrix of half2
// values (weight = 1, bias = 0) and prints the input, weight, bias and output
// buffers to stdout.
//
// Every CUDA runtime call is checked. On failure the CUDA error string is
// written to stderr and the process exits with EXIT_FAILURE. On success all
// host and device buffers are released before returning.
void layernorm_kernel_half2_test() {
  // Reports a failed CUDA call and terminates the test run.
  auto check = [](cudaError_t status, const char* what) {
    if (status != cudaSuccess) {
      fprintf(stderr, "CUDA error (%s): %s\n", what,
              cudaGetErrorString(status));
      exit(EXIT_FAILURE);
    }
  };

  float epsilon = 1e-6f;
  int m = 2;
  int n = 2;
  // Size in bytes of each of the four buffers (m * n half2 elements).
  const size_t bytes = sizeof(half2) * m * n;

  half2* out = (half2*)malloc(bytes);
  half2* input = (half2*)malloc(bytes);
  // NOTE(review): weight and bias are sized m * n here, as in the original
  // test; how many of those elements the kernel actually reads is not visible
  // from this file -- confirm against invoke_layernorm_kernel.
  half2* weight = (half2*)malloc(bytes);
  half2* bias = (half2*)malloc(bytes);
  if (out == nullptr || input == nullptr || weight == nullptr ||
      bias == nullptr) {
    fprintf(stderr, "host allocation of %zu bytes failed\n", bytes);
    exit(EXIT_FAILURE);
  }

  for (int i = 0; i < m; i++) {
    for (int j = 0; j < n; j++) {
      // NOTE(review): with this formula row i starts at value i * n, so
      // consecutive rows overlap in value (row 0: 0..3, row 1: 2..5). Kept
      // unchanged so the printed test data stays the same.
      input[i * n + j] = half2(__float2half((float)(i * n + j * 2)),
                               __float2half((float)(i * n + j * 2 + 1)));
      weight[i * n + j] = half2(__float2half(1.0f), __float2half(1.0f));
      bias[i * n + j] = half2(__float2half(0.0f), __float2half(0.0f));
    }
  }

  half2* dout = nullptr;
  half2* dinput = nullptr;
  half2* dweight = nullptr;
  half2* dbias = nullptr;
  check(cudaMalloc((void**)&dout, bytes), "cudaMalloc dout");
  check(cudaMalloc((void**)&dinput, bytes), "cudaMalloc dinput");
  check(cudaMalloc((void**)&dweight, bytes), "cudaMalloc dweight");
  check(cudaMalloc((void**)&dbias, bytes), "cudaMalloc dbias");

  check(cudaMemcpy(dinput, input, bytes, cudaMemcpyHostToDevice),
        "copy input to device");
  check(cudaMemcpy(dweight, weight, bytes, cudaMemcpyHostToDevice),
        "copy weight to device");
  check(cudaMemcpy(dbias, bias, bytes, cudaMemcpyHostToDevice),
        "copy bias to device");

  llm::kernel::invoke_layernorm_kernel<half2>(
      dout, dinput, dweight, dbias, epsilon, m, n);
  // A kernel launch does not return an error itself: query the launch status
  // explicitly, then synchronize so execution faults surface here rather than
  // at some later, unrelated call.
  check(cudaGetLastError(), "layernorm kernel launch");
  check(cudaDeviceSynchronize(), "layernorm kernel execution");

  check(cudaMemcpy(out, dout, bytes, cudaMemcpyDeviceToHost),
        "copy output to host");

  printf(" ---------- test half2 layernorm kernel -----------\n ");
  printf(" input:\n ");
  printMatrix<half2>(input, m, n);
  printf(" weights:\n ");
  printMatrix<half2>(weight, m, n);
  printf(" bias:\n ");
  printMatrix<half2>(bias, m, n);
  printf(" outputs:\n ");
  printMatrix<half2>(out, m, n);

  // Release device buffers first, then their host counterparts.
  check(cudaFree(dout), "cudaFree dout");
  check(cudaFree(dinput), "cudaFree dinput");
  check(cudaFree(dweight), "cudaFree dweight");
  check(cudaFree(dbias), "cudaFree dbias");
  free(out);
  free(input);
  free(weight);
  free(bias);
}
77+
// Runs llm::kernel::invoke_layernorm_kernel<float> on a 2 x 4 float matrix
// holding the values 0..7 (weight = 1, bias = 0) and prints the input, weight,
// bias and output buffers to stdout.
//
// Every CUDA runtime call is checked. On failure the CUDA error string is
// written to stderr and the process exits with EXIT_FAILURE. On success all
// host and device buffers are released before returning.
void layernorm_kernel_float_test() {
  // Reports a failed CUDA call and terminates the test run.
  auto check = [](cudaError_t status, const char* what) {
    if (status != cudaSuccess) {
      fprintf(stderr, "CUDA error (%s): %s\n", what,
              cudaGetErrorString(status));
      exit(EXIT_FAILURE);
    }
  };

  float epsilon = 1e-6f;
  int m = 2;
  int n = 4;
  // Size in bytes of each of the four buffers (m * n floats).
  const size_t bytes = sizeof(float) * m * n;

  float* out = (float*)malloc(bytes);
  float* input = (float*)malloc(bytes);
  // NOTE(review): weight and bias are sized m * n here, as in the original
  // test; how many of those elements the kernel actually reads is not visible
  // from this file -- confirm against invoke_layernorm_kernel.
  float* weight = (float*)malloc(bytes);
  float* bias = (float*)malloc(bytes);
  if (out == nullptr || input == nullptr || weight == nullptr ||
      bias == nullptr) {
    fprintf(stderr, "host allocation of %zu bytes failed\n", bytes);
    exit(EXIT_FAILURE);
  }

  for (int i = 0; i < m; i++) {
    for (int j = 0; j < n; j++) {
      input[i * n + j] = (float)(i * n + j);
      weight[i * n + j] = 1.0f;
      bias[i * n + j] = 0.0f;
    }
  }

  float* dout = nullptr;
  float* dinput = nullptr;
  float* dweight = nullptr;
  float* dbias = nullptr;
  check(cudaMalloc((void**)&dout, bytes), "cudaMalloc dout");
  check(cudaMalloc((void**)&dinput, bytes), "cudaMalloc dinput");
  check(cudaMalloc((void**)&dweight, bytes), "cudaMalloc dweight");
  check(cudaMalloc((void**)&dbias, bytes), "cudaMalloc dbias");

  check(cudaMemcpy(dinput, input, bytes, cudaMemcpyHostToDevice),
        "copy input to device");
  check(cudaMemcpy(dweight, weight, bytes, cudaMemcpyHostToDevice),
        "copy weight to device");
  check(cudaMemcpy(dbias, bias, bytes, cudaMemcpyHostToDevice),
        "copy bias to device");

  llm::kernel::invoke_layernorm_kernel<float>(
      dout, dinput, dweight, dbias, epsilon, m, n);
  // A kernel launch does not return an error itself: query the launch status
  // explicitly, then synchronize so execution faults surface here rather than
  // at some later, unrelated call.
  check(cudaGetLastError(), "layernorm kernel launch");
  check(cudaDeviceSynchronize(), "layernorm kernel execution");

  check(cudaMemcpy(out, dout, bytes, cudaMemcpyDeviceToHost),
        "copy output to host");

  printf(" ---------- test float layernorm kernel -----------\n ");
  printf(" input:\n ");
  printMatrix<float>(input, m, n);
  printf(" weights:\n ");
  printMatrix<float>(weight, m, n);
  printf(" bias:\n ");
  printMatrix<float>(bias, m, n);
  printf(" outputs:\n ");
  printMatrix<float>(out, m, n);

  // Release device buffers first, then their host counterparts.
  check(cudaFree(dout), "cudaFree dout");
  check(cudaFree(dinput), "cudaFree dinput");
  check(cudaFree(dweight), "cudaFree dweight");
  check(cudaFree(dbias), "cudaFree dbias");
  free(out);
  free(input);
  free(weight);
  free(bias);
}
124+
// Entry point: runs the float layernorm test followed by the half2 layernorm
// test, then returns 0.
int main() {
  // Tests listed in the order they are executed.
  void (*const tests[])() = {
      layernorm_kernel_float_test,
      layernorm_kernel_half2_test,
  };
  for (auto* run_test : tests) {
    run_test();
  }
  return 0;
}
0 commit comments