
Commit 92b5a5c

Merge pull request #372 from janhq/feat/batch_embedding
feat(embedding): Add input as vector for /embeddings
2 parents 1485be2 + 5f9a0a4
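With this change, the body of a POST to the /embeddings endpoint may carry "input" either as a single string or as an array of strings, mirroring the OpenAI-style batch embedding request. A sketch of the two accepted payload shapes, inferred from the handler in the diff below (the field name "input" comes from the code; the prompt text is illustrative):

    {"input": "hello world"}

    {"input": ["first prompt", "second prompt"]}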

File tree

1 file changed: +36 -31 lines changed

controllers/llamaCPP.cc

Lines changed: 36 additions & 31 deletions
@@ -21,13 +21,8 @@ std::shared_ptr<inferenceState> create_inference_state(llamaCPP *instance) {
 
 // --------------------------------------------
 
-std::string create_embedding_payload(const std::vector<float> &embedding,
+Json::Value create_embedding_payload(const std::vector<float> &embedding,
                                      int prompt_tokens) {
-  Json::Value root;
-
-  root["object"] = "list";
-
-  Json::Value dataArray(Json::arrayValue);
   Json::Value dataItem;
 
   dataItem["object"] = "embedding";
@@ -39,20 +34,7 @@ std::string create_embedding_payload(const std::vector<float> &embedding,
   dataItem["embedding"] = embeddingArray;
   dataItem["index"] = 0;
 
-  dataArray.append(dataItem);
-  root["data"] = dataArray;
-
-  root["model"] = "_";
-
-  Json::Value usage;
-  usage["prompt_tokens"] = prompt_tokens;
-  usage["total_tokens"] = prompt_tokens; // Assuming total tokens equals prompt
-                                         // tokens in this context
-  root["usage"] = usage;
-
-  Json::StreamWriterBuilder writer;
-  writer["indentation"] = ""; // Compact output
-  return Json::writeString(writer, root);
+  return dataItem;
 }
 
 std::string create_full_return_json(const std::string &id,
@@ -406,19 +388,42 @@ void llamaCPP::embedding(
     std::function<void(const HttpResponsePtr &)> &&callback) {
   const auto &jsonBody = req->getJsonObject();
 
-  json prompt;
-  if (jsonBody->isMember("input") != 0) {
-    prompt = (*jsonBody)["input"].asString();
-  } else {
-    prompt = "";
+  Json::Value responseData(Json::arrayValue);
+
+  if (jsonBody->isMember("input")) {
+    const Json::Value &input = (*jsonBody)["input"];
+    if (input.isString()) {
+      // Process the single string input
+      const int task_id = llama.request_completion(
+          {{"prompt", input.asString()}, {"n_predict", 0}}, false, true, -1);
+      task_result result = llama.next_result(task_id);
+      std::vector<float> embedding_result = result.result_json["embedding"];
+      responseData.append(create_embedding_payload(embedding_result, 0));
+    } else if (input.isArray()) {
+      // Process each element in the array input
+      for (const auto &elem : input) {
+        if (elem.isString()) {
+          const int task_id = llama.request_completion(
+              {{"prompt", elem.asString()}, {"n_predict", 0}}, false, true, -1);
+          task_result result = llama.next_result(task_id);
+          std::vector<float> embedding_result = result.result_json["embedding"];
+          responseData.append(create_embedding_payload(embedding_result, 0));
+        }
+      }
+    }
   }
-  const int task_id = llama.request_completion(
-      {{"prompt", prompt}, {"n_predict", 0}}, false, true, -1);
-  task_result result = llama.next_result(task_id);
-  std::vector<float> embedding_result = result.result_json["embedding"];
+
   auto resp = nitro_utils::nitroHttpResponse();
-  std::string embedding_resp = create_embedding_payload(embedding_result, 0);
-  resp->setBody(embedding_resp);
+  Json::Value root;
+  root["data"] = responseData;
+  root["model"] = "_";
+  root["object"] = "list";
+  Json::Value usage;
+  usage["prompt_tokens"] = 0;
+  usage["total_tokens"] = 0;
+  root["usage"] = usage;
+
+  resp->setBody(Json::writeString(Json::StreamWriterBuilder(), root));
   resp->setContentTypeString("application/json");
   callback(resp);
   return;
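Reading the rewritten handler, the response is assembled as an OpenAI-style list object: one create_embedding_payload item per input string, with "model" fixed to "_" and both usage counters hard-coded to 0. A sketch of the resulting body for a two-element input (embedding values elided):

    {
      "data": [
        {"embedding": [...], "index": 0, "object": "embedding"},
        {"embedding": [...], "index": 0, "object": "embedding"}
      ],
      "model": "_",
      "object": "list",
      "usage": {"prompt_tokens": 0, "total_tokens": 0}
    }

Note that create_embedding_payload always sets "index" to 0, so every item in "data" reports index 0 rather than its position in the input array.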
