Skip to content

Commit 7472bcd

Browse files
authored
[refactor] refactoring for sequence (#140)
This refactoring includes: 1) moved incremental decoding logic into a central place; 2) passed capacity into sequence and moved all GFlags into a central place; 3) added unit tests for speculative decoding for sequence.
1 parent dff774e commit 7472bcd

24 files changed

+664
-431
lines changed

src/engine/batch.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ void Batch::process_sample_output(const SampleOutput& sample_output) {
285285
// add the next token to sequence
286286
const int32_t next_token_id =
287287
static_cast<int32_t>(next_tokens[output_idx++].item<int64_t>());
288-
seq->append_new_token_id(next_token_id);
288+
seq->append_token(next_token_id);
289289
}
290290
CHECK_EQ(output_idx, num_seqs);
291291
}
@@ -307,7 +307,7 @@ void Batch::process_validate_output(const torch::Tensor& accepted_ids) {
307307
ids.data_ptr<int64_t>(), static_cast<size_t>(ids.numel())};
308308

309309
// validate the draft tokens with accepted tokens
310-
seq->validate_token_ids(accepted_token_ids);
310+
seq->validate_tokens(accepted_token_ids);
311311
}
312312
CHECK_EQ(output_idx, num_seqs);
313313
}

src/engine/batch_test.cpp

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -34,40 +34,37 @@ TEST(BatchTest, Basic) {
3434
auto block_0 = allocator.allocate();
3535
EXPECT_EQ(block_0.id(), 0);
3636

37-
SamplingParameter sampling_param;
38-
sampling_param.frequency_penalty = 0.1;
39-
StoppingCriteria stopping_criteria;
40-
stopping_criteria.max_tokens = 20;
37+
Sequence::Options options;
38+
options.sampling_param.frequency_penalty = 0.1;
39+
options.stopping_criteria.max_tokens = 20;
40+
const size_t capacity = 100;
4141

4242
// prepare sequences
4343
// sequence in prefill phase
44-
Sequence seq1(/*token_ids=*/{1, 3, 5, 7, 5, 4, 3, 2, 1},
45-
sampling_param,
46-
stopping_criteria,
47-
/*echo=*/false,
48-
/*on_stream=*/nullptr);
44+
Sequence seq1(/*prompt=*/"",
45+
/*token_ids=*/{1, 3, 5, 7, 5, 4, 3, 2, 1},
46+
capacity,
47+
options);
4948
seq1.append_blocks(allocator.allocate(3)); // [1, 2, 3]
5049

5150
// seq in decode phase
52-
Sequence seq2(/*token_ids=*/{2, 4, 6, 8, 6, 4, 2},
53-
sampling_param,
54-
stopping_criteria,
55-
/*echo=*/false,
56-
/*on_stream=*/nullptr);
51+
Sequence seq2(/*prompt=*/"",
52+
/*token_ids=*/{2, 4, 6, 8, 6, 4, 2},
53+
capacity,
54+
options);
5755
seq2.append_blocks(allocator.allocate(4)); // [4, 5, 6, 7]
5856
seq2.commit_kv_cache(/*size=*/7);
59-
seq2.append_new_token_id(100);
57+
seq2.append_token(100);
6058

6159
// seq in decode phase
6260
Sequence seq3(
61+
/*prompt=*/"",
6362
/*token_ids=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19},
64-
sampling_param,
65-
stopping_criteria,
66-
/*echo=*/false,
67-
/*on_stream=*/nullptr);
63+
capacity,
64+
options);
6865
seq3.append_blocks(allocator.allocate(5)); // [8, 9, 10, 11, 12]
6966
seq3.commit_kv_cache(/*size=*/15);
70-
seq3.append_new_token_id(200);
67+
seq3.append_token(200);
7168

7269
// define outputs
7370
Batch batch({&seq1, &seq2, &seq3});

src/handlers/chat_handler.cpp

Lines changed: 39 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "chat_handler.h"
22

3+
#include <absl/strings/escaping.h>
34
#include <glog/logging.h>
45
#include <grpcpp/grpcpp.h>
56
#include <torch/torch.h>
@@ -17,9 +18,9 @@
1718
#include "scheduler/scheduler.h"
1819
#include "utils.h"
1920

20-
DEFINE_bool(disable_default_chat_template,
21-
false,
22-
"Disable default chat template");
21+
DEFINE_bool(enable_jinja_chat_template, false, "Enable Jinja chat template");
22+
23+
DECLARE_int32(num_speculative_tokens);
2324

2425
namespace llm {
2526

@@ -87,10 +88,9 @@ bool send_delta_to_client(ChatCallData* call_data,
8788
Request* request,
8889
uint32_t index,
8990
bool first_message,
90-
const std::string& delta,
91-
FinishReason reason) {
91+
const SequenceDeltaOutput& output) {
9292
// send delta to client
93-
if (!delta.empty()) {
93+
if (!output.delta.empty()) {
9494
ChatResponse response;
9595
response.set_object("chat.completion.chunk");
9696
response.set_id(request->id);
@@ -104,22 +104,22 @@ bool send_delta_to_client(ChatCallData* call_data,
104104
if (first_message) {
105105
message->set_role("assistant");
106106
}
107-
message->set_content(delta);
107+
message->set_content(output.delta);
108108
if (!call_data->write(std::move(response))) {
109109
return false;
110110
}
111111
}
112112

113113
// send finish reason as a separate message
114-
if (reason != FinishReason::NONE) {
114+
if (output.finish_reason != FinishReason::NONE) {
115115
ChatResponse response;
116116
response.set_object("chat.completion");
117117
response.set_id(request->id);
118118
response.set_created(request->created_time);
119119
// response.set_model(request->model);
120120
auto* choice = response.add_choices();
121121
choice->set_index(index);
122-
choice->set_finish_reason(finish_reason_to_string(reason));
122+
choice->set_finish_reason(finish_reason_to_string(output.finish_reason));
123123
if (!call_data->write(std::move(response))) {
124124
return false;
125125
}
@@ -129,7 +129,7 @@ bool send_delta_to_client(ChatCallData* call_data,
129129

130130
bool send_result_to_client(ChatCallData* call_data,
131131
Request* request,
132-
const std::vector<SequenceResult>& seq_results,
132+
const std::vector<SequenceOutput>& seq_results,
133133
const Status& /*status*/,
134134
const Statistics& stats) {
135135
ChatResponse response;
@@ -145,7 +145,7 @@ bool send_result_to_client(ChatCallData* call_data,
145145
choice->set_index(i);
146146
auto* message = choice->mutable_message();
147147
message->set_role("assistant");
148-
message->set_content(seq_result.output_text);
148+
message->set_content(seq_result.text);
149149
if (seq_result.finish_reason != FinishReason::NONE) {
150150
choice->set_finish_reason(
151151
finish_reason_to_string(seq_result.finish_reason));
@@ -206,9 +206,21 @@ std::unique_ptr<Request> grpc_request_to_request(ChatCallData* call_data,
206206
return nullptr;
207207
}
208208

209+
uint32_t max_tokens = 0;
210+
if (grpc_request.has_max_tokens()) {
211+
max_tokens = grpc_request.max_tokens();
212+
} else {
213+
const uint32_t kDefaultMaxTokens = 16;
214+
max_tokens = kDefaultMaxTokens;
215+
}
216+
209217
const uint32_t num_seqs = grpc_request.has_n() ? grpc_request.n() : 1;
218+
// allocate enough capacity for prompt tokens, max tokens, and speculative
219+
// tokens
220+
const size_t capacity = prompt_tokens.size() + max_tokens +
221+
FLAGS_num_speculative_tokens + /*bouns_token*/ 1;
210222
auto request = std::make_unique<Request>(
211-
generate_request_id(), "", num_seqs, prompt_tokens);
223+
generate_request_id(), "", prompt_tokens, capacity, num_seqs);
212224

213225
// construct sampling parameters
214226
auto& sampling_param = request->sampling_param;
@@ -232,16 +244,9 @@ std::unique_ptr<Request> grpc_request_to_request(ChatCallData* call_data,
232244

233245
// construct stopping criteria
234246
auto& stopping_criteria = request->stopping_criteria;
235-
auto max_tokens =
236-
static_cast<uint32_t>(max_context_len - prompt_tokens.size());
237-
if (grpc_request.has_max_tokens()) {
238-
max_tokens = std::min(max_tokens, grpc_request.max_tokens());
239-
} else {
240-
const uint32_t kDefaultMaxTokens = 128;
241-
max_tokens = std::min(max_tokens, kDefaultMaxTokens);
242-
}
243247
stopping_criteria.max_tokens = max_tokens;
244-
stopping_criteria.max_context_length = model_args.max_position_embeddings();
248+
stopping_criteria.max_context_len =
249+
max_context_len - FLAGS_num_speculative_tokens;
245250
// stopping_criteria.ignore_eos_token = false;
246251
stopping_criteria.eos_token_id = model_args.eos_token_id();
247252

@@ -280,14 +285,14 @@ std::unique_ptr<Request> grpc_request_to_request(ChatCallData* call_data,
280285
// set callbacks
281286
if (request->stream) {
282287
// set callback for stream delta
283-
request->on_stream_delta = [call_data, request = request.get()](
284-
size_t index,
285-
bool first_message,
286-
const std::string& delta,
287-
FinishReason reason) -> bool {
288-
return send_delta_to_client(
289-
call_data, request, index, first_message, delta, reason);
290-
};
288+
request->on_stream_delta =
289+
[call_data, request = request.get(), first_message = true](
290+
size_t index, const SequenceDeltaOutput& output) mutable {
291+
const auto ret = send_delta_to_client(
292+
call_data, request, index, first_message, output);
293+
first_message = false;
294+
return ret;
295+
};
291296

292297
// set callback for stream request
293298
request->on_stream_finish = [call_data](const Status& /*status*/) -> bool {
@@ -296,7 +301,7 @@ std::unique_ptr<Request> grpc_request_to_request(ChatCallData* call_data,
296301
} else {
297302
// set callback for non-stream request
298303
request->on_finish = [call_data, request = request.get()](
299-
const std::vector<SequenceResult>& seq_results,
304+
const std::vector<SequenceOutput>& seq_results,
300305
const Status& status,
301306
const Statistics& stats) -> bool {
302307
return send_result_to_client(
@@ -323,15 +328,15 @@ ChatHandler::ChatHandler(Scheduler* scheduler, const Engine* engine)
323328
// construct chat template
324329
auto factory = ModelRegistry::get_default_chat_template_factory(
325330
model_args_.model_type());
326-
if (!FLAGS_disable_default_chat_template && factory) {
327-
LOG(INFO) << "Use default chat template for model type: "
331+
if (!FLAGS_enable_jinja_chat_template && factory) {
332+
LOG(INFO) << "Using default chat template for model type: "
328333
<< model_args_.model_type();
329334
chat_template_ = factory();
330335
} else {
331336
const auto& tokenizer_args = engine->tokenizer_args();
332337
if (!tokenizer_args.chat_template().empty()) {
333-
LOG(INFO) << "Use chat template from tokenizer args for model type: "
334-
<< model_args_.model_type();
338+
LOG(INFO) << "Using jinja chat template: "
339+
<< absl::CEscape(tokenizer_args.chat_template());
335340
chat_template_ = std::make_unique<JinjaChatTemplate>(
336341
tokenizer_args.chat_template(), /*add_generation_prompt=*/true);
337342
}

src/handlers/completion_handler.cpp

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#include "scheduler/scheduler.h"
1515
#include "utils.h"
1616

17+
DECLARE_int32(num_speculative_tokens);
18+
1719
namespace llm {
1820

1921
namespace {
@@ -102,30 +104,29 @@ bool verify_request_arguments(CompletionCallData* call_data) {
102104
bool send_delta_to_client(CompletionCallData* call_data,
103105
Request* request,
104106
uint32_t index,
105-
const std::string& delta,
106-
FinishReason reason) {
107-
if (!delta.empty()) {
107+
const SequenceDeltaOutput& output) {
108+
if (!output.delta.empty()) {
108109
CompletionResponse response;
109110
response.set_object("text_completion");
110111
response.set_id(request->id);
111112
response.set_created(request->created_time);
112113
// response.set_model(request->model);
113114
auto* choice = response.add_choices();
114115
choice->set_index(index);
115-
choice->set_text(delta);
116+
choice->set_text(output.delta);
116117
if (!call_data->write(std::move(response))) {
117118
return false;
118119
}
119120
}
120121

121-
if (reason != FinishReason::NONE) {
122+
if (output.finish_reason != FinishReason::NONE) {
122123
CompletionResponse response;
123124
response.set_object("text_completion");
124125
response.set_id(request->id);
125126
response.set_created(request->created_time);
126127
// response.set_model(request->model);
127128
auto* choice = response.add_choices();
128-
choice->set_finish_reason(finish_reason_to_string(reason));
129+
choice->set_finish_reason(finish_reason_to_string(output.finish_reason));
129130
if (!call_data->write(std::move(response))) {
130131
return false;
131132
}
@@ -135,7 +136,7 @@ bool send_delta_to_client(CompletionCallData* call_data,
135136

136137
bool send_result_to_client(CompletionCallData* call_data,
137138
Request* request,
138-
const std::vector<SequenceResult>& seq_results,
139+
const std::vector<SequenceOutput>& outputs,
139140
const Status& /*status*/,
140141
const Statistics& stats) {
141142
CompletionResponse response;
@@ -145,15 +146,14 @@ bool send_result_to_client(CompletionCallData* call_data,
145146
// response.set_model(request->model);
146147

147148
// add choices into response
148-
for (uint32_t i = 0; i < seq_results.size(); ++i) {
149-
const auto& seq_result = seq_results[i];
149+
for (uint32_t i = 0; i < outputs.size(); ++i) {
150+
const auto& output = outputs[i];
150151
auto* choice = response.add_choices();
151152
choice->set_index(i);
152-
choice->set_text(seq_result.output_text);
153+
choice->set_text(output.text);
153154
// choice->set_logprobs(0);
154-
if (seq_result.finish_reason != FinishReason::NONE) {
155-
choice->set_finish_reason(
156-
finish_reason_to_string(seq_result.finish_reason));
155+
if (output.finish_reason != FinishReason::NONE) {
156+
choice->set_finish_reason(finish_reason_to_string(output.finish_reason));
157157
}
158158
}
159159

@@ -191,9 +191,24 @@ std::unique_ptr<Request> grpc_request_to_request(CompletionCallData* call_data,
191191
return nullptr;
192192
}
193193

194+
uint32_t max_tokens = 0;
195+
if (grpc_request.has_max_tokens()) {
196+
max_tokens = grpc_request.max_tokens();
197+
} else {
198+
const uint32_t kDefaultMaxTokens = 16;
199+
max_tokens = kDefaultMaxTokens;
200+
}
201+
// allocate enough capacity for prompt tokens, max tokens, and speculative
202+
// tokens
203+
const size_t capacity = prompt_tokens.size() + max_tokens +
204+
FLAGS_num_speculative_tokens + /*bouns_token*/ 1;
205+
194206
const uint32_t num_seqs = grpc_request.has_n() ? grpc_request.n() : 1;
195-
auto request = std::make_unique<Request>(
196-
generate_request_id(), grpc_request.prompt(), num_seqs, prompt_tokens);
207+
auto request = std::make_unique<Request>(generate_request_id(),
208+
grpc_request.prompt(),
209+
prompt_tokens,
210+
capacity,
211+
num_seqs);
197212

198213
// construct sampling parameters
199214
auto& sampling_param = request->sampling_param;
@@ -217,16 +232,9 @@ std::unique_ptr<Request> grpc_request_to_request(CompletionCallData* call_data,
217232

218233
// construct stopping criteria
219234
auto& stopping_criteria = request->stopping_criteria;
220-
auto max_tokens =
221-
static_cast<uint32_t>(max_context_len - prompt_tokens.size());
222-
if (grpc_request.has_max_tokens()) {
223-
max_tokens = std::min(max_tokens, grpc_request.max_tokens());
224-
} else {
225-
const uint32_t kDefaultMaxTokens = 128;
226-
max_tokens = std::min(max_tokens, kDefaultMaxTokens);
227-
}
228235
stopping_criteria.max_tokens = max_tokens;
229-
stopping_criteria.max_context_length = model_args.max_position_embeddings();
236+
stopping_criteria.max_context_len =
237+
max_context_len - FLAGS_num_speculative_tokens;
230238
// stopping_criteria.ignore_eos_token = false;
231239
stopping_criteria.eos_token_id = model_args.eos_token_id();
232240

@@ -263,10 +271,8 @@ std::unique_ptr<Request> grpc_request_to_request(CompletionCallData* call_data,
263271
if (request->stream) {
264272
request->on_stream_delta = [call_data, request = request.get()](
265273
size_t index,
266-
bool /*first_message*/,
267-
const std::string& delta,
268-
FinishReason reason) -> bool {
269-
return send_delta_to_client(call_data, request, index, delta, reason);
274+
const SequenceDeltaOutput& output) -> bool {
275+
return send_delta_to_client(call_data, request, index, output);
270276
};
271277

272278
// add on_stream_finish callback
@@ -276,7 +282,7 @@ std::unique_ptr<Request> grpc_request_to_request(CompletionCallData* call_data,
276282
} else {
277283
// add on_finish callback
278284
request->on_finish = [call_data, request = request.get()](
279-
const std::vector<SequenceResult>& seq_results,
285+
const std::vector<SequenceOutput>& seq_results,
280286
const Status& status,
281287
const Statistics& stats) -> bool {
282288
return send_result_to_client(

0 commit comments

Comments
 (0)