From da419f2a99e226488fc46f4026f558e3a59e4e15 Mon Sep 17 00:00:00 2001
From: "yueyang.hyy"
Date: Tue, 24 Jun 2025 11:22:44 +0800
Subject: [PATCH] fix batch cfg && support flux ckpt without guidance_embedder

---
 .../models/basic/transformer_helper.py    |  2 +-
 diffsynth_engine/models/flux/flux_dit.py  | 19 +++++++++++++------
 tests/test_pipelines/test_flux_image.py   |  3 ++-
 3 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/diffsynth_engine/models/basic/transformer_helper.py b/diffsynth_engine/models/basic/transformer_helper.py
index f5170d4..d1d9027 100644
--- a/diffsynth_engine/models/basic/transformer_helper.py
+++ b/diffsynth_engine/models/basic/transformer_helper.py
@@ -4,7 +4,7 @@
 
 
 def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
-    return x * (1 + scale) + shift
+    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
 
 
 class AdaLayerNorm(nn.Module):
diff --git a/diffsynth_engine/models/flux/flux_dit.py b/diffsynth_engine/models/flux/flux_dit.py
index b48045c..ff28e1a 100644
--- a/diffsynth_engine/models/flux/flux_dit.py
+++ b/diffsynth_engine/models/flux/flux_dit.py
@@ -255,16 +255,17 @@ def forward(self, image, text, t_emb, rope_emb, image_emb=None):
         image_in, gate_a = self.norm_msa_a(image, t_emb)
         text_in, gate_b = self.norm_msa_b(text, t_emb)
         image_out, text_out = self.attn(image_in, text_in, rope_emb, image_emb)
-        image = image + gate_a * image_out
-        text = text + gate_b * text_out
+
+        image = image + gate_a.unsqueeze(1) * image_out
+        text = text + gate_b.unsqueeze(1) * text_out
 
         # AdaLayerNorm-Zero for Image MLP
         image_in, gate_a = self.norm_mlp_a(image, t_emb)
-        image = image + gate_a * self.ff_a(image_in)
+        image = image + gate_a.unsqueeze(1) * self.ff_a(image_in)
 
         # AdaLayerNorm-Zero for Text MLP
         text_in, gate_b = self.norm_mlp_b(text, t_emb)
-        text = text + gate_b * self.ff_b(text_in)
+        text = text + gate_b.unsqueeze(1) * self.ff_b(text_in)
 
         return image, text
 
@@ -318,7 +319,7 @@ def forward(self, x, t_emb, rope_emb, image_emb=None):
         h, gate = self.norm(x, emb=t_emb)
         attn_output = self.attn(h, rope_emb, image_emb)
         mlp_output = self.mlp(h)
-        return x + gate * self.proj_out(torch.cat([attn_output, mlp_output], dim=2))
+        return x + gate.unsqueeze(1) * self.proj_out(torch.cat([attn_output, mlp_output], dim=2))
 
 
 class FluxDiT(PreTrainedModel):
@@ -329,13 +330,17 @@ def __init__(
         in_channel: int = 64,
         attn_impl: Optional[str] = None,
         use_usp: bool = False,
+        guidance_embedder: bool = True,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
         super().__init__()
         self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
         self.time_embedder = TimestepEmbeddings(256, 3072, device=device, dtype=dtype)
-        self.guidance_embedder = TimestepEmbeddings(256, 3072, device=device, dtype=dtype)
+        if guidance_embedder:
+            self.guidance_embedder = TimestepEmbeddings(256, 3072, device=device, dtype=dtype)
+        else:
+            self.guidance_embedder = None
         self.pooled_text_embedder = nn.Sequential(
             nn.Linear(768, 3072, device=device, dtype=dtype),
             nn.SiLU(),
@@ -476,6 +481,7 @@ def from_state_dict(
         in_channel: int = 64,
         attn_impl: Optional[str] = None,
         use_usp: bool = False,
+        guidance_embedder: bool = True,
     ):
         with no_init_weights():
             model = torch.nn.utils.skip_init(
@@ -485,6 +491,7 @@
                 in_channel=in_channel,
                 attn_impl=attn_impl,
                 use_usp=use_usp,
+                guidance_embedder=guidance_embedder,
             )
         model = model.requires_grad_(False)  # for loading gguf
         model.load_state_dict(state_dict, assign=True)
diff --git a/tests/test_pipelines/test_flux_image.py b/tests/test_pipelines/test_flux_image.py
index 361830b..ba7da40 100644
--- a/tests/test_pipelines/test_flux_image.py
+++ b/tests/test_pipelines/test_flux_image.py
@@ -46,7 +46,8 @@ def test_unfused_lora(self):
         )
         self.pipe.unload_loras()
         self.assertImageEqualAndSaveFailed(image, "flux/flux_lora.png", threshold=0.98)
-
+
+# TODO: Add batch cfg test / Add diffusers format ckpt test
 
 class TestFLUXGGUF(ImageTestCase):
     @classmethod
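
Note (illustrative sketch, not part of the patch): the unsqueeze(1) calls are
what make batched CFG work. AdaLayerNorm produces per-sample shift/scale/gate
tensors of shape (B, C), while the token sequences are (B, L, C); without the
extra singleton dimension the old expression only broadcasts when B == 1, so
running cond + uncond in a single batch fails. A minimal repro, with assumed
shape values:

    import torch

    B, L, C = 2, 4096, 3072        # batched CFG: cond + uncond in one batch
    x = torch.randn(B, L, C)       # image/text tokens
    shift = torch.randn(B, C)      # per-sample modulation derived from t_emb
    scale = torch.randn(B, C)

    # old form: (B, L, C) * (B, C) only broadcasts when B == 1
    try:
        _ = x * (1 + scale) + shift
    except RuntimeError as e:
        print("broadcast error:", e)

    # patched form: insert a token dimension so each sample keeps its own
    # modulation across all L tokens
    y = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
    print(y.shape)  # torch.Size([2, 4096, 3072])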
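
Note (hypothetical usage, not part of the patch): with guidance_embedder=False,
FluxDiT.from_state_dict builds the model without a guidance embedder, matching
checkpoints that ship no guidance_embedder weights. The checkpoint path and the
device/dtype keyword names below are assumptions; only the guidance_embedder
flag comes from this patch:

    import torch
    from safetensors.torch import load_file
    from diffsynth_engine.models.flux.flux_dit import FluxDiT

    state_dict = load_file("flux-ckpt-without-guidance.safetensors")  # placeholder path
    dit = FluxDiT.from_state_dict(
        state_dict,
        device="cuda:0",
        dtype=torch.bfloat16,
        guidance_embedder=False,  # checkpoint has no guidance_embedder weights
    )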