Skip to content

Commit 1473ea6

Browse files
committed
feat: add KeepAlive parameter to chat, embedder, and image generation requests
1 parent 311d64b commit 1473ea6

File tree

4 files changed

+40
-8
lines changed

4 files changed

+40
-8
lines changed

runner/server/handler/completion.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"github.com/openai/openai-go"
1414
"github.com/openai/openai-go/shared/constant"
1515

16+
"github.com/NexaAI/nexa-sdk/runner/internal/config"
1617
"github.com/NexaAI/nexa-sdk/runner/internal/store"
1718
"github.com/NexaAI/nexa-sdk/runner/internal/types"
1819
nexa_sdk "github.com/NexaAI/nexa-sdk/runner/nexa-sdk"
@@ -29,7 +30,8 @@ type ChatCompletionNewParams openai.ChatCompletionNewParams
2930
// ChatCompletionRequest defines the request body for the chat completions API.
3031
// example: { "model": "nexaml/nexaml-models", "messages": [ { "role": "user", "content": "why is the sky blue?" } ] }
3132
type ChatCompletionRequest struct {
32-
Stream bool `json:"stream"`
33+
Stream bool `json:"stream"`
34+
KeepAlive *int64 `json:"keep_alive"`
3335

3436
EnableThink bool `json:"enable_think"`
3537
TopK int32 `json:"top_k"`
@@ -124,11 +126,16 @@ func chatCompletionsLLM(c *gin.Context, param ChatCompletionRequest) {
124126

125127
samplerConfig := parseSamplerConfig(param)
126128

129+
keepAlive := config.Get().KeepAlive
130+
if param.KeepAlive != nil {
131+
keepAlive = *param.KeepAlive
132+
}
127133
// Get LLM instance
128134
p, err := service.KeepAliveGet[nexa_sdk.LLM](
129135
string(param.Model),
130136
types.ModelParam{NCtx: 4096, NGpuLayers: 999, SystemPrompt: systemPrompt},
131137
c.GetHeader("Nexa-KeepCache") != "true",
138+
keepAlive,
132139
)
133140
if errors.Is(err, os.ErrNotExist) {
134141
c.JSON(http.StatusNotFound, map[string]any{"error": "model not found"})
@@ -353,11 +360,16 @@ func chatCompletionsVLM(c *gin.Context, param ChatCompletionRequest) {
353360

354361
samplerConfig := parseSamplerConfig(param)
355362

363+
keepAlive := config.Get().KeepAlive
364+
if param.KeepAlive != nil {
365+
keepAlive = *param.KeepAlive
366+
}
356367
// Get VLM instance
357368
p, err := service.KeepAliveGet[nexa_sdk.VLM](
358369
string(param.Model),
359370
types.ModelParam{NCtx: 4096, NGpuLayers: 999, SystemPrompt: systemPrompt},
360371
c.GetHeader("Nexa-KeepCache") != "true",
372+
keepAlive,
361373
)
362374
if errors.Is(err, os.ErrNotExist) {
363375
c.JSON(http.StatusNotFound, map[string]any{"error": "model not found"})

runner/server/handler/embedder.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"github.com/gin-gonic/gin"
77
"github.com/openai/openai-go"
88

9+
"github.com/NexaAI/nexa-sdk/runner/internal/config"
910
"github.com/NexaAI/nexa-sdk/runner/internal/types"
1011
nexa_sdk "github.com/NexaAI/nexa-sdk/runner/nexa-sdk"
1112
"github.com/NexaAI/nexa-sdk/runner/server/service"
@@ -17,16 +18,25 @@ import (
1718
// @Accept json
1819
// @Param request body openai.EmbeddingNewParams true "Embedding request"
1920
func Embeddings(c *gin.Context) {
20-
param := openai.EmbeddingNewParams{}
21+
param := struct {
22+
openai.EmbeddingNewParams
23+
KeepAlive *int64 `json:"keep_alive"`
24+
}{}
25+
2126
if err := c.ShouldBindJSON(&param); err != nil {
2227
c.JSON(http.StatusBadRequest, map[string]any{"error": err.Error()})
2328
return
2429
}
2530

31+
keepAlive := config.Get().KeepAlive
32+
if param.KeepAlive != nil {
33+
keepAlive = *param.KeepAlive
34+
}
2635
p, err := service.KeepAliveGet[nexa_sdk.Embedder](
2736
string(param.Model),
2837
types.ModelParam{},
2938
false,
39+
keepAlive,
3040
)
3141
if err != nil {
3242
c.JSON(http.StatusInternalServerError, map[string]any{"error": err.Error()})

runner/server/handler/image.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"github.com/gin-gonic/gin"
1515
"github.com/openai/openai-go"
1616

17+
"github.com/NexaAI/nexa-sdk/runner/internal/config"
1718
"github.com/NexaAI/nexa-sdk/runner/internal/types"
1819
nexa_sdk "github.com/NexaAI/nexa-sdk/runner/nexa-sdk"
1920
"github.com/NexaAI/nexa-sdk/runner/server/service"
@@ -30,7 +31,10 @@ import (
3031
// @Failure 404 {object} map[string]any "Model not found"
3132
// @Failure 500 {object} map[string]any "Internal server error"
3233
func ImageGenerations(c *gin.Context) {
33-
param := openai.ImageGenerateParams{}
34+
param := struct {
35+
openai.ImageGenerateParams
36+
KeepAlive *int64 `json:"keep_alive"`
37+
}{}
3438
if err := c.ShouldBindJSON(&param); err != nil {
3539
slog.Error("Failed to bind JSON request", "error", err)
3640
c.JSON(http.StatusBadRequest, map[string]any{"error": err.Error()})
@@ -53,10 +57,15 @@ func ImageGenerations(c *gin.Context) {
5357
param.ResponseFormat = openai.ImageGenerateParamsResponseFormatURL
5458
}
5559

60+
keepAlive := config.Get().KeepAlive
61+
if param.KeepAlive != nil {
62+
keepAlive = *param.KeepAlive
63+
}
5664
imageGen, err := service.KeepAliveGet[nexa_sdk.ImageGen](
5765
param.Model,
5866
types.ModelParam{},
5967
c.GetHeader("Nexa-KeepCache") != "true",
68+
keepAlive,
6069
)
6170
if err != nil {
6271
c.JSON(http.StatusInternalServerError, map[string]any{"error": err.Error()})

runner/server/service/keepalive.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,15 @@ import (
66
"sync"
77
"time"
88

9-
"github.com/NexaAI/nexa-sdk/runner/internal/config"
109
"github.com/NexaAI/nexa-sdk/runner/internal/store"
1110
"github.com/NexaAI/nexa-sdk/runner/internal/types"
1211
nexa_sdk "github.com/NexaAI/nexa-sdk/runner/nexa-sdk"
1312
)
1413

1514
// KeepAliveGet retrieves a model from the keepalive cache or creates it if not found
1615
// This avoids the overhead of repeatedly loading/unloading models from disk.
// timeout is the keep-alive period, in seconds, stored with the model when it is created.
17-
func KeepAliveGet[T any](name string, param types.ModelParam, reset bool) (*T, error) {
18-
t, err := keepAliveGet[T](name, param, reset)
16+
func KeepAliveGet[T any](name string, param types.ModelParam, reset bool, timeout int64) (*T, error) {
17+
t, err := keepAliveGet[T](name, param, reset, timeout)
1918
if err != nil {
2019
return nil, err
2120
}
@@ -37,6 +36,7 @@ type modelKeepInfo struct {
3736
model keepable
3837
param types.ModelParam
3938
lastTime time.Time
39+
timeout int64 // timeout in seconds for this specific model
4040
}
4141

4242
// keepable interface defines objects that can be managed by the keepalive service
@@ -70,7 +70,7 @@ func (keepAlive *keepAliveService) start() {
7070
case <-t.C:
7171
keepAlive.Lock()
7272
for name, model := range keepAlive.models {
73-
if time.Since(model.lastTime).Milliseconds()/1000 > config.Get().KeepAlive {
73+
if int64(time.Since(model.lastTime).Seconds()) > model.timeout {
7474
model.model.Destroy()
7575
delete(keepAlive.models, name)
7676
}
@@ -83,7 +83,7 @@ func (keepAlive *keepAliveService) start() {
8383

8484
// keepAliveGet retrieves a cached model or creates a new one if not found
8585
// Ensures only one model is kept in memory at a time by clearing others
86-
func keepAliveGet[T any](name string, param types.ModelParam, reset bool) (any, error) {
86+
func keepAliveGet[T any](name string, param types.ModelParam, reset bool, timeout int64) (any, error) {
8787
keepAlive.Lock()
8888
defer keepAlive.Unlock()
8989

@@ -195,6 +195,7 @@ func keepAliveGet[T any](name string, param types.ModelParam, reset bool) (any,
195195
model: t,
196196
param: param,
197197
lastTime: time.Now(),
198+
timeout: timeout,
198199
}
199200
keepAlive.models[name] = model
200201

0 commit comments

Comments
 (0)