Commit 741baaf

server : minor refactor
1 parent b951161 commit 741baaf

File tree

1 file changed: +29 −31 lines

tools/server/server.cpp

Lines changed: 29 additions & 31 deletions
@@ -1686,14 +1686,13 @@ struct server_slot {
         llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0);
     }

-    void prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
+    bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
         bool res = prompt_cache.load(prompt, tokens, ctx, id);
         if (!res) {
             SLT_WRN(*this, "%s", "failed to load prompt from cache\n");
-
-            llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);
-            prompt.tokens.clear();
         }
+
+        return res;
     }

     std::vector<common_adapter_lora_info> lora;

@@ -2339,7 +2338,6 @@ struct server_context {

     llama_batch batch {};

-    bool clean_kv_cache = true;
     bool add_bos_token = true;

     int32_t n_ctx; // total context for all clients / slots

@@ -2701,7 +2699,10 @@ struct server_context {
             const int64_t t_start = ggml_time_us();

             ret->prompt_save(*prompt_cache);
-            ret->prompt_load(*prompt_cache, task.tokens);
+
+            if (!ret->prompt_load(*prompt_cache, task.tokens)) {
+                clear_slot(*ret);
+            }

             prompt_cache->update();

@@ -2712,12 +2713,21 @@ struct server_context {
         return ret;
     }

-    // return true if at least one slot has been purged
+    void clear_slot(server_slot & slot) const {
+        GGML_ASSERT(!slot.is_processing());
+
+        SLT_WRN(slot, "clearing slot with %zu tokens\n", slot.prompt.tokens.size());
+
+        llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
+        slot.prompt.tokens.clear();
+    }
+
+    // return true if at least one slot has been cleared
     // TODO: improve logic
-    //       - smarter decision which slot to purge (LRU or longest prompt?)
+    //       - smarter decision which slot to clear (LRU or longest prompt?)
     //       - move slot to level 2 cache instead of removing?
     //       - instead of purging, try to store and resume later?
-    bool try_purge_idle_slots() {
+    bool try_clear_idle_slots() {
         bool res = false;

         if (!params_base.kv_unified) {

@@ -2732,12 +2742,11 @@ struct server_context {
             if (slot.prompt.n_tokens() > 0) {
                 SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());

-                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
-                slot.prompt.tokens.clear();
+                clear_slot(slot);

                 res = true;

-                // purge slots one by one
+                // clear slots one by one
                 break;
             }
         }

@@ -2847,14 +2856,6 @@ struct server_context {
         return true;
     }

-    void kv_cache_clear() {
-        SRV_DBG("%s", "clearing KV cache\n");
-
-        // clear the entire KV cache
-        llama_memory_clear(llama_get_memory(ctx), true);
-        clean_kv_cache = false;
-    }
-
     bool process_token(completion_token_output & result, server_slot & slot) {
         // remember which tokens were sampled - used for repetition penalties during sampling
         const std::string token_str = result.text_to_send;

@@ -3442,8 +3443,8 @@ struct server_context {

                 // Erase token cache
                 const size_t n_erased = slot->prompt.tokens.size();
-                llama_memory_seq_rm(llama_get_memory(ctx), slot->id, -1, -1);
-                slot->prompt.tokens.clear();
+
+                clear_slot(*slot);

                 auto res = std::make_unique<server_task_result_slot_erase>();
                 res->id = task.id;

@@ -3476,9 +3477,6 @@ struct server_context {

         if (all_idle) {
             SRV_INF("%s", "all slots are idle\n");
-            if (clean_kv_cache) {
-                kv_cache_clear();
-            }

             return;
         }

@@ -3872,12 +3870,11 @@ struct server_context {

                 if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
                     SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
-                    llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
+
+                    clear_slot(slot);

                     // there is no common part left
                     slot.n_prompt_tokens_cache = 0;
-
-                    slot.prompt.tokens.clear();
                 }

                 // check if we should process the image

@@ -4108,8 +4105,9 @@ struct server_context {
                     send_error(slot, err);
                     slot.release();

-                    llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
-                    slot.prompt.tokens.clear();
+                    // note: it's complicated to keep track of how much of the current batch has been
+                    // processed before the error occurred, so we simply clear the entire context
+                    clear_slot(slot);
                 }
             }

@@ -4118,7 +4116,7 @@ struct server_context {
             }

             // retry with half the batch size to try to find a free slot in the KV cache
-            if (!try_purge_idle_slots()) {
+            if (!try_clear_idle_slots()) {
                 n_batch /= 2;
             }
