#include <cuda_fp16.h>

#include <cstdio>
#include <vector>

#include <gtest/gtest.h>
#include <torch/nn/functional.h>

#include "layernorm_kernels.h"
66
77template <typename T>
88void printMatrix (T* a, int m, int n) {
@@ -15,6 +15,28 @@ void printMatrix(T* a, int m, int n) {
1515 puts (" " );
1616}
1717
// Debug helper: dump an m x n row-major float matrix to stdout,
// one row per line, with a blank line after the matrix.
template <>
void printMatrix<float>(float* a, int m, int n) {
  for (int row = 0; row < m; ++row) {
    const float* row_ptr = a + row * n;
    for (int col = 0; col < n; ++col) {
      printf("%f ", row_ptr[col]);
    }
    puts("");
  }
  puts("");
}
28+
// Debug helper: dump an m x n row-major half matrix to stdout.
// Each element is widened to float via __half2float for printing;
// blank line after the matrix.
template <>
void printMatrix<half>(half* a, int m, int n) {
  for (int row = 0; row < m; ++row) {
    const half* row_ptr = a + row * n;
    for (int col = 0; col < n; ++col) {
      printf("%f ", __half2float(row_ptr[col]));
    }
    puts("");
  }
  puts("");
}
39+
1840template <>
1941void printMatrix<half2>(half2* a, int m, int n) {
2042 for (int i = 0 ; i < m; i++) {
@@ -122,8 +144,44 @@ void layernorm_kernel_float_test() {
122144 printMatrix<float >(out, m, n);
123145}
124146
125- int main () {
126- layernorm_kernel_float_test ();
127- layernorm_kernel_half2_test ();
128- return 0 ;
// Checks the custom float layernorm kernel against torch's reference
// layer_norm on a random (m, n) input, elementwise within rtol=1e-3,
// atol=1e-5.
TEST(NormalizationKernelTest, LayernormFloatTest) {
  const float epsilon = 1e-6f;
  const int m = 32;
  const int n = 512;

  auto input = torch::randn({m, n});
  auto weight = torch::randn({n});
  auto bias = torch::randn({n});
  // Pass the kernel's epsilon to the reference as well: layer_norm's
  // default eps is 1e-5, which silently disagreed with the 1e-6 given
  // to the kernel under test.
  auto desired_out = torch::nn::functional::layer_norm(
      input, torch::nn::functional::LayerNormFuncOptions({n})
                 .weight(weight)
                 .bias(bias)
                 .eps(epsilon));

  // std::vector instead of raw malloc: no leak if an ASSERT aborts the
  // test body early.
  std::vector<float> hout(static_cast<size_t>(m) * n);
  float* hinput = input.data_ptr<float>();
  float* hweight = weight.data_ptr<float>();
  float* hbias = bias.data_ptr<float>();

  float* dout = nullptr;
  float* dinput = nullptr;
  float* dweight = nullptr;
  float* dbias = nullptr;
  ASSERT_EQ(cudaMalloc((void**)&dout, sizeof(float) * m * n), cudaSuccess);
  ASSERT_EQ(cudaMalloc((void**)&dinput, sizeof(float) * m * n), cudaSuccess);
  ASSERT_EQ(cudaMalloc((void**)&dweight, sizeof(float) * n), cudaSuccess);
  ASSERT_EQ(cudaMalloc((void**)&dbias, sizeof(float) * n), cudaSuccess);

  ASSERT_EQ(cudaMemcpy(dinput, hinput, sizeof(float) * m * n,
                       cudaMemcpyHostToDevice),
            cudaSuccess);
  ASSERT_EQ(cudaMemcpy(dweight, hweight, sizeof(float) * n,
                       cudaMemcpyHostToDevice),
            cudaSuccess);
  ASSERT_EQ(cudaMemcpy(dbias, hbias, sizeof(float) * n,
                       cudaMemcpyHostToDevice),
            cudaSuccess);

  llm::kernel::invoke_layernorm_kernel<float>(
      dout, dinput, dweight, dbias, epsilon, m, n);
  // Kernel launches don't return errors directly: surface bad launch
  // configs and asynchronous execution faults before comparing output.
  ASSERT_EQ(cudaGetLastError(), cudaSuccess);
  ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess);

  ASSERT_EQ(cudaMemcpy(hout.data(), dout, sizeof(float) * m * n,
                       cudaMemcpyDeviceToHost),
            cudaSuccess);

  // from_blob does not own hout; hout outlives all uses of `out`.
  auto out = torch::from_blob(hout.data(), {m, n});
  EXPECT_TRUE(torch::allclose(out, desired_out, 1e-3, 1e-5));

  cudaFree(dout);
  cudaFree(dinput);
  cudaFree(dweight);
  cudaFree(dbias);
}
0 commit comments