Skip to content

Commit acd6ae0

Browse files
authored
[unittest] added more unittests for speculative decoding (#141)
1 parent 7472bcd commit acd6ae0

File tree

7 files changed

+325
-774
lines changed

7 files changed

+325
-774
lines changed

src/common/slice.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,22 @@ class Slice final {
6464
// Equality between a Slice<T> (left) and a std::vector<T> (right).
// The post-change form (lines marked '+') returns true when both hold the same
// number of elements and either data() pointers are equal or every element
// compares equal in order. The line marked '-' is the pre-change form, which
// always ran the element-wise comparison.
template <typename T>
6565
inline bool operator==(const Slice<T>& lhs, const std::vector<T>& rhs) {
6666
// sizes are compared first, so the 3-iterator std::equal below never reads
// past the end of rhs
return lhs.size() == rhs.size() &&
67-
std::equal(lhs.begin(), lhs.end(), rhs.begin());
67+
// identical data() pointers with identical sizes: skip the element scan
// (presumably Slice::data() points at the first viewed element -- confirm
// against the Slice class definition)
(lhs.data() == rhs.data() ||
68+
std::equal(lhs.begin(), lhs.end(), rhs.begin()));
6869
}
6970

7071
// Equality between a std::vector<T> (left) and a Slice<T> (right); the
// operand-swapped counterpart of the Slice/vector overload.
// The post-change form (lines marked '+') returns true when the sizes match and
// either data() pointers are equal or every element compares equal in order.
// The line marked '-' is the pre-change form without the pointer check.
template <typename T>
7172
inline bool operator==(const std::vector<T>& lhs, const Slice<T>& rhs) {
7273
// sizes are compared first, so the 3-iterator std::equal below never reads
// past the end of rhs
return lhs.size() == rhs.size() &&
73-
std::equal(lhs.begin(), lhs.end(), rhs.begin());
74+
// identical data() pointers with identical sizes: skip the element scan
(lhs.data() == rhs.data() ||
75+
std::equal(lhs.begin(), lhs.end(), rhs.begin()));
76+
}
77+
78+
// Equality between two Slice<T> values (overload introduced by this change).
// Returns true when the sizes match and either data() pointers are equal or
// every element compares equal in order.
template <typename T>
79+
inline bool operator==(const Slice<T>& lhs, const Slice<T>& rhs) {
80+
// sizes are compared first, so the 3-iterator std::equal below never reads
// past the end of rhs
return lhs.size() == rhs.size() &&
81+
// identical data() pointers with identical sizes: skip the element scan
(lhs.data() == rhs.data() ||
82+
std::equal(lhs.begin(), lhs.end(), rhs.begin()));
7483
}
7584

7685
} // namespace llm

src/request/sequence.cpp

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,24 +54,29 @@ void Sequence::append_token(int32_t token_id) {
5454

5555
size_t Sequence::validate_tokens(const Slice<int64_t>& accpeted_token_ids) {
5656
const size_t len = accpeted_token_ids.size();
57+
CHECK_GT(len, 0) << "empty accepted token ids";
5758
CHECK_GT(num_tokens_, len) << "accepted tokens exceed the sequence length";
59+
const auto bonus_token_id = accpeted_token_ids.back();
60+
CHECK(bonus_token_id == -1 || bonus_token_id == token_ids().back())
61+
<< "bonus token mismatch with the last token";
5862

5963
// validate the accepted tokens with draft tokens, stop at the first mismatch
6064
const size_t start_idx = num_tokens_ - len;
61-
size_t accpeted_len = 0;
65+
bool mismatch = false;
66+
size_t num_accpeted = 0;
6267
for (size_t i = 0; i < len; ++i) {
6368
const size_t cur_idx = start_idx + i;
6469
const int32_t draft_token_id = token_ids_[cur_idx];
6570
const int32_t target_token_id = static_cast<int32_t>(accpeted_token_ids[i]);
6671

67-
// stop at first rejected token id
68-
if (target_token_id == -1) {
72+
// stop at first mismatch or rejected token
73+
if (mismatch || target_token_id == -1) {
6974
num_tokens_ = cur_idx;
7075
break;
7176
}
72-
73-
++accpeted_len;
74-
if (target_token_id != draft_token_id) {
77+
++num_accpeted;
78+
mismatch = target_token_id != draft_token_id;
79+
if (mismatch) {
7580
// overwrite the token id with the accepted token id
7681
token_ids_[cur_idx] = target_token_id;
7782
// update the token count
@@ -93,9 +98,8 @@ size_t Sequence::validate_tokens(const Slice<int64_t>& accpeted_token_ids) {
9398
}
9499

95100
// adjust the token count for remaining discarded tokens
96-
for (size_t i = accpeted_len; i < len; ++i) {
97-
const auto token_id = token_ids_[start_idx + i];
98-
--token_to_count_map_[token_id];
101+
for (size_t i = num_accpeted; i < len; ++i) {
102+
--token_to_count_map_[token_ids_[start_idx + i]];
99103
}
100104

101105
// adjust kv cache position
@@ -104,9 +108,11 @@ size_t Sequence::validate_tokens(const Slice<int64_t>& accpeted_token_ids) {
104108
num_kv_cache_tokens = std::min(num_kv_cache_tokens, num_tokens_ - 1);
105109
}
106110

111+
CHECK_GT(num_accpeted, 0) << "no token accepted";
112+
107113
// the finish status is valid after the validation
108114
finish_status_invalidated_ = false;
109-
return accpeted_len;
115+
return num_accpeted;
110116
}
111117

112118
// decode the sequence to get delta text using the tokenizer

src/request/sequence.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ class Sequence final {
133133

134134
// validate draft tokens with accepted tokens for speculative decoding
135135
// N.B. take int64_t as input to be compatible with torch::Tensor
136-
// returns the number of accepted tokens
136+
// returns the number of accepted tokens, including the resampled token
137137
size_t validate_tokens(const Slice<int64_t>& accpeted_token_ids);
138138

139139
// add new cache blocks

0 commit comments

Comments
 (0)