Commit bbfae59

Enhance Cosmos 2.5 model with distributed attention support
- Integrated DistributedAttention and LocalAttention for flexible backend support in the Cosmos25SelfAttention and Cosmos25CrossAttention classes.
- Updated the attention computation to handle both distributed and local scenarios.
- Refactored attention backend initialization to check for a distributed environment.
- Cleaned up unused comments and improved code readability in the Cosmos 2.5 test suite.

This update improves the model's adaptability to different hardware configurations while maintaining performance.
1 parent 3dfef6f commit bbfae59
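
The commit message above describes picking between a distributed and a local attention backend when the attention layers are initialized. The sketch below is hypothetical and only illustrates that selection pattern: the class name, constructor, and world-size check are assumptions for illustration, not the actual Cosmos25SelfAttention/Cosmos25CrossAttention code, which wires up FastVideo's DistributedAttention and LocalAttention layers.

# Illustrative sketch only; not the code changed in this commit.
import torch
import torch.distributed as dist


def _in_distributed_env() -> bool:
    # The commit chooses the backend by checking for a distributed environment;
    # torch.distributed provides one such check.
    return dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1


class SelfAttentionBackendSketch(torch.nn.Module):
    """Hypothetical stand-in: records which attention backend would be used."""

    def __init__(self, num_heads: int, head_dim: int) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        # In FastVideo this would instantiate DistributedAttention or LocalAttention;
        # here we only record the decision.
        self.backend = "distributed" if _in_distributed_env() else "local"

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # Both paths reduce to scaled dot-product attention on a single device;
        # the distributed path would additionally shard the sequence dimension.
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)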

3 files changed: +205 -121 lines

fastvideo/configs/models/dits/cosmos2_5.py

Lines changed: 100 additions & 56 deletions
@@ -11,74 +11,110 @@ def is_transformer_blocks(n: str, m) -> bool:
 @dataclass
 class Cosmos25ArchConfig(DiTArchConfig):
     """Configuration for Cosmos 2.5 architecture (MiniTrainDIT)."""
-
+
     _fsdp_shard_conditions: list = field(
         default_factory=lambda: [is_transformer_blocks])

     param_names_mapping: dict = field(
         default_factory=lambda: {
             # Remove "net." prefix and map official structure to FastVideo
             # Patch embedding: net.x_embedder.proj.1.weight -> patch_embed.proj.weight
-            r"^net\.x_embedder\.proj\.1\.(.*)$": r"patch_embed.proj.\1",
-
+            r"^net\.x_embedder\.proj\.1\.(.*)$":
+            r"patch_embed.proj.\1",
+
             # Time embedding: net.t_embedder.1.linear_1.weight -> time_embed.t_embedder.linear_1.weight
-            r"^net\.t_embedder\.1\.linear_1\.(.*)$": r"time_embed.t_embedder.linear_1.\1",
-            r"^net\.t_embedder\.1\.linear_2\.(.*)$": r"time_embed.t_embedder.linear_2.\1",
+            r"^net\.t_embedder\.1\.linear_1\.(.*)$":
+            r"time_embed.t_embedder.linear_1.\1",
+            r"^net\.t_embedder\.1\.linear_2\.(.*)$":
+            r"time_embed.t_embedder.linear_2.\1",
             # Time embedding norm: net.t_embedding_norm.weight -> time_embed.norm.weight
             # Note: This also handles _extra_state if present
-            r"^net\.t_embedding_norm\.(.*)$": r"time_embed.norm.\1",
-
+            r"^net\.t_embedding_norm\.(.*)$":
+            r"time_embed.norm.\1",
+
             # Cross-attention projection (optional): net.crossattn_proj.0.weight -> crossattn_proj.0.weight
-            r"^net\.crossattn_proj\.0\.weight$": r"crossattn_proj.0.weight",
-            r"^net\.crossattn_proj\.0\.bias$": r"crossattn_proj.0.bias",
-
+            r"^net\.crossattn_proj\.0\.weight$":
+            r"crossattn_proj.0.weight",
+            r"^net\.crossattn_proj\.0\.bias$":
+            r"crossattn_proj.0.bias",
+
             # Transformer blocks: net.blocks.N -> transformer_blocks.N
             # Self-attention (self_attn -> attn1)
-            r"^net\.blocks\.(\d+)\.self_attn\.q_proj\.(.*)$": r"transformer_blocks.\1.attn1.to_q.\2",
-            r"^net\.blocks\.(\d+)\.self_attn\.k_proj\.(.*)$": r"transformer_blocks.\1.attn1.to_k.\2",
-            r"^net\.blocks\.(\d+)\.self_attn\.v_proj\.(.*)$": r"transformer_blocks.\1.attn1.to_v.\2",
-            r"^net\.blocks\.(\d+)\.self_attn\.output_proj\.(.*)$": r"transformer_blocks.\1.attn1.to_out.\2",
-            r"^net\.blocks\.(\d+)\.self_attn\.q_norm\.weight$": r"transformer_blocks.\1.attn1.norm_q.weight",
-            r"^net\.blocks\.(\d+)\.self_attn\.k_norm\.weight$": r"transformer_blocks.\1.attn1.norm_k.weight",
+            r"^net\.blocks\.(\d+)\.self_attn\.q_proj\.(.*)$":
+            r"transformer_blocks.\1.attn1.to_q.\2",
+            r"^net\.blocks\.(\d+)\.self_attn\.k_proj\.(.*)$":
+            r"transformer_blocks.\1.attn1.to_k.\2",
+            r"^net\.blocks\.(\d+)\.self_attn\.v_proj\.(.*)$":
+            r"transformer_blocks.\1.attn1.to_v.\2",
+            r"^net\.blocks\.(\d+)\.self_attn\.output_proj\.(.*)$":
+            r"transformer_blocks.\1.attn1.to_out.\2",
+            r"^net\.blocks\.(\d+)\.self_attn\.q_norm\.weight$":
+            r"transformer_blocks.\1.attn1.norm_q.weight",
+            r"^net\.blocks\.(\d+)\.self_attn\.k_norm\.weight$":
+            r"transformer_blocks.\1.attn1.norm_k.weight",
             # RMSNorm _extra_state keys (internal PyTorch state, will be recomputed automatically)
-            r"^net\.blocks\.(\d+)\.self_attn\.q_norm\._extra_state$": r"transformer_blocks.\1.attn1.norm_q._extra_state",
-            r"^net\.blocks\.(\d+)\.self_attn\.k_norm\._extra_state$": r"transformer_blocks.\1.attn1.norm_k._extra_state",
-
+            r"^net\.blocks\.(\d+)\.self_attn\.q_norm\._extra_state$":
+            r"transformer_blocks.\1.attn1.norm_q._extra_state",
+            r"^net\.blocks\.(\d+)\.self_attn\.k_norm\._extra_state$":
+            r"transformer_blocks.\1.attn1.norm_k._extra_state",
+
             # Cross-attention (cross_attn -> attn2)
-            r"^net\.blocks\.(\d+)\.cross_attn\.q_proj\.(.*)$": r"transformer_blocks.\1.attn2.to_q.\2",
-            r"^net\.blocks\.(\d+)\.cross_attn\.k_proj\.(.*)$": r"transformer_blocks.\1.attn2.to_k.\2",
-            r"^net\.blocks\.(\d+)\.cross_attn\.v_proj\.(.*)$": r"transformer_blocks.\1.attn2.to_v.\2",
-            r"^net\.blocks\.(\d+)\.cross_attn\.output_proj\.(.*)$": r"transformer_blocks.\1.attn2.to_out.\2",
-            r"^net\.blocks\.(\d+)\.cross_attn\.q_norm\.weight$": r"transformer_blocks.\1.attn2.norm_q.weight",
-            r"^net\.blocks\.(\d+)\.cross_attn\.k_norm\.weight$": r"transformer_blocks.\1.attn2.norm_k.weight",
+            r"^net\.blocks\.(\d+)\.cross_attn\.q_proj\.(.*)$":
+            r"transformer_blocks.\1.attn2.to_q.\2",
+            r"^net\.blocks\.(\d+)\.cross_attn\.k_proj\.(.*)$":
+            r"transformer_blocks.\1.attn2.to_k.\2",
+            r"^net\.blocks\.(\d+)\.cross_attn\.v_proj\.(.*)$":
+            r"transformer_blocks.\1.attn2.to_v.\2",
+            r"^net\.blocks\.(\d+)\.cross_attn\.output_proj\.(.*)$":
+            r"transformer_blocks.\1.attn2.to_out.\2",
+            r"^net\.blocks\.(\d+)\.cross_attn\.q_norm\.weight$":
+            r"transformer_blocks.\1.attn2.norm_q.weight",
+            r"^net\.blocks\.(\d+)\.cross_attn\.k_norm\.weight$":
+            r"transformer_blocks.\1.attn2.norm_k.weight",
             # RMSNorm _extra_state keys for cross-attention
-            r"^net\.blocks\.(\d+)\.cross_attn\.q_norm\._extra_state$": r"transformer_blocks.\1.attn2.norm_q._extra_state",
-            r"^net\.blocks\.(\d+)\.cross_attn\.k_norm\._extra_state$": r"transformer_blocks.\1.attn2.norm_k._extra_state",
-
+            r"^net\.blocks\.(\d+)\.cross_attn\.q_norm\._extra_state$":
+            r"transformer_blocks.\1.attn2.norm_q._extra_state",
+            r"^net\.blocks\.(\d+)\.cross_attn\.k_norm\._extra_state$":
+            r"transformer_blocks.\1.attn2.norm_k._extra_state",
+
             # MLP: net.blocks.N.mlp.layer1 -> transformer_blocks.N.mlp.fc_in
-            r"^net\.blocks\.(\d+)\.mlp\.layer1\.(.*)$": r"transformer_blocks.\1.mlp.fc_in.\2",
-            r"^net\.blocks\.(\d+)\.mlp\.layer2\.(.*)$": r"transformer_blocks.\1.mlp.fc_out.\2",
-
+            r"^net\.blocks\.(\d+)\.mlp\.layer1\.(.*)$":
+            r"transformer_blocks.\1.mlp.fc_in.\2",
+            r"^net\.blocks\.(\d+)\.mlp\.layer2\.(.*)$":
+            r"transformer_blocks.\1.mlp.fc_out.\2",
+
             # AdaLN-LoRA modulations: net.blocks.N.adaln_modulation_* -> transformer_blocks.N.adaln_modulation_*
             # These are now at the block level, not inside norm layers
-            r"^net\.blocks\.(\d+)\.adaln_modulation_self_attn\.1\.(.*)$": r"transformer_blocks.\1.adaln_modulation_self_attn.1.\2",
-            r"^net\.blocks\.(\d+)\.adaln_modulation_self_attn\.2\.(.*)$": r"transformer_blocks.\1.adaln_modulation_self_attn.2.\2",
-            r"^net\.blocks\.(\d+)\.adaln_modulation_cross_attn\.1\.(.*)$": r"transformer_blocks.\1.adaln_modulation_cross_attn.1.\2",
-            r"^net\.blocks\.(\d+)\.adaln_modulation_cross_attn\.2\.(.*)$": r"transformer_blocks.\1.adaln_modulation_cross_attn.2.\2",
-            r"^net\.blocks\.(\d+)\.adaln_modulation_mlp\.1\.(.*)$": r"transformer_blocks.\1.adaln_modulation_mlp.1.\2",
-            r"^net\.blocks\.(\d+)\.adaln_modulation_mlp\.2\.(.*)$": r"transformer_blocks.\1.adaln_modulation_mlp.2.\2",
-
+            r"^net\.blocks\.(\d+)\.adaln_modulation_self_attn\.1\.(.*)$":
+            r"transformer_blocks.\1.adaln_modulation_self_attn.1.\2",
+            r"^net\.blocks\.(\d+)\.adaln_modulation_self_attn\.2\.(.*)$":
+            r"transformer_blocks.\1.adaln_modulation_self_attn.2.\2",
+            r"^net\.blocks\.(\d+)\.adaln_modulation_cross_attn\.1\.(.*)$":
+            r"transformer_blocks.\1.adaln_modulation_cross_attn.1.\2",
+            r"^net\.blocks\.(\d+)\.adaln_modulation_cross_attn\.2\.(.*)$":
+            r"transformer_blocks.\1.adaln_modulation_cross_attn.2.\2",
+            r"^net\.blocks\.(\d+)\.adaln_modulation_mlp\.1\.(.*)$":
+            r"transformer_blocks.\1.adaln_modulation_mlp.1.\2",
+            r"^net\.blocks\.(\d+)\.adaln_modulation_mlp\.2\.(.*)$":
+            r"transformer_blocks.\1.adaln_modulation_mlp.2.\2",
+
             # Layer norms: net.blocks.N.layer_norm_* -> transformer_blocks.N.norm*.norm
-            r"^net\.blocks\.(\d+)\.layer_norm_self_attn\._extra_state$": r"transformer_blocks.\1.norm1.norm._extra_state",
-            r"^net\.blocks\.(\d+)\.layer_norm_cross_attn\._extra_state$": r"transformer_blocks.\1.norm2.norm._extra_state",
-            r"^net\.blocks\.(\d+)\.layer_norm_mlp\._extra_state$": r"transformer_blocks.\1.norm3.norm._extra_state",
-
+            r"^net\.blocks\.(\d+)\.layer_norm_self_attn\._extra_state$":
+            r"transformer_blocks.\1.norm1.norm._extra_state",
+            r"^net\.blocks\.(\d+)\.layer_norm_cross_attn\._extra_state$":
+            r"transformer_blocks.\1.norm2.norm._extra_state",
+            r"^net\.blocks\.(\d+)\.layer_norm_mlp\._extra_state$":
+            r"transformer_blocks.\1.norm3.norm._extra_state",
+
             # Final layer: net.final_layer.linear -> final_layer.proj_out
-            r"^net\.final_layer\.linear\.(.*)$": r"final_layer.proj_out.\1",
+            r"^net\.final_layer\.linear\.(.*)$":
+            r"final_layer.proj_out.\1",
             # Final layer AdaLN-LoRA: net.final_layer.adaln_modulation -> final_layer.linear_*
-            r"^net\.final_layer\.adaln_modulation\.1\.(.*)$": r"final_layer.linear_1.\1",
-            r"^net\.final_layer\.adaln_modulation\.2\.(.*)$": r"final_layer.linear_2.\1",
-
+            r"^net\.final_layer\.adaln_modulation\.1\.(.*)$":
+            r"final_layer.linear_1.\1",
+            r"^net\.final_layer\.adaln_modulation\.2\.(.*)$":
+            r"final_layer.linear_2.\1",
+
             # Note: The following keys from official checkpoint are NOT mapped and can be safely ignored:
             # - net.pos_embedder.* (seq, dim_spatial_range, dim_temporal_range) - These are computed dynamically
             #   in FastVideo's Cosmos25RotaryPosEmbed forward() method, so they don't need to be loaded.
@@ -87,15 +123,24 @@ class Cosmos25ArchConfig(DiTArchConfig):

     lora_param_names_mapping: dict = field(
         default_factory=lambda: {
-            r"^transformer_blocks\.(\d+)\.attn1\.to_q\.(.*)$": r"transformer_blocks.\1.attn1.to_q.\2",
-            r"^transformer_blocks\.(\d+)\.attn1\.to_k\.(.*)$": r"transformer_blocks.\1.attn1.to_k.\2",
-            r"^transformer_blocks\.(\d+)\.attn1\.to_v\.(.*)$": r"transformer_blocks.\1.attn1.to_v.\2",
-            r"^transformer_blocks\.(\d+)\.attn1\.to_out\.(.*)$": r"transformer_blocks.\1.attn1.to_out.\2",
-            r"^transformer_blocks\.(\d+)\.attn2\.to_q\.(.*)$": r"transformer_blocks.\1.attn2.to_q.\2",
-            r"^transformer_blocks\.(\d+)\.attn2\.to_k\.(.*)$": r"transformer_blocks.\1.attn2.to_k.\2",
-            r"^transformer_blocks\.(\d+)\.attn2\.to_v\.(.*)$": r"transformer_blocks.\1.attn2.to_v.\2",
-            r"^transformer_blocks\.(\d+)\.attn2\.to_out\.(.*)$": r"transformer_blocks.\1.attn2.to_out.\2",
-            r"^transformer_blocks\.(\d+)\.mlp\.(.*)$": r"transformer_blocks.\1.mlp.\2",
+            r"^transformer_blocks\.(\d+)\.attn1\.to_q\.(.*)$":
+            r"transformer_blocks.\1.attn1.to_q.\2",
+            r"^transformer_blocks\.(\d+)\.attn1\.to_k\.(.*)$":
+            r"transformer_blocks.\1.attn1.to_k.\2",
+            r"^transformer_blocks\.(\d+)\.attn1\.to_v\.(.*)$":
+            r"transformer_blocks.\1.attn1.to_v.\2",
+            r"^transformer_blocks\.(\d+)\.attn1\.to_out\.(.*)$":
+            r"transformer_blocks.\1.attn1.to_out.\2",
+            r"^transformer_blocks\.(\d+)\.attn2\.to_q\.(.*)$":
+            r"transformer_blocks.\1.attn2.to_q.\2",
+            r"^transformer_blocks\.(\d+)\.attn2\.to_k\.(.*)$":
+            r"transformer_blocks.\1.attn2.to_k.\2",
+            r"^transformer_blocks\.(\d+)\.attn2\.to_v\.(.*)$":
+            r"transformer_blocks.\1.attn2.to_v.\2",
+            r"^transformer_blocks\.(\d+)\.attn2\.to_out\.(.*)$":
+            r"transformer_blocks.\1.attn2.to_out.\2",
+            r"^transformer_blocks\.(\d+)\.mlp\.(.*)$":
+            r"transformer_blocks.\1.mlp.\2",
         })

     # Cosmos 2.5 specific config parameters
@@ -134,4 +179,3 @@ class Cosmos25VideoConfig(DiTConfig):
     """Configuration for Cosmos 2.5 video generation model."""
     arch_config: DiTArchConfig = field(default_factory=Cosmos25ArchConfig)
     prefix: str = "Cosmos25"
-
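
The param_names_mapping table in the diff above is a regex-to-regex rename map from official Cosmos checkpoint keys to FastVideo module names. The snippet below is a minimal, hypothetical illustration of how such a table can be applied to a key; the remap_key helper is not FastVideo's actual loader, and only two of the patterns are reproduced.

# Hypothetical helper showing how a regex table like param_names_mapping
# renames official checkpoint keys into FastVideo's module names.
import re

PARAM_NAMES_MAPPING = {
    r"^net\.blocks\.(\d+)\.self_attn\.q_proj\.(.*)$":
    r"transformer_blocks.\1.attn1.to_q.\2",
    r"^net\.final_layer\.linear\.(.*)$":
    r"final_layer.proj_out.\1",
}


def remap_key(key: str, mapping: dict) -> str:
    # First pattern that matches wins; unmatched keys (e.g. net.pos_embedder.*)
    # are returned unchanged and can be dropped by the caller.
    for pattern, replacement in mapping.items():
        if re.match(pattern, key):
            return re.sub(pattern, replacement, key)
    return key


print(remap_key("net.blocks.3.self_attn.q_proj.weight", PARAM_NAMES_MAPPING))
# -> transformer_blocks.3.attn1.to_q.weight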