@@ -21,13 +21,8 @@ std::shared_ptr<inferenceState> create_inference_state(llamaCPP *instance) {
 
 // --------------------------------------------
 
-std::string create_embedding_payload(const std::vector<float> &embedding,
+Json::Value create_embedding_payload(const std::vector<float> &embedding,
                                      int prompt_tokens) {
-  Json::Value root;
-
-  root["object"] = "list";
-
-  Json::Value dataArray(Json::arrayValue);
   Json::Value dataItem;
 
   dataItem["object"] = "embedding";
@@ -39,20 +34,7 @@ std::string create_embedding_payload(const std::vector<float> &embedding,
   dataItem["embedding"] = embeddingArray;
   dataItem["index"] = 0;
 
-  dataArray.append(dataItem);
-  root["data"] = dataArray;
-
-  root["model"] = "_";
-
-  Json::Value usage;
-  usage["prompt_tokens"] = prompt_tokens;
-  usage["total_tokens"] = prompt_tokens; // Assuming total tokens equals prompt
-                                         // tokens in this context
-  root["usage"] = usage;
-
-  Json::StreamWriterBuilder writer;
-  writer["indentation"] = ""; // Compact output
-  return Json::writeString(writer, root);
+  return dataItem;
 }
 
 std::string create_full_return_json(const std::string &id,
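
With this change, create_embedding_payload builds only the per-input "data" item; wrapping the items in the list envelope (and the usage block that previously consumed prompt_tokens) moves to the caller. A minimal sketch of how a caller might compose the refactored helper, assuming jsoncpp; build_embedding_list is an illustrative name, not part of this commit:

// Sketch only: aggregates per-input items the way the old helper
// serialized a single one. Everything except create_embedding_payload
// is hypothetical.
#include <json/json.h>
#include <string>
#include <vector>

Json::Value create_embedding_payload(const std::vector<float> &embedding,
                                     int prompt_tokens);

std::string build_embedding_list(
    const std::vector<std::vector<float>> &embeddings) {
  Json::Value root;
  Json::Value data(Json::arrayValue);
  for (const auto &e : embeddings)
    data.append(create_embedding_payload(e, 0)); // one "data" item per input
  root["data"] = data;
  root["object"] = "list";
  Json::StreamWriterBuilder writer;
  writer["indentation"] = ""; // compact output, as the old helper produced
  return Json::writeString(writer, root);
}
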
@@ -406,19 +388,42 @@ void llamaCPP::embedding(
     std::function<void(const HttpResponsePtr &)> &&callback) {
   const auto &jsonBody = req->getJsonObject();
 
-  json prompt;
-  if (jsonBody->isMember("input") != 0) {
-    prompt = (*jsonBody)["input"].asString();
-  } else {
-    prompt = "";
+  Json::Value responseData(Json::arrayValue);
+
+  if (jsonBody->isMember("input")) {
+    const Json::Value &input = (*jsonBody)["input"];
+    if (input.isString()) {
+      // Process the single string input
+      const int task_id = llama.request_completion(
+          {{"prompt", input.asString()}, {"n_predict", 0}}, false, true, -1);
+      task_result result = llama.next_result(task_id);
+      std::vector<float> embedding_result = result.result_json["embedding"];
+      responseData.append(create_embedding_payload(embedding_result, 0));
+    } else if (input.isArray()) {
+      // Process each element in the array input
+      for (const auto &elem : input) {
+        if (elem.isString()) {
+          const int task_id = llama.request_completion(
+              {{"prompt", elem.asString()}, {"n_predict", 0}}, false, true, -1);
+          task_result result = llama.next_result(task_id);
+          std::vector<float> embedding_result = result.result_json["embedding"];
+          responseData.append(create_embedding_payload(embedding_result, 0));
+        }
+      }
+    }
   }
-  const int task_id = llama.request_completion(
-      {{"prompt", prompt}, {"n_predict", 0}}, false, true, -1);
-  task_result result = llama.next_result(task_id);
-  std::vector<float> embedding_result = result.result_json["embedding"];
+
   auto resp = nitro_utils::nitroHttpResponse();
-  std::string embedding_resp = create_embedding_payload(embedding_result, 0);
-  resp->setBody(embedding_resp);
+  Json::Value root;
+  root["data"] = responseData;
+  root["model"] = "_";
+  root["object"] = "list";
+  Json::Value usage;
+  usage["prompt_tokens"] = 0;
+  usage["total_tokens"] = 0;
+  root["usage"] = usage;
+
+  resp->setBody(Json::writeString(Json::StreamWriterBuilder(), root));
   resp->setContentTypeString("application/json");
   callback(resp);
   return;
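
With these changes the endpoint accepts "input" as either a single string or an array of strings and answers with the OpenAI-style list envelope assembled above. A sketch of the response body for an array of two inputs (embedding values illustrative; per the code above, every item carries index 0 and the usage counters are hard-coded to 0):

{
  "data": [
    {"embedding": [0.0123, -0.0456, ...], "index": 0, "object": "embedding"},
    {"embedding": [0.0789, 0.0012, ...], "index": 0, "object": "embedding"}
  ],
  "model": "_",
  "object": "list",
  "usage": {"prompt_tokens": 0, "total_tokens": 0}
}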