@@ -10,7 +10,7 @@ using json = nlohmann::json;
 /**
  * The state of the inference task
  */
-enum InferenceStatus { PENDING, RUNNING, FINISHED };
+enum InferenceStatus { PENDING, RUNNING, EOS, FINISHED };

 /**
  * There is a need to save state of current ongoing inference status of a
@@ -21,7 +21,7 @@ enum InferenceStatus { PENDING, RUNNING, FINISHED };
  */
 struct inferenceState {
   int task_id;
-  InferenceStatus inferenceStatus = PENDING;
+  InferenceStatus inference_status = PENDING;
   llamaCPP *instance;

   inferenceState(llamaCPP *inst) : instance(inst) {}
@@ -111,7 +111,6 @@ std::string create_full_return_json(const std::string &id,
 std::string create_return_json(const std::string &id, const std::string &model,
                                const std::string &content,
                                Json::Value finish_reason = Json::Value()) {
-
   Json::Value root;

   root["id"] = id;
@@ -167,7 +166,6 @@ void llamaCPP::warmupModel() {
 void llamaCPP::inference(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
-
   const auto &jsonBody = req->getJsonObject();
   // Check if model is loaded
   if (checkModelLoaded(callback)) {
@@ -180,7 +178,6 @@ void llamaCPP::inference(
 void llamaCPP::inferenceImpl(
     std::shared_ptr<Json::Value> jsonBody,
     std::function<void(const HttpResponsePtr &)> &callback) {
-
   std::string formatted_output = pre_prompt;

   json data;
@@ -218,7 +215,6 @@ void llamaCPP::inferenceImpl(
   };

   if (!llama.multimodal) {
-
     for (const auto &message : messages) {
       std::string input_role = message["role"].asString();
       std::string role;
@@ -243,7 +239,6 @@ void llamaCPP::inferenceImpl(
     }
     formatted_output += ai_prompt;
   } else {
-
     data["image_data"] = json::array();
     for (const auto &message : messages) {
       std::string input_role = message["role"].asString();
@@ -327,18 +322,33 @@ void llamaCPP::inferenceImpl(
     auto state = create_inference_state(this);
     auto chunked_content_provider =
         [state, data](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
-      if (state->inferenceStatus == PENDING) {
-        state->inferenceStatus = RUNNING;
-      } else if (state->inferenceStatus == FINISHED) {
+      if (state->inference_status == PENDING) {
+        state->inference_status = RUNNING;
+      } else if (state->inference_status == FINISHED) {
         return 0;
       }

       if (!pBuffer) {
         LOG_INFO << "Connection closed or buffer is null. Reset context";
-        state->inferenceStatus = FINISHED;
+        state->inference_status = FINISHED;
         return 0;
       }

+      if (state->inference_status == EOS) {
+        LOG_INFO << "End of result";
+        const std::string str =
+            "data: " +
+            create_return_json(nitro_utils::generate_random_string(20), "_", "",
+                               "stop") +
+            "\n\n" + "data: [DONE]" + "\n\n";
+
+        LOG_VERBOSE("data stream", {{"to_send", str}});
+        std::size_t nRead = std::min(str.size(), nBuffSize);
+        memcpy(pBuffer, str.data(), nRead);
+        state->inference_status = FINISHED;
+        return nRead;
+      }
+
       task_result result = state->instance->llama.next_result(state->task_id);
       if (!result.error) {
         const std::string to_send = result.result_json["content"];
@@ -352,28 +362,22 @@ void llamaCPP::inferenceImpl(
         memcpy(pBuffer, str.data(), nRead);

         if (result.stop) {
-          const std::string str =
-              "data: " +
-              create_return_json(nitro_utils::generate_random_string(20), "_",
-                                 "", "stop") +
-              "\n\n" + "data: [DONE]" + "\n\n";
-
-          LOG_VERBOSE("data stream", {{"to_send", str}});
-          std::size_t nRead = std::min(str.size(), nBuffSize);
-          memcpy(pBuffer, str.data(), nRead);
           LOG_INFO << "reached result stop";
-          state->inferenceStatus = FINISHED;
+          state->inference_status = EOS;
+          return nRead;
         }

         // Make sure nBufferSize is not zero
         // Otherwise it stop streaming
         if (!nRead) {
-          state->inferenceStatus = FINISHED;
+          state->inference_status = FINISHED;
         }

         return nRead;
+      } else {
+        LOG_INFO << "Error during inference";
       }
-      state->inferenceStatus = FINISHED;
+      state->inference_status = FINISHED;
       return 0;
     };
     // Queued task
@@ -391,16 +395,17 @@ void llamaCPP::inferenceImpl(

       // Since this is an async task, we will wait for the task to be
       // completed
-      while (state->inferenceStatus != FINISHED && retries < 10) {
+      while (state->inference_status != FINISHED && retries < 10) {
         // Should wait chunked_content_provider lambda to be called within
         // 3s
-        if (state->inferenceStatus == PENDING) {
+        if (state->inference_status == PENDING) {
           retries += 1;
         }
-        if (state->inferenceStatus != RUNNING)
+        if (state->inference_status != RUNNING)
           LOG_INFO << "Wait for task to be released:" << state->task_id;
         std::this_thread::sleep_for(std::chrono::milliseconds(100));
       }
+      LOG_INFO << "Task completed, release it";
       // Request completed, release it
       state->instance->llama.request_cancel(state->task_id);
     });
@@ -445,7 +450,6 @@ void llamaCPP::embedding(
 void llamaCPP::embeddingImpl(
     std::shared_ptr<Json::Value> jsonBody,
     std::function<void(const HttpResponsePtr &)> &callback) {
-
   // Queue embedding task
   auto state = create_inference_state(this);

@@ -532,7 +536,6 @@ void llamaCPP::modelStatus(
 void llamaCPP::loadModel(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
-
   if (llama.model_loaded_external) {
     LOG_INFO << "model loaded";
     Json::Value jsonResp;
@@ -561,7 +564,6 @@ void llamaCPP::loadModel(
 }

 bool llamaCPP::loadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
-
   gpt_params params;
   // By default will setting based on number of handlers
   if (jsonBody) {
@@ -570,11 +572,9 @@ bool llamaCPP::loadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
       params.mmproj = jsonBody->operator[]("mmproj").asString();
     }
     if (!jsonBody->operator[]("grp_attn_n").isNull()) {
-
       params.grp_attn_n = jsonBody->operator[]("grp_attn_n").asInt();
     }
     if (!jsonBody->operator[]("grp_attn_w").isNull()) {
-
       params.grp_attn_w = jsonBody->operator[]("grp_attn_w").asInt();
     }
     if (!jsonBody->operator[]("mlock").isNull()) {
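
The diff above moves the final SSE event out of the result.stop branch: instead of packing the closing "stop" chunk and the "[DONE]" marker into the same buffer write as the last content chunk, the provider now flags a new EOS state and emits the terminator on its next invocation, only then switching to FINISHED. Below is a minimal standalone sketch of that state machine, assuming made-up stand-ins: Chunk, results, and provide are hypothetical replacements for task_result, llama.next_result(), and the Drogon chunked content provider; only the PENDING -> RUNNING -> EOS -> FINISHED ordering mirrors the patch.

// Sketch only: fake token source and provider loop, not the project's code.
#include <algorithm>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>

enum InferenceStatus { PENDING, RUNNING, EOS, FINISHED };

struct Chunk {
  std::string content;
  bool stop = false;
};

// Hypothetical stand-in for llama.next_result(task_id).
std::vector<Chunk> results = {{"Hello", false}, {" world", false}, {"", true}};
std::size_t cursor = 0;

std::size_t provide(InferenceStatus &status, char *pBuffer,
                    std::size_t nBuffSize) {
  if (status == PENDING) {
    status = RUNNING;
  } else if (status == FINISHED) {
    return 0;  // nothing more to stream
  }

  if (status == EOS) {
    // Final call: emit only the terminating event, then finish.
    const std::string str = "data: [DONE]\n\n";
    std::size_t nRead = std::min(str.size(), nBuffSize);
    memcpy(pBuffer, str.data(), nRead);
    status = FINISHED;
    return nRead;
  }

  const Chunk &result = results[cursor++];
  const std::string str = "data: " + result.content + "\n\n";
  std::size_t nRead = std::min(str.size(), nBuffSize);
  memcpy(pBuffer, str.data(), nRead);
  if (result.stop) {
    // Defer the terminator to the next call so the last content chunk and
    // the closing event never have to share a single buffer write.
    status = EOS;
  }
  return nRead;
}

int main() {
  InferenceStatus status = PENDING;
  char buffer[256];
  std::size_t n;
  while ((n = provide(status, buffer, sizeof(buffer))) > 0) {
    std::cout << std::string(buffer, n);
  }
}

Splitting the terminator into its own write appears to be the point of the EOS state: the stop chunk can no longer be truncated by a buffer already holding the last token, and the wait loop keeps polling until the provider itself reports FINISHED.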