From 696f693d26dc401bcc6b6a4357e020cff5f61cba Mon Sep 17 00:00:00 2001 From: "zhangjun.937" Date: Mon, 27 Oct 2025 14:43:38 +0800 Subject: [PATCH] feat: implement FLUX.1-Fill model with image generation capability. --- xllm/core/framework/batch/dit_batch.cpp | 19 + xllm/core/framework/request/CMakeLists.txt | 2 + xllm/core/framework/request/dit_request.cpp | 46 +- .../framework/request/dit_request_params.cpp | 27 ++ .../framework/request/dit_request_state.h | 11 +- xllm/core/framework/request/mm_codec.cpp | 76 +++ xllm/core/framework/request/mm_codec.h | 44 ++ .../framework/request/mm_input_helper.cpp | 21 +- xllm/core/runtime/dit_forward_params.h | 18 + xllm/models/dit/autoencoder_kl.h | 131 +++++- xllm/models/dit/dit.h | 7 +- xllm/models/dit/pipeline_flux.h | 409 ++-------------- xllm/models/dit/pipeline_flux_base.h | 351 ++++++++++++++ xllm/models/dit/pipeline_flux_fill.h | 436 ++++++++++++++++++ xllm/models/models.h | 29 +- xllm/proto/image_generation.proto | 16 +- 16 files changed, 1159 insertions(+), 484 deletions(-) create mode 100644 xllm/core/framework/request/mm_codec.cpp create mode 100644 xllm/core/framework/request/mm_codec.h create mode 100644 xllm/models/dit/pipeline_flux_base.h create mode 100644 xllm/models/dit/pipeline_flux_fill.h diff --git a/xllm/core/framework/batch/dit_batch.cpp b/xllm/core/framework/batch/dit_batch.cpp index dbf4e72f..1c09327a 100644 --- a/xllm/core/framework/batch/dit_batch.cpp +++ b/xllm/core/framework/batch/dit_batch.cpp @@ -60,7 +60,11 @@ DiTForwardInput DiTBatch::prepare_forward_input() { std::vector negative_prompt_embeds; std::vector negative_pooled_prompt_embeds; + std::vector images; + std::vector mask_images; + std::vector latents; + std::vector masked_image_latents; for (const auto& request : request_vec_) { const auto& generation_params = request->state().generation_params(); if (input.generation_params != generation_params) { @@ -88,6 +92,10 @@ DiTForwardInput DiTBatch::prepare_forward_input() { input_params.negative_pooled_prompt_embed); latents.emplace_back(input_params.latent); + masked_image_latents.emplace_back(input_params.masked_image_latent); + + images.emplace_back(input_params.image); + mask_images.emplace_back(input_params.mask_image); } if (input.prompts.size() != request_vec_.size()) { @@ -106,6 +114,14 @@ DiTForwardInput DiTBatch::prepare_forward_input() { input.negative_prompts_2.clear(); } + if (check_tensors_valid(images)) { + input.images = torch::stack(images); + } + + if (check_tensors_valid(mask_images)) { + input.mask_images = torch::stack(mask_images); + } + if (check_tensors_valid(prompt_embeds)) { input.prompt_embeds = torch::stack(prompt_embeds); } @@ -127,6 +143,9 @@ DiTForwardInput DiTBatch::prepare_forward_input() { input.latents = torch::stack(latents); } + if (check_tensors_valid(masked_image_latents)) { + input.masked_image_latents = torch::stack(masked_image_latents); + } return input; } diff --git a/xllm/core/framework/request/CMakeLists.txt b/xllm/core/framework/request/CMakeLists.txt index c1de17ab..c3f1fbdb 100644 --- a/xllm/core/framework/request/CMakeLists.txt +++ b/xllm/core/framework/request/CMakeLists.txt @@ -11,6 +11,7 @@ cc_library( incremental_decoder.h mm_data.h mm_input_helper.h + mm_codec.h request_base.h request.h dit_request.h @@ -31,6 +32,7 @@ cc_library( incremental_decoder.cpp mm_data.cpp mm_input_helper.cpp + mm_codec.cpp request.cpp dit_request.cpp request_output.cpp diff --git a/xllm/core/framework/request/dit_request.cpp b/xllm/core/framework/request/dit_request.cpp index 
91867ba0..c0009e1f 100644 --- a/xllm/core/framework/request/dit_request.cpp +++ b/xllm/core/framework/request/dit_request.cpp @@ -26,51 +26,7 @@ limitations under the License. #include #include "api_service/call.h" - -namespace { -class OpenCVImageEncoder { - public: - // t float32, cpu, chw - bool encode(const torch::Tensor& t, std::string& raw_data) { - if (!valid(t)) { - return false; - } - - auto img = t.permute({1, 2, 0}).contiguous(); - cv::Mat mat(img.size(0), img.size(1), CV_32FC3, img.data_ptr()); - - cv::Mat mat_8u; - mat.convertTo(mat_8u, CV_8UC3, 255.0); - - // rgb -> bgr - cv::cvtColor(mat_8u, mat_8u, cv::COLOR_RGB2BGR); - - std::vector data; - if (!cv::imencode(".png", mat_8u, data)) { - LOG(ERROR) << "image encode faild"; - return false; - } - - raw_data.assign(data.begin(), data.end()); - return true; - } - - private: - bool valid(const torch::Tensor& t) { - if (t.dim() != 3 || t.size(0) != 3) { - LOG(ERROR) << "input tensor must be 3HW tensor"; - return false; - } - - if (t.scalar_type() != torch::kFloat32 || !t.device().is_cpu()) { - LOG(ERROR) << "tensor must be cpu float32"; - return false; - } - - return true; - } -}; -} // namespace +#include "mm_codec.h" namespace xllm { DiTRequest::DiTRequest(const std::string& request_id, diff --git a/xllm/core/framework/request/dit_request_params.cpp b/xllm/core/framework/request/dit_request_params.cpp index 130e69f3..2d01537f 100644 --- a/xllm/core/framework/request/dit_request_params.cpp +++ b/xllm/core/framework/request/dit_request_params.cpp @@ -16,9 +16,11 @@ limitations under the License. #include "dit_request_params.h" +#include "butil/base64.h" #include "core/common/instance_name.h" #include "core/common/macros.h" #include "core/util/uuid.h" +#include "mm_codec.h" #include "request.h" namespace xllm { @@ -242,6 +244,31 @@ DiTRequestParams::DiTRequestParams(const proto::ImageGenerationRequest& request, if (input.has_latent()) { input_params.latent = proto_to_torch(input.latent()); } + if (input.has_masked_image_latent()) { + input_params.masked_image_latent = + proto_to_torch(input.masked_image_latent()); + } + + OpenCVImageDecoder decoder; + if (input.has_image()) { + std::string raw_bytes; + if (!butil::Base64Decode(input.image(), &raw_bytes)) { + LOG(ERROR) << "Base64 image decode failed"; + } + if (!decoder.decode(raw_bytes, input_params.image)) { + LOG(ERROR) << "Image decode failed."; + } + } + + if (input.has_mask_image()) { + std::string raw_bytes; + if (!butil::Base64Decode(input.mask_image(), &raw_bytes)) { + LOG(ERROR) << "Base64 mask_image decode failed"; + } + if (!decoder.decode(raw_bytes, input_params.mask_image)) { + LOG(ERROR) << "Mask_image decode failed."; + } + } // generation params const auto& params = request.parameters(); diff --git a/xllm/core/framework/request/dit_request_state.h b/xllm/core/framework/request/dit_request_state.h index caef3fcb..fab43cb1 100644 --- a/xllm/core/framework/request/dit_request_state.h +++ b/xllm/core/framework/request/dit_request_state.h @@ -40,7 +40,8 @@ struct DiTGenerationParams { guidance_scale == other.guidance_scale && num_images_per_prompt == other.num_images_per_prompt && seed == other.seed && - max_sequence_length == other.max_sequence_length; + max_sequence_length == other.max_sequence_length && + strength == other.strength; } bool operator!=(const DiTGenerationParams& other) const { @@ -62,6 +63,8 @@ struct DiTGenerationParams { int64_t seed = 0; int32_t max_sequence_length = 512; + + float strength = 1.0; }; struct DiTInputParams { @@ -86,6 +89,12 @@ struct 
DiTInputParams { torch::Tensor negative_pooled_prompt_embed; torch::Tensor latent; + + torch::Tensor image; + + torch::Tensor mask_image; + + torch::Tensor masked_image_latent; }; struct DiTRequestState { diff --git a/xllm/core/framework/request/mm_codec.cpp b/xllm/core/framework/request/mm_codec.cpp new file mode 100644 index 00000000..cdb1abc1 --- /dev/null +++ b/xllm/core/framework/request/mm_codec.cpp @@ -0,0 +1,76 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. +Copyright 2024 The ScaleLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mm_codec.h" + +namespace xllm { + +bool OpenCVImageDecoder::decode(const std::string& raw_data, torch::Tensor& t) { + cv::Mat buffer(1, raw_data.size(), CV_8UC1, (void*)raw_data.data()); + cv::Mat image = cv::imdecode(buffer, cv::IMREAD_COLOR); + if (image.empty()) { + LOG(ERROR) << "opencv image decode failed"; + return false; + } + + cv::cvtColor(image, image, cv::COLOR_BGR2RGB); // RGB + + torch::Tensor tensor = + torch::from_blob(image.data, {image.rows, image.cols, 3}, torch::kUInt8); + + t = tensor.permute({2, 0, 1}).clone(); // [C, H, W] + return true; +} + +bool OpenCVImageEncoder::encode(const torch::Tensor& t, std::string& raw_data) { + if (!valid(t)) { + return false; + } + + auto img = t.permute({1, 2, 0}).contiguous(); + cv::Mat mat(img.size(0), img.size(1), CV_32FC3, img.data_ptr()); + + cv::Mat mat_8u; + mat.convertTo(mat_8u, CV_8UC3, 255.0); + + // rgb -> bgr + cv::cvtColor(mat_8u, mat_8u, cv::COLOR_RGB2BGR); + + std::vector<uchar> data; + if (!cv::imencode(".png", mat_8u, data)) { + LOG(ERROR) << "image encode failed"; + return false; + } + + raw_data.assign(data.begin(), data.end()); + return true; +} + +bool OpenCVImageEncoder::valid(const torch::Tensor& t) { + if (t.dim() != 3 || t.size(0) != 3) { + LOG(ERROR) << "input tensor must be a 3HW tensor"; + return false; + } + + if (t.scalar_type() != torch::kFloat32 || !t.device().is_cpu()) { + LOG(ERROR) << "tensor must be cpu float32"; + return false; + } + + return true; +} + +} // namespace xllm diff --git a/xllm/core/framework/request/mm_codec.h b/xllm/core/framework/request/mm_codec.h new file mode 100644 index 00000000..eea7d9d3 --- /dev/null +++ b/xllm/core/framework/request/mm_codec.h @@ -0,0 +1,44 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. +Copyright 2024 The ScaleLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once +#include + +#include +#include + +namespace xllm { + +class OpenCVImageDecoder { + public: + OpenCVImageDecoder() = default; + ~OpenCVImageDecoder() = default; + + bool decode(const std::string& raw_data, torch::Tensor& t); +}; + +class OpenCVImageEncoder { + public: + OpenCVImageEncoder() = default; + ~OpenCVImageEncoder() = default; + + bool encode(const torch::Tensor& t, std::string& raw_data); + + private: + bool valid(const torch::Tensor& t); +}; + +} // namespace xllm diff --git a/xllm/core/framework/request/mm_input_helper.cpp b/xllm/core/framework/request/mm_input_helper.cpp index d11604f0..6cc61a7d 100644 --- a/xllm/core/framework/request/mm_input_helper.cpp +++ b/xllm/core/framework/request/mm_input_helper.cpp @@ -23,29 +23,10 @@ limitations under the License. #include #include "butil/base64.h" +#include "mm_codec.h" namespace xllm { -class OpenCVImageDecoder { - public: - bool decode(const std::string& raw_data, torch::Tensor& t) { - cv::Mat buffer(1, raw_data.size(), CV_8UC1, (void*)raw_data.data()); - cv::Mat image = cv::imdecode(buffer, cv::IMREAD_COLOR); - if (image.empty()) { - LOG(INFO) << " opencv image decode failed"; - return false; - } - - cv::cvtColor(image, image, cv::COLOR_BGR2RGB); // RGB - - torch::Tensor tensor = torch::from_blob( - image.data, {image.rows, image.cols, 3}, torch::kUInt8); - - t = tensor.permute({2, 0, 1}).clone(); // [C, H, W] - return true; - } -}; - class FileDownloadHelper { public: FileDownloadHelper() {} diff --git a/xllm/core/runtime/dit_forward_params.h b/xllm/core/runtime/dit_forward_params.h index 26cd9af6..e52d4bb2 100644 --- a/xllm/core/runtime/dit_forward_params.h +++ b/xllm/core/runtime/dit_forward_params.h @@ -51,6 +51,18 @@ struct DiTForwardInput { if (latents.defined()) { input.latents = latents.to(device, dtype); } + + if (masked_image_latents.defined()) { + input.masked_image_latents = masked_image_latents.to(device, dtype); + } + + if (images.defined()) { + input.images = images.to(device, dtype); + } + + if (mask_images.defined()) { + input.mask_images = mask_images.to(device, dtype); + } return input; } @@ -68,6 +80,12 @@ struct DiTForwardInput { // Secondary negative prompt to exclude additional unwanted features std::vector negative_prompts_2; + torch::Tensor images; + + torch::Tensor mask_images; + + torch::Tensor masked_image_latents; + torch::Tensor prompt_embeds; torch::Tensor pooled_prompt_embeds; diff --git a/xllm/models/dit/autoencoder_kl.h b/xllm/models/dit/autoencoder_kl.h index ec2a15f2..f57acc5b 100644 --- a/xllm/models/dit/autoencoder_kl.h +++ b/xllm/models/dit/autoencoder_kl.h @@ -34,15 +34,43 @@ limitations under the License. 
#include "framework/model_context.h" #include "models/model_registry.h" // VAE model compatible with huggingface weights -// ref to: -// https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl.py +// ref to: +// https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl.py namespace xllm { + +torch::Tensor randn_tensor(const std::vector& shape, + int64_t seed, + torch::TensorOptions& options) { + if (shape.empty()) { + LOG(FATAL) << "Shape must not be empty."; + } + at::Generator gen = at::detail::createCPUGenerator(); + gen = gen.clone(); + gen.set_current_seed(seed); + torch::Tensor latents; + latents = torch::randn( + shape, gen, options.device(torch::kCPU).dtype(torch::kFloat32)); + latents = latents.to(options); + return latents; +} + class VAEImageProcessorImpl : public torch::nn::Module { public: - explicit VAEImageProcessorImpl(ModelContext context) { + explicit VAEImageProcessorImpl(ModelContext context, + bool do_resize = true, + bool do_normalize = true, + bool do_binarize = false, + bool do_convert_rgb = false, + bool do_convert_grayscale = false) { const auto& model_args = context.get_model_args(); - scale_factor_ = 1 << (model_args.block_out_channels().size() - 1); + scale_factor_ = 1 << model_args.block_out_channels().size(); + latent_channels_ = 4; + do_resize_ = do_resize; + do_normalize_ = do_normalize; + do_binarize_ = do_binarize; + do_convert_rgb_ = do_convert_rgb; + do_convert_grayscale_ = do_convert_grayscale; } std::pair adjust_dimensions(int64_t height, @@ -84,16 +112,9 @@ class VAEImageProcessorImpl : public torch::nn::Module { torch::indexing::Slice(x1, x2)}); } } - if (do_convert_grayscale_ && processed.size(1) == 3) { - std::vector weights = {0.299f, 0.587f, 0.114f}; - torch::Tensor weight_tensor = - torch::tensor(weights, torch::kFloat32).view({1, 3, 1, 1}); - if (processed.dim() == 3) { - weight_tensor = weight_tensor.squeeze(0); - } - processed = torch::sum(processed * weight_tensor, 1, true); - } else if (do_convert_rgb_ && processed.size(1) == 1) { - processed = torch::cat({processed, processed, processed}, 1); + int channel = processed.size(1); + if (channel == latent_channels_) { + return image; } auto [target_h, target_w] = @@ -108,7 +129,7 @@ class VAEImageProcessorImpl : public torch::nn::Module { if (do_binarize_) { processed = (processed >= 0.5f).to(torch::kFloat32); } - + processed = processed.to(image.dtype()); return processed; } @@ -170,11 +191,12 @@ class VAEImageProcessorImpl : public torch::nn::Module { image, torch::nn::functional::InterpolateFuncOptions() .size(std::vector{target_height, target_width}) - .align_corners(false)); + .mode(torch::kNearest)); } private: int scale_factor_ = 8; + int latent_channels_ = 4; bool do_resize_ = true; bool do_normalize_ = true; bool do_binarize_ = false; @@ -836,6 +858,78 @@ class UpDecoderBlock2DImpl : public torch::nn::Module { }; TORCH_MODULE(UpDecoderBlock2D); +class DiagonalGaussianDistribution { + public: + DiagonalGaussianDistribution(torch::Tensor parameters, + bool deterministic = false) + : parameters_(std::move(parameters)), deterministic_(deterministic) { + auto chunks = parameters_.chunk(2, 1); + mean_ = chunks[0]; + logvar_ = chunks[1]; + + logvar_ = torch::clamp(logvar_, -30.0f, 20.0f); + + std_ = torch::exp(0.5f * logvar_); + var_ = torch::exp(logvar_); + + if (deterministic_) { + std_.fill_(0.0f); + var_.fill_(0.0f); + } + } + + torch::Tensor sample(int64_t seed) const { + torch::TensorOptions options = 
mean_.options(); + std::vector shape(mean_.sizes().begin(), mean_.sizes().end()); + return mean_ + std_ * randn_tensor(shape, seed, options); + } + + torch::Tensor kl(const std::optional& other = + std::nullopt) const { + if (deterministic_) { + return torch::tensor(0.0f, mean_.options()); + } + + if (!other.has_value()) { + return 0.5f * torch::sum(torch::pow(mean_, 2) + var_ - 1.0f - logvar_, + {1, 2, 3}); + } else { + const auto& other_dist = other.value(); + return 0.5f * torch::sum(torch::pow(mean_ - other_dist.mean_, 2) / + other_dist.var_ + + var_ / other_dist.var_ - 1.0f - logvar_ + + other_dist.logvar_, + {1, 2, 3}); + } + } + + torch::Tensor nll(const torch::Tensor& sample, + const std::vector& dims = {1, 2, 3}) const { + if (deterministic_) { + return torch::tensor(0.0f, mean_.options()); + } + const float logtwopi = std::log(2.0f * M_PI); + return 0.5f * + torch::sum(logtwopi + logvar_ + torch::pow(sample - mean_, 2) / var_, + dims); + } + + torch::Tensor mode() const { return mean_; } + + const torch::Tensor& mean() const { return mean_; } + const torch::Tensor& std() const { return std_; } + const torch::Tensor& var() const { return var_; } + const torch::Tensor& logvar() const { return logvar_; } + + private: + torch::Tensor parameters_; + torch::Tensor mean_; + torch::Tensor logvar_; + torch::Tensor std_; + torch::Tensor var_; + bool deterministic_; +}; + // VAE standard encoder implementation // This class is used to encode images into latent representations. class VAEEncoderImpl : public torch::nn::Module { @@ -1129,12 +1223,13 @@ class VAEImpl : public torch::nn::Module { } } - torch::Tensor encode(const torch::Tensor& images) { + torch::Tensor encode(const torch::Tensor& images, int64_t seed) { auto enc = encoder_(images); if (args_.use_quant_conv()) { enc = quant_conv_(enc); } - return enc; + auto posterior = DiagonalGaussianDistribution(enc); + return posterior.sample(seed); } torch::Tensor decode(const torch::Tensor& latents) { diff --git a/xllm/models/dit/dit.h b/xllm/models/dit/dit.h index 5e1a6359..e9d9302a 100644 --- a/xllm/models/dit/dit.h +++ b/xllm/models/dit/dit.h @@ -1216,7 +1216,8 @@ class FluxTransformer2DModelImpl : public torch::nn::Module { auto num_layers = model_args.num_layers(); auto num_single_layers = model_args.num_single_layers(); auto patch_size = model_args.mm_patch_size(); - out_channels_ = model_args.in_channels(); + in_channels_ = model_args.in_channels(); + out_channels_ = model_args.out_channels(); guidance_embeds_ = model_args.guidance_embeds(); // Initialize the transformer model components here @@ -1235,7 +1236,7 @@ class FluxTransformer2DModelImpl : public torch::nn::Module { context_embedder_ = register_module( "context_embedder", DiTLinear(joint_attention_dim, inner_dim)); x_embedder_ = - register_module("x_embedder", DiTLinear(out_channels_, inner_dim)); + register_module("x_embedder", DiTLinear(in_channels_, inner_dim)); context_embedder_->to(options_); x_embedder_->to(options_); // mm-dit block @@ -1450,6 +1451,7 @@ class FluxTransformer2DModelImpl : public torch::nn::Module { AdaLayerNormContinuous norm_out_{nullptr}; DiTLinear proj_out_{nullptr}; bool guidance_embeds_; + int64_t in_channels_; int64_t out_channels_; torch::TensorOptions options_; }; @@ -1500,6 +1502,7 @@ REGISTER_MODEL_ARGS(FluxTransformer2DModel, [&] { LOAD_ARG_OR(dtype, "dtype", "bfloat16"); LOAD_ARG_OR(mm_patch_size, "patch_size", 1); LOAD_ARG_OR(in_channels, "in_channels", 64); + LOAD_ARG_OR(out_channels, "out_channels", 64); LOAD_ARG_OR(num_layers, 
"num_layers", 19); LOAD_ARG_OR(num_single_layers, "num_single_layers", 38); LOAD_ARG_OR(head_dim, "attention_head_dim", 128); diff --git a/xllm/models/dit/pipeline_flux.h b/xllm/models/dit/pipeline_flux.h index 209ff5e5..dfaea129 100644 --- a/xllm/models/dit/pipeline_flux.h +++ b/xllm/models/dit/pipeline_flux.h @@ -14,224 +14,32 @@ limitations under the License. ==============================================================================*/ #pragma once -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "autoencoder_kl.h" -#include "clip_text_model.h" -#include "core/framework/dit_model_loader.h" -#include "core/framework/model/model_input_params.h" -#include "core/framework/model_context.h" -#include "core/framework/request/dit_request_state.h" -#include "core/framework/state_dict/state_dict.h" -#include "core/framework/state_dict/utils.h" #include "core/layers/pos_embedding.h" -#include "core/layers/rms_norm.h" #include "core/layers/rotary_embedding.h" #include "dit.h" -#include "flowmatch_euler_discrete_scheduler.h" -#include "framework/model_context.h" -#include "models/model_registry.h" -#include "t5_encoder.h" -namespace xllm { - -float calculate_shift(int64_t image_seq_len, - int64_t base_seq_len = 256, - int64_t max_seq_len = 4096, - float base_shift = 0.5f, - float max_shift = 1.15f) { - float m = - (max_shift - base_shift) / static_cast(max_seq_len - base_seq_len); - float b = base_shift - m * static_cast(base_seq_len); - float mu = static_cast(image_seq_len) * m + b; - return mu; -} - -std::pair retrieve_timesteps( - FlowMatchEulerDiscreteScheduler scheduler, - int64_t num_inference_steps = 0, - torch::Device device = torch::kCPU, - std::optional> sigmas = std::nullopt, - std::optional mu = std::nullopt) { - torch::Tensor scheduler_timesteps; - int64_t steps; - if (sigmas.has_value()) { - steps = sigmas->size(); - scheduler->set_timesteps( - static_cast(steps), device, *sigmas, mu, std::nullopt); - - scheduler_timesteps = scheduler->timesteps(); - } else { - steps = num_inference_steps; - scheduler->set_timesteps( - static_cast(steps), device, std::nullopt, mu, std::nullopt); - scheduler_timesteps = scheduler->timesteps(); - } - if (scheduler_timesteps.device() != device) { - scheduler_timesteps = scheduler_timesteps.to(device); - } - return {scheduler_timesteps, steps}; -} - -torch::Tensor randn_tensor(const std::vector& shape, - int64_t seed, - torch::TensorOptions& options) { - if (shape.empty()) { - LOG(FATAL) << "Shape must not be empty."; - } - at::Generator gen = at::detail::createCPUGenerator(); - gen = gen.clone(); - gen.set_current_seed(seed); - torch::Tensor latents; - latents = torch::randn(shape, gen, options.device(torch::kCPU)); - latents = latents.to(options); - return latents; -} - -inline torch::Tensor get_1d_rotary_pos_embed( - int64_t dim, - const torch::Tensor& pos, - float theta = 10000.0, - bool use_real = false, - float linear_factor = 1.0, - float ntk_factor = 1.0, - bool repeat_interleave_real = true, - torch::Dtype freqs_dtype = torch::kFloat32) { - TORCH_CHECK(dim % 2 == 0, "Dimension must be even"); - - torch::Tensor pos_tensor = pos; - if (pos.dim() == 0) { - pos_tensor = torch::arange(pos.item(), pos.options()); - } - - theta = theta * ntk_factor; - - auto freqs = - 1.0 / - (torch::pow( - theta, - torch::arange( - 0, dim, 2, torch::dtype(freqs_dtype).device(pos.device())) / - dim) * - linear_factor); // [D/2] - - auto tensors = {pos_tensor, 
freqs}; +#include "pipeline_flux_base.h" +// pipeline_flux compatible with huggingface weights +// ref to: +// https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux/pipeline_flux.py - auto freqs_outer = torch::einsum("s,d->sd", tensors); // [S, D/2] -#if defined(USE_NPU) - freqs_outer = freqs_outer.to(torch::kFloat32); -#endif - if (use_real && repeat_interleave_real) { - auto cos_vals = torch::cos(freqs_outer); // [S, D/2] - auto sin_vals = torch::sin(freqs_outer); // [S, D/2] - - auto freqs_cos = cos_vals.transpose(-1, -2) - .repeat_interleave(2, -2) - .transpose(-1, -2) - .to(torch::kFloat32); // [S, D] - - auto freqs_sin = sin_vals.transpose(-1, -2) - .repeat_interleave(2, -2) - .transpose(-1, -2) - .to(torch::kFloat32); // [S, D] - return torch::cat({freqs_cos.unsqueeze(0), freqs_sin.unsqueeze(0)}, - 0); // [2, S, D] - } -} - -class FluxPosEmbedImpl : public torch::nn::Module { - public: - FluxPosEmbedImpl(int64_t theta, std::vector axes_dim) { - theta_ = theta; - axes_dim_ = axes_dim; - } - - std::pair forward_cache( - const torch::Tensor& txt_ids, - const torch::Tensor& img_ids, - int64_t height = -1, - int64_t width = -1) { - auto seq_len = txt_ids.size(0); - - // recompute the cache if height or width changes - if (height != cached_image_height_ || width != cached_image_width_ || - seq_len != max_seq_len_) { - torch::Tensor ids = torch::cat({txt_ids, img_ids}, 0); - cached_image_height_ = height; - cached_image_width_ = width; - max_seq_len_ = seq_len; - auto [cos, sin] = forward(ids); - freqs_cos_cache_ = std::move(cos); - freqs_sin_cache_ = std::move(sin); - } - return {freqs_cos_cache_, freqs_sin_cache_}; - } - - std::pair forward(const torch::Tensor& ids) { - int64_t n_axes = ids.size(-1); - std::vector cos_out, sin_out; - auto pos = ids.to(torch::kFloat32); - torch::Dtype freqs_dtype = torch::kFloat64; - for (int64_t i = 0; i < n_axes; ++i) { - auto pos_slice = pos.select(-1, i); - auto result = get_1d_rotary_pos_embed(axes_dim_[i], - pos_slice, - theta_, - true, // repeat_interleave_real - 1, - 1, - true, // use_real - freqs_dtype); - auto cos = result[0]; - auto sin = result[1]; - cos_out.push_back(cos); - sin_out.push_back(sin); - } - - auto freqs_cos = torch::cat(cos_out, -1); - auto freqs_sin = torch::cat(sin_out, -1); - return {freqs_cos, freqs_sin}; - } - - private: - int64_t theta_; - std::vector axes_dim_; - torch::Tensor freqs_cos_cache_; - torch::Tensor freqs_sin_cache_; - int64_t max_seq_len_ = -1; - int64_t cached_image_height_ = -1; - int64_t cached_image_width_ = -1; -}; -TORCH_MODULE(FluxPosEmbed); +namespace xllm { -class FluxPipelineImpl : public torch::nn::Module { +class FluxPipelineImpl : public FluxPipelineBaseImpl { public: - explicit FluxPipelineImpl(const DiTModelContext& context) - : options_(context.get_tensor_options()) { + FluxPipelineImpl(const DiTModelContext& context) { const auto& model_args = context.get_model_args("vae"); + options_ = context.get_tensor_options(); vae_scale_factor_ = 1 << (model_args.block_out_channels().size() - 1); - execution_device_ = options_.device(); - execution_dtype_ = options_.dtype().toScalarType(); + device_ = options_.device(); + dtype_ = options_.dtype().toScalarType(); vae_shift_factor_ = model_args.shift_factor(); vae_scaling_factor_ = model_args.scale_factor(); default_sample_size_ = 128; tokenizer_max_length_ = 77; // TODO: get from config file LOG(INFO) << "Initializing Flux pipeline..."; - vae_image_processor_ = VAEImageProcessor(context.get_model_context("vae")); + 
vae_image_processor_ = VAEImageProcessor( + context.get_model_context("vae"), true, true, false, false, false); vae_ = VAE(context.get_model_context("vae")); LOG(INFO) << "VAE initialized."; pos_embed_ = register_module( @@ -323,7 +131,6 @@ class FluxPipelineImpl : public torch::nn::Module { void load_model(std::unique_ptr loader) { LOG(INFO) << "FluxPipeline loading model from" << loader->model_root_path(); - // transformer_.to(options_); std::string model_path = loader->model_root_path(); auto transformer_loader = loader->take_component_loader("transformer"); auto vae_loader = loader->take_component_loader("vae"); @@ -334,13 +141,13 @@ class FluxPipelineImpl : public torch::nn::Module { LOG(INFO) << "Flux model components loaded, start to load weights to sub models"; transformer_->load_model(std::move(transformer_loader)); - transformer_->to(execution_device_); + transformer_->to(device_); vae_->load_model(std::move(vae_loader)); - vae_->to(execution_device_); + vae_->to(device_); t5_->load_model(std::move(t5_loader)); - t5_->to(execution_device_); + t5_->to(device_); clip_text_model_->load_model(std::move(clip_loader)); - clip_text_model_->to(execution_device_); + clip_text_model_->to(device_); tokenizer_ = tokenizer_loader->tokenizer(); tokenizer_2_ = tokenizer_2_loader->tokenizer(); } @@ -358,174 +165,21 @@ class FluxPipelineImpl : public torch::nn::Module { std::vector shape = { batch_size, num_channels_latents, adjusted_height, adjusted_width}; if (latents.has_value()) { - torch::Tensor latent_image_ids = _prepare_latent_image_ids( + torch::Tensor latent_image_ids = prepare_latent_image_ids( batch_size, adjusted_height / 2, adjusted_width / 2); return {latents.value(), latent_image_ids}; } torch::Tensor latents_tensor = randn_tensor(shape, seed, options_); - torch::Tensor packed_latents = _pack_latents(latents_tensor, - batch_size, - num_channels_latents, - adjusted_height, - adjusted_width); - torch::Tensor latent_image_ids = _prepare_latent_image_ids( + torch::Tensor packed_latents = pack_latents(latents_tensor, + batch_size, + num_channels_latents, + adjusted_height, + adjusted_width); + torch::Tensor latent_image_ids = prepare_latent_image_ids( batch_size, adjusted_height / 2, adjusted_width / 2); return {packed_latents, latent_image_ids}; } - torch::Tensor _prepare_latent_image_ids(int64_t batch_size, - int64_t height, - int64_t width) { - torch::Tensor latent_image_ids = torch::zeros({height, width, 3}, options_); - torch::Tensor height_range = torch::arange(height, options_).unsqueeze(1); - latent_image_ids.select(2, 1) += height_range; - torch::Tensor width_range = torch::arange(width, options_).unsqueeze(0); - latent_image_ids.select(2, 2) += width_range; - latent_image_ids = latent_image_ids.view({height * width, 3}); - return latent_image_ids; - } - - torch::Tensor _pack_latents(const torch::Tensor& latents, - int64_t batch_size, - int64_t num_channels_latents, - int64_t height, - int64_t width) { - torch::Tensor packed = latents.view( - {batch_size, num_channels_latents, height / 2, 2, width / 2, 2}); - packed = packed.permute({0, 2, 4, 1, 3, 5}); - packed = packed.reshape( - {batch_size, (height / 2) * (width / 2), num_channels_latents * 4}); - - return packed; - } - - torch::Tensor _unpack_latents(const torch::Tensor& latents, - int64_t height, - int64_t width, - int64_t vae_scale_factor) { - int64_t batch_size = latents.size(0); - int64_t num_patches = latents.size(1); - int64_t channels = latents.size(2); - int64_t adjusted_height = 2 * (height / (vae_scale_factor * 
2)); - int64_t adjusted_width = 2 * (width / (vae_scale_factor * 2)); - torch::Tensor unpacked = latents.view({batch_size, - adjusted_height / 2, - adjusted_width / 2, - channels / 4, - 2, - 2}); - unpacked = unpacked.permute({0, 3, 1, 4, 2, 5}); - unpacked = unpacked.reshape( - {batch_size, channels / (2 * 2), adjusted_height, adjusted_width}); - - return unpacked; - } - - torch::Tensor _get_clip_prompt_embeds(std::vector& prompt, - int64_t num_images_per_prompt = 1) { - std::vector processed_prompt = prompt; - int64_t batch_size = processed_prompt.size(); - TORCH_CHECK(batch_size > 0, "Prompt list cannot be empty"); - - std::vector> text_input_ids; - text_input_ids.reserve(batch_size); - CHECK(tokenizer_->batch_encode(processed_prompt, &text_input_ids)); - for (auto& ids : text_input_ids) { - LOG(INFO) << "CLIP Original IDs size: " << ids; - ids.resize(tokenizer_max_length_, 49407); - ids.back() = 49407; - } - std::vector text_input_ids_flat; - text_input_ids_flat.reserve(batch_size * tokenizer_max_length_); - for (const auto& ids : text_input_ids) { - text_input_ids_flat.insert( - text_input_ids_flat.end(), ids.begin(), ids.end()); - } - auto input_ids = - torch::tensor(text_input_ids_flat, torch::dtype(torch::kLong)) - .view({batch_size, tokenizer_max_length_}) - .to(execution_device_); - auto encoder_output = clip_text_model_->forward(input_ids); - torch::Tensor prompt_embeds = encoder_output; - prompt_embeds = prompt_embeds.to(execution_device_).to(execution_dtype_); - prompt_embeds = prompt_embeds.repeat({1, num_images_per_prompt}); - prompt_embeds = - prompt_embeds.view({batch_size * num_images_per_prompt, -1}); - return prompt_embeds; - } - - torch::Tensor _get_t5_prompt_embeds(std::vector& prompt, - int64_t num_images_per_prompt = 1, - int64_t max_sequence_length = 512) { - std::vector processed_prompt = prompt; - int64_t batch_size = processed_prompt.size(); - TORCH_CHECK(batch_size > 0, "Prompt list cannot be empty"); - - std::vector> text_input_ids; - text_input_ids.reserve(batch_size); - CHECK(tokenizer_2_->batch_encode(processed_prompt, &text_input_ids)); - for (auto& ids : text_input_ids) { - LOG(INFO) << "T5 Original IDs size: " << ids; - ids.resize(max_sequence_length, 0); - } - - std::vector text_input_ids_flat; - text_input_ids_flat.reserve(batch_size * max_sequence_length); - for (const auto& ids : text_input_ids) { - text_input_ids_flat.insert( - text_input_ids_flat.end(), ids.begin(), ids.end()); - } - auto input_ids = - torch::tensor(text_input_ids_flat, torch::dtype(torch::kLong)) - .view({batch_size, max_sequence_length}) - .to(execution_device_); - torch::Tensor prompt_embeds = t5_->forward(input_ids); - prompt_embeds = prompt_embeds.to(execution_dtype_).to(execution_device_); - int64_t seq_len = prompt_embeds.size(1); - prompt_embeds = prompt_embeds.repeat({1, num_images_per_prompt, 1}); - prompt_embeds = - prompt_embeds.view({batch_size * num_images_per_prompt, seq_len, -1}); - return prompt_embeds; - } - - std::tuple encode_prompt( - std::optional> prompt, - std::optional> prompt_2, - std::optional prompt_embeds, - std::optional pooled_prompt_embeds, - int64_t num_images_per_prompt = 1, - int64_t max_sequence_length = 512) { - std::vector prompt_list; - if (prompt.has_value()) { - prompt_list = prompt.value(); - } - if (prompt_list.empty()) { - prompt_list = {""}; - } - if (!prompt_embeds.has_value()) { - std::vector prompt_2_list; - if (prompt_2.has_value()) { - prompt_2_list = prompt_2.value(); - } - if (prompt_2_list.empty()) { - prompt_2_list = 
prompt_list; - } - pooled_prompt_embeds = - _get_clip_prompt_embeds(prompt_list, num_images_per_prompt); - prompt_embeds = _get_t5_prompt_embeds( - prompt_2_list, num_images_per_prompt, max_sequence_length); - } - torch::Tensor text_ids = - torch::zeros({prompt_embeds.value().size(1), 3}, - torch::device(execution_device_).dtype(execution_dtype_)); - - return std::make_tuple(prompt_embeds.value(), - pooled_prompt_embeds.has_value() - ? pooled_prompt_embeds.value() - : torch::Tensor(), - text_ids); - } - std::vector forward_( std::optional> prompt = std::nullopt, std::optional> prompt_2 = std::nullopt, @@ -609,7 +263,7 @@ class FluxPipelineImpl : public torch::nn::Module { scheduler_->base_shift(), scheduler_->max_shift()); auto [timesteps, num_inference_steps_actual] = retrieve_timesteps( - scheduler_, num_inference_steps, execution_device_, new_sigmas, mu); + scheduler_, num_inference_steps, device_, new_sigmas, mu); int64_t num_warmup_steps = std::max(static_cast(timesteps.numel()) - num_inference_steps_actual * scheduler_->order(), @@ -618,7 +272,7 @@ class FluxPipelineImpl : public torch::nn::Module { torch::Tensor guidance; if (transformer_->guidance_embeds()) { torch::TensorOptions options = - torch::dtype(torch::kFloat32).device(execution_device_); + torch::dtype(torch::kFloat32).device(device_); guidance = torch::full(at::IntArrayRef({1}), guidance_scale, options); guidance = guidance.expand({prepared_latents.size(0)}); @@ -672,11 +326,11 @@ class FluxPipelineImpl : public torch::nn::Module { } torch::Tensor image; // Unpack latents - torch::Tensor unpacked_latents = _unpack_latents( + torch::Tensor unpacked_latents = unpack_latents( prepared_latents, actual_height, actual_width, vae_scale_factor_); unpacked_latents = (unpacked_latents / vae_scaling_factor_) + vae_shift_factor_; - unpacked_latents = unpacked_latents.to(execution_dtype_); + unpacked_latents = unpacked_latents.to(dtype_); image = vae_->decode(unpacked_latents); image = vae_image_processor_->postprocess(image, "pil"); return std::vector{{image}}; @@ -687,19 +341,10 @@ class FluxPipelineImpl : public torch::nn::Module { VAE vae_{nullptr}; VAEImageProcessor vae_image_processor_{nullptr}; FluxDiTModel transformer_{nullptr}; - T5EncoderModel t5_{nullptr}; - CLIPTextModel clip_text_model_{nullptr}; - int vae_scale_factor_; float vae_scaling_factor_; float vae_shift_factor_; - int tokenizer_max_length_; int default_sample_size_; - torch::Device execution_device_ = torch::kCPU; - torch::ScalarType execution_dtype_ = torch::kFloat32; - torch::TensorOptions options_; FluxPosEmbed pos_embed_{nullptr}; - std::unique_ptr tokenizer_; - std::unique_ptr tokenizer_2_; }; TORCH_MODULE(FluxPipeline); diff --git a/xllm/models/dit/pipeline_flux_base.h b/xllm/models/dit/pipeline_flux_base.h new file mode 100644 index 00000000..b9f8442f --- /dev/null +++ b/xllm/models/dit/pipeline_flux_base.h @@ -0,0 +1,351 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once +#include +#include + +#include +#include +#include + +#include "autoencoder_kl.h" +#include "clip_text_model.h" +#include "core/framework/dit_model_loader.h" +#include "core/framework/model_context.h" +#include "core/framework/request/dit_request_state.h" +#include "core/framework/state_dict/state_dict.h" +#include "core/framework/state_dict/utils.h" +#include "core/layers/pos_embedding.h" +#include "core/layers/rotary_embedding.h" +#include "flowmatch_euler_discrete_scheduler.h" +#include "models/model_registry.h" +#include "t5_encoder.h" + +namespace xllm { + +float calculate_shift(int64_t image_seq_len, + int64_t base_seq_len = 256, + int64_t max_seq_len = 4096, + float base_shift = 0.5f, + float max_shift = 1.15f) { + float m = + (max_shift - base_shift) / static_cast(max_seq_len - base_seq_len); + float b = base_shift - m * static_cast(base_seq_len); + float mu = static_cast(image_seq_len) * m + b; + return mu; +} + +std::pair retrieve_timesteps( + FlowMatchEulerDiscreteScheduler scheduler, + int64_t num_inference_steps = 0, + torch::Device device = torch::kCPU, + std::optional> sigmas = std::nullopt, + std::optional mu = std::nullopt) { + torch::Tensor scheduler_timesteps; + int64_t steps; + if (sigmas.has_value()) { + steps = sigmas->size(); + scheduler->set_timesteps( + static_cast(steps), device, *sigmas, mu, std::nullopt); + + scheduler_timesteps = scheduler->timesteps(); + } else { + steps = num_inference_steps; + scheduler->set_timesteps( + static_cast(steps), device, std::nullopt, mu, std::nullopt); + scheduler_timesteps = scheduler->timesteps(); + } + if (scheduler_timesteps.device() != device) { + scheduler_timesteps = scheduler_timesteps.to(device); + } + return {scheduler_timesteps, steps}; +} + +torch::Tensor get_1d_rotary_pos_embed( + int64_t dim, + const torch::Tensor& pos, + float theta = 10000.0, + bool use_real = false, + float linear_factor = 1.0, + float ntk_factor = 1.0, + bool repeat_interleave_real = true, + torch::Dtype freqs_dtype = torch::kFloat32) { + TORCH_CHECK(dim % 2 == 0, "Dimension must be even"); + + torch::Tensor pos_tensor = pos; + if (pos.dim() == 0) { + pos_tensor = torch::arange(pos.item(), pos.options()); + } + + theta = theta * ntk_factor; + + auto freqs = + 1.0 / + (torch::pow( + theta, + torch::arange( + 0, dim, 2, torch::dtype(freqs_dtype).device(pos.device())) / + dim) * + linear_factor); // [D/2] + + auto tensors = {pos_tensor, freqs}; + + auto freqs_outer = torch::einsum("s,d->sd", tensors); // [S, D/2] +#if defined(USE_NPU) + freqs_outer = freqs_outer.to(torch::kFloat32); +#endif + if (use_real && repeat_interleave_real) { + auto cos_vals = torch::cos(freqs_outer); // [S, D/2] + auto sin_vals = torch::sin(freqs_outer); // [S, D/2] + + auto freqs_cos = cos_vals.transpose(-1, -2) + .repeat_interleave(2, -2) + .transpose(-1, -2) + .to(torch::kFloat32); // [S, D] + + auto freqs_sin = sin_vals.transpose(-1, -2) + .repeat_interleave(2, -2) + .transpose(-1, -2) + .to(torch::kFloat32); // [S, D] + return torch::cat({freqs_cos.unsqueeze(0), freqs_sin.unsqueeze(0)}, + 0); // [2, S, D] + } +} + +class FluxPosEmbedImpl : public torch::nn::Module { + public: + FluxPosEmbedImpl(int64_t theta, std::vector axes_dim) { + theta_ = theta; + axes_dim_ = axes_dim; + } + + std::pair forward_cache( + const torch::Tensor& txt_ids, + const torch::Tensor& img_ids, + int64_t height = -1, + int64_t width = -1) { + auto seq_len = txt_ids.size(0); + + // recompute the 
cache if height or width changes + if (height != cached_image_height_ || width != cached_image_width_ || + seq_len != max_seq_len_) { + torch::Tensor ids = torch::cat({txt_ids, img_ids}, 0); + cached_image_height_ = height; + cached_image_width_ = width; + max_seq_len_ = seq_len; + auto [cos, sin] = forward(ids); + freqs_cos_cache_ = std::move(cos); + freqs_sin_cache_ = std::move(sin); + } + return {freqs_cos_cache_, freqs_sin_cache_}; + } + + std::pair forward(const torch::Tensor& ids) { + int64_t n_axes = ids.size(-1); + std::vector cos_out, sin_out; + auto pos = ids.to(torch::kFloat32); + torch::Dtype freqs_dtype = torch::kFloat64; + for (int64_t i = 0; i < n_axes; ++i) { + auto pos_slice = pos.select(-1, i); + auto result = get_1d_rotary_pos_embed(axes_dim_[i], + pos_slice, + theta_, + true, // repeat_interleave_real + 1, + 1, + true, // use_real + freqs_dtype); + auto cos = result[0]; + auto sin = result[1]; + cos_out.push_back(cos); + sin_out.push_back(sin); + } + + auto freqs_cos = torch::cat(cos_out, -1); + auto freqs_sin = torch::cat(sin_out, -1); + return {freqs_cos, freqs_sin}; + } + + private: + int64_t theta_; + std::vector axes_dim_; + torch::Tensor freqs_cos_cache_; + torch::Tensor freqs_sin_cache_; + int64_t max_seq_len_ = -1; + int64_t cached_image_height_ = -1; + int64_t cached_image_width_ = -1; +}; +TORCH_MODULE(FluxPosEmbed); + +class FluxPipelineBaseImpl : public torch::nn::Module { + protected: + torch::Tensor get_t5_prompt_embeds(std::vector& prompt, + int64_t num_images_per_prompt = 1, + int64_t max_sequence_length = 512) { + int64_t batch_size = prompt.size(); + std::vector> text_input_ids; + text_input_ids.reserve(batch_size); + CHECK(tokenizer_2_->batch_encode(prompt, &text_input_ids)); + for (auto& ids : text_input_ids) { + LOG(INFO) << "T5 Original IDs size: " << ids; + ids.resize(max_sequence_length, 0); + } + + std::vector text_input_ids_flat; + text_input_ids_flat.reserve(batch_size * max_sequence_length); + for (const auto& ids : text_input_ids) { + text_input_ids_flat.insert( + text_input_ids_flat.end(), ids.begin(), ids.end()); + } + auto input_ids = + torch::tensor(text_input_ids_flat, torch::dtype(torch::kLong)) + .view({batch_size, max_sequence_length}) + .to(device_); + torch::Tensor prompt_embeds = t5_->forward(input_ids); + prompt_embeds = prompt_embeds.to(device_).to(dtype_); + int64_t seq_len = prompt_embeds.size(1); + prompt_embeds = prompt_embeds.repeat({1, num_images_per_prompt, 1}); + prompt_embeds = + prompt_embeds.view({batch_size * num_images_per_prompt, seq_len, -1}); + return prompt_embeds; + } + + torch::Tensor get_clip_prompt_embeds(std::vector& prompt, + int64_t num_images_per_prompt = 1) { + int64_t batch_size = prompt.size(); + std::vector> text_input_ids; + text_input_ids.reserve(batch_size); + CHECK(tokenizer_->batch_encode(prompt, &text_input_ids)); + for (auto& ids : text_input_ids) { + LOG(INFO) << "CLIP Original IDs size: " << ids; + ids.resize(tokenizer_max_length_, 49407); + ids.back() = 49407; + } + + std::vector text_input_ids_flat; + text_input_ids_flat.reserve(batch_size * tokenizer_max_length_); + for (const auto& ids : text_input_ids) { + text_input_ids_flat.insert( + text_input_ids_flat.end(), ids.begin(), ids.end()); + } + auto input_ids = + torch::tensor(text_input_ids_flat, torch::dtype(torch::kLong)) + .view({batch_size, tokenizer_max_length_}) + .to(device_); + auto encoder_output = clip_text_model_->forward(input_ids); + torch::Tensor prompt_embeds = encoder_output; + prompt_embeds = 
prompt_embeds.to(device_).to(dtype_); + prompt_embeds = prompt_embeds.repeat({1, num_images_per_prompt}); + prompt_embeds = + prompt_embeds.view({batch_size * num_images_per_prompt, -1}); + return prompt_embeds; + } + + std::tuple encode_prompt( + std::optional> prompt, + std::optional> prompt_2, + std::optional prompt_embeds, + std::optional pooled_prompt_embeds, + int64_t num_images_per_prompt = 1, + int64_t max_sequence_length = 512) { + std::vector prompt_list; + if (prompt.has_value()) { + prompt_list = prompt.value(); + } + if (prompt_list.empty()) { + prompt_list = {""}; + } + if (!prompt_embeds.has_value()) { + std::vector prompt_2_list; + if (prompt_2.has_value()) { + prompt_2_list = prompt_2.value(); + } + if (prompt_2_list.empty()) { + prompt_2_list = prompt_list; + } + pooled_prompt_embeds = + get_clip_prompt_embeds(prompt_list, num_images_per_prompt); + prompt_embeds = get_t5_prompt_embeds( + prompt_2_list, num_images_per_prompt, max_sequence_length); + } + torch::Tensor text_ids = torch::zeros({prompt_embeds.value().size(1), 3}, + torch::device(device_).dtype(dtype_)); + + return std::make_tuple(prompt_embeds.value(), + pooled_prompt_embeds.has_value() + ? pooled_prompt_embeds.value() + : torch::Tensor(), + text_ids); + } + + torch::Tensor prepare_latent_image_ids(int64_t batch_size, + int64_t height, + int64_t width) { + torch::Tensor latent_image_ids = torch::zeros({height, width, 3}, options_); + torch::Tensor height_range = torch::arange(height, options_).unsqueeze(1); + latent_image_ids.select(2, 1) += height_range; + torch::Tensor width_range = torch::arange(width, options_).unsqueeze(0); + latent_image_ids.select(2, 2) += width_range; + latent_image_ids = latent_image_ids.view({height * width, 3}); + return latent_image_ids; + } + + torch::Tensor pack_latents(const torch::Tensor& latents, + int64_t batch_size, + int64_t num_channels_latents, + int64_t height, + int64_t width) { + torch::Tensor latents_packed = latents.view( + {batch_size, num_channels_latents, height / 2, 2, width / 2, 2}); + latents_packed = latents_packed.permute({0, 2, 4, 1, 3, 5}); + latents_packed = latents_packed.reshape( + {batch_size, (height / 2) * (width / 2), num_channels_latents * 4}); + + return latents_packed; + } + + torch::Tensor unpack_latents(const torch::Tensor& latents, + int64_t height, + int64_t width, + int64_t vae_scale_factor) { + int64_t batch_size = latents.size(0); + int64_t num_patches = latents.size(1); + int64_t channels = latents.size(2); + height = 2 * (height / (vae_scale_factor_ * 2)); + width = 2 * (width / (vae_scale_factor_ * 2)); + + torch::Tensor latents_unpacked = + latents.view({batch_size, height / 2, width / 2, channels / 4, 2, 2}); + latents_unpacked = latents_unpacked.permute({0, 3, 1, 4, 2, 5}); + latents_unpacked = latents_unpacked.reshape( + {batch_size, channels / (2 * 2), height, width}); + + return latents_unpacked; + } + + protected: + T5EncoderModel t5_{nullptr}; + CLIPTextModel clip_text_model_{nullptr}; + torch::Device device_ = torch::kCPU; + torch::ScalarType dtype_; + std::unique_ptr tokenizer_; + std::unique_ptr tokenizer_2_; + torch::TensorOptions options_; + int tokenizer_max_length_; + int vae_scale_factor_; +}; + +} // namespace xllm diff --git a/xllm/models/dit/pipeline_flux_fill.h b/xllm/models/dit/pipeline_flux_fill.h new file mode 100644 index 00000000..73e1579f --- /dev/null +++ b/xllm/models/dit/pipeline_flux_fill.h @@ -0,0 +1,436 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once +#include "core/layers/pos_embedding.h" +#include "core/layers/rotary_embedding.h" +#include "dit.h" +#include "pipeline_flux_base.h" +// pipeline_flux_fill compatible with huggingface weights +// ref to: +// https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux/pipeline_flux_fill.py + +namespace xllm { + +class FluxFillPipelineImpl : public FluxPipelineBaseImpl { + public: + FluxFillPipelineImpl(const DiTModelContext& context) { + auto model_args = context.get_model_args("vae"); + options_ = context.get_tensor_options(); + device_ = options_.device(); + dtype_ = options_.dtype().toScalarType(); + vae_scale_factor_ = 1 << (model_args.block_out_channels().size() - 1); + vae_shift_factor_ = model_args.shift_factor(); + vae_scaling_factor_ = model_args.scale_factor(); + latent_channels_ = model_args.latent_channels(); + + default_sample_size_ = 128; + tokenizer_max_length_ = 77; // TODO: get from config file + LOG(INFO) << "Initializing FluxFill pipeline..."; + image_processor_ = VAEImageProcessor( + context.get_model_context("vae"), true, true, false, false, false); + mask_processor_ = VAEImageProcessor( + context.get_model_context("vae"), true, false, true, false, true); + vae_ = VAE(context.get_model_context("vae")); + LOG(INFO) << "VAE initialized."; + pos_embed_ = register_module( + "pos_embed", + FluxPosEmbed(10000, + context.get_model_args("transformer").axes_dims_rope())); + transformer_ = FluxDiTModel(context.get_model_context("transformer")); + LOG(INFO) << "DiT transformer initialized."; + t5_ = T5EncoderModel(context.get_model_context("text_encoder_2")); + LOG(INFO) << "T5 initialized."; + clip_text_model_ = CLIPTextModel(context.get_model_context("text_encoder")); + LOG(INFO) << "CLIP text model initialized."; + scheduler_ = + FlowMatchEulerDiscreteScheduler(context.get_model_context("scheduler")); + LOG(INFO) << "FluxFill pipeline initialized."; + register_module("vae", vae_); + LOG(INFO) << "VAE registered."; + register_module("vae_image_processor", image_processor_); + LOG(INFO) << "VAE image processor registered."; + register_module("mask_processor", mask_processor_); + LOG(INFO) << "mask processor registered."; + register_module("transformer", transformer_); + LOG(INFO) << "DiT transformer registered."; + register_module("t5", t5_); + LOG(INFO) << "T5 registered."; + register_module("scheduler", scheduler_); + LOG(INFO) << "Scheduler registered."; + register_module("clip_text_model", clip_text_model_); + LOG(INFO) << "CLIP text model registered."; + } + + DiTForwardOutput forward(const DiTForwardInput& input) { + const auto& generation_params = input.generation_params; + int64_t height = generation_params.height; + int64_t width = generation_params.width; + auto seed = generation_params.seed > 0 ? 
generation_params.seed : 42; + auto prompts = std::make_optional(input.prompts); + auto prompts_2 = input.prompts_2.empty() + ? std::nullopt + : std::make_optional(input.prompts_2); + + auto image = input.images.defined() ? std::make_optional(input.images) + : std::nullopt; + auto mask_image = input.mask_images.defined() + ? std::make_optional(input.mask_images) + : std::nullopt; + auto masked_image_latents = + input.masked_image_latents.defined() + ? std::make_optional(input.masked_image_latents) + : std::nullopt; + + auto latents = input.latents.defined() ? std::make_optional(input.latents) + : std::nullopt; + auto prompt_embeds = input.prompt_embeds.defined() + ? std::make_optional(input.prompt_embeds) + : std::nullopt; + auto pooled_prompt_embeds = + input.pooled_prompt_embeds.defined() + ? std::make_optional(input.pooled_prompt_embeds) + : std::nullopt; + + std::vector output = + forward_(prompts, + prompts_2, + image, + mask_image, + masked_image_latents, + height, + width, + generation_params.strength, + generation_params.num_inference_steps, + generation_params.guidance_scale, + generation_params.num_images_per_prompt, + seed, + latents, + prompt_embeds, + pooled_prompt_embeds, + generation_params.max_sequence_length); + + DiTForwardOutput out; + out.tensors = torch::chunk(output[0], output[0].size(0), 0); + LOG(INFO) << "Output tensor chunks size: " << out.tensors.size(); + return out; + } + + void load_model(std::unique_ptr loader) { + LOG(INFO) << "FluxFillPipeline loading model from" + << loader->model_root_path(); + std::string model_path = loader->model_root_path(); + auto transformer_loader = loader->take_component_loader("transformer"); + auto vae_loader = loader->take_component_loader("vae"); + auto t5_loader = loader->take_component_loader("text_encoder_2"); + auto clip_loader = loader->take_component_loader("text_encoder"); + auto tokenizer_loader = loader->take_component_loader("tokenizer"); + auto tokenizer_2_loader = loader->take_component_loader("tokenizer_2"); + LOG(INFO) << "FluxFill model components loaded, start to load weights to " + "sub models"; + transformer_->load_model(std::move(transformer_loader)); + transformer_->to(device_); + vae_->load_model(std::move(vae_loader)); + vae_->to(device_); + t5_->load_model(std::move(t5_loader)); + t5_->to(device_); + clip_text_model_->load_model(std::move(clip_loader)); + clip_text_model_->to(device_); + tokenizer_ = tokenizer_loader->tokenizer(); + tokenizer_2_ = tokenizer_2_loader->tokenizer(); + } + + private: + std::pair prepare_mask_latents( + torch::Tensor mask, + torch::Tensor masked_image, + int64_t batch_size, + int64_t num_channels_latents, + int64_t num_images_per_prompt, + int64_t height, + int64_t width, + int64_t seed) { + height = 2 * (height / (vae_scale_factor_ * 2)); + width = 2 * (width / (vae_scale_factor_ * 2)); + + torch::Tensor masked_image_latents; + if (masked_image.size(1) == num_channels_latents) { + masked_image_latents = masked_image; + } else { + masked_image_latents = vae_->encode(masked_image, seed); + } + + masked_image_latents = + (masked_image_latents - vae_shift_factor_) * vae_scaling_factor_; + masked_image_latents = masked_image_latents.to(device_).to(dtype_); + + batch_size = batch_size * num_images_per_prompt; + if (mask.size(0) < batch_size) { + CHECK(batch_size % mask.size(0) == 0) + << "Masks batch size mismatch: mask cannot be duplicated to match " + "total batch."; + mask = mask.repeat({batch_size / mask.size(0), 1, 1, 1}); + } + + if (masked_image_latents.size(0) < 
batch_size) { + CHECK(batch_size % masked_image_latents.size(0) == 0) + << "Masked image batch size mismatch: cannot duplicate to match " + "total batch."; + masked_image_latents = masked_image_latents.repeat( + {batch_size / masked_image_latents.size(0), 1, 1, 1}); + } + + masked_image_latents = pack_latents( + masked_image_latents, batch_size, num_channels_latents, height, width); + + mask = mask.select(1, 0); + mask = mask.view( + {batch_size, height, vae_scale_factor_, width, vae_scale_factor_}); + mask = mask.permute({0, 2, 4, 1, 3}); + mask = mask.reshape( + {batch_size, vae_scale_factor_ * vae_scale_factor_, height, width}); + mask = pack_latents( + mask, batch_size, vae_scale_factor_ * vae_scale_factor_, height, width); + mask = mask.to(device_).to(dtype_); + + return {mask, masked_image_latents}; + } + + torch::Tensor encode_vae_image(const torch::Tensor& image, int64_t seed) { + torch::Tensor latents = vae_->encode(image, seed); + latents = (latents - vae_shift_factor_) * vae_scaling_factor_; + return latents; + } + + std::pair get_timesteps(int64_t num_inference_steps, + float strength) { + int64_t init_timestep = + std::min(static_cast(num_inference_steps * strength), + num_inference_steps); + + int64_t t_start = std::max(num_inference_steps - init_timestep, int64_t(0)); + int64_t start_idx = t_start * scheduler_->order(); + auto timesteps = + scheduler_->timesteps().slice(0, start_idx).to(device_).to(dtype_); + scheduler_->set_begin_index(start_idx); + return {timesteps, num_inference_steps - t_start}; + } + + std::pair prepare_latents( + torch::Tensor image, + torch::Tensor timesteps, + int64_t batch_size, + int64_t num_channels_latents, + int64_t height, + int64_t width, + int64_t seed, + std::optional latents = std::nullopt) { + height = 2 * (height / (vae_scale_factor_ * 2)); + width = 2 * (width / (vae_scale_factor_ * 2)); + + std::vector shape = { + batch_size, num_channels_latents, height, width}; + torch::Tensor latent_image_ids = + prepare_latent_image_ids(batch_size, height / 2, width / 2); + if (latents.has_value()) { + return {latents.value().to(device_).to(dtype_), latent_image_ids}; + } + + torch::Tensor image_latents; + if (image.size(1) != latent_channels_) { + image_latents = encode_vae_image(image, seed); + } else { + image_latents = image; + } + int64_t additional_image_per_prompt; + if (batch_size > image_latents.size(0) && + batch_size % image_latents.size(0) == 0) { + additional_image_per_prompt = batch_size / image_latents.size(0); + image_latents = + image_latents.repeat({additional_image_per_prompt, 1, 1, 1}); + } else if (batch_size > image_latents.size(0) && + batch_size % image_latents.size(0) != 0) { + LOG(FATAL) << "Cannot match batch_size with input images."; + } else { + image_latents = torch::cat({image_latents}, 0); + } + auto noise = randn_tensor(shape, seed, options_); + latents = scheduler_->scale_noise(image_latents, timesteps, noise); + latents = pack_latents( + latents.value(), batch_size, num_channels_latents, height, width); + return {latents.value(), latent_image_ids}; + } + + std::vector forward_( + std::optional> prompt = std::nullopt, + std::optional> prompt_2 = std::nullopt, + std::optional image = std::nullopt, + std::optional mask_image = std::nullopt, + std::optional masked_image_latents = std::nullopt, + int64_t height = 512, + int64_t width = 512, + float strength = 1.0f, + int64_t num_inference_steps = 50, + float guidance_scale = 30.0f, + int64_t num_images_per_prompt = 1, + int64_t seed = 42, + std::optional latents = 
+      std::optional<torch::Tensor> prompt_embeds = std::nullopt,
+      std::optional<torch::Tensor> pooled_prompt_embeds = std::nullopt,
+      int64_t max_sequence_length = 512) {
+    torch::NoGradGuard no_grad;
+    torch::Tensor init_image =
+        image_processor_->preprocess(image.value(), height, width);
+
+    int64_t batch_size;
+    if (prompt.has_value()) {
+      batch_size = prompt.value().size();
+    } else {
+      batch_size = prompt_embeds.value().size(0);
+    }
+
+    torch::Tensor text_ids;
+    std::tie(prompt_embeds, pooled_prompt_embeds, text_ids) =
+        encode_prompt(prompt,
+                      prompt_2,
+                      prompt_embeds,
+                      pooled_prompt_embeds,
+                      num_images_per_prompt,
+                      max_sequence_length);
+
+    // Linearly spaced sigmas from 1.0 down to 1 / num_inference_steps.
+    std::vector<float> sigmas = [&](int64_t steps) {
+      std::vector<float> result(steps);
+      for (int64_t i = 0; i < steps; ++i)
+        result[i] = 1.0f - static_cast<float>(i) / steps;
+      return result;
+    }(num_inference_steps);
+
+    int64_t image_seq_len =
+        (height / vae_scale_factor_ / 2) * (width / vae_scale_factor_ / 2);
+    float mu = calculate_shift(image_seq_len,
+                               scheduler_->base_image_seq_len(),
+                               scheduler_->max_image_seq_len(),
+                               scheduler_->base_shift(),
+                               scheduler_->max_shift());
+
+    retrieve_timesteps(scheduler_, num_inference_steps, device_, sigmas, mu);
+    torch::Tensor timesteps;
+    std::tie(timesteps, num_inference_steps) =
+        get_timesteps(num_inference_steps, strength);
+    CHECK(num_inference_steps >= 1);
+
+    torch::Tensor latent_timestep =
+        timesteps.index({torch::indexing::Slice(0, 1)})
+            .repeat({batch_size * num_images_per_prompt});
+
+    int64_t num_channels_latents = latent_channels_;
+    torch::Tensor latent_image_ids;
+    std::tie(latents, latent_image_ids) =
+        prepare_latents(init_image,
+                        latent_timestep,
+                        batch_size * num_images_per_prompt,
+                        num_channels_latents,
+                        height,
+                        width,
+                        seed,
+                        latents);
+
+    if (masked_image_latents.has_value()) {
+      masked_image_latents =
+          masked_image_latents.value().to(device_).to(dtype_);
+    } else {
+      mask_image =
+          mask_processor_->preprocess(mask_image.value(), height, width);
+      torch::Tensor masked_image = init_image * (1 - mask_image.value());
+
+      height = init_image.size(-2);
+      width = init_image.size(-1);
+
+      torch::Tensor mask;
+      std::tie(mask, masked_image_latents) =
+          prepare_mask_latents(mask_image.value(),
+                               masked_image,
+                               batch_size,
+                               num_channels_latents,
+                               num_images_per_prompt,
+                               height,
+                               width,
+                               seed);
+      masked_image_latents =
+          torch::cat({masked_image_latents.value(), mask}, -1);
+    }
+
+    torch::Tensor guidance;
+    if (transformer_->guidance_embeds()) {
+      guidance = torch::full(at::IntArrayRef({1}), guidance_scale, options_);
+      guidance = guidance.expand({latents.value().size(0)});
+    }
+
+    auto [rot_emb1, rot_emb2] =
+        pos_embed_->forward_cache(text_ids,
+                                  latent_image_ids,
+                                  height / (vae_scale_factor_ * 2),
+                                  width / (vae_scale_factor_ * 2));
+
+    torch::Tensor image_rotary_emb =
+        torch::stack({rot_emb1, rot_emb2}, 0).to(device_);
+
+    for (int64_t i = 0; i < timesteps.size(0); ++i) {
+      torch::Tensor t = timesteps[i];
+      torch::Tensor timestep = t.expand({latents->size(0)}).to(device_);
+
+      int64_t step_id = i + 1;
+      // Fill conditioning: concatenate the masked-image latents (with the
+      // packed mask) onto the noisy latents along the channel dimension.
+      torch::Tensor input_latents =
+          torch::cat({latents.value(), masked_image_latents.value()}, 2);
+
+      torch::Tensor noise_pred =
+          transformer_->forward(input_latents,
+                                prompt_embeds.value(),
+                                pooled_prompt_embeds.value(),
+                                timestep / 1000,
+                                image_rotary_emb,
+                                guidance,
+                                step_id);
+      auto prev_latents = scheduler_->step(noise_pred, t, latents.value());
+      latents = prev_latents.detach().to(device_);
+    }
+
+    torch::Tensor output_image;
+    torch::Tensor unpacked_latents =
+        unpack_latents(latents.value(), height, width, vae_scale_factor_);
+    unpacked_latents =
+        (unpacked_latents / vae_scaling_factor_) + vae_shift_factor_;
+
+    output_image = vae_->decode(unpacked_latents);
+    output_image = image_processor_->postprocess(output_image, "pil");
+    return std::vector<torch::Tensor>{{output_image}};
+  }
+
+ private:
+  FlowMatchEulerDiscreteScheduler scheduler_{nullptr};
+  VAE vae_{nullptr};
+  VAEImageProcessor image_processor_{nullptr};
+  VAEImageProcessor mask_processor_{nullptr};
+  FluxDiTModel transformer_{nullptr};
+  float vae_scaling_factor_;
+  float vae_shift_factor_;
+  int default_sample_size_;
+  int64_t latent_channels_;
+  FluxPosEmbed pos_embed_{nullptr};
+};
+TORCH_MODULE(FluxFillPipeline);
+
+REGISTER_DIT_MODEL(fluxfill, FluxFillPipeline);
+}  // namespace xllm
diff --git a/xllm/models/models.h b/xllm/models/models.h
index d000afd2..5c77ce86 100644
--- a/xllm/models/models.h
+++ b/xllm/models/models.h
@@ -16,20 +16,21 @@ limitations under the License.
 #pragma once
 
 #if defined(USE_NPU)
-#include "dit/pipeline_flux.h"    // IWYU pragma: keep
-#include "llm/deepseek_v2.h"      // IWYU pragma: keep
-#include "llm/deepseek_v2_mtp.h"  // IWYU pragma: keep
-#include "llm/deepseek_v3.h"      // IWYU pragma: keep
-#include "llm/glm4_moe.h"         // IWYU pragma: keep
-#include "llm/glm4_moe_mtp.h"     // IWYU pragma: keep
-#include "llm/kimi_k2.h"          // IWYU pragma: keep
-#include "llm/llama.h"            // IWYU pragma: keep
-#include "llm/llama3.h"           // IWYU pragma: keep
-#include "llm/llm_model_base.h"   // IWYU pragma: keep
-#include "llm/qwen2.h"            // IWYU pragma: keep
-#include "llm/qwen3_embedding.h"  // IWYU pragma: keep
-#include "vlm/minicpmv.h"         // IWYU pragma: keep
-#include "vlm/qwen2_5_vl.h"       // IWYU pragma: keep
+#include "dit/pipeline_flux.h"       // IWYU pragma: keep
+#include "dit/pipeline_flux_fill.h"  // IWYU pragma: keep
+#include "llm/deepseek_v2.h"         // IWYU pragma: keep
+#include "llm/deepseek_v2_mtp.h"     // IWYU pragma: keep
+#include "llm/deepseek_v3.h"         // IWYU pragma: keep
+#include "llm/glm4_moe.h"            // IWYU pragma: keep
+#include "llm/glm4_moe_mtp.h"        // IWYU pragma: keep
+#include "llm/kimi_k2.h"             // IWYU pragma: keep
+#include "llm/llama.h"               // IWYU pragma: keep
+#include "llm/llama3.h"              // IWYU pragma: keep
+#include "llm/llm_model_base.h"      // IWYU pragma: keep
+#include "llm/qwen2.h"               // IWYU pragma: keep
+#include "llm/qwen3_embedding.h"     // IWYU pragma: keep
+#include "vlm/minicpmv.h"            // IWYU pragma: keep
+#include "vlm/qwen2_5_vl.h"          // IWYU pragma: keep
 #endif
 
 #include "llm/llm_model_base.h"  // IWYU pragma: keep
diff --git a/xllm/proto/image_generation.proto b/xllm/proto/image_generation.proto
index 062d691e..f16b9f17 100644
--- a/xllm/proto/image_generation.proto
+++ b/xllm/proto/image_generation.proto
@@ -34,6 +34,15 @@ message Input {
 
   // initial latent: [batch_size][channels][height/8][width/8]
   optional Tensor latent = 9;
+
+  // Input type: "base64"
+  optional string image = 10;
+
+  // Input type: "base64"
+  optional string mask_image = 11;
+
+  // A batch of masked-image latents produced by the VAE encoder
+  optional Tensor masked_image_latent = 12;
 }
 
 // Generation parameters container
@@ -59,11 +68,14 @@ message Parameters {
   // Maximum sequence length for prompt processing
   optional int32 max_sequence_length = 7;
 
+  // The extent to which the reference image is altered, between 0 and 1
+  optional float strength = 8;
+
   // Array of sigma values for noise scheduling
-  // repeated float sigmas = 8;
+  // repeated float sigmas = 9;
 
   // Output type, either "base64" or "url"
-  // optional string output_type = 9;
+  // optional string output_type = 10;
 }
 
 // Request structure for image generation tasks using FLUX models
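A minimal client-side sketch of how the new fill fields might be populated. It assumes the request message exposes the Input and Parameters sub-messages through `input`/`parameters` fields (as referenced from dit_request_params.cpp), that butil::Base64Encode from butil/base64.h is available alongside the Base64Decode used above, and that the generated header name, the xllm::proto namespace, and the read_file helper are illustrative only; prompt and other fields are elided.

#include <fstream>
#include <sstream>
#include <string>

#include "butil/base64.h"
#include "image_generation.pb.h"  // assumed generated header name

namespace {
// Illustrative helper: read a whole file (e.g. a PNG) into a string.
std::string read_file(const std::string& path) {
  std::ifstream in(path, std::ios::binary);
  std::ostringstream buf;
  buf << in.rdbuf();
  return buf.str();
}
}  // namespace

// Build a FLUX.1-Fill request: base64-encoded source image and mask,
// plus the new strength parameter (1.0 repaints masked regions fully).
xllm::proto::ImageGenerationRequest build_fill_request(
    const std::string& image_path, const std::string& mask_path) {
  xllm::proto::ImageGenerationRequest request;

  std::string image_b64;
  std::string mask_b64;
  butil::Base64Encode(read_file(image_path), &image_b64);
  butil::Base64Encode(read_file(mask_path), &mask_b64);

  auto* input = request.mutable_input();
  input->set_image(image_b64);      // Input.image = 10, "base64"
  input->set_mask_image(mask_b64);  // Input.mask_image = 11, "base64"

  auto* params = request.mutable_parameters();
  params->set_strength(1.0f);  // Parameters.strength = 8
  params->set_max_sequence_length(512);

  return request;
}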