 #include "llama.h"
 #include "log.h"
 #include "utils/nitro_utils.h"
+#include <chrono>
+#include <thread>
+#include <trantor/utils/Logger.h>
 
 using namespace inferences;
 using json = nlohmann::json;
 
 struct inferenceState {
-  bool isStopped = false;
+  bool is_stopped = false;
+  bool is_streaming = false;
   int task_id;
   llamaCPP *instance;
 
-  inferenceState(int tid, llamaCPP *inst) : task_id(tid), instance(inst) {}
+  inferenceState(llamaCPP *inst) : instance(inst) {}
 };
 
-std::shared_ptr<inferenceState> create_inference_state(int task_id,
-                                                       llamaCPP *instance) {
-  return std::make_shared<inferenceState>(task_id, instance);
+std::shared_ptr<inferenceState> create_inference_state(llamaCPP *instance) {
+  return std::make_shared<inferenceState>(instance);
 }
 
 // --------------------------------------------
@@ -296,26 +299,35 @@ void llamaCPP::chatCompletion(
 #endif
   int task_id;
 
-  task_id = llama.request_completion(data, false, false, -1);
-
   LOG_INFO << "Resolved request for task_id:" << task_id;
 
   if (is_streamed) {
-    auto state = create_inference_state(task_id, this);
-
+    auto state = create_inference_state(this);
+    state->task_id = task_id;
     auto chunked_content_provider =
-        [this, state](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
+        [state, data](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
+          if (!state->is_streaming) {
+            state->task_id =
+                state->instance->llama.request_completion(data, false, false, -1);
+            state->instance->single_queue_is_busy = true;
+          }
           if (!pBuffer) {
             LOG_INFO << "Connection closed or buffer is null. Reset context";
             state->instance->llama.request_cancel(state->task_id);
+            state->is_streaming = false;
+            state->instance->single_queue_is_busy = false;
             return 0;
           }
-          if (state->isStopped) {
+          if (state->is_stopped) {
+            state->is_streaming = false;
+            state->instance->single_queue_is_busy = false;
             return 0;
           }
 
           task_result result = state->instance->llama.next_result(state->task_id);
           if (!result.error) {
+            // Update streaming state to being streamed
+            state->is_streaming = true;
             const std::string to_send = result.result_json["content"];
             const std::string str =
                 "data: " +
@@ -337,14 +349,30 @@ void llamaCPP::chatCompletion(
               std::size_t nRead = std::min(str.size(), nBuffSize);
               memcpy(pBuffer, str.data(), nRead);
               LOG_INFO << "reached result stop";
-              state->isStopped = true;
+              state->is_stopped = true;
               state->instance->llama.request_cancel(state->task_id);
+              state->is_streaming = false;
+              state->instance->single_queue_is_busy = false;
+
               return nRead;
             }
             return nRead;
           } else {
-            return 0;
+            if (state->instance->llama.params.n_parallel == 1) {
+              while (state->instance->single_queue_is_busy) {
+                LOG_INFO << "Waiting for task to be released status:"
+                         << state->instance->single_queue_is_busy;
+                std::this_thread::sleep_for(std::chrono::milliseconds(500)); // waiting in 500 millisecond steps
+              }
+            }
+            std::string str = "\n\n";
+            std::size_t nRead = str.size();
+            memcpy(pBuffer, str.data(), nRead);
+            LOG_INFO << "Failing retrying now";
+            return nRead;
           }
+          state->is_streaming = false;
+          state->instance->single_queue_is_busy = false;
           return 0;
         };
     auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
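
For orientation, here is a minimal, self-contained sketch of the lazy-start streaming pattern the hunks above introduce: the completion task is requested only on the first invocation of the chunked-content callback, and a busy flag is cleared whenever the stream ends. FakeEngine, InferenceState, and stream_chunk below are illustrative stand-ins, not the real llamaCPP, llama.cpp server, or Drogon API.

// Sketch of the lazy-start streaming callback (illustrative names only).
#include <algorithm>
#include <atomic>
#include <cstddef>
#include <cstring>
#include <iostream>
#include <memory>
#include <string>

struct FakeEngine {
  std::atomic<bool> single_queue_is_busy{false};
  int request_completion() { return 42; }  // pretend task id
  std::string next_chunk(int /*task_id*/) {
    return "data: {\"content\":\"hi\"}\n\n";
  }
};

struct InferenceState {
  bool is_streaming = false;
  bool is_stopped = false;
  int task_id = -1;
  FakeEngine *engine;
  explicit InferenceState(FakeEngine *e) : engine(e) {}
};

// Mirrors the shape of the chunked_content_provider lambda: returns the number
// of bytes written into the response buffer, or 0 to end the stream.
std::size_t stream_chunk(const std::shared_ptr<InferenceState> &state,
                         char *buffer, std::size_t buf_size) {
  if (!state->is_streaming) {  // lazy start on the first callback call
    state->task_id = state->engine->request_completion();
    state->engine->single_queue_is_busy = true;
    state->is_streaming = true;
  }
  if (!buffer || state->is_stopped) {  // connection closed or stream finished
    state->is_streaming = false;
    state->engine->single_queue_is_busy = false;
    return 0;
  }
  const std::string chunk = state->engine->next_chunk(state->task_id);
  const std::size_t n = std::min(chunk.size(), buf_size);
  std::memcpy(buffer, chunk.data(), n);
  return n;
}

int main() {
  FakeEngine engine;
  auto state = std::make_shared<InferenceState>(&engine);
  char buf[256];
  std::size_t n = stream_chunk(state, buf, sizeof(buf));
  std::cout << std::string(buf, n);
  state->is_stopped = true;  // a later call would then end the stream
  return 0;
}

The diff follows the same shape: create_inference_state(this) builds the state up front, the first callback call fills in task_id and marks the queue busy, and every exit path (buffer null, stop token, error branch) resets is_streaming and single_queue_is_busy.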