Skip to content

Commit 4d8f7c5

Browse files
authored
feat: implement FLUX.1-Fill model with image generation capability. (#284)
1 parent 50a3a08 commit 4d8f7c5

File tree

16 files changed

+1159
-484
lines changed

16 files changed

+1159
-484
lines changed

xllm/core/framework/batch/dit_batch.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,11 @@ DiTForwardInput DiTBatch::prepare_forward_input() {
6060
std::vector<torch::Tensor> negative_prompt_embeds;
6161
std::vector<torch::Tensor> negative_pooled_prompt_embeds;
6262

63+
std::vector<torch::Tensor> images;
64+
std::vector<torch::Tensor> mask_images;
65+
6366
std::vector<torch::Tensor> latents;
67+
std::vector<torch::Tensor> masked_image_latents;
6468
for (const auto& request : request_vec_) {
6569
const auto& generation_params = request->state().generation_params();
6670
if (input.generation_params != generation_params) {
@@ -88,6 +92,10 @@ DiTForwardInput DiTBatch::prepare_forward_input() {
8892
input_params.negative_pooled_prompt_embed);
8993

9094
latents.emplace_back(input_params.latent);
95+
masked_image_latents.emplace_back(input_params.masked_image_latent);
96+
97+
images.emplace_back(input_params.image);
98+
mask_images.emplace_back(input_params.mask_image);
9199
}
92100

93101
if (input.prompts.size() != request_vec_.size()) {
@@ -106,6 +114,14 @@ DiTForwardInput DiTBatch::prepare_forward_input() {
106114
input.negative_prompts_2.clear();
107115
}
108116

117+
if (check_tensors_valid(images)) {
118+
input.images = torch::stack(images);
119+
}
120+
121+
if (check_tensors_valid(mask_images)) {
122+
input.mask_images = torch::stack(mask_images);
123+
}
124+
109125
if (check_tensors_valid(prompt_embeds)) {
110126
input.prompt_embeds = torch::stack(prompt_embeds);
111127
}
@@ -127,6 +143,9 @@ DiTForwardInput DiTBatch::prepare_forward_input() {
127143
input.latents = torch::stack(latents);
128144
}
129145

146+
if (check_tensors_valid(masked_image_latents)) {
147+
input.masked_image_latents = torch::stack(masked_image_latents);
148+
}
130149
return input;
131150
}
132151

xllm/core/framework/request/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ cc_library(
1111
incremental_decoder.h
1212
mm_data.h
1313
mm_input_helper.h
14+
mm_codec.h
1415
request_base.h
1516
request.h
1617
dit_request.h
@@ -31,6 +32,7 @@ cc_library(
3132
incremental_decoder.cpp
3233
mm_data.cpp
3334
mm_input_helper.cpp
35+
mm_codec.cpp
3436
request.cpp
3537
dit_request.cpp
3638
request_output.cpp

xllm/core/framework/request/dit_request.cpp

Lines changed: 1 addition & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -26,51 +26,7 @@ limitations under the License.
2626
#include <vector>
2727

2828
#include "api_service/call.h"
29-
30-
namespace {
31-
class OpenCVImageEncoder {
32-
public:
33-
// t float32, cpu, chw
34-
bool encode(const torch::Tensor& t, std::string& raw_data) {
35-
if (!valid(t)) {
36-
return false;
37-
}
38-
39-
auto img = t.permute({1, 2, 0}).contiguous();
40-
cv::Mat mat(img.size(0), img.size(1), CV_32FC3, img.data_ptr<float>());
41-
42-
cv::Mat mat_8u;
43-
mat.convertTo(mat_8u, CV_8UC3, 255.0);
44-
45-
// rgb -> bgr
46-
cv::cvtColor(mat_8u, mat_8u, cv::COLOR_RGB2BGR);
47-
48-
std::vector<uchar> data;
49-
if (!cv::imencode(".png", mat_8u, data)) {
50-
LOG(ERROR) << "image encode faild";
51-
return false;
52-
}
53-
54-
raw_data.assign(data.begin(), data.end());
55-
return true;
56-
}
57-
58-
private:
59-
bool valid(const torch::Tensor& t) {
60-
if (t.dim() != 3 || t.size(0) != 3) {
61-
LOG(ERROR) << "input tensor must be 3HW tensor";
62-
return false;
63-
}
64-
65-
if (t.scalar_type() != torch::kFloat32 || !t.device().is_cpu()) {
66-
LOG(ERROR) << "tensor must be cpu float32";
67-
return false;
68-
}
69-
70-
return true;
71-
}
72-
};
73-
} // namespace
29+
#include "mm_codec.h"
7430

7531
namespace xllm {
7632
DiTRequest::DiTRequest(const std::string& request_id,

xllm/core/framework/request/dit_request_params.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@ limitations under the License.
1616

1717
#include "dit_request_params.h"
1818

19+
#include "butil/base64.h"
1920
#include "core/common/instance_name.h"
2021
#include "core/common/macros.h"
2122
#include "core/util/uuid.h"
23+
#include "mm_codec.h"
2224
#include "request.h"
2325

2426
namespace xllm {
@@ -242,6 +244,31 @@ DiTRequestParams::DiTRequestParams(const proto::ImageGenerationRequest& request,
242244
if (input.has_latent()) {
243245
input_params.latent = proto_to_torch(input.latent());
244246
}
247+
if (input.has_masked_image_latent()) {
248+
input_params.masked_image_latent =
249+
proto_to_torch(input.masked_image_latent());
250+
}
251+
252+
OpenCVImageDecoder decoder;
253+
if (input.has_image()) {
254+
std::string raw_bytes;
255+
if (!butil::Base64Decode(input.image(), &raw_bytes)) {
256+
LOG(ERROR) << "Base64 image decode failed";
257+
}
258+
if (!decoder.decode(raw_bytes, input_params.image)) {
259+
LOG(ERROR) << "Image decode failed.";
260+
}
261+
}
262+
263+
if (input.has_mask_image()) {
264+
std::string raw_bytes;
265+
if (!butil::Base64Decode(input.mask_image(), &raw_bytes)) {
266+
LOG(ERROR) << "Base64 mask_image decode failed";
267+
}
268+
if (!decoder.decode(raw_bytes, input_params.mask_image)) {
269+
LOG(ERROR) << "Mask_image decode failed.";
270+
}
271+
}
245272

246273
// generation params
247274
const auto& params = request.parameters();

xllm/core/framework/request/dit_request_state.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ struct DiTGenerationParams {
4040
guidance_scale == other.guidance_scale &&
4141
num_images_per_prompt == other.num_images_per_prompt &&
4242
seed == other.seed &&
43-
max_sequence_length == other.max_sequence_length;
43+
max_sequence_length == other.max_sequence_length &&
44+
strength == other.strength;
4445
}
4546

4647
bool operator!=(const DiTGenerationParams& other) const {
@@ -62,6 +63,8 @@ struct DiTGenerationParams {
6263
int64_t seed = 0;
6364

6465
int32_t max_sequence_length = 512;
66+
67+
float strength = 1.0;
6568
};
6669

6770
struct DiTInputParams {
@@ -86,6 +89,12 @@ struct DiTInputParams {
8689
torch::Tensor negative_pooled_prompt_embed;
8790

8891
torch::Tensor latent;
92+
93+
torch::Tensor image;
94+
95+
torch::Tensor mask_image;
96+
97+
torch::Tensor masked_image_latent;
8998
};
9099

91100
struct DiTRequestState {
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
Copyright 2024 The ScaleLLM Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
==============================================================================*/
16+
17+
#include "mm_codec.h"
18+
19+
namespace xllm {
20+
21+
bool OpenCVImageDecoder::decode(const std::string& raw_data, torch::Tensor& t) {
22+
cv::Mat buffer(1, raw_data.size(), CV_8UC1, (void*)raw_data.data());
23+
cv::Mat image = cv::imdecode(buffer, cv::IMREAD_COLOR);
24+
if (image.empty()) {
25+
LOG(INFO) << " opencv image decode failed";
26+
return false;
27+
}
28+
29+
cv::cvtColor(image, image, cv::COLOR_BGR2RGB); // RGB
30+
31+
torch::Tensor tensor =
32+
torch::from_blob(image.data, {image.rows, image.cols, 3}, torch::kUInt8);
33+
34+
t = tensor.permute({2, 0, 1}).clone(); // [C, H, W]
35+
return true;
36+
}
37+
38+
bool OpenCVImageEncoder::encode(const torch::Tensor& t, std::string& raw_data) {
39+
if (!valid(t)) {
40+
return false;
41+
}
42+
43+
auto img = t.permute({1, 2, 0}).contiguous();
44+
cv::Mat mat(img.size(0), img.size(1), CV_32FC3, img.data_ptr<float>());
45+
46+
cv::Mat mat_8u;
47+
mat.convertTo(mat_8u, CV_8UC3, 255.0);
48+
49+
// rgb -> bgr
50+
cv::cvtColor(mat_8u, mat_8u, cv::COLOR_RGB2BGR);
51+
52+
std::vector<uchar> data;
53+
if (!cv::imencode(".png", mat_8u, data)) {
54+
LOG(ERROR) << "image encode faild";
55+
return false;
56+
}
57+
58+
raw_data.assign(data.begin(), data.end());
59+
return true;
60+
}
61+
62+
bool OpenCVImageEncoder::valid(const torch::Tensor& t) {
63+
if (t.dim() != 3 || t.size(0) != 3) {
64+
LOG(ERROR) << "input tensor must be 3HW tensor";
65+
return false;
66+
}
67+
68+
if (t.scalar_type() != torch::kFloat32 || !t.device().is_cpu()) {
69+
LOG(ERROR) << "tensor must be cpu float32";
70+
return false;
71+
}
72+
73+
return true;
74+
}
75+
76+
} // namespace xllm
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
Copyright 2024 The ScaleLLM Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
==============================================================================*/
16+
17+
#pragma once
18+
#include <torch/torch.h>
19+
20+
#include <opencv2/opencv.hpp>
21+
#include <string>
22+
23+
namespace xllm {
24+
25+
class OpenCVImageDecoder {
26+
public:
27+
OpenCVImageDecoder() = default;
28+
~OpenCVImageDecoder() = default;
29+
30+
bool decode(const std::string& raw_data, torch::Tensor& t);
31+
};
32+
33+
class OpenCVImageEncoder {
34+
public:
35+
OpenCVImageEncoder() = default;
36+
~OpenCVImageEncoder() = default;
37+
38+
bool encode(const torch::Tensor& t, std::string& raw_data);
39+
40+
private:
41+
bool valid(const torch::Tensor& t);
42+
};
43+
44+
} // namespace xllm

xllm/core/framework/request/mm_input_helper.cpp

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,29 +23,10 @@ limitations under the License.
2323
#include <opencv2/opencv.hpp>
2424

2525
#include "butil/base64.h"
26+
#include "mm_codec.h"
2627

2728
namespace xllm {
2829

29-
class OpenCVImageDecoder {
30-
public:
31-
bool decode(const std::string& raw_data, torch::Tensor& t) {
32-
cv::Mat buffer(1, raw_data.size(), CV_8UC1, (void*)raw_data.data());
33-
cv::Mat image = cv::imdecode(buffer, cv::IMREAD_COLOR);
34-
if (image.empty()) {
35-
LOG(INFO) << " opencv image decode failed";
36-
return false;
37-
}
38-
39-
cv::cvtColor(image, image, cv::COLOR_BGR2RGB); // RGB
40-
41-
torch::Tensor tensor = torch::from_blob(
42-
image.data, {image.rows, image.cols, 3}, torch::kUInt8);
43-
44-
t = tensor.permute({2, 0, 1}).clone(); // [C, H, W]
45-
return true;
46-
}
47-
};
48-
4930
class FileDownloadHelper {
5031
public:
5132
FileDownloadHelper() {}

xllm/core/runtime/dit_forward_params.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,18 @@ struct DiTForwardInput {
5151
if (latents.defined()) {
5252
input.latents = latents.to(device, dtype);
5353
}
54+
55+
if (masked_image_latents.defined()) {
56+
input.masked_image_latents = masked_image_latents.to(device, dtype);
57+
}
58+
59+
if (images.defined()) {
60+
input.images = images.to(device, dtype);
61+
}
62+
63+
if (mask_images.defined()) {
64+
input.mask_images = mask_images.to(device, dtype);
65+
}
5466
return input;
5567
}
5668

@@ -68,6 +80,12 @@ struct DiTForwardInput {
6880
// Secondary negative prompt to exclude additional unwanted features
6981
std::vector<std::string> negative_prompts_2;
7082

83+
torch::Tensor images;
84+
85+
torch::Tensor mask_images;
86+
87+
torch::Tensor masked_image_latents;
88+
7189
torch::Tensor prompt_embeds;
7290

7391
torch::Tensor pooled_prompt_embeds;

0 commit comments

Comments
 (0)