Skip to content

Commit 6e806d7

Browse files
committed
feat(completion): enhance completion and chat endpoints with keep-alive functionality
1 parent b14b7bc commit 6e806d7

File tree

4 files changed

+54
-17
lines changed

4 files changed

+54
-17
lines changed

runner/server/handler/completion.go

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,33 @@ import (
1313
"github.com/openai/openai-go"
1414
"github.com/openai/openai-go/shared/constant"
1515

16+
"github.com/NexaAI/nexa-sdk/runner/internal/config"
1617
"github.com/NexaAI/nexa-sdk/runner/internal/store"
1718
"github.com/NexaAI/nexa-sdk/runner/internal/types"
1819
nexa_sdk "github.com/NexaAI/nexa-sdk/runner/nexa-sdk"
1920
"github.com/NexaAI/nexa-sdk/runner/server/service"
2021
)
2122

23+
// BaseParams carries the request options shared by the completion-style
// endpoints. It is meant to be embedded into the endpoint-specific request
// structs so its JSON keys appear at the top level of the request body.
type BaseParams struct {
	// Stream selects the response shape: when false the reply is returned as
	// a single response object, otherwise as a stream of objects.
	Stream bool `json:"stream" default:"false"`

	// KeepAlive controls how long, in seconds, the model stays loaded in
	// memory after the request. It is a pointer so that an omitted
	// "keep_alive" key (nil) can be told apart from an explicit value.
	KeepAlive *int64 `json:"keep_alive" default:"300"`
}
29+
30+
// getKeepAliveValue extracts the keepAlive value from BaseParams, using default if not set
31+
func getKeepAliveValue(param BaseParams) int64 {
32+
if param.KeepAlive != nil {
33+
return *param.KeepAlive
34+
}
35+
return config.Get().KeepAlive
36+
}
37+
38+
type CompletionRequest struct {
39+
BaseParams
40+
openai.CompletionNewParams
41+
}
42+
2243
// @Router /completions [post]
2344
// @Summary completion
2445
// @Description Legacy completion endpoint for text generation. It is recommended to use the Chat Completions endpoint for new applications.
@@ -27,16 +48,18 @@ import (
2748
// @Produce json
2849
// @Success 200 {object} openai.Completion
2950
func Completions(c *gin.Context) {
30-
param := openai.CompletionNewParams{}
51+
param := CompletionRequest{}
3152
if err := c.ShouldBindJSON(&param); err != nil {
3253
c.JSON(http.StatusBadRequest, map[string]any{"error": err.Error()})
3354
return
3455
}
56+
slog.Debug("param", "param", param)
3557

3658
p, err := service.KeepAliveGet[nexa_sdk.LLM](
3759
string(param.Model),
3860
types.ModelParam{NCtx: 4096},
3961
c.GetHeader("Nexa-KeepCache") != "true",
62+
getKeepAliveValue(param.BaseParams),
4063
)
4164
if err != nil {
4265
c.JSON(http.StatusInternalServerError, map[string]any{"error": err.Error()})
@@ -58,19 +81,18 @@ func Completions(c *gin.Context) {
5881
}
5982
}
6083

61-
type ChatCompletionNewParams openai.ChatCompletionNewParams
62-
6384
// ChatCompletionRequest defines the request body for the chat completions API.
6485
// example: { "model": "nexaml/nexaml-models", "messages": [ { "role": "user", "content": "why is the sky blue?" } ] }
6586
type ChatCompletionRequest struct {
66-
Stream bool `json:"stream" default:"false"`
6787
EnableThink bool `json:"enable_think" default:"true"`
68-
69-
ChatCompletionNewParams
88+
BaseParams
89+
openai.ChatCompletionNewParams
7090
}
7191

7292
// toolCallRegex locates a tool call in model output, wrapped either in
// <tool_call>...</tool_call> tags (capture group 1) or in a ```json fenced
// block (capture group 2).
//
// The quantifiers are lazy so that each match stops at the nearest closing
// delimiter. With greedy matching, output containing two tool calls produced
// a single capture spanning from the first opener to the last closer, which
// is not one parseable payload.
var toolCallRegex = regexp.MustCompile(`<tool_call>([\s\S]+?)</tool_call>` + "|" + "```json([\\s\\S]+?)```")
7393

94+
95+
7496
// @Router /chat/completions [post]
7597
// @Summary Creates a model response for the given chat conversation.
7698
// @Description This endpoint generates a model response for a given conversation, which can include text and images. It supports both single-turn and multi-turn conversations and can be used for various tasks like question answering, code generation, and function calling.
@@ -85,6 +107,8 @@ func ChatCompletions(c *gin.Context) {
85107
return
86108
}
87109

110+
slog.Debug("param", "param", param)
111+
88112
s := store.Get()
89113
manifest, err := s.GetManifest(param.Model)
90114
if err != nil {
@@ -109,6 +133,7 @@ func chatCompletionsLLM(c *gin.Context, param ChatCompletionRequest) {
109133
string(param.Model),
110134
types.ModelParam{NCtx: 4096},
111135
c.GetHeader("Nexa-KeepCache") != "true",
136+
getKeepAliveValue(param.BaseParams),
112137
)
113138
if errors.Is(err, os.ErrNotExist) {
114139
c.JSON(http.StatusNotFound, map[string]any{"error": "model not found"})
@@ -276,6 +301,7 @@ func chatCompletionsVLM(c *gin.Context, param ChatCompletionRequest) {
276301
string(param.Model),
277302
types.ModelParam{NCtx: 4096},
278303
c.GetHeader("Nexa-KeepCache") != "true",
304+
getKeepAliveValue(param.BaseParams),
279305
)
280306
if errors.Is(err, os.ErrNotExist) {
281307
c.JSON(http.StatusNotFound, map[string]any{"error": "model not found"})

runner/server/handler/embedder.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ func Embeddings(c *gin.Context) {
2727
string(param.Model),
2828
types.ModelParam{},
2929
false,
30+
300, // default 5 minutes for embedder
3031
)
3132
if err != nil {
3233
c.JSON(http.StatusInternalServerError, map[string]any{"error": err.Error()})

runner/server/handler/image.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ func ImageGenerations(c *gin.Context) {
5757
param.Model,
5858
types.ModelParam{},
5959
c.GetHeader("Nexa-KeepCache") != "true",
60+
300, // default 5 minutes for image generation
6061
)
6162
if err != nil {
6263
c.JSON(http.StatusInternalServerError, map[string]any{"error": err.Error()})

runner/server/service/keepalive.go

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@ import (
1414

1515
// KeepAliveGet retrieves a model from the keepalive cache or creates it if not found
1616
// This avoids the overhead of repeatedly loading/unloading models from disk
17-
func KeepAliveGet[T any](name string, param types.ModelParam, reset bool) (*T, error) {
18-
t, err := keepAliveGet[T](name, param, reset)
17+
// keepAlive specifies the timeout in seconds for this specific model instance
18+
func KeepAliveGet[T any](name string, param types.ModelParam, reset bool, keepAlive int64) (*T, error) {
19+
t, err := keepAliveGet[T](name, param, reset, keepAlive)
1920
if err != nil {
2021
return nil, err
2122
}
@@ -34,9 +35,10 @@ type keepAliveService struct {
3435

3536
// modelKeepInfo holds metadata for a cached model instance
3637
type modelKeepInfo struct {
37-
model keepable
38-
param types.ModelParam
39-
lastTime time.Time
38+
model keepable
39+
param types.ModelParam
40+
lastTime time.Time
41+
keepAliveTimeout int64
4042
}
4143

4244
// keepable interface defines objects that can be managed by the keepalive service
@@ -70,7 +72,12 @@ func (keepAlive *keepAliveService) start() {
7072
case <-t.C:
7173
keepAlive.Lock()
7274
for name, model := range keepAlive.models {
73-
if time.Since(model.lastTime).Milliseconds()/1000 > config.Get().KeepAlive {
75+
// Use the model-specific keepAlive timeout, fallback to global config if not set
76+
timeout := model.keepAliveTimeout
77+
if timeout <= 0 {
78+
timeout = config.Get().KeepAlive
79+
}
80+
if time.Since(model.lastTime).Milliseconds()/1000 > timeout {
7481
model.model.Destroy()
7582
delete(keepAlive.models, name)
7683
}
@@ -83,7 +90,7 @@ func (keepAlive *keepAliveService) start() {
8390

8491
// keepAliveGet retrieves a cached model or creates a new one if not found
8592
// Ensures only one model is kept in memory at a time by clearing others
86-
func keepAliveGet[T any](name string, param types.ModelParam, reset bool) (any, error) {
93+
func keepAliveGet[T any](name string, param types.ModelParam, reset bool, keepAliveTimeout int64) (any, error) {
8794
keepAlive.Lock()
8895
defer keepAlive.Unlock()
8996

@@ -102,6 +109,7 @@ func keepAliveGet[T any](name string, param types.ModelParam, reset bool) (any,
102109
model.model.Reset()
103110
}
104111
model.lastTime = time.Now()
112+
model.keepAliveTimeout = keepAliveTimeout
105113
return model.model, nil
106114
}
107115

@@ -127,7 +135,7 @@ func keepAliveGet[T any](name string, param types.ModelParam, reset bool) (any,
127135
break
128136
}
129137
}
130-
138+
131139
var t keepable
132140
var e error
133141
switch reflect.TypeFor[T]() {
@@ -188,9 +196,10 @@ func keepAliveGet[T any](name string, param types.ModelParam, reset bool) (any,
188196
return nil, e
189197
}
190198
model = &modelKeepInfo{
191-
model: t,
192-
param: param,
193-
lastTime: time.Now(),
199+
model: t,
200+
param: param,
201+
lastTime: time.Now(),
202+
keepAliveTimeout: keepAliveTimeout,
194203
}
195204
keepAlive.models[name] = model
196205

0 commit comments

Comments
 (0)