Skip to content

Commit b14b7bc

Browse files
authored
Merge pull request #593 from NexaAI/feat/mengsheng/image-gen-serve
feat: add image generation support for server
2 parents be88ece + 2360532 commit b14b7bc

File tree

6 files changed

+310
-2
lines changed

6 files changed

+310
-2
lines changed

runner/Makefile

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ else
2323
UNZIP := unzip -q
2424
endif
2525

26-
.PHONY: build link xcopy download clean
26+
.PHONY: build link xcopy download clean serve
2727

2828
build:
2929
go build \
@@ -58,3 +58,6 @@ download: clean
5858

5959
clean:
6060
-$(RM) build
61+
62+
serve:
63+
./build/nexa serve

runner/nexa-sdk/image_gen.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,13 @@ func (ig *ImageGen) Destroy() error {
379379
return nil
380380
}
381381

382+
// Reset resets the ImageGen internal state (no-op for image generation)
383+
func (ig *ImageGen) Reset() error {
384+
slog.Debug("Reset called", "ptr", ig.ptr)
385+
// Image generation doesn't maintain state between generations, so this is a no-op
386+
return nil
387+
}
388+
382389
// Txt2Img generates an image from text prompt
383390
func (ig *ImageGen) Txt2Img(input ImageGenTxt2ImgInput) (ImageGenOutput, error) {
384391
slog.Debug("Txt2Img called", "input", input)

runner/server/docs/swagger.yaml

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,43 @@ paths:
6363
schema:
6464
$ref: '#/components/schemas/EmbeddingResponse'
6565

66+
/v1/images/generations:
67+
post:
68+
summary: Creates an image given a prompt
69+
description: Creates an image given a prompt. This endpoint follows OpenAI DALL-E 3 API specification for compatibility.
70+
operationId: PostV1ImagesGenerations
71+
requestBody:
72+
required: true
73+
content:
74+
application/json:
75+
schema:
76+
$ref: '#/components/schemas/ImageGenerationRequest'
77+
responses:
78+
'200':
79+
description: Successful image generation response
80+
content:
81+
application/json:
82+
schema:
83+
$ref: '#/components/schemas/ImageGenerationResponse'
84+
'400':
85+
description: Bad request - invalid parameters
86+
content:
87+
application/json:
88+
schema:
89+
$ref: '#/components/schemas/ErrorResponse'
90+
'404':
91+
description: Model not found
92+
content:
93+
application/json:
94+
schema:
95+
$ref: '#/components/schemas/ErrorResponse'
96+
'500':
97+
description: Internal server error
98+
content:
99+
application/json:
100+
schema:
101+
$ref: '#/components/schemas/ErrorResponse'
102+
66103
/v1/reranking:
67104
post:
68105
summary: Reranks the given documents for the given query
@@ -581,6 +618,75 @@ components:
581618
type: integer
582619
description: The index of the embedding in the list of embeddings
583620

621+
# ---------- Image Generation ----------
622+
ImageGenerationRequest:
623+
type: object
624+
required: [model, prompt]
625+
properties:
626+
model:
627+
type: string
628+
description: ID of the model to use
629+
default: "nexaml/sdxl-turbo-ryzen-ai"
630+
prompt:
631+
type: string
632+
description: A text description of the desired image(s). The maximum length is 1000 characters.
633+
default: "A white cat with red eyes"
634+
n:
635+
type: integer
636+
minimum: 1
637+
maximum: 10
638+
description: The number of images to generate. Must be between 1 and 10.
639+
default: 1
640+
size:
641+
type: string
642+
enum: ["512x512", "1024x1024", "1792x1024", "1024x1792"]
643+
description: The size of the generated images. Must be one of the supported sizes.
644+
default: "512x512"
645+
quality:
646+
type: string
647+
enum: ["standard", "hd"]
648+
description: The quality of the image that will be generated
649+
default: "standard"
650+
style:
651+
type: string
652+
enum: ["vivid", "natural"]
653+
description: The style of the generated images
654+
default: "vivid"
655+
response_format:
656+
type: string
657+
enum: ["url", "b64_json"]
658+
description: The format in which the generated images are returned
659+
default: "url"
660+
user:
661+
type: string
662+
description: A unique identifier representing your end-user
663+
664+
ImageGenerationResponse:
665+
type: object
666+
required: [created, data]
667+
properties:
668+
created:
669+
type: integer
670+
description: The Unix timestamp (in seconds) of when the image was created
671+
data:
672+
type: array
673+
items:
674+
$ref: '#/components/schemas/ImageGenerationData'
675+
description: The list of generated images
676+
677+
ImageGenerationData:
678+
type: object
679+
properties:
680+
url:
681+
type: string
682+
description: The URL of the generated image, if response_format is "url"
683+
b64_json:
684+
type: string
685+
description: The base64-encoded JSON of the generated image, if response_format is "b64_json"
686+
revised_prompt:
687+
type: string
688+
description: The prompt that was used to generate the image, if there was a revision to the prompt
689+
584690
# ---------- Reranking ----------
585691
RerankingRequest:
586692
type: object
@@ -643,6 +749,14 @@ components:
643749
description: Additional metadata about the document
644750

645751
# ---------- Common ----------
752+
ErrorResponse:
753+
type: object
754+
required: [error]
755+
properties:
756+
error:
757+
type: string
758+
description: Error message describing what went wrong
759+
646760
TokenUsage:
647761
type: object
648762
properties:

runner/server/handler/image.go

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
package handler
2+
3+
import (
4+
"encoding/base64"
5+
"errors"
6+
"fmt"
7+
"log/slog"
8+
"net/http"
9+
"os"
10+
"strconv"
11+
"strings"
12+
"time"
13+
14+
"github.com/gin-gonic/gin"
15+
"github.com/openai/openai-go"
16+
17+
"github.com/NexaAI/nexa-sdk/runner/internal/types"
18+
nexa_sdk "github.com/NexaAI/nexa-sdk/runner/nexa-sdk"
19+
"github.com/NexaAI/nexa-sdk/runner/server/service"
20+
)
21+
22+
// @Router /images/generations [post]
23+
// @Summary Creates an image given a prompt.
24+
// @Description Creates an image given a prompt. This endpoint follows OpenAI DALL-E 3 API specification for compatibility.
25+
// @Accept json
26+
// @Param request body openai.ImageGenerateParams true "Image generation request"
27+
// @Produce json
28+
// @Success 200 {object} openai.ImagesResponse "Successful image generation response"
29+
// @Failure 400 {object} map[string]any "Bad request - invalid parameters"
30+
// @Failure 404 {object} map[string]any "Model not found"
31+
// @Failure 500 {object} map[string]any "Internal server error"
32+
func ImageGenerations(c *gin.Context) {
33+
param := openai.ImageGenerateParams{}
34+
if err := c.ShouldBindJSON(&param); err != nil {
35+
slog.Error("Failed to bind JSON request", "error", err)
36+
c.JSON(http.StatusBadRequest, map[string]any{"error": err.Error()})
37+
return
38+
}
39+
40+
slog.Info("Image generation request received",
41+
"model", param.Model,
42+
"prompt_length", len(param.Prompt),
43+
"n", param.N,
44+
"size", param.Size)
45+
46+
if param.N.Value == 0 {
47+
param.N.Value = 1
48+
}
49+
if param.Size == "" {
50+
param.Size = openai.ImageGenerateParamsSize256x256
51+
}
52+
if param.ResponseFormat == "" {
53+
param.ResponseFormat = openai.ImageGenerateParamsResponseFormatURL
54+
}
55+
56+
imageGen, err := service.KeepAliveGet[nexa_sdk.ImageGen](
57+
param.Model,
58+
types.ModelParam{},
59+
c.GetHeader("Nexa-KeepCache") != "true",
60+
)
61+
if err != nil {
62+
c.JSON(http.StatusInternalServerError, map[string]any{"error": err.Error()})
63+
return
64+
}
65+
66+
width, height, err := parseImageSize(string(param.Size))
67+
if err != nil {
68+
c.JSON(http.StatusBadRequest, map[string]any{"error": err.Error()})
69+
return
70+
}
71+
72+
var imageData []openai.Image
73+
n := int(param.N.Value)
74+
slog.Info("Starting image generation", "count", n, "size", string(param.Size))
75+
for i := range n {
76+
outputPath := fmt.Sprintf("imagegen_output_%d.png", time.Now().UnixNano())
77+
slog.Debug("Generating image", "index", i, "output_path", outputPath)
78+
79+
config := &nexa_sdk.ImageGenerationConfig{
80+
Prompts: []string{param.Prompt},
81+
NegativePrompts: []string{"blurry, low quality, distorted, low resolution"},
82+
Height: height,
83+
Width: width,
84+
SamplerConfig: nexa_sdk.ImageSamplerConfig{
85+
Method: "ddim",
86+
Steps: 20,
87+
GuidanceScale: 7.5,
88+
Eta: 0.0,
89+
Seed: int32(time.Now().UnixNano() % 1000000),
90+
},
91+
SchedulerConfig: nexa_sdk.SchedulerConfig{
92+
Type: "ddim",
93+
NumTrainTimesteps: 1000,
94+
StepsOffset: 1,
95+
BetaStart: 0.00085,
96+
BetaEnd: 0.012,
97+
BetaSchedule: "scaled_linear",
98+
PredictionType: "epsilon",
99+
TimestepType: "discrete",
100+
TimestepSpacing: "leading",
101+
InterpolationType: "linear",
102+
ConfigPath: "",
103+
},
104+
Strength: 1.0,
105+
}
106+
107+
result, err := imageGen.Txt2Img(nexa_sdk.ImageGenTxt2ImgInput{
108+
PromptUTF8: param.Prompt,
109+
Config: config,
110+
OutputPath: outputPath,
111+
})
112+
if err != nil {
113+
c.JSON(http.StatusInternalServerError, map[string]any{"error": fmt.Sprintf("image generation failed: %v", err)})
114+
return
115+
}
116+
117+
data := openai.Image{
118+
RevisedPrompt: param.Prompt,
119+
}
120+
121+
if param.ResponseFormat == openai.ImageGenerateParamsResponseFormatB64JSON {
122+
b64Data, err := encodeImageToBase64(result.OutputImagePath)
123+
os.Remove(result.OutputImagePath)
124+
if err != nil {
125+
c.JSON(http.StatusInternalServerError, map[string]any{"error": fmt.Sprintf("failed to encode image: %v", err)})
126+
return
127+
}
128+
data.B64JSON = b64Data
129+
} else {
130+
data.URL = result.OutputImagePath
131+
}
132+
133+
imageData = append(imageData, data)
134+
slog.Info("Image generated successfully", "index", i, "output_path", result.OutputImagePath)
135+
}
136+
137+
response := openai.ImagesResponse{
138+
Created: time.Now().Unix(),
139+
Data: imageData,
140+
}
141+
142+
slog.Info("Image generation completed successfully", "total_images", len(imageData))
143+
c.JSON(http.StatusOK, response)
144+
}
145+
146+
// parseImageSize parses a size string of the form "<width>x<height>", for
// example "512x512", and returns the width and height in that order.
//
// Both dimensions must be positive decimal integers that fit in an int32.
// An error is returned when the string does not consist of exactly two
// "x"-separated parts, or when either part is non-numeric, non-positive or
// out of the int32 range (which would otherwise be silently truncated).
func parseImageSize(size string) (int32, int32, error) {
	widthPart, heightPart, found := strings.Cut(size, "x")
	// Exactly one separator is allowed: "1x2x3" must be rejected.
	if !found || strings.Contains(heightPart, "x") {
		return 0, 0, errors.New("invalid size format")
	}

	// bitSize 32 makes ParseInt report values that do not fit in an int32
	// instead of letting the conversion below wrap around.
	width, err := strconv.ParseInt(widthPart, 10, 32)
	if err != nil || width <= 0 {
		return 0, 0, errors.New("invalid width")
	}

	height, err := strconv.ParseInt(heightPart, 10, 32)
	if err != nil || height <= 0 {
		return 0, 0, errors.New("invalid height")
	}

	return int32(width), int32(height), nil
}
164+
165+
// encodeImageToBase64 reads the file at imagePath and returns its content as
// a data URI of the form "data:<mime>;base64,<payload>". The MIME type is
// sniffed from the file content with http.DetectContentType and the payload
// uses standard (padded) base64 encoding.
//
// If the file cannot be read, the returned error wraps the underlying cause so
// callers can inspect it with errors.Is / errors.As.
func encodeImageToBase64(imagePath string) (string, error) {
	imageData, err := os.ReadFile(imagePath)
	if err != nil {
		return "", fmt.Errorf("failed to read image file: %w", err)
	}
	mimeType := http.DetectContentType(imageData)
	base64String := base64.StdEncoding.EncodeToString(imageData)
	return fmt.Sprintf("data:%s;base64,%s", mimeType, base64String), nil
}

runner/server/route.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ func RegisterAPIv1(r *gin.Engine) {
3333

3434
g.POST("/embeddings", handler.Embeddings)
3535

36+
g.POST("/images/generations", handler.ImageGenerations)
37+
3638
//g.POST("/reranking", handler.Reranking)
3739

3840
g.GET("/models/*model", handler.RetrieveModel)

runner/server/service/keepalive.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ func keepAliveGet[T any](name string, param types.ModelParam, reset bool) (any,
127127
break
128128
}
129129
}
130-
130+
131131
var t keepable
132132
var e error
133133
switch reflect.TypeFor[T]() {
@@ -168,6 +168,15 @@ func keepAliveGet[T any](name string, param types.ModelParam, reset bool) (any,
168168
PluginID: manifest.PluginId,
169169
DeviceID: manifest.DeviceId,
170170
})
171+
case reflect.TypeFor[nexa_sdk.ImageGen]():
172+
// For image generation models, use the model directory path instead of specific file
173+
modelDir := s.ModelfilePath(manifest.Name, "")
174+
t, e = nexa_sdk.NewImageGen(nexa_sdk.ImageGenCreateInput{
175+
ModelName: manifest.ModelName,
176+
ModelPath: modelDir,
177+
PluginID: manifest.PluginId,
178+
DeviceID: manifest.DeviceId,
179+
})
171180
//case reflect.TypeFor[nexa_sdk.Reranker]():
172181
// t, e = nexa_sdk.NewReranker(modelfile, nil, param.Device)
173182
//case reflect.TypeFor[nexa_sdk.TTS]():

0 commit comments

Comments
 (0)