From 951dfefe62db79a3d92457d01ff17850b1afec74 Mon Sep 17 00:00:00 2001 From: akleine Date: Thu, 1 Jan 2026 11:54:52 +0100 Subject: [PATCH 1/5] feat: add U-Net specials of SDXS --- model.cpp | 7 +++++++ model.h | 3 ++- stable-diffusion.cpp | 1 + unet.hpp | 5 ++++- 4 files changed, 14 insertions(+), 2 deletions(-) diff --git a/model.cpp b/model.cpp index dcb77e816..f40699557 100644 --- a/model.cpp +++ b/model.cpp @@ -1038,6 +1038,7 @@ SDVersion ModelLoader::get_sd_version() { int64_t patch_embedding_channels = 0; bool has_img_emb = false; bool has_middle_block_1 = false; + bool has_output_block_71 = false; for (auto& [name, tensor_storage] : tensor_storage_map) { if (!(is_xl)) { @@ -1094,6 +1095,9 @@ SDVersion ModelLoader::get_sd_version() { tensor_storage.name.find("unet.mid_block.resnets.1.") != std::string::npos) { has_middle_block_1 = true; } + if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1") != std::string::npos) { + has_output_block_71 = true; + } if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" || tensor_storage.name == "cond_stage_model.model.token_embedding.weight" || tensor_storage.name == "text_model.embeddings.token_embedding.weight" || @@ -1155,6 +1159,9 @@ SDVersion ModelLoader::get_sd_version() { return VERSION_SD1_PIX2PIX; } if (!has_middle_block_1) { + if (!has_output_block_71) { + return VERSION_SDXS; + } return VERSION_SD1_TINY_UNET; } return VERSION_SD1; diff --git a/model.h b/model.h index b9e50ad63..e52766cc0 100644 --- a/model.h +++ b/model.h @@ -28,6 +28,7 @@ enum SDVersion { VERSION_SD2, VERSION_SD2_INPAINT, VERSION_SD2_TINY_UNET, + VERSION_SDXS, VERSION_SDXL, VERSION_SDXL_INPAINT, VERSION_SDXL_PIX2PIX, @@ -50,7 +51,7 @@ enum SDVersion { }; static inline bool sd_version_is_sd1(SDVersion version) { - if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX || version == VERSION_SD1_TINY_UNET) { + if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX || version == VERSION_SD1_TINY_UNET || version == VERSION_SDXS) { return true; } return false; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 1c8c55ba8..f9976b43c 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -31,6 +31,7 @@ const char* model_version_to_str[] = { "SD 2.x", "SD 2.x Inpaint", "SD 2.x Tiny UNet", + "SDXS", "SDXL", "SDXL Inpaint", "SDXL Instruct-Pix2Pix", diff --git a/unet.hpp b/unet.hpp index ec7578e4b..563e09916 100644 --- a/unet.hpp +++ b/unet.hpp @@ -215,10 +215,13 @@ class UnetModelBlock : public GGMLBlock { } else if (sd_version_is_unet_edit(version)) { in_channels = 8; } - if (version == VERSION_SD1_TINY_UNET || version == VERSION_SD2_TINY_UNET) { + if (version == VERSION_SD1_TINY_UNET || version == VERSION_SD2_TINY_UNET || version == VERSION_SDXS) { num_res_blocks = 1; channel_mult = {1, 2, 4}; tiny_unet = true; + if (version == VERSION_SDXS) { + attention_resolutions = {4, 2}; // here just like SDXL + } } // dims is always 2 From b59f908e6f7bedd45f94e980e046d50e23068704 Mon Sep 17 00:00:00 2001 From: akleine Date: Tue, 6 Jan 2026 10:48:41 +0100 Subject: [PATCH 2/5] docs: update distilled_sd.md for SDXS-512 --- docs/distilled_sd.md | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/docs/distilled_sd.md b/docs/distilled_sd.md index 478305f27..b47de664d 100644 --- a/docs/distilled_sd.md +++ b/docs/distilled_sd.md @@ -83,7 +83,7 @@ python 
convert_diffusers_to_original_stable_diffusion.py \ The file segmind_tiny-sd.ckpt will be generated and is now ready for use with sd.cpp. You can follow a similar process for the other models mentioned above. -### Another available .ckpt file: +##### Another available .ckpt file: * https://huggingface.co/ClashSAN/small-sd/resolve/main/tinySDdistilled.ckpt @@ -97,3 +97,26 @@ for key, value in ckpt['state_dict'].items(): ckpt['state_dict'][key] = value.contiguous() torch.save(ckpt, "tinySDdistilled_fixed.ckpt") ``` + + +### SDXS-512 + +Another very tiny and **incredibly fast** model is SDXS. The authors refer to it as *"Real-Time One-Step Latent Diffusion Models with Image Conditions"*. For details read the paper: https://arxiv.org/pdf/2403.16627 . Once again the authors removed some more blocks of U-Net part and unlike other SD1 models they use an adjusted _AutoencoderTiny_ instead of default _AutoencoderKL_ for the VAE part. + +##### First download the diffusers models from Hugging Face using Python: + +```python +from diffusers import StableDiffusionPipeline +pipe = StableDiffusionPipeline.from_pretrained("IDKiro/sdxs-512-dreamshaper") +pipe.save_pretrained(save_directory="sdxs") +``` + +##### Second run the model as follows: + +```python +~/stable-diffusion.cpp/build/bin/sd-cli -m sdxs -p "portrait of a lovely cat" \ + --cfg-scale 1 --steps 1 \ + --taesd sdxs/vae/diffusion_pytorch_model.safetensors +``` + +All options: ``` --cfg-scale 1 ``` , ``` --steps 1 ``` and ``` --taesd sdxs/vae/diffusion_pytorch_model.safetensors``` are mandatory here. From 25f55cd7d9928b98af06e9d1e27263fb1baf1368 Mon Sep 17 00:00:00 2001 From: akleine Date: Tue, 6 Jan 2026 17:15:47 +0100 Subject: [PATCH 3/5] feat: for SDXS use AutoencoderTiny as the primary VAE --- stable-diffusion.cpp | 107 ++++++++++++++++++++++++++++++------------- tae.hpp | 9 ++++ 2 files changed, 85 insertions(+), 31 deletions(-) diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index f9976b43c..b8e587c7d 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -115,7 +115,8 @@ class StableDiffusionGGML { std::shared_ptr clip_vision; // for svd or wan2.1 i2v std::shared_ptr diffusion_model; std::shared_ptr high_noise_diffusion_model; - std::shared_ptr first_stage_model; + std::shared_ptr first_stage_model = nullptr; + std::shared_ptr first_stage_model_tiny = nullptr; std::shared_ptr tae_first_stage; std::shared_ptr control_net; std::shared_ptr pmid_model; @@ -606,28 +607,42 @@ class StableDiffusionGGML { first_stage_model = std::make_shared(vae_backend, offload_params_to_cpu); } else { - first_stage_model = std::make_shared(vae_backend, - offload_params_to_cpu, - tensor_storage_map, - "first_stage_model", - vae_decode_only, - false, - version); - if (sd_ctx_params->vae_conv_direct) { - LOG_INFO("Using Conv2d direct in the vae model"); - first_stage_model->set_conv2d_direct_enabled(true); - } - if (version == VERSION_SDXL && - (strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale)) { - float vae_conv_2d_scale = 1.f / 32.f; - LOG_WARN( - "No VAE specified with --vae or --force-sdxl-vae-conv-scale flag set, " - "using Conv2D scale %.3f", - vae_conv_2d_scale); - first_stage_model->set_conv2d_scale(vae_conv_2d_scale); + if (version == VERSION_SDXS) { + first_stage_model_tiny = std::make_shared(vae_backend, + offload_params_to_cpu, + tensor_storage_map, + "decoder.layers", + vae_decode_only, + version); + first_stage_model_tiny->alloc_params_buffer(); + 
first_stage_model_tiny->get_param_tensors(tensors,"first_stage_model"); + if (sd_ctx_params->vae_conv_direct) { + first_stage_model_tiny->set_conv2d_direct_enabled(true); + } + } else { + first_stage_model = std::make_shared(vae_backend, + offload_params_to_cpu, + tensor_storage_map, + "first_stage_model", + vae_decode_only, + false, + version); + if (sd_ctx_params->vae_conv_direct) { + LOG_INFO("Using Conv2d direct in the vae model"); + first_stage_model->set_conv2d_direct_enabled(true); + } + if (version == VERSION_SDXL && + (strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale)) { + float vae_conv_2d_scale = 1.f / 32.f; + LOG_WARN( + "No VAE specified with --vae or --force-sdxl-vae-conv-scale flag set, " + "using Conv2D scale %.3f", + vae_conv_2d_scale); + first_stage_model->set_conv2d_scale(vae_conv_2d_scale); + } + first_stage_model->alloc_params_buffer(); + first_stage_model->get_param_tensors(tensors, "first_stage_model"); } - first_stage_model->alloc_params_buffer(); - first_stage_model->get_param_tensors(tensors, "first_stage_model"); } } @@ -723,6 +738,9 @@ class StableDiffusionGGML { if (first_stage_model) { first_stage_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y); } + if (first_stage_model_tiny) { + first_stage_model_tiny->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y); + } if (tae_first_stage) { tae_first_stage->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y); } @@ -784,7 +802,11 @@ class StableDiffusionGGML { } size_t vae_params_mem_size = 0; if (!use_tiny_autoencoder || sd_ctx_params->tae_preview_only) { - vae_params_mem_size = first_stage_model->get_params_buffer_size(); + if (first_stage_model_tiny != nullptr) { + vae_params_mem_size = first_stage_model_tiny->get_params_buffer_size(); + } else { + vae_params_mem_size = first_stage_model->get_params_buffer_size(); + } } if (use_tiny_autoencoder) { if (!tae_first_stage->load_from_file(taesd_path, n_threads)) { @@ -2518,9 +2540,17 @@ class StableDiffusionGGML { }; sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, on_tiling); } else { - first_stage_model->compute(n_threads, x, false, &result, work_ctx); + if (version == VERSION_SDXS) { + first_stage_model_tiny->compute(n_threads, x, false, &result, work_ctx); + } else { + first_stage_model->compute(n_threads, x, false, &result, work_ctx); + } + } + if (version == VERSION_SDXS) { + first_stage_model_tiny->free_compute_buffer(); + } else { + first_stage_model->free_compute_buffer(); } - first_stage_model->free_compute_buffer(); } else { if (vae_tiling_params.enabled && !encode_video) { // split latent in 32x32 tiles and compute in several steps @@ -2574,6 +2604,7 @@ class StableDiffusionGGML { sd_version_is_qwen_image(version) || sd_version_is_wan(version) || sd_version_is_flux2(version) || + version == VERSION_SDXS || version == VERSION_CHROMA_RADIANCE) { latent = vae_output; } else if (version == VERSION_SD1_PIX2PIX) { @@ -2632,7 +2663,9 @@ class StableDiffusionGGML { if (sd_version_is_qwen_image(version)) { x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]); } - process_latent_out(x); + if (first_stage_model_tiny == nullptr) { + process_latent_out(x); + } // x = load_tensor_from_file(work_ctx, "wan_vae_z.bin"); if (vae_tiling_params.enabled && !decode_video) { float tile_overlap; @@ -2643,14 +2676,22 @@ class StableDiffusionGGML { // split latent in 32x32 tiles and compute in several steps 
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
-                    first_stage_model->compute(n_threads, in, true, &out, nullptr);
+                    first_stage_model_tiny != nullptr ? first_stage_model_tiny->compute(n_threads, in, true, &out, nullptr) : first_stage_model->compute(n_threads, in, true, &out, nullptr);
                 };
                 sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, on_tiling);
             } else {
-                first_stage_model->compute(n_threads, x, true, &result, work_ctx);
+                if (first_stage_model_tiny != nullptr) {
+                    first_stage_model_tiny->compute(n_threads, x, true, &result, work_ctx);
+                } else {
+                    first_stage_model->compute(n_threads, x, true, &result, work_ctx);
+                }
+            }
+            if (first_stage_model_tiny != nullptr) {
+                first_stage_model_tiny->free_compute_buffer();
+            } else {
+                first_stage_model->free_compute_buffer();
+                process_vae_output_tensor(result);
             }
-            first_stage_model->free_compute_buffer();
-            process_vae_output_tensor(result);
         } else {
             if (vae_tiling_params.enabled && !decode_video) {
                 // split latent in 64x64 tiles and compute in several steps
@@ -3412,7 +3453,11 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
     int64_t t4 = ggml_time_ms();
     LOG_INFO("decode_first_stage completed, taking %.2fs", (t4 - t3) * 1.0f / 1000);
     if (sd_ctx->sd->free_params_immediately && !sd_ctx->sd->use_tiny_autoencoder) {
-        sd_ctx->sd->first_stage_model->free_params_buffer();
+        if (sd_ctx->sd->first_stage_model_tiny != nullptr) {
+            sd_ctx->sd->first_stage_model_tiny->free_params_buffer();
+        } else {
+            sd_ctx->sd->first_stage_model->free_params_buffer();
+        }
     }
     sd_ctx->sd->lora_stat();
diff --git a/tae.hpp b/tae.hpp
index 5da76e692..dc4dbcad8 100644
--- a/tae.hpp
+++ b/tae.hpp
@@ -506,6 +506,7 @@ struct TinyAutoEncoder : public GGMLRunner {
                              struct ggml_context* output_ctx = nullptr) = 0;

     virtual bool load_from_file(const std::string& file_path, int n_threads) = 0;
+    virtual void get_param_tensors(std::map& tensors, const std::string prefix) = 0;
 };

 struct TinyImageAutoEncoder : public TinyAutoEncoder {
@@ -555,6 +556,10 @@ struct TinyImageAutoEncoder : public TinyAutoEncoder {
         return success;
     }

+    void get_param_tensors(std::map& tensors, const std::string prefix) {
+        taesd.get_param_tensors(tensors, prefix);
+    }
+
     struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
         struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
         z = to_backend(z);
@@ -624,6 +629,10 @@ struct TinyVideoAutoEncoder : public TinyAutoEncoder {
         return success;
     }

+    void get_param_tensors(std::map& tensors, const std::string prefix) {
+        taehv.get_param_tensors(tensors, prefix);
+    }
+
     struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
         struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
         z = to_backend(z);

From 6669e227f5661c3c9d2336eecd7451d485f865ed Mon Sep 17 00:00:00 2001
From: akleine
Date: Wed, 7 Jan 2026 09:20:41 +0100
Subject: [PATCH 4/5] docs: update distilled_sd.md for SDXS-512

---
 docs/distilled_sd.md | 21 +++++++++++++--------
 1 file changed, 13 insertions(+), 8 deletions(-)

diff --git a/docs/distilled_sd.md b/docs/distilled_sd.md
index b47de664d..232c02288 100644
--- a/docs/distilled_sd.md
+++ b/docs/distilled_sd.md
@@ -101,22 +101,27 @@ torch.save(ckpt, "tinySDdistilled_fixed.ckpt")
 
 ### SDXS-512
 
-Another very tiny and **incredibly fast** model is SDXS. The authors refer to it as *"Real-Time One-Step Latent Diffusion Models with Image Conditions"*. For details read the paper: https://arxiv.org/pdf/2403.16627 . Once again the authors removed some more blocks of U-Net part and unlike other SD1 models they use an adjusted _AutoencoderTiny_ instead of default _AutoencoderKL_ for the VAE part.
+Another very tiny and **incredibly fast** model is SDXS by IDKiro et al. The authors refer to it as *"Real-Time One-Step Latent Diffusion Models with Image Conditions"*. For details read the paper: https://arxiv.org/pdf/2403.16627 . Once again the authors removed some more blocks of the U-Net part, and unlike other SD1 models they use an adjusted _AutoencoderTiny_ instead of the default _AutoencoderKL_ for the VAE part.
 
-##### First download the diffusers models from Hugging Face using Python:
+##### 1. Download the diffusers model from Hugging Face using Python:
 
 ```python
 from diffusers import StableDiffusionPipeline
 pipe = StableDiffusionPipeline.from_pretrained("IDKiro/sdxs-512-dreamshaper")
 pipe.save_pretrained(save_directory="sdxs")
 ```
+##### 2. Create a safetensors file:
 
-##### Second run the model as follows:
+```bash
+python convert_diffusers_to_original_stable_diffusion.py \
+  --model_path sdxs --checkpoint_path sdxs.safetensors --half --use_safetensors
+```
 
-```python
-~/stable-diffusion.cpp/build/bin/sd-cli -m sdxs -p "portrait of a lovely cat" \
-  --cfg-scale 1 --steps 1 \
-  --taesd sdxs/vae/diffusion_pytorch_model.safetensors
+##### 3. Run the model as follows:
+
+```bash
+~/stable-diffusion.cpp/build/bin/sd-cli -m sdxs.safetensors -p "portrait of a lovely cat" \
+  --cfg-scale 1 --steps 1
 ```
 
-All options: ``` --cfg-scale 1 ``` , ``` --steps 1 ``` and ``` --taesd sdxs/vae/diffusion_pytorch_model.safetensors``` are mandatory here.
+Both options, `--cfg-scale 1` and `--steps 1`, are mandatory here: SDXS is distilled for one-step sampling without classifier-free guidance.
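For reference, the one-step setup that the docs above make mandatory can be reproduced on the diffusers side. This is a minimal sketch, not part of the patches; it assumes the parameters suggested on the SDXS model card (`num_inference_steps=1` with guidance disabled, which is also what `--cfg-scale 1` amounts to in sd.cpp):

```python
import torch
from diffusers import StableDiffusionPipeline

# One-step sampling with CFG disabled -- the diffusers analogue of
# `sd-cli --steps 1 --cfg-scale 1` from the docs above.
pipe = StableDiffusionPipeline.from_pretrained(
    "IDKiro/sdxs-512-dreamshaper", torch_dtype=torch.float16
).to("cuda")
image = pipe("portrait of a lovely cat",
             num_inference_steps=1, guidance_scale=0.0).images[0]
image.save("cat.png")
```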
From 995e4d071fbf9dd8c06fbdbe1db59df09fcb8c74 Mon Sep 17 00:00:00 2001
From: akleine
Date: Thu, 8 Jan 2026 12:49:19 +0100
Subject: [PATCH 5/5] fix: SDXS code cleanup after review by stduhpf

---
 stable-diffusion.cpp | 123 +++++++++++++++----------------------------
 1 file changed, 41 insertions(+), 82 deletions(-)

diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index b8e587c7d..6e995adee 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -116,8 +116,7 @@ class StableDiffusionGGML {
     std::shared_ptr diffusion_model;
     std::shared_ptr high_noise_diffusion_model;
     std::shared_ptr first_stage_model = nullptr;
-    std::shared_ptr first_stage_model_tiny = nullptr;
-    std::shared_ptr tae_first_stage;
+    std::shared_ptr tae_first_stage = nullptr;
     std::shared_ptr control_net;
     std::shared_ptr pmid_model;
     std::shared_ptr pmid_lora;
@@ -593,7 +592,7 @@ class StableDiffusionGGML {
             vae_backend = backend;
         }
 
-        if (!use_tiny_autoencoder || sd_ctx_params->tae_preview_only) {
+        if (!(use_tiny_autoencoder || version == VERSION_SDXS) || sd_ctx_params->tae_preview_only) {
             if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
                 first_stage_model = std::make_shared(vae_backend,
                                                      offload_params_to_cpu,
@@ -607,46 +606,31 @@ class StableDiffusionGGML {
                 first_stage_model = std::make_shared(vae_backend, offload_params_to_cpu);
             } else {
-                if (version == VERSION_SDXS) {
-                    first_stage_model_tiny = std::make_shared(vae_backend,
-                                                              offload_params_to_cpu,
-                                                              tensor_storage_map,
-                                                              "decoder.layers",
-                                                              vae_decode_only,
-                                                              version);
-                    first_stage_model_tiny->alloc_params_buffer();
-                    first_stage_model_tiny->get_param_tensors(tensors,"first_stage_model");
-                    if (sd_ctx_params->vae_conv_direct) {
-                        first_stage_model_tiny->set_conv2d_direct_enabled(true);
-                    }
-                } else {
-                    first_stage_model = std::make_shared(vae_backend,
- offload_params_to_cpu, - tensor_storage_map, - "first_stage_model", - vae_decode_only, - false, - version); - if (sd_ctx_params->vae_conv_direct) { - LOG_INFO("Using Conv2d direct in the vae model"); - first_stage_model->set_conv2d_direct_enabled(true); - } - if (version == VERSION_SDXL && - (strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale)) { - float vae_conv_2d_scale = 1.f / 32.f; - LOG_WARN( - "No VAE specified with --vae or --force-sdxl-vae-conv-scale flag set, " - "using Conv2D scale %.3f", - vae_conv_2d_scale); - first_stage_model->set_conv2d_scale(vae_conv_2d_scale); - } - first_stage_model->alloc_params_buffer(); - first_stage_model->get_param_tensors(tensors, "first_stage_model"); + first_stage_model = std::make_shared(vae_backend, + offload_params_to_cpu, + tensor_storage_map, + "first_stage_model", + vae_decode_only, + false, + version); + if (sd_ctx_params->vae_conv_direct) { + LOG_INFO("Using Conv2d direct in the vae model"); + first_stage_model->set_conv2d_direct_enabled(true); + } + if (version == VERSION_SDXL && + (strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale)) { + float vae_conv_2d_scale = 1.f / 32.f; + LOG_WARN( + "No VAE specified with --vae or --force-sdxl-vae-conv-scale flag set, " + "using Conv2D scale %.3f", + vae_conv_2d_scale); + first_stage_model->set_conv2d_scale(vae_conv_2d_scale); } + first_stage_model->alloc_params_buffer(); + first_stage_model->get_param_tensors(tensors, "first_stage_model"); } } - - if (use_tiny_autoencoder) { + if (use_tiny_autoencoder || version == VERSION_SDXS) { if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) { tae_first_stage = std::make_shared(vae_backend, offload_params_to_cpu, @@ -661,6 +645,10 @@ class StableDiffusionGGML { "decoder.layers", vae_decode_only, version); + if (version == VERSION_SDXS) { + tae_first_stage->alloc_params_buffer(); + tae_first_stage->get_param_tensors(tensors,"first_stage_model"); + } } if (sd_ctx_params->vae_conv_direct) { LOG_INFO("Using Conv2d direct in the tae model"); @@ -738,9 +726,6 @@ class StableDiffusionGGML { if (first_stage_model) { first_stage_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y); } - if (first_stage_model_tiny) { - first_stage_model_tiny->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y); - } if (tae_first_stage) { tae_first_stage->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y); } @@ -801,17 +786,14 @@ class StableDiffusionGGML { unet_params_mem_size += high_noise_diffusion_model->get_params_buffer_size(); } size_t vae_params_mem_size = 0; - if (!use_tiny_autoencoder || sd_ctx_params->tae_preview_only) { - if (first_stage_model_tiny != nullptr) { - vae_params_mem_size = first_stage_model_tiny->get_params_buffer_size(); - } else { - vae_params_mem_size = first_stage_model->get_params_buffer_size(); - } + if (!(use_tiny_autoencoder || version == VERSION_SDXS) || sd_ctx_params->tae_preview_only) { + vae_params_mem_size = first_stage_model->get_params_buffer_size(); } - if (use_tiny_autoencoder) { - if (!tae_first_stage->load_from_file(taesd_path, n_threads)) { + if (use_tiny_autoencoder || version == VERSION_SDXS) { + if (use_tiny_autoencoder && !tae_first_stage->load_from_file(taesd_path, n_threads)) { return false; } + use_tiny_autoencoder = true; // now the processing is identical for VERSION_SDXS vae_params_mem_size = tae_first_stage->get_params_buffer_size(); } size_t 
control_net_params_mem_size = 0; @@ -2540,17 +2522,9 @@ class StableDiffusionGGML { }; sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, on_tiling); } else { - if (version == VERSION_SDXS) { - first_stage_model_tiny->compute(n_threads, x, false, &result, work_ctx); - } else { - first_stage_model->compute(n_threads, x, false, &result, work_ctx); - } - } - if (version == VERSION_SDXS) { - first_stage_model_tiny->free_compute_buffer(); - } else { - first_stage_model->free_compute_buffer(); + first_stage_model->compute(n_threads, x, false, &result, work_ctx); } + first_stage_model->free_compute_buffer(); } else { if (vae_tiling_params.enabled && !encode_video) { // split latent in 32x32 tiles and compute in several steps @@ -2604,7 +2578,6 @@ class StableDiffusionGGML { sd_version_is_qwen_image(version) || sd_version_is_wan(version) || sd_version_is_flux2(version) || - version == VERSION_SDXS || version == VERSION_CHROMA_RADIANCE) { latent = vae_output; } else if (version == VERSION_SD1_PIX2PIX) { @@ -2663,9 +2636,7 @@ class StableDiffusionGGML { if (sd_version_is_qwen_image(version)) { x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]); } - if (first_stage_model_tiny == nullptr) { - process_latent_out(x); - } + process_latent_out(x); // x = load_tensor_from_file(work_ctx, "wan_vae_z.bin"); if (vae_tiling_params.enabled && !decode_video) { float tile_overlap; @@ -2676,22 +2647,14 @@ class StableDiffusionGGML { // split latent in 32x32 tiles and compute in several steps auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { - first_stage_model_tiny != nullptr ? first_stage_model_tiny->compute(n_threads, in, true, &out, nullptr) : first_stage_model->compute(n_threads, in, true, &out, nullptr); + first_stage_model->compute(n_threads, in, true, &out, nullptr); }; sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, on_tiling); } else { - if (first_stage_model_tiny != nullptr) { - first_stage_model_tiny->compute(n_threads, x, true, &result, work_ctx); - } else { - first_stage_model->compute(n_threads, x, true, &result, work_ctx); - } - } - if (first_stage_model_tiny != nullptr) { - first_stage_model_tiny->free_compute_buffer(); - } else { - first_stage_model->free_compute_buffer(); - process_vae_output_tensor(result); + first_stage_model->compute(n_threads, x, true, &result, work_ctx); } + first_stage_model->free_compute_buffer(); + process_vae_output_tensor(result); } else { if (vae_tiling_params.enabled && !decode_video) { // split latent in 64x64 tiles and compute in several steps @@ -3453,11 +3416,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, int64_t t4 = ggml_time_ms(); LOG_INFO("decode_first_stage completed, taking %.2fs", (t4 - t3) * 1.0f / 1000); if (sd_ctx->sd->free_params_immediately && !sd_ctx->sd->use_tiny_autoencoder) { - if (sd_ctx->sd->first_stage_model_tiny != nullptr) { - sd_ctx->sd->first_stage_model_tiny->free_params_buffer(); - } else { - sd_ctx->sd->first_stage_model->free_params_buffer(); - } + sd_ctx->sd->first_stage_model->free_params_buffer(); } sd_ctx->sd->lora_stat();
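As a closing sanity check, the SDXS detection added in PATCH 1 can be verified from tensor names alone. This is a hedged sketch, not part of the patches: it assumes `sdxs.safetensors` was produced by the conversion step in the docs, and that the converter emits the LDM-style key names used below.

```python
from safetensors import safe_open

# model.cpp keys its heuristic off two U-Net blocks: SD1 tiny U-Nets lack
# middle_block.1, and SDXS additionally lacks output_blocks.7.1.
with safe_open("sdxs.safetensors", framework="pt") as f:
    keys = list(f.keys())

def block_present(prefix: str) -> bool:
    return any(k.startswith(prefix) for k in keys)

print("middle_block.1    :", block_present("model.diffusion_model.middle_block.1."))
print("output_blocks.7.1 :", block_present("model.diffusion_model.output_blocks.7.1"))
# Both should print False for SDXS, which is the combination that makes
# ModelLoader::get_sd_version() return VERSION_SDXS.
```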