Fix: Remove hardcoded CUDA autocast in Kandinsky 5 to fix import warning #12814
Conversation
|
Looks good to me! |
|
Thanks for the quick fix! I didn't have time to submit a PR myself, so I really appreciate you jumping on this. 🙏 |
|
@yiyixuxu @sayakpaul |
sayakpaul
left a comment
Thank you! Could you also provide your testing script?
The verification script is already provided in the PR description above:
from diffusers.models.transformers import transformer_kandinsky
print("Import successful.")
On main this import emits a UserWarning; on this branch it does not.
yiyixuxu
left a comment
thanks!
|
@bot /style |
|
Style bot fixed some files and pushed the changes. |
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
@torch.autocast(device_type="cuda", dtype=torch.float32)
def forward(self, time):
    args = torch.outer(time, self.freqs.to(device=time.device))
    time = time.to(dtype=torch.float32)
Suggested change (replace the first line with the two below it):
time = time.to(dtype=torch.float32)
original_dtype = time.dtype
time = time.to(dtype=torch.float32)
freqs = self.freqs.to(device=time.device, dtype=torch.float32)
args = torch.outer(time, freqs)
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
time_embed = time_embed.to(dtype=self.in_layer.weight.dtype)
Suggested change (replace the first line with the second):
time_embed = time_embed.to(dtype=self.in_layer.weight.dtype)
time_embed = time_embed.to(dtype=original_dtype)
The reason I cast to self.in_layer.weight.dtype instead of original_dtype is to prevent runtime crashes on backends like XPU as mentioned by @vladmandic here.
If users load the pipeline in float16, and we pass time_embed as float32, that will raise an error, won't it?
I might be wrong, correct me if so.
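For illustration, here is a minimal standalone sketch (not from the PR) of the mismatch being discussed: a float32 tensor fed into a half-precision Linear raises the same dtype error reported later in this thread.
import torch
import torch.nn as nn

layer = nn.Linear(4, 4).to(torch.bfloat16)   # module loaded in bf16, as with torch_dtype=torch.bfloat16
x = torch.randn(2, 4, dtype=torch.float32)   # activation left in float32

try:
    layer(x)
except RuntimeError as e:
    print(e)  # e.g. "mat1 and mat2 must have the same dtype, but got Float and BFloat16"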
Hi! I also tried applying your suggested change and ran the exact same code attached in the PR description above on an L40S (which supports bf16), and hit the error below:
Traceback (most recent call last):
File "/teamspace/studios/this_studio/diffusers/verify_fix.py", line 25, in <module>
output = pipe(
^^^^^
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/teamspace/studios/this_studio/diffusers/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky_t2i.py", line 731, in __call__
pred_velocity = self.transformer(
^^^^^^^^^^^^^^^^^
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/transformers/transformer_kandinsky.py", line 647, in forward
text_embed = text_transformer_block(text_embed, time_embed, text_rope)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/transformers/transformer_kandinsky.py", line 461, in forward
self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/transformers/transformer_kandinsky.py", line 280, in forward
return self.out_layer(self.activation(x))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/linear.py", line 125, in forward
return F.linear(input, self.weight, self.bias)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: mat1 and mat2 must have the same dtype, but got Float and BFloat16
Thus, I think casting to self.in_layer.weight.dtype is a safer option.
Please let me know your thoughts.
|
|
@torch.autocast(device_type="cuda", dtype=torch.float32)
def forward(self, x):
    x = x.to(dtype=self.out_layer.weight.dtype)
umm actually this did not look correct to me - we want to upcast it to float32, no?
Similarly, if we force x to float32 here, we might hit the same mismatch crash if the out_layer weights are float16/bfloat16.
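One way to satisfy both concerns (keep the math in float32 without a dtype mismatch) is to upcast the layer parameters as well; this is essentially the manual F.linear casting adopted a few comments below. A minimal standalone sketch, with hypothetical names that do not come from the PR:
import torch
import torch.nn as nn
import torch.nn.functional as F

activation = nn.SiLU()
out_layer = nn.Linear(512, 512).to(torch.bfloat16)
x = torch.randn(2, 512, dtype=torch.bfloat16)

# Upcast both the input and the parameters, so the whole computation runs in float32
# and no Float/BFloat16 mismatch can occur.
out = F.linear(
    activation(x.to(torch.float32)),
    out_layer.weight.to(torch.float32),
    out_layer.bias.to(torch.float32),
)
print(out.dtype)  # torch.float32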
|
Thanks for the detailed analysis and script @hlky! You are right. I misunderstood the original author's intent. I assumed they only wanted to protect specific operations (like sin/cos) from overflow. I'll update the PR with your suggested fix. Thanks again! |
- Removed @torch.autocast decorator from Kandinsky classes. - Implemented manual F.linear casting to ensure numerical parity with FP32. - Verified bit-exact output matches main branch. Co-authored-by: hlky <hlky@hlky.ac>
|
@yiyixuxu Hi! Just a gentle bump on this. |
liutyi
left a comment
Tested on Intel platform with sdnext. That works now.
|
bump to review (and merge)? |
|
@vladmandic Agree, this should not have taken so long to be reviewed/merged. @adi776borate has asked multiple times. Just to reiterate, numerically these changes are now exactly the same as with the autocast context. |
yes, i can see the difference - so its the question of doing what feels right vs bad original implementation. too many times i've seen hard coded cuda devices, fp32 where its not needed, etc. - imo, there is no visual difference and bf16 clearly uses less vram |
|
cc @leffff: the fix here is incredibly hacky and we strongly prefer not to go with it unless there is a strong reason to. |
|
I believe we are okay! |
yiyixuxu
left a comment
Can you change to _keep_in_fp32_modules here instead of the manual F.linear approach?
While I understand the concern about VRAM, this is a single linear layer - the memory difference between storing it in FP32 vs BF16/FP16 is negligible in practice. Using _keep_in_fp32_modules is cleaner, more maintainable, and aligns with our standards. Let's keep it simple.
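For context, _keep_in_fp32_modules is a class attribute on diffusers model classes: submodules whose names match the listed strings keep float32 parameters even when the checkpoint is loaded in a lower dtype. A rough sketch of checking the effect after this PR (the subfolder and import path are assumptions based on the repo id and the file touched by the patch below):
import torch
from diffusers.models.transformers.transformer_kandinsky import Kandinsky5Transformer3DModel

transformer = Kandinsky5Transformer3DModel.from_pretrained(
    "kandinskylab/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
# The time embedding is listed in _keep_in_fp32_modules, so its parameters should stay
# in float32 even though the rest of the transformer is loaded in bfloat16.
print(transformer.time_embeddings.in_layer.weight.dtype)  # expected: torch.float32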
|
It is negligible, though for transparency the numerical differences between the approaches are worth showing; a reproduction script and the _keep_in_fp32_modules patch follow.
Reproduction
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def get_freqs(dim, max_period=10000.0):
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=dim, dtype=torch.float32)
/ dim
)
return freqs
class Kandinsky5TimeEmbeddings(nn.Module):
def __init__(self, model_dim, time_dim, max_period=10000.0):
super().__init__()
assert model_dim % 2 == 0
self.model_dim = model_dim
self.max_period = max_period
self.freqs = get_freqs(self.model_dim // 2, self.max_period)
self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
self.activation = nn.SiLU()
self.out_layer = nn.Linear(time_dim, time_dim, bias=True)
@torch.autocast(device_type="cuda", dtype=torch.float32)
def forward(self, time):
args = torch.outer(time, self.freqs.to(device=time.device))
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
return time_embed
class Kandinsky5TimeEmbeddingsBF16(nn.Module):
def __init__(self, model_dim, time_dim, max_period=10000.0):
super().__init__()
assert model_dim % 2 == 0
self.model_dim = model_dim
self.max_period = max_period
self.freqs = get_freqs(self.model_dim // 2, self.max_period)
self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
self.activation = nn.SiLU()
self.out_layer = nn.Linear(time_dim, time_dim, bias=True)
def forward(self, time):
time = time.to(dtype=torch.float32)
freqs = self.freqs.to(device=time.device, dtype=torch.float32)
args = torch.outer(time, freqs)
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
time_embed = time_embed.to(dtype=self.in_layer.weight.dtype)
time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
return time_embed
class Kandinsky5TimeEmbeddingsNoAutocast(nn.Module):
def __init__(self, model_dim, time_dim, max_period=10000.0):
super().__init__()
assert model_dim % 2 == 0
self.model_dim = model_dim
self.max_period = max_period
self.freqs = get_freqs(self.model_dim // 2, self.max_period)
self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
self.activation = nn.SiLU()
self.out_layer = nn.Linear(time_dim, time_dim, bias=True)
def forward(self, time):
args = torch.outer(time.to(torch.float32), self.freqs.to(device=time.device))
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
time_embed = F.linear(
self.activation(
F.linear(
time_embed,
self.in_layer.weight.to(torch.float32),
self.in_layer.bias.to(torch.float32),
)
),
self.out_layer.weight.to(torch.float32),
self.out_layer.bias.to(torch.float32),
)
return time_embed
class Kandinsky5TimeEmbeddingsKeepModules(nn.Module):
def __init__(self, model_dim, time_dim, max_period=10000.0):
super().__init__()
assert model_dim % 2 == 0
self.model_dim = model_dim
self.max_period = max_period
self.freqs = get_freqs(self.model_dim // 2, self.max_period)
self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
self.activation = nn.SiLU()
self.out_layer = nn.Linear(time_dim, time_dim, bias=True)
def forward(self, time):
args = torch.outer(time.to(torch.float32), self.freqs.to(device=time.device))
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
return time_embed
torch.manual_seed(0)
with_autocast = (
Kandinsky5TimeEmbeddings(model_dim=2560, time_dim=512)
.to("cuda", torch.bfloat16)
.eval()
)
torch.manual_seed(0)
pr = (
Kandinsky5TimeEmbeddingsBF16(model_dim=2560, time_dim=512)
.to("cuda", torch.bfloat16)
.eval()
)
torch.manual_seed(0)
no_autocast = (
Kandinsky5TimeEmbeddingsNoAutocast(model_dim=2560, time_dim=512)
.to("cuda", torch.bfloat16)
.eval()
)
torch.manual_seed(0)
keep_modules = (
Kandinsky5TimeEmbeddingsKeepModules(model_dim=2560, time_dim=512).to("cuda").eval()
)
torch.manual_seed(0)
keep_modules_cast_from_bf16 = (
Kandinsky5TimeEmbeddingsKeepModules(model_dim=2560, time_dim=512)
.to("cuda", torch.bfloat16)
.to(torch.float32)
.eval()
)
with torch.no_grad():
time = torch.tensor([952.0]).to("cuda", torch.bfloat16)
with_out = with_autocast(time.clone())
bf16_out = pr(time.clone())
no_out = no_autocast(time.clone())
keep_out = keep_modules(time.clone())
keep_bf16_out = keep_modules_cast_from_bf16(time.clone())
diff_autocast_bf16 = (with_out - bf16_out).abs().max().item()
diff_autocast_keep = (with_out - keep_out).abs().max().item()
diff_autocast_no = (with_out - no_out).abs().max().item()
diff_autocast_keep_bf16 = (with_out - keep_bf16_out).abs().max().item()
print(f"{diff_autocast_bf16=}")
print(f"{diff_autocast_keep=}")
print(f"{diff_autocast_no=}")
print(f"{diff_autocast_keep_bf16=}")_keep_in_fp32_modules patch
diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py
index c841cc522..57b28991d 100644
--- a/src/diffusers/models/transformers/transformer_kandinsky.py
+++ b/src/diffusers/models/transformers/transformer_kandinsky.py
@@ -168,17 +168,7 @@ class Kandinsky5TimeEmbeddings(nn.Module):
def forward(self, time):
args = torch.outer(time.to(torch.float32), self.freqs.to(device=time.device))
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
- time_embed = F.linear(
- self.activation(
- F.linear(
- time_embed,
- self.in_layer.weight.to(torch.float32),
- self.in_layer.bias.to(torch.float32),
- )
- ),
- self.out_layer.weight.to(torch.float32),
- self.out_layer.bias.to(torch.float32),
- )
+ time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
return time_embed
@@ -279,11 +269,7 @@ class Kandinsky5Modulation(nn.Module):
self.out_layer.bias.data.zero_()
def forward(self, x):
- return F.linear(
- self.activation(x.to(torch.float32)),
- self.out_layer.weight.to(torch.float32),
- self.out_layer.bias.to(torch.float32),
- )
+ return self.out_layer(self.activation(x))
class Kandinsky5AttnProcessor:
@@ -537,6 +523,7 @@ class Kandinsky5Transformer3DModel(
"Kandinsky5TransformerEncoderBlock",
"Kandinsky5TransformerDecoderBlock",
]
+ _keep_in_fp32_modules = ["time_embeddings", "modulation", "visual_modulation", "text_modulation"]
_supports_gradient_checkpointing = True
@register_to_config
Note on the patch: |
|
@adi776borate please provide some examples of video generation, because changes in the DiT will also affect t2v and i2v generation. |
|
Hi! I applied the suggested _keep_in_fp32_modules patch. I also tested both image and video generation. The images remain bit-exact. The videos are visually similar, but there are some minor differences and the file hashes also differ. Can anyone point out the reason? Before (main branch):
Video: output_before-fix_lite_seed42.mp4
After (this branch):
Video: output_after-fix_lite_seed42.mp4
Ready for review! |
import torch
from diffusers import Kandinsky5T2VPipeline
from diffusers.utils import export_to_video
# Load the pipeline
model_id = "kandinskylab/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers"
device = "cuda" if torch.cuda.is_available() else "cpu"
seed = 42
generator = torch.Generator(device=device).manual_seed(seed)
pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
# Generate video
prompt = "A cat and a dog baking a cake together in a kitchen."
negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
output = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=512,
width=768,
num_frames=121, # ~5 seconds at 24fps
num_inference_steps=50,
guidance_scale=5.0,
generator=generator,
).frames[0]
export_to_video(output, "output_before-fix_lite_seed42.mp4", fps=24, quality=9)
Note: I verified T2V as shown above. I was unable to verify I2V locally due to VRAM constraints.
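For reference, a byte-for-byte comparison of the two runs can be done by hashing the output files (filenames as used in the script above, run once per branch). Since video encoding propagates even tiny frame-level numerical differences into different bytes, differing hashes alongside visually similar videos are expected here:
import hashlib

def sha256_of(path: str) -> str:
    # Hash the encoded video file; only byte-identical outputs give identical digests.
    with open(path, "rb") as f:
        return hashlib.sha256(f.read()).hexdigest()

print(sha256_of("output_before-fix_lite_seed42.mp4"))
print(sha256_of("output_after-fix_lite_seed42.mp4"))
|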
|
@adi776borate Thanks, great job! Thank you for your effort. Everything looks good, let's merge. |
It's as I described here.
|
thanks a lot for working on this! |





What does this PR do?
Fixes #12809
This PR fixes it by removing the hardcoded @torch.autocast(device_type="cuda", dtype=torch.float32) decorators from the Kandinsky 5 transformer and handling the float32 computation explicitly instead.
Verification
I verified that the results remain stable before and after this change by generating images with a fixed seed (generator=torch.manual_seed(42)). The results are almost the same, with some minor differences.
Reproduction Script
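The basic import check from the review thread above (a minimal sketch; the full image/video generation script mirrors the T2V example shown earlier, run with a fixed seed):
# On main this import emits a UserWarning triggered by the hardcoded CUDA autocast decorator;
# on this branch it should stay silent.
from diffusers.models.transformers import transformer_kandinsky  # noqa: F401
print("Import successful.")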
Who can review?
@yiyixuxu @leffff
Anyone in the community is free to review the PR once the tests have passed.