Commit a7925c9

[unittest] added more unittests for pos_embedding, sampler and rejection_sampler. (#142)
1 parent acd6ae0 commit a7925c9

9 files changed (+424 −201 lines)

src/handlers/completion_handler.cpp

Lines changed: 1 addition & 0 deletions
@@ -126,6 +126,7 @@ bool send_delta_to_client(CompletionCallData* call_data,
   response.set_created(request->created_time);
   // response.set_model(request->model);
   auto* choice = response.add_choices();
+  choice->set_index(index);
   choice->set_finish_reason(finish_reason_to_string(output.finish_reason));
   if (!call_data->write(std::move(response))) {
     return false;
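
Note on the one-line fix: an OpenAI-style streaming response may carry several choices per chunk, and clients reassemble the stream per choice index; without set_index, every delta reports the default index 0. A minimal consumer-side sketch under that assumption (names hypothetical, not from this repo):

    // route each streamed delta to the choice it belongs to
    std::vector<std::string> texts(num_choices);
    for (const auto& choice : response.choices()) {
      texts[choice.index()] += choice.delta_text();
    }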

src/layers/attention/attention.h

Lines changed: 0 additions & 2 deletions
@@ -7,8 +7,6 @@
 #include "memory/kv_cache.h"
 #include "models/parameters.h"
 
-DECLARE_bool(disable_custom_kernels);
-
 namespace llm {
 
 class AttentionImpl : public torch::nn::Module {
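
The deleted line is the gflags half of a define/declare pair: DEFINE_bool creates the flag in one translation unit, while DECLARE_bool only makes FLAGS_disable_custom_kernels visible to another, so dropping it from attention.h is safe once the header no longer reads the flag. A minimal sketch of the pairing (file split illustrative, not from this repo):

    // kernels.cpp -- the single definition of the flag
    #include <gflags/gflags.h>
    DEFINE_bool(disable_custom_kernels, false, "fall back to reference kernels");

    // any_consumer.cpp -- declares, then reads, the same flag
    DECLARE_bool(disable_custom_kernels);
    bool use_custom_kernels() { return !FLAGS_disable_custom_kernels; }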

src/layers/attention/ref_handler.cpp

Lines changed: 1 addition & 3 deletions
@@ -10,8 +10,6 @@ namespace llm {
 using ISlice = torch::indexing::Slice;
 
 namespace {
-constexpr float negative_infinity = -std::numeric_limits<float>::infinity();
-
 torch::Tensor masked_self_attention(
     const torch::Tensor& query,  // [q_seq_len, n_heads, head_dim]
     const torch::Tensor& key,    // [k_seq_len, n_heads, head_dim]
@@ -31,7 +29,7 @@ torch::Tensor masked_self_attention(
   }
   // apply causal mask
   if (mask.defined()) {
-    scores = scores.masked_fill(mask == 0, negative_infinity);
+    scores = scores.masked_fill(mask == 0, -INFINITY);
   }
 
   scores = torch::softmax(scores, /*dim=*/-1);
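
For context, masked_fill with -INFINITY is the standard causal-mask trick: masked positions become -inf before the softmax and therefore contribute probability exactly 0, so the named negative_infinity constant was equivalent and no longer needed. A self-contained single-head sketch of the same reference computation, assuming equal query and key lengths (a sketch, not the repo's handler):

    #include <cmath>
    #include <torch/torch.h>

    // reference scaled dot-product attention with a causal mask
    torch::Tensor ref_attention(const torch::Tensor& q,    // [seq_len, head_dim]
                                const torch::Tensor& k,    // [seq_len, head_dim]
                                const torch::Tensor& v) {  // [seq_len, head_dim]
      const double scale = 1.0 / std::sqrt(static_cast<double>(q.size(-1)));
      auto scores = torch::matmul(q, k.t()) * scale;
      // lower-triangular mask: token i attends only to positions j <= i
      auto mask = torch::ones({q.size(0), q.size(0)}).tril();
      scores = scores.masked_fill(mask == 0, -INFINITY);
      return torch::matmul(torch::softmax(scores, /*dim=*/-1), v);
    }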

src/layers/pos_embedding_test.cpp

Lines changed: 158 additions & 76 deletions
@@ -1,25 +1,22 @@
 #include "pos_embedding.h"
 
-#include <c10/core/Device.h>
-#include <c10/core/ScalarType.h>
-#include <c10/core/TensorImpl.h>
 #include <gtest/gtest.h>
+#include <torch/torch.h>
+
+#include <tuple>
 
 namespace llm {
 namespace {
-using torch::indexing::None;
-using ISlice = torch::indexing::Slice;
-
 // Rotary code ported from llama repo, which is used as disired output
 torch::Tensor precompute_freqs_cis(int64_t dim,
                                    int64_t max_position_embeddings,
                                    float theta) {
-  auto range = torch::arange(0, dim, 2);
-  auto slice =
-      range.slice(/*dim=*/0, /*start=*/0, /*end=*/dim / 2).to(torch::kFloat32);
+  auto range =
+      torch::arange(/*start=*/0, /*end=*/dim, /*step=*/2, torch::kFloat32);
+  auto slice = range.slice(/*dim=*/0, /*start=*/0, /*end=*/dim / 2);
   auto freqs = 1.0 / torch::pow(theta, slice / dim);
-  auto t = torch::arange(0, max_position_embeddings, 1).to(torch::kFloat32);
-  freqs = torch::outer(t, freqs).to(torch::kFloat32);
+  auto t = torch::arange(/*end=*/max_position_embeddings, torch::kFloat32);
+  freqs = torch::outer(t, freqs);
   return torch::polar(torch::ones_like(freqs), freqs);
 }
 
@@ -55,8 +52,10 @@ std::tuple<torch::Tensor, torch::Tensor> apply_rotary_emb(
 
 // [1, 2, 3, 4, 5, 6] => [1, 3, 5, 2, 4, 6]
 inline torch::Tensor interleaved_to_half(const torch::Tensor& x) {
-  auto x1 = x.index({ISlice(), ISlice(), ISlice(0, None, 2)});
-  auto x2 = x.index({ISlice(), ISlice(), ISlice(1, None, 2)});
+  using torch::indexing::None;
+  using ISlice = torch::indexing::Slice;
+  auto x1 = x.index({"...", ISlice(0, None, 2)});
+  auto x2 = x.index({"...", ISlice(1, None, 2)});
   return torch::cat({x1, x2}, /*dim=*/-1);
 }
 
@@ -67,98 +66,181 @@ inline torch::Tensor half_to_interleaved(const torch::Tensor& x) {
       .flatten(/*start_dim=*/-2);
 }
 
+std::tuple<torch::Tensor, torch::Tensor> apply_rotary_emb_ref(
+    const torch::Tensor& query,
+    const torch::Tensor& key,
+    const torch::Tensor& positions,
+    int64_t head_dim,
+    int64_t max_position_embeddings,
+    float theta,
+    bool interleaved) {
+  auto freqs_cis =
+      precompute_freqs_cis(head_dim, max_position_embeddings, theta);
+  namespace F = torch::nn::functional;
+  auto selected_freqs_cis = F::embedding(positions, freqs_cis);
+
+  if (interleaved) {
+    return apply_rotary_emb(query, key, selected_freqs_cis);
+  }
+
+  auto interleaved_query = half_to_interleaved(query);
+  auto interleaved_key = half_to_interleaved(key);
+  auto [query_ref, key_ref] =
+      apply_rotary_emb(interleaved_query, interleaved_key, selected_freqs_cis);
+  query_ref = interleaved_to_half(query_ref);
+  key_ref = interleaved_to_half(key_ref);
+  return std::make_tuple(query_ref, key_ref);
+}
+
 }  // namespace
 
-TEST(RotaryEmbeddingTest, Interleaved) {
-  const int64_t num_tokens = 16;
-  const int64_t n_heads = 4;
-  const int64_t head_dim = 4;
-  const int64_t max_position_embeddings = 128;
-  torch::ScalarType dtype(torch::kFloat);
-  torch::Device device(torch::kCPU);
+class PosEmbeddingTest : public ::testing::TestWithParam<
+                             std::tuple<torch::Device,
+                                        torch::ScalarType,
+                                        int64_t /*num_tokens*/,
+                                        int64_t /*n_heads*/,
+                                        int64_t /*n_kv_heads*/,
+                                        int64_t /*head_dim*/,
+                                        float /*theta*/,
+                                        bool /*interleaved*/,
+                                        int64_t /*max_position_embeddings*/>> {
+};
+
+TEST_P(PosEmbeddingTest, Rotary) {
+  const auto [device,
+              dtype,
+              num_tokens,
+              n_heads,
+              n_kv_heads,
+              head_dim,
+              theta,
+              interleaved,
+              max_position_embeddings] = GetParam();
   const auto options = torch::dtype(dtype).device(device);
 
+  // prepare inputs
+  torch::Tensor query = torch::rand({num_tokens, n_heads, head_dim}, options);
+  torch::Tensor key = torch::rand({num_tokens, n_kv_heads, head_dim}, options);
+  const torch::Tensor positions = torch::randint(
+      0, max_position_embeddings, {num_tokens}, options.dtype(torch::kInt));
+
   RotaryEmbeddingGeneric rotary_embedding(head_dim,
                                           max_position_embeddings,
                                           /*scaling_factor*/ 0.0f,
-                                          /*theta=*/10000.0f,
-                                          /*interleaved=*/true,
+                                          theta,
+                                          interleaved,
                                           options);
-
-  torch::Tensor query = torch::rand({num_tokens, n_heads, head_dim});
-  torch::Tensor key = torch::rand({num_tokens, n_heads, head_dim});
-  const torch::Tensor positions =
-      torch::randint(0, max_position_embeddings, {num_tokens});
-
-  // make a copy for inplace operation
   const auto [query_output, key_output] =
       rotary_embedding.forward(query, key, positions);
 
   // compute the desired output
-  auto freqs_cis =
-      precompute_freqs_cis(head_dim, max_position_embeddings, 10000.0f);
-  namespace F = torch::nn::functional;
-  auto selected_freqs_cis = F::embedding(positions, freqs_cis);
-  const auto [desired_query, desired_key] =
-      apply_rotary_emb(query, key, selected_freqs_cis);
-
-  // check the output
-  ASSERT_TRUE(torch::allclose(desired_query,
+  auto [query_ref, key_ref] = apply_rotary_emb_ref(query,
+                                                   key,
+                                                   positions,
+                                                   head_dim,
+                                                   max_position_embeddings,
+                                                   theta,
+                                                   interleaved);
+
+  ASSERT_TRUE(torch::allclose(query_ref,
                               query_output,
                               /*rtol=*/1e-03,
                               /*atol=*/1e-05));
-  ASSERT_TRUE(torch::allclose(desired_key,
+  ASSERT_TRUE(torch::allclose(key_ref,
                               key_output,
                               /*rtol=*/1e-03,
                               /*atol=*/1e-05));
 }
 
-TEST(RotaryEmbeddingTest, HalfRotated) {
-  const int64_t num_tokens = 16;
-  const int64_t n_heads = 4;
-  const int64_t head_dim = 4;
-  const int64_t max_position_embeddings = 128;
-  torch::ScalarType dtype(torch::kFloat);
-  torch::Device device(torch::kCPU);
+INSTANTIATE_TEST_SUITE_P(
+    RotaryCorrectness,
+    PosEmbeddingTest,
+    ::testing::Combine(
+        ::testing::Values(torch::kCPU),
+        ::testing::Values(torch::kFloat),
+        ::testing::Values(1, 2, 8, 16),                       // num_tokens
+        ::testing::Values(32),                                // n_heads
+        ::testing::Values(32 /*mha*/, 8 /*gqa*/, 1 /*mqa*/),  // n_kv_heads
+        ::testing::Values(128),                               // head_dim
+        ::testing::Values(100000.0f, 500000.0f),              // theta
+        ::testing::Values(false, true),                       // interleaved
        ::testing::Values(4096, 8192)  // max_position_embeddings
+        ));
+
+class PosEmbeddingKernelTest
+    : public ::testing::TestWithParam<
+          std::tuple<torch::Device,
+                     torch::ScalarType,
+                     int64_t /*num_tokens*/,
+                     int64_t /*n_heads*/,
+                     int64_t /*n_kv_heads*/,
+                     int64_t /*head_dim*/,
+                     int64_t /*rotary_dim*/,
+                     float /*scaling_factor*/,
+                     float /*theta*/,
+                     bool /*interleaved*/,
+                     int64_t /*max_position_embeddings*/>> {};
+
+TEST_P(PosEmbeddingKernelTest, Rotary) {
+  const auto [device,
+              dtype,
+              num_tokens,
+              n_heads,
+              n_kv_heads,
+              head_dim,
+              rotary_dim,
+              scaling_factor,
+              theta,
+              interleaved,
+              max_position_embeddings] = GetParam();
+
   const auto options = torch::dtype(dtype).device(device);
-  RotaryEmbeddingGeneric rotary_embedding(head_dim,
+  // prepare inputs
+  torch::Tensor query = torch::rand({num_tokens, n_heads, head_dim}, options);
+  torch::Tensor key = torch::rand({num_tokens, n_kv_heads, head_dim}, options);
+  const torch::Tensor positions = torch::randint(
+      0, max_position_embeddings, {num_tokens}, options.dtype(torch::kInt));
+
+  RotaryEmbeddingGeneric rotary_embedding(rotary_dim,
                                           max_position_embeddings,
-                                          /*scaling_factor*/ 0.0f,
-                                          /*theta=*/10000.0f,
-                                          /*interleaved=*/false,
+                                          scaling_factor,
+                                          10000.0f,
+                                          interleaved,
                                           options);
 
-  torch::Tensor query = torch::rand({num_tokens, n_heads, head_dim});
-  torch::Tensor key = torch::rand({num_tokens, n_heads, head_dim});
-  const torch::Tensor positions =
-      torch::randint(0, max_position_embeddings, {num_tokens});
+  RotaryEmbeddingKernel rotary_embedding_kernel(rotary_dim,
+                                                max_position_embeddings,
+                                                scaling_factor,
+                                                10000.0f,
+                                                interleaved,
+                                                options);
 
-  // make a copy for inplace operation
-  const auto [query_output, key_output] =
+  auto [query_output, key_output] =
       rotary_embedding.forward(query, key, positions);
 
-  // compute the desired output
-  auto freqs_cis =
-      precompute_freqs_cis(head_dim, max_position_embeddings, 10000.0f);
-  namespace F = torch::nn::functional;
-  auto selected_freqs_cis = F::embedding(positions, freqs_cis);
-  auto [desired_query, desired_key] = apply_rotary_emb(
-      half_to_interleaved(query), half_to_interleaved(key), selected_freqs_cis);
+  // apply rotary embedding using the kernel in place
+  auto [query_output_kernel, key_output_kernel] =
+      rotary_embedding_kernel.forward(query.clone(), key.clone(), positions);
 
-  desired_query = interleaved_to_half(desired_query);
-  desired_key = interleaved_to_half(desired_key);
-
-  // check the output
-  ASSERT_TRUE(torch::allclose(desired_query,
-                              query_output,
-                              /*rtol=*/1e-03,
-                              /*atol=*/1e-05));
-  ASSERT_TRUE(torch::allclose(desired_key,
-                              key_output,
-                              /*rtol=*/1e-03,
-                              /*atol=*/1e-05));
+  ASSERT_TRUE(torch::allclose(query_output, query_output_kernel));
+  ASSERT_TRUE(torch::allclose(key_output, key_output_kernel));
 }
 
-// TODO: test kernel version
+INSTANTIATE_TEST_SUITE_P(
+    Rotary,
+    PosEmbeddingKernelTest,
+    ::testing::Combine(
+        ::testing::Values(torch::kCUDA),
+        ::testing::Values(torch::kHalf, torch::kBFloat16),
+        ::testing::Values(1, 2, 8, 16),                       // num_tokens
+        ::testing::Values(32),                                // n_heads
+        ::testing::Values(32 /*mha*/, 8 /*gqa*/, 1 /*mqa*/),  // n_kv_heads
+        ::testing::Values(128),                               // head_dim
+        ::testing::Values(128, 64),                           // rotary_dim
+        ::testing::Values(0.0f, 0.5f),                        // scaling_factor
+        ::testing::Values(100000.0f, 500000.0f),              // theta
+        ::testing::Values(false, true),                       // interleaved
        ::testing::Values(4096, 8192)  // max_position_embeddings
+        ));
 
 }  // namespace llm
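
A note on why one reference path suffices: interleaved_to_half regroups even/odd feature pairs into two contiguous halves ([1, 2, 3, 4, 5, 6] => [1, 3, 5, 2, 4, 6]) and half_to_interleaved inverts it, so apply_rotary_emb_ref can always apply the llama-style interleaved formula and convert layouts around it. A tiny standalone check of that round trip, reusing the helpers above (a sketch, not part of the commit):

    // interleaved layout [1, 2, 3, 4, 5, 6] regroups to [1, 3, 5, 2, 4, 6]
    auto x = torch::tensor({{{1.f, 2.f, 3.f, 4.f, 5.f, 6.f}}});
    auto half = interleaved_to_half(x);
    // half.flatten() is [1, 3, 5, 2, 4, 6]
    ASSERT_TRUE(torch::equal(half_to_interleaved(half), x));  // exact inverse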

src/sampling/sampler_test.cpp

Lines changed: 21 additions & 19 deletions
@@ -15,39 +15,41 @@ TEST(SamplerTest, Greedy) {
   torch::ScalarType dtype(torch::kFloat32);
   torch::Device device(torch::kCPU);
   const auto options = torch::dtype(dtype).device(device);
-  const auto do_sample = torch::tensor({false, false}, device);
-  Sampler sampler(do_sample);
 
   int64_t batch_size = 2;
   int64_t vocab_size = 32000;
-  const auto logits = torch::randn({batch_size, vocab_size}, options);
-  auto output = sampler(logits);
-
   const auto probs =
-      torch::softmax(logits, /*dim=*/-1, /*dtype=*/torch::kFloat32);
-  const auto next_tokens = probs.argmax(/*dim=*/-1);
-  EXPECT_TRUE(torch::allclose(output.next_tokens, next_tokens));
+      torch::randn({batch_size, vocab_size}, options).softmax(/*dim=*/-1);
+  auto output = Sampler::greedy_sample(probs);
+  const auto desired_output = probs.argmax(/*dim=*/-1);
+  EXPECT_TRUE(torch::allclose(output, desired_output));
 }
 
 TEST(SamplerTest, Random) {
   // Test GreedySampler
   torch::ScalarType dtype(torch::kFloat32);
   torch::Device device(torch::kCPU);
   const auto options = torch::dtype(dtype).device(device);
-  const auto do_sample = torch::tensor({true, false}, device);
-  Sampler sampler(do_sample);
 
-  int64_t batch_size = 2;
-  int64_t vocab_size = 32000;
-  const auto logits = torch::randn({batch_size, vocab_size}, options);
-  auto output = sampler(logits);
+  // set random seed
+  torch::manual_seed(100);
 
-  const auto probs =
-      torch::softmax(logits, /*dim=*/-1, /*dtype=*/torch::kFloat32);
-  const auto next_tokens_greedy = probs.argmax(/*dim=*/-1);
-  EXPECT_TRUE(torch::allclose(output.next_tokens[1], next_tokens_greedy[1]));
+  int64_t vocab_size = 50;
+  int64_t num_samples = 500000;
+
+  auto target_prob = torch::randn({vocab_size}, options).softmax(/*dim=*/-1);
+
+  auto probs = target_prob.reshape({1, -1}).repeat({num_samples, 1});
+  auto output = Sampler::random_sample(probs);
+
+  auto token_ids = output.flatten();
+  // calculate the probability of each sampled token
+  auto bincount =
+      token_ids.bincount(/*weights=*/torch::nullopt, /*minlength=*/vocab_size);
+  auto sample_prob = bincount.to(torch::kFloat) / num_samples;
 
-  // TODO: add unittests for Random
+  EXPECT_TRUE(
+      torch::allclose(target_prob, sample_prob, /*rtol=*/1e-2, /*atol=*/1e-3));
 }
 
 }  // namespace llm
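
The rewritten Random test replaces the old smoke check with a statistical one: draw half a million samples from a single fixed categorical distribution and require the empirical token frequencies to match the target probabilities within rtol=1e-2. The same idea in standalone form, with torch::multinomial standing in for the repo's Sampler::random_sample (a sketch under that substitution):

    #include <iostream>
    #include <torch/torch.h>

    int main() {
      torch::manual_seed(100);
      const int64_t vocab_size = 50;
      const int64_t num_samples = 500000;
      auto target = torch::randn({vocab_size}).softmax(/*dim=*/-1);
      // i.i.d. draws from the categorical distribution defined by `target`
      auto ids = torch::multinomial(target, num_samples, /*replacement=*/true);
      auto sample_prob =
          ids.bincount(/*weights=*/{}, /*minlength=*/vocab_size)
                  .to(torch::kFloat) /
          num_samples;
      // by the law of large numbers the two should agree within tolerance
      std::cout << torch::allclose(
                       target, sample_prob, /*rtol=*/1e-2, /*atol=*/1e-3)
                << std::endl;
    }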
