This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 184de90

Merge pull request #447 from janhq/fix/missing-stop-token-eos
fix: missing [DONE] token EOS
2 parents 26b69b1 + d5e3a7a commit 184de90

File tree

1 file changed: +31 -31 lines


controllers/llamaCPP.cc

Lines changed: 31 additions & 31 deletions
@@ -10,7 +10,7 @@ using json = nlohmann::json;
 /**
  * The state of the inference task
  */
-enum InferenceStatus { PENDING, RUNNING, FINISHED };
+enum InferenceStatus { PENDING, RUNNING, EOS, FINISHED };
 
 /**
  * There is a need to save state of current ongoing inference status of a
@@ -21,7 +21,7 @@ enum InferenceStatus { PENDING, RUNNING, FINISHED };
  */
 struct inferenceState {
   int task_id;
-  InferenceStatus inferenceStatus = PENDING;
+  InferenceStatus inference_status = PENDING;
   llamaCPP *instance;
 
   inferenceState(llamaCPP *inst) : instance(inst) {}
@@ -111,7 +111,6 @@ std::string create_full_return_json(const std::string &id,
 std::string create_return_json(const std::string &id, const std::string &model,
                                const std::string &content,
                                Json::Value finish_reason = Json::Value()) {
-
   Json::Value root;
 
   root["id"] = id;
@@ -167,7 +166,6 @@ void llamaCPP::warmupModel() {
 void llamaCPP::inference(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
-
   const auto &jsonBody = req->getJsonObject();
   // Check if model is loaded
   if (checkModelLoaded(callback)) {
@@ -180,7 +178,6 @@ void llamaCPP::inference(
 void llamaCPP::inferenceImpl(
     std::shared_ptr<Json::Value> jsonBody,
     std::function<void(const HttpResponsePtr &)> &callback) {
-
   std::string formatted_output = pre_prompt;
 
   json data;
@@ -218,7 +215,6 @@ void llamaCPP::inferenceImpl(
   };
 
   if (!llama.multimodal) {
-
     for (const auto &message : messages) {
       std::string input_role = message["role"].asString();
       std::string role;
@@ -243,7 +239,6 @@ void llamaCPP::inferenceImpl(
     }
     formatted_output += ai_prompt;
   } else {
-
     data["image_data"] = json::array();
     for (const auto &message : messages) {
       std::string input_role = message["role"].asString();
@@ -327,18 +322,33 @@ void llamaCPP::inferenceImpl(
   auto state = create_inference_state(this);
   auto chunked_content_provider =
       [state, data](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
-    if (state->inferenceStatus == PENDING) {
-      state->inferenceStatus = RUNNING;
-    } else if (state->inferenceStatus == FINISHED) {
+    if (state->inference_status == PENDING) {
+      state->inference_status = RUNNING;
+    } else if (state->inference_status == FINISHED) {
       return 0;
     }
 
     if (!pBuffer) {
       LOG_INFO << "Connection closed or buffer is null. Reset context";
-      state->inferenceStatus = FINISHED;
+      state->inference_status = FINISHED;
       return 0;
     }
 
+    if (state->inference_status == EOS) {
+      LOG_INFO << "End of result";
+      const std::string str =
+          "data: " +
+          create_return_json(nitro_utils::generate_random_string(20), "_", "",
+                             "stop") +
+          "\n\n" + "data: [DONE]" + "\n\n";
+
+      LOG_VERBOSE("data stream", {{"to_send", str}});
+      std::size_t nRead = std::min(str.size(), nBuffSize);
+      memcpy(pBuffer, str.data(), nRead);
+      state->inference_status = FINISHED;
+      return nRead;
+    }
+
     task_result result = state->instance->llama.next_result(state->task_id);
     if (!result.error) {
       const std::string to_send = result.result_json["content"];
@@ -352,28 +362,22 @@ void llamaCPP::inferenceImpl(
       memcpy(pBuffer, str.data(), nRead);
 
       if (result.stop) {
-        const std::string str =
-            "data: " +
-            create_return_json(nitro_utils::generate_random_string(20), "_",
-                               "", "stop") +
-            "\n\n" + "data: [DONE]" + "\n\n";
-
-        LOG_VERBOSE("data stream", {{"to_send", str}});
-        std::size_t nRead = std::min(str.size(), nBuffSize);
-        memcpy(pBuffer, str.data(), nRead);
         LOG_INFO << "reached result stop";
-        state->inferenceStatus = FINISHED;
+        state->inference_status = EOS;
+        return nRead;
       }
 
       // Make sure nBufferSize is not zero
       // Otherwise it stop streaming
       if (!nRead) {
-        state->inferenceStatus = FINISHED;
+        state->inference_status = FINISHED;
      }
 
       return nRead;
+    } else {
+      LOG_INFO << "Error during inference";
     }
-    state->inferenceStatus = FINISHED;
+    state->inference_status = FINISHED;
     return 0;
   };
   // Queued task
@@ -391,16 +395,17 @@ void llamaCPP::inferenceImpl(
 
       // Since this is an async task, we will wait for the task to be
      // completed
-      while (state->inferenceStatus != FINISHED && retries < 10) {
+      while (state->inference_status != FINISHED && retries < 10) {
        // Should wait chunked_content_provider lambda to be called within
        // 3s
-        if (state->inferenceStatus == PENDING) {
+        if (state->inference_status == PENDING) {
          retries += 1;
        }
-        if (state->inferenceStatus != RUNNING)
+        if (state->inference_status != RUNNING)
          LOG_INFO << "Wait for task to be released:" << state->task_id;
        std::this_thread::sleep_for(std::chrono::milliseconds(100));
      }
+      LOG_INFO << "Task completed, release it";
      // Request completed, release it
      state->instance->llama.request_cancel(state->task_id);
    });
@@ -445,7 +450,6 @@ void llamaCPP::embedding(
 void llamaCPP::embeddingImpl(
     std::shared_ptr<Json::Value> jsonBody,
     std::function<void(const HttpResponsePtr &)> &callback) {
-
   // Queue embedding task
   auto state = create_inference_state(this);
 
@@ -532,7 +536,6 @@ void llamaCPP::modelStatus(
 void llamaCPP::loadModel(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
-
   if (llama.model_loaded_external) {
     LOG_INFO << "model loaded";
     Json::Value jsonResp;
@@ -561,7 +564,6 @@ void llamaCPP::loadModel(
 }
 
 bool llamaCPP::loadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
-
   gpt_params params;
   // By default will setting based on number of handlers
   if (jsonBody) {
@@ -570,11 +572,9 @@ bool llamaCPP::loadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
       params.mmproj = jsonBody->operator[]("mmproj").asString();
     }
     if (!jsonBody->operator[]("grp_attn_n").isNull()) {
-
       params.grp_attn_n = jsonBody->operator[]("grp_attn_n").asInt();
     }
     if (!jsonBody->operator[]("grp_attn_w").isNull()) {
-
       params.grp_attn_w = jsonBody->operator[]("grp_attn_w").asInt();
     }
     if (!jsonBody->operator[]("mlock").isNull()) {
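The core of the change is the new EOS state in InferenceStatus: when llama.cpp reports result.stop, the chunked_content_provider no longer writes the final "stop" delta and "data: [DONE]" into the same buffer pass as the last content chunk; it flags EOS and emits the terminator in a dedicated callback invocation. Below is a minimal, self-contained sketch of that state machine, not the Nitro code itself: the token source, the make_sse_chunk helper, and the buffer size are invented for illustration, and the real provider additionally deals with llama.cpp task results, errors, and client disconnects.

// Minimal sketch of the streaming state machine this patch introduces.
// NOT the Nitro source: fake token stream and helper names are assumptions.
// It shows the key behavior: after the last content chunk, the provider
// switches to EOS and sends "stop" + "data: [DONE]" in its own invocation.
#include <algorithm>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>

enum InferenceStatus { PENDING, RUNNING, EOS, FINISHED };

struct InferenceState {
  InferenceStatus inference_status = PENDING;
  std::size_t token_index = 0;  // position in the fake token stream
};

// Stand-in for create_return_json(...): wraps one delta as an SSE data line.
static std::string make_sse_chunk(const std::string &content,
                                  const std::string &finish_reason) {
  return "data: {\"content\":\"" + content + "\",\"finish_reason\":\"" +
         finish_reason + "\"}\n\n";
}

// Mirrors the shape of the chunked_content_provider lambda in the diff.
static std::size_t provide(InferenceState &state,
                           const std::vector<std::string> &tokens,
                           char *pBuffer, std::size_t nBuffSize) {
  if (state.inference_status == PENDING) {
    state.inference_status = RUNNING;
  } else if (state.inference_status == FINISHED) {
    return 0;  // returning 0 ends the chunked response
  }

  if (state.inference_status == EOS) {
    // Dedicated invocation for the terminator, as in the fix.
    const std::string str = make_sse_chunk("", "stop") + "data: [DONE]\n\n";
    std::size_t nRead = std::min(str.size(), nBuffSize);
    std::memcpy(pBuffer, str.data(), nRead);
    state.inference_status = FINISHED;
    return nRead;
  }

  // Normal path: stream the next token as one SSE chunk.
  const std::string str = make_sse_chunk(tokens[state.token_index], "");
  std::size_t nRead = std::min(str.size(), nBuffSize);
  std::memcpy(pBuffer, str.data(), nRead);

  if (++state.token_index == tokens.size()) {
    // Before the fix this branch also wrote the [DONE] chunk into the same
    // buffer; now it only flags EOS and lets the next call send it.
    state.inference_status = EOS;
  }
  return nRead;
}

int main() {
  InferenceState state;
  const std::vector<std::string> tokens = {"Hello", ", ", "world", "!"};
  char buffer[512];
  std::size_t n;
  while ((n = provide(state, tokens, buffer, sizeof(buffer))) > 0) {
    std::cout << std::string(buffer, n);
  }
}

Splitting the terminator into its own EOS invocation means the final chunk no longer competes with the last content chunk for the same buffer write, which appears to be how the [DONE] token could go missing before this commit.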

0 commit comments