11#include < cuda_fp16.h>
2- #include < cstdio>
3- #include " layernorm_kernels.h"
4- #include < torch/nn/functional.h>
52#include < gtest/gtest.h>
3+ #include < torch/nn/functional.h>
64
// Prints an m x n row-major matrix to stdout, one row per line, converting
// every element to float for display. A separator line follows the matrix.
template <typename T>
void printMatrix(T* a, int m, int n) {
  for (int row = 0; row < m; ++row) {
    T* cells = a + row * n;  // start of the current row
    for (int col = 0; col < n; ++col) {
      printf(" %f ", (float)cells[col]);
    }
    puts(" ");
  }
  puts(" ");
}
17-
18- template <>
19- void printMatrix<float >(float * a, int m, int n) {
20- for (int i = 0 ; i < m; i++) {
21- for (int j = 0 ; j < n; j++) {
22- printf (" %f " , a[i * n + j]);
23- }
24- puts (" " );
25- }
26- puts (" " );
27- }
28-
29- template <>
30- void printMatrix<half>(half* a, int m, int n) {
31- for (int i = 0 ; i < m; i++) {
32- for (int j = 0 ; j < n; j++) {
33- printf (" %f " , __half2float (a[i * n + j]));
34- }
35- puts (" " );
36- }
37- puts (" " );
38- }
39-
40- template <>
41- void printMatrix<half2>(half2* a, int m, int n) {
42- for (int i = 0 ; i < m; i++) {
43- for (int j = 0 ; j < n; j++) {
44- printf (
45- " %f %f " , __half2float (a[i * n + j].x ), __half2float (a[i * n + j].y ));
46- }
47- puts (" " );
48- }
49- puts (" " );
50- }
51-
52- void layernorm_kernel_half2_test () {
53- float epsilon = 1e-6 ;
54- int m = 2 ;
55- int n = 2 ;
56-
57- half2* out = (half2*)malloc (m * n * sizeof (half2));
58- half2* input = (half2*)malloc (m * n * sizeof (half2));
59- half2* weight = (half2*)malloc (m * n * sizeof (half2));
60- half2* bias = (half2*)malloc (m * n * sizeof (half2));
61-
62- for (int i = 0 ; i < m; i++) {
63- for (int j = 0 ; j < n; j++) {
64- input[i * n + j] = half2 (__float2half ((float )(i * n + j * 2 )),
65- __float2half ((float )(i * n + j * 2 + 1 )));
66- weight[i * n + j] = half2 (__float2half (1 .), __float2half (1 .));
67- bias[i * n + j] = half2 (__float2half (0 .), __float2half (0 .));
68- }
69- }
70-
71- half2* dout;
72- half2* dinput;
73- half2* dweight;
74- half2* dbias;
75- cudaMalloc ((void **)&dout, sizeof (half2) * m * n);
76- cudaMalloc ((void **)&dinput, sizeof (half2) * m * n);
77- cudaMalloc ((void **)&dweight, sizeof (half2) * m * n);
78- cudaMalloc ((void **)&dbias, sizeof (half2) * m * n);
79-
80- cudaMemcpy (dinput, input, sizeof (half2) * m * n, cudaMemcpyHostToDevice);
81- cudaMemcpy (dweight, weight, sizeof (half2) * m * n, cudaMemcpyHostToDevice);
82- cudaMemcpy (dbias, bias, sizeof (half2) * m * n, cudaMemcpyHostToDevice);
83-
84- llm::kernel::invoke_layernorm_kernel<half2>(
85- dout, dinput, dweight, dbias, epsilon, m, n);
86-
87- cudaMemcpy (out, dout, sizeof (half2) * m * n, cudaMemcpyDeviceToHost);
88-
89- printf (" ---------- test half2 layernorm kernel -----------\n " );
90- printf (" input:\n " );
91- printMatrix<half2>(input, m, n);
92- printf (" weights:\n " );
93- printMatrix<half2>(weight, m, n);
94- printf (" bias:\n " );
95- printMatrix<half2>(bias, m, n);
96- printf (" outputs:\n " );
97- printMatrix<half2>(out, m, n);
98- }
99-
// Smoke test for the float layernorm kernel: builds a 2 x 4 row-major input
// holding i * n + j, all-ones weights and all-zero bias, copies them to the
// device, calls llm::kernel::invoke_layernorm_kernel<float>, copies the result
// back and prints input, weights, bias and outputs with printMatrix<float>.
// NOTE(review): none of the malloc / cudaMalloc buffers below are released in
// this function, and no CUDA return code is inspected.
100- void layernorm_kernel_float_test () {
101- float epsilon = 1e-6 ;
102- int m = 2 ;
103- int n = 4 ;
104-
105- float * out = (float *)malloc (m * n * sizeof (float ));
106- float * input = (float *)malloc (m * n * sizeof (float ));
107- float * weight = (float *)malloc (m * n * sizeof (float ));
108- float * bias = (float *)malloc (m * n * sizeof (float ));
109-
110- for (int i = 0 ; i < m; i++) {
111- for (int j = 0 ; j < n; j++) {
112- input[i * n + j] = (float )(i * n + j);
113- weight[i * n + j] = 1 .;
114- bias[i * n + j] = 0 .;
115- }
116- }
117-
118- float * dout;
119- float * dinput;
120- float * dweight;
121- float * dbias;
122- cudaMalloc ((void **)&dout, sizeof (float ) * m * n);
123- cudaMalloc ((void **)&dinput, sizeof (float ) * m * n);
124- cudaMalloc ((void **)&dweight, sizeof (float ) * m * n);
125- cudaMalloc ((void **)&dbias, sizeof (float ) * m * n);
126-
127- cudaMemcpy (dinput, input, sizeof (float ) * m * n, cudaMemcpyHostToDevice);
128- cudaMemcpy (dweight, weight, sizeof (float ) * m * n, cudaMemcpyHostToDevice);
129- cudaMemcpy (dbias, bias, sizeof (float ) * m * n, cudaMemcpyHostToDevice);
130-
131- llm::kernel::invoke_layernorm_kernel<float >(
132- dout, dinput, dweight, dbias, epsilon, m, n);
5+ #include < cstdio>
1336
134- cudaMemcpy (out, dout, sizeof (float ) * m * n, cudaMemcpyDeviceToHost);
135-
136- printf (" ---------- test float layernorm kernel -----------\n " );
137- printf (" input:\n " );
138- printMatrix<float >(input, m, n);
139- printf (" weights:\n " );
140- printMatrix<float >(weight, m, n);
141- printf (" bias:\n " );
142- printMatrix<float >(bias, m, n);
143- printf (" outputs:\n " );
144- printMatrix<float >(out, m, n);
145- }
7+ #include " layernorm_kernels.h"
1468
1479TEST (NormalizationKernelTest, LayernormFloatTest) {
14810 float epsilon = 1e-6 ;
@@ -152,7 +14,10 @@ TEST(NormalizationKernelTest, LayernormFloatTest) {
15214 auto input = torch::randn ({m, n});
15315 auto weight = torch::randn ({n});
15416 auto bias = torch::randn ({n});
155- auto desired_out = torch::nn::functional::layer_norm (input, torch::nn::functional::LayerNormFuncOptions ({n}).weight (weight).bias (bias));
17+ auto desired_out = torch::nn::functional::layer_norm (
18+ input,
19+ torch::nn::functional::LayerNormFuncOptions ({n}).weight (weight).bias (
20+ bias));
15621
15722 float * hout = (float *)malloc (m * n * sizeof (float ));
15823 float * hinput = input.data_ptr <float >();
@@ -184,4 +49,65 @@ TEST(NormalizationKernelTest, LayernormFloatTest) {
18449 cudaFree (dinput);
18550 cudaFree (dweight);
18651 cudaFree (dbias);
187- }
52+ }
53+
54+ TEST (NormalizationKernelTest, LayernormHalfTest) {
55+ float epsilon = 1e-6 ;
56+ int m = 4 ;
57+ int n = 512 ;
58+ auto input = torch::randn ({m, n});
59+ auto weight = torch::randn ({n});
60+ auto bias = torch::randn ({n});
61+ auto desired_out = torch::nn::functional::layer_norm (
62+ input,
63+ torch::nn::functional::LayerNormFuncOptions ({n}).weight (weight).bias (
64+ bias));
65+
66+ half* hout = (half*)malloc (m * n * sizeof (half));
67+ half* hinput = (half*)malloc (m * n * sizeof (half));
68+ half* hweight = (half*)malloc (n * sizeof (half));
69+ half* hbias = (half*)malloc (n * sizeof (half));
70+
71+ for (int i = 0 ; i < m; i++) {
72+ for (int j = 0 ; j < n; j++) {
73+ hinput[i * n + j] = __float2half (input[i][j].item <float >());
74+ }
75+ }
76+ for (int i = 0 ; i < weight.numel (); i++)
77+ hweight[i] = __float2half (weight[i].item <float >());
78+ for (int i = 0 ; i < bias.numel (); i++)
79+ hbias[i] = __float2half (bias[i].item <float >());
80+
81+ half* dout;
82+ half* dinput;
83+ half* dweight;
84+ half* dbias;
85+ cudaMalloc ((void **)&dout, sizeof (half) * m * n);
86+ cudaMalloc ((void **)&dinput, sizeof (half) * m * n);
87+ cudaMalloc ((void **)&dweight, sizeof (half) * n);
88+ cudaMalloc ((void **)&dbias, sizeof (half) * n);
89+
90+ cudaMemcpy (dinput, hinput, sizeof (half) * m * n, cudaMemcpyHostToDevice);
91+ cudaMemcpy (dweight, hweight, sizeof (half) * n, cudaMemcpyHostToDevice);
92+ cudaMemcpy (dbias, hbias, sizeof (half) * n, cudaMemcpyHostToDevice);
93+
94+ llm::kernel::invoke_layernorm_kernel<half>(
95+ dout, dinput, dweight, dbias, epsilon, m, n);
96+
97+ cudaMemcpy (hout, dout, sizeof (half) * m * n, cudaMemcpyDeviceToHost);
98+
99+ float * float_hout = (float *)malloc (m * n * sizeof (float ));
100+ for (int i = 0 ; i < m * n; i++) float_hout[i] = __half2float (hout[i]);
101+
102+ auto out = torch::from_blob (float_hout, {m, n});
103+ EXPECT_TRUE (torch::allclose (out, desired_out, 0.05 , 1e-3 ));
104+ free (hout);
105+ free (hinput);
106+ free (hweight);
107+ free (hbias);
108+ free (float_hout);
109+ cudaFree (dout);
110+ cudaFree (dinput);
111+ cudaFree (dweight);
112+ cudaFree (dbias);
113+ }
0 commit comments