@@ -54,24 +54,29 @@ void Sequence::append_token(int32_t token_id) {
5454
5555size_t Sequence::validate_tokens (const Slice<int64_t >& accpeted_token_ids) {
5656 const size_t len = accpeted_token_ids.size ();
57+ CHECK_GT (len, 0 ) << " empty accepted token ids" ;
5758 CHECK_GT (num_tokens_, len) << " accepted tokens exceed the sequence length" ;
59+ const auto bonus_token_id = accpeted_token_ids.back ();
60+ CHECK (bonus_token_id == -1 || bonus_token_id == token_ids ().back ())
61+ << " bonus token mismatch with the last token" ;
5862
5963 // validate the accepted tokens with draft tokens, stop at the first mismatch
6064 const size_t start_idx = num_tokens_ - len;
61- size_t accpeted_len = 0 ;
65+ bool mismatch = false ;
66+ size_t num_accpeted = 0 ;
6267 for (size_t i = 0 ; i < len; ++i) {
6368 const size_t cur_idx = start_idx + i;
6469 const int32_t draft_token_id = token_ids_[cur_idx];
6570 const int32_t target_token_id = static_cast <int32_t >(accpeted_token_ids[i]);
6671
67- // stop at first rejected token id
68- if (target_token_id == -1 ) {
72+ // stop at first mismatch or rejected token
73+ if (mismatch || target_token_id == -1 ) {
6974 num_tokens_ = cur_idx;
7075 break ;
7176 }
72-
73- ++accpeted_len ;
74- if (target_token_id != draft_token_id ) {
77+ ++num_accpeted;
78+ mismatch = target_token_id != draft_token_id ;
79+ if (mismatch ) {
7580 // overwrite the token id with the accepted token id
7681 token_ids_[cur_idx] = target_token_id;
7782 // update the token count
@@ -93,9 +98,8 @@ size_t Sequence::validate_tokens(const Slice<int64_t>& accpeted_token_ids) {
9398 }
9499
95100 // adjust the token count for remaining discarded tokens
96- for (size_t i = accpeted_len; i < len; ++i) {
97- const auto token_id = token_ids_[start_idx + i];
98- --token_to_count_map_[token_id];
101+ for (size_t i = num_accpeted; i < len; ++i) {
102+ --token_to_count_map_[token_ids_[start_idx + i]];
99103 }
100104
101105 // adjust kv cache position
@@ -104,9 +108,11 @@ size_t Sequence::validate_tokens(const Slice<int64_t>& accpeted_token_ids) {
104108 num_kv_cache_tokens = std::min (num_kv_cache_tokens, num_tokens_ - 1 );
105109 }
106110
111+ CHECK_GT (num_accpeted, 0 ) << " no token accepted" ;
112+
107113 // the finish status is valid after the validation
108114 finish_status_invalidated_ = false ;
109- return accpeted_len ;
115+ return num_accpeted ;
110116}
111117
112118// decode the sequence to get delta text using the tokenizer
0 commit comments