Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions xllm/core/framework/batch/dit_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,11 @@ DiTForwardInput DiTBatch::prepare_forward_input() {
std::vector<torch::Tensor> negative_prompt_embeds;
std::vector<torch::Tensor> negative_pooled_prompt_embeds;

std::vector<torch::Tensor> images;
std::vector<torch::Tensor> mask_images;

std::vector<torch::Tensor> latents;
std::vector<torch::Tensor> masked_image_latents;
for (const auto& request : request_vec_) {
const auto& generation_params = request->state().generation_params();
if (input.generation_params != generation_params) {
Expand Down Expand Up @@ -88,6 +92,10 @@ DiTForwardInput DiTBatch::prepare_forward_input() {
input_params.negative_pooled_prompt_embed);

latents.emplace_back(input_params.latent);
masked_image_latents.emplace_back(input_params.masked_image_latent);

images.emplace_back(input_params.image);
mask_images.emplace_back(input_params.mask_image);
}

if (input.prompts.size() != request_vec_.size()) {
Expand All @@ -106,6 +114,14 @@ DiTForwardInput DiTBatch::prepare_forward_input() {
input.negative_prompts_2.clear();
}

if (check_tensors_valid(images)) {
input.images = torch::stack(images);
}

if (check_tensors_valid(mask_images)) {
input.mask_images = torch::stack(mask_images);
}

if (check_tensors_valid(prompt_embeds)) {
input.prompt_embeds = torch::stack(prompt_embeds);
}
Expand All @@ -127,6 +143,9 @@ DiTForwardInput DiTBatch::prepare_forward_input() {
input.latents = torch::stack(latents);
}

if (check_tensors_valid(masked_image_latents)) {
input.masked_image_latents = torch::stack(masked_image_latents);
}
return input;
}

Expand Down
2 changes: 2 additions & 0 deletions xllm/core/framework/request/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ cc_library(
incremental_decoder.h
mm_data.h
mm_input_helper.h
mm_codec.h
request_base.h
request.h
dit_request.h
Expand All @@ -31,6 +32,7 @@ cc_library(
incremental_decoder.cpp
mm_data.cpp
mm_input_helper.cpp
mm_codec.cpp
request.cpp
dit_request.cpp
request_output.cpp
Expand Down
46 changes: 1 addition & 45 deletions xllm/core/framework/request/dit_request.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,51 +26,7 @@ limitations under the License.
#include <vector>

#include "api_service/call.h"

namespace {
class OpenCVImageEncoder {
public:
// t float32, cpu, chw
bool encode(const torch::Tensor& t, std::string& raw_data) {
if (!valid(t)) {
return false;
}

auto img = t.permute({1, 2, 0}).contiguous();
cv::Mat mat(img.size(0), img.size(1), CV_32FC3, img.data_ptr<float>());

cv::Mat mat_8u;
mat.convertTo(mat_8u, CV_8UC3, 255.0);

// rgb -> bgr
cv::cvtColor(mat_8u, mat_8u, cv::COLOR_RGB2BGR);

std::vector<uchar> data;
if (!cv::imencode(".png", mat_8u, data)) {
LOG(ERROR) << "image encode faild";
return false;
}

raw_data.assign(data.begin(), data.end());
return true;
}

private:
bool valid(const torch::Tensor& t) {
if (t.dim() != 3 || t.size(0) != 3) {
LOG(ERROR) << "input tensor must be 3HW tensor";
return false;
}

if (t.scalar_type() != torch::kFloat32 || !t.device().is_cpu()) {
LOG(ERROR) << "tensor must be cpu float32";
return false;
}

return true;
}
};
} // namespace
#include "mm_codec.h"

namespace xllm {
DiTRequest::DiTRequest(const std::string& request_id,
Expand Down
27 changes: 27 additions & 0 deletions xllm/core/framework/request/dit_request_params.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ limitations under the License.

#include "dit_request_params.h"

#include "butil/base64.h"
#include "core/common/instance_name.h"
#include "core/common/macros.h"
#include "core/util/uuid.h"
#include "mm_codec.h"
#include "request.h"

namespace xllm {
Expand Down Expand Up @@ -242,6 +244,31 @@ DiTRequestParams::DiTRequestParams(const proto::ImageGenerationRequest& request,
if (input.has_latent()) {
input_params.latent = proto_to_torch(input.latent());
}
if (input.has_masked_image_latent()) {
input_params.masked_image_latent =
proto_to_torch(input.masked_image_latent());
}

OpenCVImageDecoder decoder;
if (input.has_image()) {
std::string raw_bytes;
if (!butil::Base64Decode(input.image(), &raw_bytes)) {
LOG(ERROR) << "Base64 image decode failed";
}
if (!decoder.decode(raw_bytes, input_params.image)) {
LOG(ERROR) << "Image decode failed.";
}
}

if (input.has_mask_image()) {
std::string raw_bytes;
if (!butil::Base64Decode(input.mask_image(), &raw_bytes)) {
LOG(ERROR) << "Base64 mask_image decode failed";
}
if (!decoder.decode(raw_bytes, input_params.mask_image)) {
LOG(ERROR) << "Mask_image decode failed.";
}
}

// generation params
const auto& params = request.parameters();
Expand Down
11 changes: 10 additions & 1 deletion xllm/core/framework/request/dit_request_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ struct DiTGenerationParams {
guidance_scale == other.guidance_scale &&
num_images_per_prompt == other.num_images_per_prompt &&
seed == other.seed &&
max_sequence_length == other.max_sequence_length;
max_sequence_length == other.max_sequence_length &&
strength == other.strength;
}

bool operator!=(const DiTGenerationParams& other) const {
Expand All @@ -62,6 +63,8 @@ struct DiTGenerationParams {
int64_t seed = 0;

int32_t max_sequence_length = 512;

float strength = 1.0;
};

struct DiTInputParams {
Expand All @@ -86,6 +89,12 @@ struct DiTInputParams {
torch::Tensor negative_pooled_prompt_embed;

torch::Tensor latent;

torch::Tensor image;

torch::Tensor mask_image;

torch::Tensor masked_image_latent;
};

struct DiTRequestState {
Expand Down
76 changes: 76 additions & 0 deletions xllm/core/framework/request/mm_codec.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
Copyright 2024 The ScaleLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mm_codec.h"

namespace xllm {

bool OpenCVImageDecoder::decode(const std::string& raw_data, torch::Tensor& t) {
cv::Mat buffer(1, raw_data.size(), CV_8UC1, (void*)raw_data.data());
cv::Mat image = cv::imdecode(buffer, cv::IMREAD_COLOR);
if (image.empty()) {
LOG(INFO) << " opencv image decode failed";
return false;
}

cv::cvtColor(image, image, cv::COLOR_BGR2RGB); // RGB

torch::Tensor tensor =
torch::from_blob(image.data, {image.rows, image.cols, 3}, torch::kUInt8);

t = tensor.permute({2, 0, 1}).clone(); // [C, H, W]
return true;
}

bool OpenCVImageEncoder::encode(const torch::Tensor& t, std::string& raw_data) {
if (!valid(t)) {
return false;
}

auto img = t.permute({1, 2, 0}).contiguous();
cv::Mat mat(img.size(0), img.size(1), CV_32FC3, img.data_ptr<float>());

cv::Mat mat_8u;
mat.convertTo(mat_8u, CV_8UC3, 255.0);

// rgb -> bgr
cv::cvtColor(mat_8u, mat_8u, cv::COLOR_RGB2BGR);

std::vector<uchar> data;
if (!cv::imencode(".png", mat_8u, data)) {
LOG(ERROR) << "image encode faild";
return false;
}

raw_data.assign(data.begin(), data.end());
return true;
}

bool OpenCVImageEncoder::valid(const torch::Tensor& t) {
if (t.dim() != 3 || t.size(0) != 3) {
LOG(ERROR) << "input tensor must be 3HW tensor";
return false;
}

if (t.scalar_type() != torch::kFloat32 || !t.device().is_cpu()) {
LOG(ERROR) << "tensor must be cpu float32";
return false;
}

return true;
}

} // namespace xllm
44 changes: 44 additions & 0 deletions xllm/core/framework/request/mm_codec.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
Copyright 2024 The ScaleLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once
#include <torch/torch.h>

#include <opencv2/opencv.hpp>
#include <string>

namespace xllm {

class OpenCVImageDecoder {
public:
OpenCVImageDecoder() = default;
~OpenCVImageDecoder() = default;

bool decode(const std::string& raw_data, torch::Tensor& t);
};

class OpenCVImageEncoder {
public:
OpenCVImageEncoder() = default;
~OpenCVImageEncoder() = default;

bool encode(const torch::Tensor& t, std::string& raw_data);

private:
bool valid(const torch::Tensor& t);
};

} // namespace xllm
21 changes: 1 addition & 20 deletions xllm/core/framework/request/mm_input_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,29 +23,10 @@ limitations under the License.
#include <opencv2/opencv.hpp>

#include "butil/base64.h"
#include "mm_codec.h"

namespace xllm {

class OpenCVImageDecoder {
public:
bool decode(const std::string& raw_data, torch::Tensor& t) {
cv::Mat buffer(1, raw_data.size(), CV_8UC1, (void*)raw_data.data());
cv::Mat image = cv::imdecode(buffer, cv::IMREAD_COLOR);
if (image.empty()) {
LOG(INFO) << " opencv image decode failed";
return false;
}

cv::cvtColor(image, image, cv::COLOR_BGR2RGB); // RGB

torch::Tensor tensor = torch::from_blob(
image.data, {image.rows, image.cols, 3}, torch::kUInt8);

t = tensor.permute({2, 0, 1}).clone(); // [C, H, W]
return true;
}
};

class FileDownloadHelper {
public:
FileDownloadHelper() {}
Expand Down
18 changes: 18 additions & 0 deletions xllm/core/runtime/dit_forward_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,18 @@ struct DiTForwardInput {
if (latents.defined()) {
input.latents = latents.to(device, dtype);
}

if (masked_image_latents.defined()) {
input.masked_image_latents = masked_image_latents.to(device, dtype);
}

if (images.defined()) {
input.images = images.to(device, dtype);
}

if (mask_images.defined()) {
input.mask_images = mask_images.to(device, dtype);
}
return input;
}

Expand All @@ -68,6 +80,12 @@ struct DiTForwardInput {
// Secondary negative prompt to exclude additional unwanted features
std::vector<std::string> negative_prompts_2;

torch::Tensor images;

torch::Tensor mask_images;

torch::Tensor masked_image_latents;

torch::Tensor prompt_embeds;

torch::Tensor pooled_prompt_embeds;
Expand Down
Loading