Conversation

@adi776borate (Contributor) commented Dec 9, 2025

What does this PR do?

Fixes #12809
This PR fixes it by (sketched below):

  1. Removing the @torch.autocast decorator (Fixes the import warning).
  2. Explicitly casting inputs to float32 inside the forward method (Preserves the required numerical stability).
  3. Casting the result back to weight.dtype before passing it to the Linear layers (Fixes the dtype mismatch crash).
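A minimal sketch of what Kandinsky5TimeEmbeddings.forward looks like with these three changes applied (simplified; this was later revised during review):

def forward(self, time):
    # run the sinusoidal embedding math in float32 for numerical stability
    time = time.to(dtype=torch.float32)
    freqs = self.freqs.to(device=time.device, dtype=torch.float32)
    args = torch.outer(time, freqs)
    time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    # cast back to the Linear layers' dtype to avoid a mat1/mat2 dtype mismatch
    time_embed = time_embed.to(dtype=self.in_layer.weight.dtype)
    time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
    return time_embed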

Verification

I verified that the results remain stable before and after this change by generating images with a fixed seed (generator=torch.manual_seed(42)).

The results are almost the same with some minor differences.

Before Fix: kandinsky_before_fix
After Fix: kandinsky_after_fix
Reproduction Script
import torch
from diffusers import Kandinsky5T2IPipeline

model_id = "kandinskylab/Kandinsky-5.0-T2I-Lite-sft-Diffusers"
device = "cuda" if torch.cuda.is_available() else "cpu"

dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
pipe = Kandinsky5T2IPipeline.from_pretrained(model_id, torch_dtype=dtype)
pipe.to(device)

seed = 42
generator = torch.Generator(device=device).manual_seed(seed)

print("Generating image...")
output = pipe(
    prompt="A cat and a dog baking a cake together in a kitchen.",
    negative_prompt="",
    num_inference_steps=25, # Reduced for faster verification
    guidance_scale=3.5,
    height=1024,
    width=1024,
    generator=generator, 
)

image = output.image[0]
image.save("kandinsky_after_fix.png")

Who can review?

@yiyixuxu @leffff
Anyone in the community is free to review the PR once the tests have passed.

@leffff (Contributor) commented Dec 9, 2025

Looks good to me!

@knd0331 commented Dec 10, 2025

Thanks for the quick fix! I didn't have time to submit a PR myself, so I really appreciate you jumping on this. 🙏
@adi776borate

@adi776borate (Contributor, Author) commented

@yiyixuxu @sayakpaul
A gentle ping to review

@sayakpaul (Member) left a comment

Thank you! Could you also provide your testing script?

@adi776borate (Contributor, Author) commented

> Thank you! Could you also provide your testing script?

The verification script is already provided in the PR description above.
If you want to test minimally, we can just do:

from diffusers.models.transformers import transformer_kandinsky
print("Import successful.")

This should emit a UserWarning on main, but not on this branch.

@sayakpaul requested a review from yiyixuxu on December 11, 2025 12:00
@yiyixuxu (Collaborator) left a comment

thanks!

@yiyixuxu (Collaborator) commented

@bot /style

github-actions bot commented Dec 11, 2025

Style bot fixed some files and pushed the changes.


Review thread on Kandinsky5TimeEmbeddings.forward:

@torch.autocast(device_type="cuda", dtype=torch.float32)
def forward(self, time):
    args = torch.outer(time, self.freqs.to(device=time.device))
    time = time.to(dtype=torch.float32)
Collaborator:

Suggested change:
-    time = time.to(dtype=torch.float32)
+    original_dtype = time.dtype
+    time = time.to(dtype=torch.float32)

    freqs = self.freqs.to(device=time.device, dtype=torch.float32)
    args = torch.outer(time, freqs)
    time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    time_embed = time_embed.to(dtype=self.in_layer.weight.dtype)
Collaborator:

Suggested change:
-    time_embed = time_embed.to(dtype=self.in_layer.weight.dtype)
+    time_embed = time_embed.to(dtype=original_dtype)

@adi776borate (Contributor, Author):

The reason I cast to self.in_layer.weight.dtype instead of original_dtype is to prevent runtime crashes on backends like XPU as mentioned by @vladmandic here.
If users load the pipeline in float16, and we pass time_embed as float32, that will raise an error, won't it?
I might be wrong, correct me if so.
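
For illustration, a standalone snippet (not from the PR) showing the kind of crash this guards against when a float32 tensor reaches a half-precision Linear layer:

import torch
import torch.nn as nn

layer = nn.Linear(4, 4).to(torch.bfloat16)  # weights stored in bfloat16, as when loading with torch_dtype=torch.bfloat16
x = torch.randn(1, 4, dtype=torch.float32)  # float32 activation, as if we skipped the cast back
try:
    layer(x)
except RuntimeError as e:
    print(e)  # dtype mismatch error, similar to the traceback below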

@adi776borate (Contributor, Author):

Hi! I also tried applying your suggested change and ran the exact same code attached in the PR description above on an L40S (which supports bf16), and hit the error below:

Traceback (most recent call last):
  File "/teamspace/studios/this_studio/diffusers/verify_fix.py", line 25, in <module>
    output = pipe(
             ^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/teamspace/studios/this_studio/diffusers/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky_t2i.py", line 731, in __call__
    pred_velocity = self.transformer(
                    ^^^^^^^^^^^^^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/transformers/transformer_kandinsky.py", line 647, in forward
    text_embed = text_transformer_block(text_embed, time_embed, text_rope)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/transformers/transformer_kandinsky.py", line 461, in forward
    self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1)
                                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/transformers/transformer_kandinsky.py", line 280, in forward
    return self.out_layer(self.activation(x))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/linear.py", line 125, in forward
    return F.linear(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: mat1 and mat2 must have the same dtype, but got Float and BFloat16

Thus, I think casting to self.in_layer.weight.dtype is a safer option.
Please let me know your thoughts.


Review thread on Kandinsky5Modulation.forward:

@torch.autocast(device_type="cuda", dtype=torch.float32)
def forward(self, x):
    x = x.to(dtype=self.out_layer.weight.dtype)
Collaborator:

umm actually this did not look correct to me - we want to upcast it to float32, no?

@adi776borate (Contributor, Author):

Similarly, if we force x to float32 here, we might hit the same mismatch crash if the out_layer weights are float16/bfloat16.

@hlky (Contributor) commented Dec 13, 2025

This is incorrect.

Minimal reproduction

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def get_freqs(dim, max_period=10000.0):
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=dim, dtype=torch.float32)
        / dim
    )
    return freqs


class Kandinsky5TimeEmbeddings(nn.Module):
    def __init__(self, model_dim, time_dim, max_period=10000.0):
        super().__init__()
        assert model_dim % 2 == 0
        self.model_dim = model_dim
        self.max_period = max_period
        self.freqs = get_freqs(self.model_dim // 2, self.max_period)
        self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, time_dim, bias=True)

    @torch.autocast(device_type="cuda", dtype=torch.float32)
    def forward(self, time):
        args = torch.outer(time, self.freqs.to(device=time.device))
        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
        return time_embed


class Kandinsky5TimeEmbeddingsPR(nn.Module):
    def __init__(self, model_dim, time_dim, max_period=10000.0):
        super().__init__()
        assert model_dim % 2 == 0
        self.model_dim = model_dim
        self.max_period = max_period
        self.freqs = get_freqs(self.model_dim // 2, self.max_period)
        self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, time_dim, bias=True)

    def forward(self, time):
        time = time.to(dtype=torch.float32)
        freqs = self.freqs.to(device=time.device, dtype=torch.float32)
        args = torch.outer(time, freqs)
        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        time_embed = time_embed.to(dtype=self.in_layer.weight.dtype)
        time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
        return time_embed


class Kandinsky5TimeEmbeddingsNoAutocast(nn.Module):
    def __init__(self, model_dim, time_dim, max_period=10000.0):
        super().__init__()
        assert model_dim % 2 == 0
        self.model_dim = model_dim
        self.max_period = max_period
        self.freqs = get_freqs(self.model_dim // 2, self.max_period)
        self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, time_dim, bias=True)

    def forward(self, time):
        args = torch.outer(time.to(torch.float32), self.freqs.to(device=time.device))
        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        time_embed = F.linear(
            self.activation(
                F.linear(
                    time_embed,
                    self.in_layer.weight.to(torch.float32),
                    self.in_layer.bias.to(torch.float32),
                )
            ),
            self.out_layer.weight.to(torch.float32),
            self.out_layer.bias.to(torch.float32),
        )
        return time_embed


torch.manual_seed(0)
with_autocast = (
    Kandinsky5TimeEmbeddings(model_dim=2560, time_dim=512)
    .to("cuda", torch.bfloat16)
    .eval()
)
torch.manual_seed(0)
pr = (
    Kandinsky5TimeEmbeddingsPR(model_dim=2560, time_dim=512)
    .to("cuda", torch.bfloat16)
    .eval()
)
torch.manual_seed(0)
no_autocast = (
    Kandinsky5TimeEmbeddingsNoAutocast(model_dim=2560, time_dim=512)
    .to("cuda", torch.bfloat16)
    .eval()
)


with torch.no_grad():
    time = torch.tensor([952.0]).to("cuda", torch.bfloat16)
    with_out = with_autocast(time.clone())
    pr_out = pr(time.clone())
    no_out = no_autocast(time.clone())

print(f"{with_out.dtype=}, {pr_out.dtype=}, {no_out.dtype=}")
try:
    print(f"{torch.allclose(with_out, pr_out)=}")
except RuntimeError as e:
    print(f"{e}, casting")
    print(f"{torch.allclose(with_out.to(pr_out.dtype), pr_out)=}")

print(f"{torch.allclose(with_out, no_out)=}")

with_out.dtype=torch.float32, pr_out.dtype=torch.bfloat16, no_out.dtype=torch.float32
Float did not match BFloat16, casting
torch.allclose(with_out.to(pr_out.dtype), pr_out)=False
torch.allclose(with_out, no_out)=True

As we see from the minimal reproduction of Kandinsky5TimeEmbeddings, the output from this PR does not match the output from main.

@torch.autocast(device_type="cuda", dtype=torch.float32) means everything is cast to float32: the Linear layers and activation also run in float32, and the output from forward is float32.

In this PR the Linear layers and activation are running in bfloat16, which results in different output from the module and, in turn, a different output image.

Kandinsky5TimeEmbeddings should be:

class Kandinsky5TimeEmbeddingsNoAutocast(nn.Module):
    def __init__(self, model_dim, time_dim, max_period=10000.0):
        super().__init__()
        assert model_dim % 2 == 0
        self.model_dim = model_dim
        self.max_period = max_period
        self.freqs = get_freqs(self.model_dim // 2, self.max_period)
        self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, time_dim, bias=True)

    def forward(self, time):
        args = torch.outer(time.to(torch.float32), self.freqs.to(device=time.device))
        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        time_embed = F.linear(
            self.activation(
                F.linear(
                    time_embed,
                    self.in_layer.weight.to(torch.float32),
                    self.in_layer.bias.to(torch.float32),
                )
            ),
            self.out_layer.weight.to(torch.float32),
            self.out_layer.bias.to(torch.float32),
        )
        return time_embed

and Kandinsky5Modulation:

class Kandinsky5Modulation(nn.Module):
    def __init__(self, time_dim, model_dim, num_params):
        super().__init__()
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, num_params * model_dim)
        self.out_layer.weight.data.zero_()
        self.out_layer.bias.data.zero_()

    def forward(self, x):
        return F.linear(
            self.activation(x.to(torch.float32)),
            self.out_layer.weight.to(torch.float32),
            self.out_layer.bias.to(torch.float32),
        )

With those changes, image output matches exactly:

Main: kandinsky_before_fix
Fix: kandinsky_after_fix
PS C:\Users\user\Downloads> certutil -hashfile kandinsky_after_fix.png SHA256
SHA256 hash of kandinsky_after_fix.png:
3fb7319edc17983593d2a1abc0b5ffed418700f5f7f70d450aefd1e225b52143
CertUtil: -hashfile command completed successfully.
PS C:\Users\user\Downloads> certutil -hashfile kandinsky_before_fix.png SHA256
SHA256 hash of kandinsky_before_fix.png:
3fb7319edc17983593d2a1abc0b5ffed418700f5f7f70d450aefd1e225b52143
CertUtil: -hashfile command completed successfully.

Perhaps the changes could be slightly simplified by making use of _keep_in_fp32_modules so we wouldn't need to cast the weights, but we would still need to cast everything else.

@adi776borate (Contributor, Author) commented

Thanks for the detailed analysis and script @hlky! You are right.

I misunderstood the original author's intent. I assumed they only wanted to protect specific operations (like sin/cos) from overflow.
Regarding _keep_in_fp32_modules: I prefer the manual F.linear approach because it allows us to keep the weights stored in bfloat16/float16 (saving VRAM) and only cast them on-the-fly, whereas _keep_in_fp32_modules would force them to be stored in FP32 permanently.

I'll update the PR with your suggested fix. Thanks again!

- Removed @torch.autocast decorator from Kandinsky classes.
- Implemented manual F.linear casting to ensure numerical parity with FP32.
- Verified bit-exact output matches main branch.

Co-authored-by: hlky <hlky@hlky.ac>
@adi776borate (Contributor, Author) commented

@yiyixuxu Hi! Just a gentle bump on this.
Let me know if there are any other changes needed! Thanks.

@liutyi left a comment

Tested on an Intel platform with sdnext. It works now.

@vladmandic (Contributor) commented

Bump to review (and merge)?
BTW, having a hardcoded pipeline for a month basically results in the model being forgotten by the general public - most people give up and move on.

@hlky (Contributor) commented Jan 6, 2026

@vladmandic Agree, this should not have taken so long to be reviewed/merged. @adi776borate has asked multiple times. Just to reiterate, numerically these changes are now exactly the same as with the autocast context.

@vladmandic (Contributor) commented

> @vladmandic Agree, this should not have taken so long to be reviewed/merged. @adi776borate has asked multiple times. Just to reiterate, numerically these changes are now exactly the same as with the autocast context.

yes, I can see the difference - so it's a question of doing what feels right vs. the bad original implementation. Too many times I've seen hardcoded cuda devices, fp32 where it's not needed, etc. - IMO, there is no visual difference and bf16 clearly uses less VRAM.

@yiyixuxu (Collaborator) commented Jan 7, 2026

cc @leffff here
Are we OK removing the autocast and just going with bf16 instead? It seems like there is no quality difference.

The fix here is incredibly hacky and we would strongly prefer not to go with it unless there is a strong reason to.

@leffff (Contributor) commented Jan 7, 2026

I believe we are okay!
I am thankful for the effort put into this fix.
The output matches exactly, so let's remove the autocast.

@yiyixuxu (Collaborator) left a comment

Can you change to _keep_in_fp32_modules here instead of the manual F.linear approach?

While I understand the concern about VRAM, this is a single linear layer - the memory difference between storing it in FP32 vs BF16/FP16 is negligible in practice. Using _keep_in_fp32_modules is cleaner, more maintainable, and aligns with our standards. Let's keep it simple.

@hlky (Contributor) commented Jan 8, 2026

It is negligible, though for transparency the difference in weights across Kandinsky5Modulation and Kandinsky5TimeEmbeddings will be ~45MB compared to storing in bfloat16.

Using _keep_in_fp32_modules will work; it produces exactly the same results as the autocast context. It should be noted, though, that this is only because the distributed weights are in bfloat16, meaning precision has already been lost. There would be a difference between the autocast context and _keep_in_fp32_modules if the weights were distributed in float32 (see the reproduction for further details) - just something to keep in mind, because sometimes checkpoints are distributed in float32.

Differences:

Autocast -> Original PR: 0.00146
Autocast -> float32 weights: 0.00103
Autocast -> bfloat16 weights (upcast by _keep_in_fp32_modules): 0.0
Autocast -> Manual cast: 0.0
Reproduction

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def get_freqs(dim, max_period=10000.0):
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=dim, dtype=torch.float32)
        / dim
    )
    return freqs


class Kandinsky5TimeEmbeddings(nn.Module):
    def __init__(self, model_dim, time_dim, max_period=10000.0):
        super().__init__()
        assert model_dim % 2 == 0
        self.model_dim = model_dim
        self.max_period = max_period
        self.freqs = get_freqs(self.model_dim // 2, self.max_period)
        self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, time_dim, bias=True)

    @torch.autocast(device_type="cuda", dtype=torch.float32)
    def forward(self, time):
        args = torch.outer(time, self.freqs.to(device=time.device))
        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
        return time_embed


class Kandinsky5TimeEmbeddingsBF16(nn.Module):
    def __init__(self, model_dim, time_dim, max_period=10000.0):
        super().__init__()
        assert model_dim % 2 == 0
        self.model_dim = model_dim
        self.max_period = max_period
        self.freqs = get_freqs(self.model_dim // 2, self.max_period)
        self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, time_dim, bias=True)

    def forward(self, time):
        time = time.to(dtype=torch.float32)
        freqs = self.freqs.to(device=time.device, dtype=torch.float32)
        args = torch.outer(time, freqs)
        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        time_embed = time_embed.to(dtype=self.in_layer.weight.dtype)
        time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
        return time_embed


class Kandinsky5TimeEmbeddingsNoAutocast(nn.Module):
    def __init__(self, model_dim, time_dim, max_period=10000.0):
        super().__init__()
        assert model_dim % 2 == 0
        self.model_dim = model_dim
        self.max_period = max_period
        self.freqs = get_freqs(self.model_dim // 2, self.max_period)
        self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, time_dim, bias=True)

    def forward(self, time):
        args = torch.outer(time.to(torch.float32), self.freqs.to(device=time.device))
        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        time_embed = F.linear(
            self.activation(
                F.linear(
                    time_embed,
                    self.in_layer.weight.to(torch.float32),
                    self.in_layer.bias.to(torch.float32),
                )
            ),
            self.out_layer.weight.to(torch.float32),
            self.out_layer.bias.to(torch.float32),
        )
        return time_embed


class Kandinsky5TimeEmbeddingsKeepModules(nn.Module):
    def __init__(self, model_dim, time_dim, max_period=10000.0):
        super().__init__()
        assert model_dim % 2 == 0
        self.model_dim = model_dim
        self.max_period = max_period
        self.freqs = get_freqs(self.model_dim // 2, self.max_period)
        self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, time_dim, bias=True)

    def forward(self, time):
        args = torch.outer(time.to(torch.float32), self.freqs.to(device=time.device))
        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
        return time_embed


torch.manual_seed(0)
with_autocast = (
    Kandinsky5TimeEmbeddings(model_dim=2560, time_dim=512)
    .to("cuda", torch.bfloat16)
    .eval()
)
torch.manual_seed(0)
pr = (
    Kandinsky5TimeEmbeddingsBF16(model_dim=2560, time_dim=512)
    .to("cuda", torch.bfloat16)
    .eval()
)
torch.manual_seed(0)
no_autocast = (
    Kandinsky5TimeEmbeddingsNoAutocast(model_dim=2560, time_dim=512)
    .to("cuda", torch.bfloat16)
    .eval()
)
torch.manual_seed(0)
keep_modules = (
    Kandinsky5TimeEmbeddingsKeepModules(model_dim=2560, time_dim=512).to("cuda").eval()
)
torch.manual_seed(0)
keep_modules_cast_from_bf16 = (
    Kandinsky5TimeEmbeddingsKeepModules(model_dim=2560, time_dim=512)
    .to("cuda", torch.bfloat16)
    .to(torch.float32)
    .eval()
)


with torch.no_grad():
    time = torch.tensor([952.0]).to("cuda", torch.bfloat16)
    with_out = with_autocast(time.clone())
    bf16_out = pr(time.clone())
    no_out = no_autocast(time.clone())
    keep_out = keep_modules(time.clone())
    keep_bf16_out = keep_modules_cast_from_bf16(time.clone())

diff_autocast_bf16 = (with_out - bf16_out).abs().max().item()
diff_autocast_keep = (with_out - keep_out).abs().max().item()
diff_autocast_no = (with_out - no_out).abs().max().item()
diff_autocast_keep_bf16 = (with_out - keep_bf16_out).abs().max().item()
print(f"{diff_autocast_bf16=}")
print(f"{diff_autocast_keep=}")
print(f"{diff_autocast_no=}")
print(f"{diff_autocast_keep_bf16=}")

_keep_in_fp32_modules patch

diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py
index c841cc522..57b28991d 100644
--- a/src/diffusers/models/transformers/transformer_kandinsky.py
+++ b/src/diffusers/models/transformers/transformer_kandinsky.py
@@ -168,17 +168,7 @@ class Kandinsky5TimeEmbeddings(nn.Module):
     def forward(self, time):
         args = torch.outer(time.to(torch.float32), self.freqs.to(device=time.device))
         time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-        time_embed = F.linear(
-            self.activation(
-                F.linear(
-                    time_embed,
-                    self.in_layer.weight.to(torch.float32),
-                    self.in_layer.bias.to(torch.float32),
-                )
-            ),
-            self.out_layer.weight.to(torch.float32),
-            self.out_layer.bias.to(torch.float32),
-        )
+        time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
         return time_embed
 
 
@@ -279,11 +269,7 @@ class Kandinsky5Modulation(nn.Module):
         self.out_layer.bias.data.zero_()
 
     def forward(self, x):
-        return F.linear(
-            self.activation(x.to(torch.float32)),
-            self.out_layer.weight.to(torch.float32),
-            self.out_layer.bias.to(torch.float32),
-        )
+        return self.out_layer(self.activation(x))
 
 
 class Kandinsky5AttnProcessor:
@@ -537,6 +523,7 @@ class Kandinsky5Transformer3DModel(
         "Kandinsky5TransformerEncoderBlock",
         "Kandinsky5TransformerDecoderBlock",
     ]
+    _keep_in_fp32_modules = ["time_embeddings", "modulation", "visual_modulation", "text_modulation"]
     _supports_gradient_checkpointing = True
 
     @register_to_config

Note on the patch: ["time_embeddings", "modulation"] would also work, since _keep_in_fp32_modules checks whether these strings appear in the parameter name; including visual_modulation and text_modulation is more explicit, though.

@leffff (Contributor) commented Jan 8, 2026

@adi776borate please provide some examples of video generation, because changes in the DiT will also affect t2v and i2v generation.

@adi776borate (Contributor, Author) commented Jan 8, 2026

@yiyixuxu @leffff

Hi! I applied the _keep_in_fp32_modules as suggested by @yiyixuxu, including the visual_modulation and text_modulation as suggested by @hlky.

I also tested both image and video generation. The images remain bit-exact. The videos are visually similar, but there are some minor differences and the hashes also differ. Can anyone point out the reason?

⚡ main ~/diffusers shasum -a 256 output_before-fix_lite_seed42.mp4
a6721db6d7fdf3da9187d01ae20f8d9c9522571d47726722ede823269cbefbeb  output_before-fix_lite_seed42.mp4
⚡ main ~/diffusers shasum -a 256 output_after-fix_lite_seed42.mp4
9901bcc90945246ad2d58f12e77ff90c6cdc76591138579ec793900260e5a736  output_after-fix_lite_seed42.mp4
⚡ main ~/diffusers shasum -a 256 kandinsky_main.png              
b041b0ef949625255cd863a4a581056b376284f7ec4c8e942150485d0737f511  kandinsky_main.png
⚡ main ~/diffusers shasum -a 256 kandinsky_after_keepfp32.png
b041b0ef949625255cd863a4a581056b376284f7ec4c8e942150485d0737f511  kandinsky_after_keepfp32.png

Before (main branch):
Image: kandinsky_main
Video: output_before-fix_lite_seed42.mp4

After _keep_in_fp32_modules:
Image: kandinsky_after_keepfp32
Video: output_after-fix_lite_seed42.mp4

Ready for review!

@adi776borate (Contributor, Author) commented

import torch
from diffusers import Kandinsky5T2VPipeline
from diffusers.utils import export_to_video

# Load the pipeline
model_id = "kandinskylab/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers"
device = "cuda" if torch.cuda.is_available() else "cpu"

seed = 42
generator = torch.Generator(device=device).manual_seed(seed)

pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")

# Generate video
prompt = "A cat and a dog baking a cake together in a kitchen."
negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"

output = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=512,
    width=768,
    num_frames=121,  # ~5 seconds at 24fps
    num_inference_steps=50,
    guidance_scale=5.0,
    generator=generator,
).frames[0]

export_to_video(output, "output_before-fix_lite_seed42.mp4", fps=24, quality=9)

Note: I verified T2V as shown above. I was unable to verify I2V locally due to VRAM constraints.

@leffff (Contributor) commented Jan 8, 2026

@adi776borate Thanks! Great job, and thank you for your effort. Everything looks good - let's merge.

@hlky (Contributor) commented Jan 8, 2026

@adi776borate

> Can anyone point out the reason?

It's as I described here: this checkpoint, kandinskylab/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers (and the other video checkpoints I've looked at), has its time_embeddings and modulation layers stored in float32 in the checkpoint. Before, these layers were downcast to bfloat16 on load and then upcast by the autocast context, and the downcast causes precision to be lost. Now these layers remain in float32, so precision is not lost; this causes the difference in outputs.
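
A tiny standalone sketch (not from the PR) of why the downcast matters: round-tripping float32 weights through bfloat16 does not recover the original values.

import torch

torch.manual_seed(0)
w = torch.randn(512, dtype=torch.float32)            # e.g. a modulation weight stored in float32 in the checkpoint
roundtrip = w.to(torch.bfloat16).to(torch.float32)   # what happened before: downcast on load, then upcast by autocast
print((w - roundtrip).abs().max().item())            # non-zero: precision was already lost before the upcast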

@yiyixuxu merged commit 8b9f817 into huggingface:main on Jan 8, 2026 (10 of 11 checks passed)
@yiyixuxu (Collaborator) commented Jan 8, 2026

thanks a lot for working on this!

Successfully merging this pull request may close these issues.

Kandinsky5TimeEmbeddings hardcodes 'cuda' in @torch.autocast decorator, causing warning on non-CUDA systems
