Commit 6ff1738
Author: Xianzhe Dong
1 parent: 14a99f7

use gtest library to rewrite the layernorm kernel unit test

File tree: 2 files changed, +73 −23 lines

src/kernels/CMakeLists.txt
Lines changed: 9 additions & 17 deletions

@@ -73,23 +73,15 @@ cc_library(
     torch
 )
 
-# cc_test(
-#   NAME
-#     layernorm_kernels_test
-#   SRCS
-#     layernrom_kernels_test.cu
-#     layernorm_kernels.cu
-#   DEPS
-#   DEFINES
-# )
-cc_binary(
-  NAME
-    layernorm_kernels_test
-  SRCS
-    layernrom_kernels_test.cu
-    layernorm_kernels.cu
-  DEPS
-    torch
+cc_test(
+  NAME
+    layernorm_kernels_test
+  SRCS
+    layernrom_kernels_test.cu
+    layernorm_kernels.cu
+  DEPS
+    torch
+    GTest::gtest_main
 )
 
 add_subdirectory(flash_attn)
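Switching from cc_binary to cc_test also swaps the hand-written main() for the one supplied by the new GTest::gtest_main dependency, which initializes Google Test and runs every registered TEST; that is why main() can be deleted from the test source below. For reference, a minimal sketch of the entry point that gtest_main provides (only needed if a target links plain GTest::gtest instead):

#include <gtest/gtest.h>

// Roughly the main() that GTest::gtest_main links in for you: parse
// gtest command-line flags, then run all registered TEST cases.
int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}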

src/kernels/layernrom_kernels_test.cu
Lines changed: 64 additions & 6 deletions

@@ -1,8 +1,8 @@
 #include <cuda_fp16.h>
-
 #include <cstdio>
-
 #include "layernorm_kernels.h"
+#include <torch/nn/functional.h>
+#include <gtest/gtest.h>
 
 template <typename T>
 void printMatrix(T* a, int m, int n) {
@@ -15,6 +15,28 @@ void printMatrix(T* a, int m, int n) {
   puts("");
 }
 
+template <>
+void printMatrix<float>(float* a, int m, int n) {
+  for (int i = 0; i < m; i++) {
+    for (int j = 0; j < n; j++) {
+      printf("%f ", a[i * n + j]);
+    }
+    puts("");
+  }
+  puts("");
+}
+
+template <>
+void printMatrix<half>(half* a, int m, int n) {
+  for (int i = 0; i < m; i++) {
+    for (int j = 0; j < n; j++) {
+      printf("%f ", __half2float(a[i * n + j]));
+    }
+    puts("");
+  }
+  puts("");
+}
+
 template <>
 void printMatrix<half2>(half2* a, int m, int n) {
   for (int i = 0; i < m; i++) {
@@ -122,8 +144,44 @@ void layernorm_kernel_float_test() {
   printMatrix<float>(out, m, n);
 }
 
-int main() {
-  layernorm_kernel_float_test();
-  layernorm_kernel_half2_test();
-  return 0;
+TEST(NormalizationKernelTest, LayernormFloatTest) {
+  float epsilon = 1e-6;
+  int m = 32;
+  int n = 512;
+
+  auto input = torch::randn({m, n});
+  auto weight = torch::randn({n});
+  auto bias = torch::randn({n});
+  auto desired_out = torch::nn::functional::layer_norm(input, torch::nn::functional::LayerNormFuncOptions({n}).weight(weight).bias(bias));
+
+  float* hout = (float*)malloc(m * n * sizeof(float));
+  float* hinput = input.data_ptr<float>();
+  float* hweight = weight.data_ptr<float>();
+  float* hbias = bias.data_ptr<float>();
+
+  float* dout;
+  float* dinput;
+  float* dweight;
+  float* dbias;
+  cudaMalloc((void**)&dout, sizeof(float) * m * n);
+  cudaMalloc((void**)&dinput, sizeof(float) * m * n);
+  cudaMalloc((void**)&dweight, sizeof(float) * n);
+  cudaMalloc((void**)&dbias, sizeof(float) * n);
+
+  cudaMemcpy(dinput, hinput, sizeof(float) * m * n, cudaMemcpyHostToDevice);
+  cudaMemcpy(dweight, hweight, sizeof(float) * n, cudaMemcpyHostToDevice);
+  cudaMemcpy(dbias, hbias, sizeof(float) * n, cudaMemcpyHostToDevice);
+
+  llm::kernel::invoke_layernorm_kernel<float>(
+      dout, dinput, dweight, dbias, epsilon, m, n);
+
+  cudaMemcpy(hout, dout, sizeof(float) * m * n, cudaMemcpyDeviceToHost);
+
+  auto out = torch::from_blob(hout, {m, n});
+  EXPECT_TRUE(torch::allclose(out, desired_out, 1e-3, 1e-5));
+  free(hout);
+  cudaFree(dout);
+  cudaFree(dinput);
+  cudaFree(dweight);
+  cudaFree(dbias);
 }
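The new TEST generates random inputs with LibTorch, computes a reference result through torch::nn::functional::layer_norm, runs the custom kernel on device copies of the same data, and checks agreement with torch::allclose at rtol=1e-3, atol=1e-5. The launcher's declaration is not part of this diff; inferred from the call site, it presumably looks something like the sketch below (hypothetical; the real signature in layernorm_kernels.h may differ, e.g. const-qualified pointers or a CUDA stream parameter):

namespace llm {
namespace kernel {

// Hypothetical declaration reconstructed from the call site above;
// all pointers are device memory.
template <typename T>
void invoke_layernorm_kernel(T* out,        // [m, n] normalized output
                             T* input,      // [m, n] input rows
                             T* weight,     // [n] scale (gamma)
                             T* bias,       // [n] shift (beta)
                             float epsilon, // added to variance for stability
                             int m,         // number of rows
                             int n);        // normalized (last) dimension

}  // namespace kernel
}  // namespace llm

With the cc_test registration above, the binary should be picked up by CTest, so running ctest -R layernorm_kernels_test from the build directory presumably executes it.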
