11#include " chat_handler.h"
22
3+ #include < absl/strings/escaping.h>
34#include < glog/logging.h>
45#include < grpcpp/grpcpp.h>
56#include < torch/torch.h>
1718#include " scheduler/scheduler.h"
1819#include " utils.h"
1920
20- DEFINE_bool (disable_default_chat_template,
21- false ,
22- " Disable default chat template " );
21+ DEFINE_bool (enable_jinja_chat_template, false , " Enable Jinja chat template " );
22+
23+ DECLARE_int32 (num_speculative_tokens );
2324
2425namespace llm {
2526
@@ -87,10 +88,9 @@ bool send_delta_to_client(ChatCallData* call_data,
8788 Request* request,
8889 uint32_t index,
8990 bool first_message,
90- const std::string& delta,
91- FinishReason reason) {
91+ const SequenceDeltaOutput& output) {
9292 // send delta to client
93- if (!delta.empty ()) {
93+ if (!output. delta .empty ()) {
9494 ChatResponse response;
9595 response.set_object (" chat.completion.chunk" );
9696 response.set_id (request->id );
@@ -104,22 +104,22 @@ bool send_delta_to_client(ChatCallData* call_data,
104104 if (first_message) {
105105 message->set_role (" assistant" );
106106 }
107- message->set_content (delta);
107+ message->set_content (output. delta );
108108 if (!call_data->write (std::move (response))) {
109109 return false ;
110110 }
111111 }
112112
113113 // send finish reason as a separate message
114- if (reason != FinishReason::NONE) {
114+ if (output. finish_reason != FinishReason::NONE) {
115115 ChatResponse response;
116116 response.set_object (" chat.completion" );
117117 response.set_id (request->id );
118118 response.set_created (request->created_time );
119119 // response.set_model(request->model);
120120 auto * choice = response.add_choices ();
121121 choice->set_index (index);
122- choice->set_finish_reason (finish_reason_to_string (reason ));
122+ choice->set_finish_reason (finish_reason_to_string (output. finish_reason ));
123123 if (!call_data->write (std::move (response))) {
124124 return false ;
125125 }
@@ -129,7 +129,7 @@ bool send_delta_to_client(ChatCallData* call_data,
129129
130130bool send_result_to_client (ChatCallData* call_data,
131131 Request* request,
132- const std::vector<SequenceResult >& seq_results,
132+ const std::vector<SequenceOutput >& seq_results,
133133 const Status& /* status*/ ,
134134 const Statistics& stats) {
135135 ChatResponse response;
@@ -145,7 +145,7 @@ bool send_result_to_client(ChatCallData* call_data,
145145 choice->set_index (i);
146146 auto * message = choice->mutable_message ();
147147 message->set_role (" assistant" );
148- message->set_content (seq_result.output_text );
148+ message->set_content (seq_result.text );
149149 if (seq_result.finish_reason != FinishReason::NONE) {
150150 choice->set_finish_reason (
151151 finish_reason_to_string (seq_result.finish_reason ));
@@ -206,9 +206,21 @@ std::unique_ptr<Request> grpc_request_to_request(ChatCallData* call_data,
206206 return nullptr ;
207207 }
208208
209+ uint32_t max_tokens = 0 ;
210+ if (grpc_request.has_max_tokens ()) {
211+ max_tokens = grpc_request.max_tokens ();
212+ } else {
213+ const uint32_t kDefaultMaxTokens = 16 ;
214+ max_tokens = kDefaultMaxTokens ;
215+ }
216+
209217 const uint32_t num_seqs = grpc_request.has_n () ? grpc_request.n () : 1 ;
218+ // allocate enough capacity for prompt tokens, max tokens, and speculative
219+ // tokens
220+ const size_t capacity = prompt_tokens.size () + max_tokens +
221+ FLAGS_num_speculative_tokens + /* bonus_token */ 1 ;
210222 auto request = std::make_unique<Request>(
211- generate_request_id (), " " , num_seqs, prompt_tokens );
223+ generate_request_id (), " " , prompt_tokens, capacity, num_seqs );
212224
213225 // construct sampling parameters
214226 auto & sampling_param = request->sampling_param ;
@@ -232,16 +244,9 @@ std::unique_ptr<Request> grpc_request_to_request(ChatCallData* call_data,
232244
233245 // construct stopping criteria
234246 auto & stopping_criteria = request->stopping_criteria ;
235- auto max_tokens =
236- static_cast <uint32_t >(max_context_len - prompt_tokens.size ());
237- if (grpc_request.has_max_tokens ()) {
238- max_tokens = std::min (max_tokens, grpc_request.max_tokens ());
239- } else {
240- const uint32_t kDefaultMaxTokens = 128 ;
241- max_tokens = std::min (max_tokens, kDefaultMaxTokens );
242- }
243247 stopping_criteria.max_tokens = max_tokens;
244- stopping_criteria.max_context_length = model_args.max_position_embeddings ();
248+ stopping_criteria.max_context_len =
249+ max_context_len - FLAGS_num_speculative_tokens;
245250 // stopping_criteria.ignore_eos_token = false;
246251 stopping_criteria.eos_token_id = model_args.eos_token_id ();
247252
@@ -280,14 +285,14 @@ std::unique_ptr<Request> grpc_request_to_request(ChatCallData* call_data,
280285 // set callbacks
281286 if (request->stream ) {
282287 // set callback for stream delta
283- request->on_stream_delta = [call_data, request = request. get ()](
284- size_t index,
285- bool first_message,
286- const std::string& delta,
287- FinishReason reason) -> bool {
288- return send_delta_to_client (
289- call_data, request, index, first_message, delta, reason) ;
290- };
288+ request->on_stream_delta =
289+ [call_data, request = request. get (), first_message = true ](
290+ size_t index, const SequenceDeltaOutput& output) mutable {
291+ const auto ret = send_delta_to_client (
292+ call_data, request, index, first_message, output);
293+ first_message = false ;
294+ return ret ;
295+ };
291296
292297 // set callback for stream request
293298 request->on_stream_finish = [call_data](const Status& /* status*/ ) -> bool {
@@ -296,7 +301,7 @@ std::unique_ptr<Request> grpc_request_to_request(ChatCallData* call_data,
296301 } else {
297302 // set callback for non-stream request
298303 request->on_finish = [call_data, request = request.get ()](
299- const std::vector<SequenceResult >& seq_results,
304+ const std::vector<SequenceOutput >& seq_results,
300305 const Status& status,
301306 const Statistics& stats) -> bool {
302307 return send_result_to_client (
@@ -323,15 +328,15 @@ ChatHandler::ChatHandler(Scheduler* scheduler, const Engine* engine)
323328 // construct chat template
324329 auto factory = ModelRegistry::get_default_chat_template_factory (
325330 model_args_.model_type ());
326- if (!FLAGS_disable_default_chat_template && factory) {
327- LOG (INFO) << " Use default chat template for model type: "
331+ if (!FLAGS_enable_jinja_chat_template && factory) {
332+ LOG (INFO) << " Using default chat template for model type: "
328333 << model_args_.model_type ();
329334 chat_template_ = factory ();
330335 } else {
331336 const auto & tokenizer_args = engine->tokenizer_args ();
332337 if (!tokenizer_args.chat_template ().empty ()) {
333- LOG (INFO) << " Use chat template from tokenizer args for model type : "
334- << model_args_. model_type ( );
338+ LOG (INFO) << " Using jinja chat template: "
339+ << absl::CEscape (tokenizer_args. chat_template () );
335340 chat_template_ = std::make_unique<JinjaChatTemplate>(
336341 tokenizer_args.chat_template (), /* add_generation_prompt=*/ true );
337342 }
0 commit comments