11#include " pos_embedding.h"
22
3- #include < c10/core/Device.h>
4- #include < c10/core/ScalarType.h>
5- #include < c10/core/TensorImpl.h>
63#include < gtest/gtest.h>
4+ #include < torch/torch.h>
5+
6+ #include < tuple>
77
88namespace llm {
99namespace {
10- using torch::indexing::None;
11- using ISlice = torch::indexing::Slice;
12-
1310// Rotary code ported from llama repo, which is used as disired output
1411torch::Tensor precompute_freqs_cis (int64_t dim,
1512 int64_t max_position_embeddings,
1613 float theta) {
17- auto range = torch::arange ( 0 , dim, 2 );
18- auto slice =
19- range.slice (/* dim=*/ 0 , /* start=*/ 0 , /* end=*/ dim / 2 ). to (torch:: kFloat32 );
14+ auto range =
15+ torch::arange ( /* start= */ 0 , /* end= */ dim, /* step= */ 2 , torch:: kFloat32 );
16+ auto slice = range.slice (/* dim=*/ 0 , /* start=*/ 0 , /* end=*/ dim / 2 );
2017 auto freqs = 1.0 / torch::pow (theta, slice / dim);
21- auto t = torch::arange (0 , max_position_embeddings, 1 ). to ( torch::kFloat32 );
22- freqs = torch::outer (t, freqs). to (torch:: kFloat32 ) ;
18+ auto t = torch::arange (/* end= */ max_position_embeddings, torch::kFloat32 );
19+ freqs = torch::outer (t, freqs);
2320 return torch::polar (torch::ones_like (freqs), freqs);
2421}
2522
@@ -55,8 +52,10 @@ std::tuple<torch::Tensor, torch::Tensor> apply_rotary_emb(
5552
5653// [1, 2, 3, 4, 5, 6] => [1, 3, 5, 2, 4, 6]
5754inline torch::Tensor interleaved_to_half (const torch::Tensor& x) {
58- auto x1 = x.index ({ISlice (), ISlice (), ISlice (0 , None, 2 )});
59- auto x2 = x.index ({ISlice (), ISlice (), ISlice (1 , None, 2 )});
55+ using torch::indexing::None;
56+ using ISlice = torch::indexing::Slice;
57+ auto x1 = x.index ({" ..." , ISlice (0 , None, 2 )});
58+ auto x2 = x.index ({" ..." , ISlice (1 , None, 2 )});
6059 return torch::cat ({x1, x2}, /* dim=*/ -1 );
6160}
6261
@@ -67,98 +66,181 @@ inline torch::Tensor half_to_interleaved(const torch::Tensor& x) {
6766 .flatten (/* start_dim=*/ -2 );
6867}
6968
69+ std::tuple<torch::Tensor, torch::Tensor> apply_rotary_emb_ref (
70+ const torch::Tensor& query,
71+ const torch::Tensor& key,
72+ const torch::Tensor& positions,
73+ int64_t head_dim,
74+ int64_t max_position_embeddings,
75+ float theta,
76+ bool interleaved) {
77+ auto freqs_cis =
78+ precompute_freqs_cis (head_dim, max_position_embeddings, theta);
79+ namespace F = torch::nn::functional;
80+ auto selected_freqs_cis = F::embedding (positions, freqs_cis);
81+
82+ if (interleaved) {
83+ return apply_rotary_emb (query, key, selected_freqs_cis);
84+ }
85+
86+ auto interleaved_query = half_to_interleaved (query);
87+ auto interleaved_key = half_to_interleaved (key);
88+ auto [query_ref, key_ref] =
89+ apply_rotary_emb (interleaved_query, interleaved_key, selected_freqs_cis);
90+ query_ref = interleaved_to_half (query_ref);
91+ key_ref = interleaved_to_half (key_ref);
92+ return std::make_tuple (query_ref, key_ref);
93+ }
94+
7095} // namespace
7196
72- TEST (RotaryEmbeddingTest, Interleaved) {
73- const int64_t num_tokens = 16 ;
74- const int64_t n_heads = 4 ;
75- const int64_t head_dim = 4 ;
76- const int64_t max_position_embeddings = 128 ;
77- torch::ScalarType dtype (torch::kFloat );
78- torch::Device device (torch::kCPU );
97+ class PosEmbeddingTest : public ::testing::TestWithParam<
98+ std::tuple<torch::Device,
99+ torch::ScalarType,
100+ int64_t /* num_tokens*/ ,
101+ int64_t /* n_heads*/ ,
102+ int64_t /* n_kv_heads*/ ,
103+ int64_t /* head_dim*/ ,
104+ float /* theta*/ ,
105+ bool /* interleaved*/ ,
106+ int64_t /* max_position_embeddings*/ >> {
107+ };
108+
109+ TEST_P (PosEmbeddingTest, Rotary) {
110+ const auto [device,
111+ dtype,
112+ num_tokens,
113+ n_heads,
114+ n_kv_heads,
115+ head_dim,
116+ theta,
117+ interleaved,
118+ max_position_embeddings] = GetParam ();
79119 const auto options = torch::dtype (dtype).device (device);
80120
121+ // prepare inputs
122+ torch::Tensor query = torch::rand ({num_tokens, n_heads, head_dim}, options);
123+ torch::Tensor key = torch::rand ({num_tokens, n_kv_heads, head_dim}, options);
124+ const torch::Tensor positions = torch::randint (
125+ 0 , max_position_embeddings, {num_tokens}, options.dtype (torch::kInt ));
126+
81127 RotaryEmbeddingGeneric rotary_embedding (head_dim,
82128 max_position_embeddings,
83129 /* scaling_factor*/ 0 .0f ,
84- /* theta= */ 10000 . 0f ,
85- /* interleaved= */ true ,
130+ theta,
131+ interleaved,
86132 options);
87-
88- torch::Tensor query = torch::rand ({num_tokens, n_heads, head_dim});
89- torch::Tensor key = torch::rand ({num_tokens, n_heads, head_dim});
90- const torch::Tensor positions =
91- torch::randint (0 , max_position_embeddings, {num_tokens});
92-
93- // make a copy for inplace operation
94133 const auto [query_output, key_output] =
95134 rotary_embedding.forward (query, key, positions);
96135
97136 // compute the desired output
98- auto freqs_cis =
99- precompute_freqs_cis (head_dim, max_position_embeddings, 10000 . 0f );
100- namespace F = torch::nn::functional;
101- auto selected_freqs_cis = F::embedding (positions, freqs_cis);
102- const auto [desired_query, desired_key] =
103- apply_rotary_emb (query, key, selected_freqs_cis);
104-
105- // check the output
106- ASSERT_TRUE (torch::allclose (desired_query ,
137+ auto [query_ref, key_ref] = apply_rotary_emb_ref (query,
138+ key,
139+ positions,
140+ head_dim,
141+ max_position_embeddings,
142+ theta,
143+ interleaved);
144+
145+ ASSERT_TRUE (torch::allclose (query_ref ,
107146 query_output,
108147 /* rtol=*/ 1e-03 ,
109148 /* atol=*/ 1e-05 ));
110- ASSERT_TRUE (torch::allclose (desired_key ,
149+ ASSERT_TRUE (torch::allclose (key_ref ,
111150 key_output,
112151 /* rtol=*/ 1e-03 ,
113152 /* atol=*/ 1e-05 ));
114153}
115154
116- TEST (RotaryEmbeddingTest, HalfRotated) {
117- const int64_t num_tokens = 16 ;
118- const int64_t n_heads = 4 ;
119- const int64_t head_dim = 4 ;
120- const int64_t max_position_embeddings = 128 ;
121- torch::ScalarType dtype (torch::kFloat );
122- torch::Device device (torch::kCPU );
155+ INSTANTIATE_TEST_SUITE_P (
156+ RotaryCorrectness,
157+ PosEmbeddingTest,
158+ ::testing::Combine (
159+ ::testing::Values (torch::kCPU ),
160+ ::testing::Values(torch::kFloat ),
161+ ::testing::Values(1 , 2 , 8 , 16 ), // num_tokens
162+ ::testing::Values(32 ), // n_heads
163+ ::testing::Values(32 /* mha*/ , 8 /* gqa*/ , 1 /* mqa*/ ), // n_kv_heads
164+ ::testing::Values(128 ), // head_dim
165+ ::testing::Values(100000 .0f , 500000 .0f ), // theta
166+ ::testing::Values(false , true ), // interleaved
167+ ::testing::Values(4096 , 8192 ) // max_position_embeddings
168+ ));
169+
170+ class PosEmbeddingKernelTest
171+ : public ::testing::TestWithParam<
172+ std::tuple<torch::Device,
173+ torch::ScalarType,
174+ int64_t /* num_tokens*/ ,
175+ int64_t /* n_heads*/ ,
176+ int64_t /* n_kv_heads*/ ,
177+ int64_t /* head_dim*/ ,
178+ int64_t /* rotary_dim*/ ,
179+ float /* scaling_factor*/ ,
180+ float /* theta*/ ,
181+ bool /* interleaved*/ ,
182+ int64_t /* max_position_embeddings*/ >> {};
183+
184+ TEST_P (PosEmbeddingKernelTest, Rotary) {
185+ const auto [device,
186+ dtype,
187+ num_tokens,
188+ n_heads,
189+ n_kv_heads,
190+ head_dim,
191+ rotary_dim,
192+ scaling_factor,
193+ theta,
194+ interleaved,
195+ max_position_embeddings] = GetParam ();
196+
123197 const auto options = torch::dtype (dtype).device (device);
124- RotaryEmbeddingGeneric rotary_embedding (head_dim,
198+ // prepare inputs
199+ torch::Tensor query = torch::rand ({num_tokens, n_heads, head_dim}, options);
200+ torch::Tensor key = torch::rand ({num_tokens, n_kv_heads, head_dim}, options);
201+ const torch::Tensor positions = torch::randint (
202+ 0 , max_position_embeddings, {num_tokens}, options.dtype (torch::kInt ));
203+
204+ RotaryEmbeddingGeneric rotary_embedding (rotary_dim,
125205 max_position_embeddings,
126- /* scaling_factor*/ 0 . 0f ,
127- /* theta= */ 10000 .0f ,
128- /* interleaved= */ false ,
206+ scaling_factor,
207+ 10000 .0f ,
208+ interleaved,
129209 options);
130210
131- torch::Tensor query = torch::rand ({num_tokens, n_heads, head_dim});
132- torch::Tensor key = torch::rand ({num_tokens, n_heads, head_dim});
133- const torch::Tensor positions =
134- torch::randint (0 , max_position_embeddings, {num_tokens});
211+ RotaryEmbeddingKernel rotary_embedding_kernel (rotary_dim,
212+ max_position_embeddings,
213+ scaling_factor,
214+ 10000 .0f ,
215+ interleaved,
216+ options);
135217
136- // make a copy for inplace operation
137- const auto [query_output, key_output] =
218+ auto [query_output, key_output] =
138219 rotary_embedding.forward (query, key, positions);
139220
140- // compute the desired output
141- auto freqs_cis =
142- precompute_freqs_cis (head_dim, max_position_embeddings, 10000 .0f );
143- namespace F = torch::nn::functional;
144- auto selected_freqs_cis = F::embedding (positions, freqs_cis);
145- auto [desired_query, desired_key] = apply_rotary_emb (
146- half_to_interleaved (query), half_to_interleaved (key), selected_freqs_cis);
221+ // apply rotary embedding using the kernel in place
222+ auto [query_output_kernel, key_output_kernel] =
223+ rotary_embedding_kernel.forward (query.clone (), key.clone (), positions);
147224
148- desired_query = interleaved_to_half (desired_query);
149- desired_key = interleaved_to_half (desired_key);
150-
151- // check the output
152- ASSERT_TRUE (torch::allclose (desired_query,
153- query_output,
154- /* rtol=*/ 1e-03 ,
155- /* atol=*/ 1e-05 ));
156- ASSERT_TRUE (torch::allclose (desired_key,
157- key_output,
158- /* rtol=*/ 1e-03 ,
159- /* atol=*/ 1e-05 ));
225+ ASSERT_TRUE (torch::allclose (query_output, query_output_kernel));
226+ ASSERT_TRUE (torch::allclose (key_output, key_output_kernel));
160227}
161228
162- // TODO: test kernel version
229+ INSTANTIATE_TEST_SUITE_P (
230+ Rotary,
231+ PosEmbeddingKernelTest,
232+ ::testing::Combine (
233+ ::testing::Values (torch::kCUDA ),
234+ ::testing::Values(torch::kHalf , torch::kBFloat16 ),
235+ ::testing::Values(1 , 2 , 8 , 16 ), // num_tokens
236+ ::testing::Values(32 ), // n_heads
237+ ::testing::Values(32 /* mha*/ , 8 /* gqa*/ , 1 /* mqa*/ ), // n_kv_heads
238+ ::testing::Values(128 ), // head_dim
239+ ::testing::Values(128 , 64 ), // rotary_dim
240+ ::testing::Values(0 .0f , 0 .5f ), // scaling_factor
241+ ::testing::Values(100000 .0f , 500000 .0f ), // theta
242+ ::testing::Values(false , true ), // interleaved
243+ ::testing::Values(4096 , 8192 ) // max_position_embeddings
244+ ));
163245
164246} // namespace llm
0 commit comments