@@ -11,74 +11,110 @@ def is_transformer_blocks(n: str, m) -> bool:
 @dataclass
 class Cosmos25ArchConfig(DiTArchConfig):
     """Configuration for Cosmos 2.5 architecture (MiniTrainDIT)."""
-
+
     _fsdp_shard_conditions: list = field(
         default_factory=lambda: [is_transformer_blocks])
 
     param_names_mapping: dict = field(
         default_factory=lambda: {
             # Remove "net." prefix and map official structure to FastVideo
             # Patch embedding: net.x_embedder.proj.1.weight -> patch_embed.proj.weight
-            r"^net\.x_embedder\.proj\.1\.(.*)$": r"patch_embed.proj.\1",
-
+            r"^net\.x_embedder\.proj\.1\.(.*)$":
+            r"patch_embed.proj.\1",
+
             # Time embedding: net.t_embedder.1.linear_1.weight -> time_embed.t_embedder.linear_1.weight
-            r"^net\.t_embedder\.1\.linear_1\.(.*)$": r"time_embed.t_embedder.linear_1.\1",
-            r"^net\.t_embedder\.1\.linear_2\.(.*)$": r"time_embed.t_embedder.linear_2.\1",
+            r"^net\.t_embedder\.1\.linear_1\.(.*)$":
+            r"time_embed.t_embedder.linear_1.\1",
+            r"^net\.t_embedder\.1\.linear_2\.(.*)$":
+            r"time_embed.t_embedder.linear_2.\1",
             # Time embedding norm: net.t_embedding_norm.weight -> time_embed.norm.weight
             # Note: This also handles _extra_state if present
-            r"^net\.t_embedding_norm\.(.*)$": r"time_embed.norm.\1",
-
+            r"^net\.t_embedding_norm\.(.*)$":
+            r"time_embed.norm.\1",
+
             # Cross-attention projection (optional): net.crossattn_proj.0.weight -> crossattn_proj.0.weight
-            r"^net\.crossattn_proj\.0\.weight$": r"crossattn_proj.0.weight",
-            r"^net\.crossattn_proj\.0\.bias$": r"crossattn_proj.0.bias",
-
+            r"^net\.crossattn_proj\.0\.weight$":
+            r"crossattn_proj.0.weight",
+            r"^net\.crossattn_proj\.0\.bias$":
+            r"crossattn_proj.0.bias",
+
             # Transformer blocks: net.blocks.N -> transformer_blocks.N
             # Self-attention (self_attn -> attn1)
-            r"^net\.blocks\.(\d+)\.self_attn\.q_proj\.(.*)$": r"transformer_blocks.\1.attn1.to_q.\2",
-            r"^net\.blocks\.(\d+)\.self_attn\.k_proj\.(.*)$": r"transformer_blocks.\1.attn1.to_k.\2",
-            r"^net\.blocks\.(\d+)\.self_attn\.v_proj\.(.*)$": r"transformer_blocks.\1.attn1.to_v.\2",
-            r"^net\.blocks\.(\d+)\.self_attn\.output_proj\.(.*)$": r"transformer_blocks.\1.attn1.to_out.\2",
-            r"^net\.blocks\.(\d+)\.self_attn\.q_norm\.weight$": r"transformer_blocks.\1.attn1.norm_q.weight",
-            r"^net\.blocks\.(\d+)\.self_attn\.k_norm\.weight$": r"transformer_blocks.\1.attn1.norm_k.weight",
+            r"^net\.blocks\.(\d+)\.self_attn\.q_proj\.(.*)$":
+            r"transformer_blocks.\1.attn1.to_q.\2",
+            r"^net\.blocks\.(\d+)\.self_attn\.k_proj\.(.*)$":
+            r"transformer_blocks.\1.attn1.to_k.\2",
+            r"^net\.blocks\.(\d+)\.self_attn\.v_proj\.(.*)$":
+            r"transformer_blocks.\1.attn1.to_v.\2",
+            r"^net\.blocks\.(\d+)\.self_attn\.output_proj\.(.*)$":
+            r"transformer_blocks.\1.attn1.to_out.\2",
+            r"^net\.blocks\.(\d+)\.self_attn\.q_norm\.weight$":
+            r"transformer_blocks.\1.attn1.norm_q.weight",
+            r"^net\.blocks\.(\d+)\.self_attn\.k_norm\.weight$":
+            r"transformer_blocks.\1.attn1.norm_k.weight",
             # RMSNorm _extra_state keys (internal PyTorch state, will be recomputed automatically)
-            r"^net\.blocks\.(\d+)\.self_attn\.q_norm\._extra_state$": r"transformer_blocks.\1.attn1.norm_q._extra_state",
-            r"^net\.blocks\.(\d+)\.self_attn\.k_norm\._extra_state$": r"transformer_blocks.\1.attn1.norm_k._extra_state",
-
+            r"^net\.blocks\.(\d+)\.self_attn\.q_norm\._extra_state$":
+            r"transformer_blocks.\1.attn1.norm_q._extra_state",
+            r"^net\.blocks\.(\d+)\.self_attn\.k_norm\._extra_state$":
+            r"transformer_blocks.\1.attn1.norm_k._extra_state",
+
             # Cross-attention (cross_attn -> attn2)
-            r"^net\.blocks\.(\d+)\.cross_attn\.q_proj\.(.*)$": r"transformer_blocks.\1.attn2.to_q.\2",
-            r"^net\.blocks\.(\d+)\.cross_attn\.k_proj\.(.*)$": r"transformer_blocks.\1.attn2.to_k.\2",
-            r"^net\.blocks\.(\d+)\.cross_attn\.v_proj\.(.*)$": r"transformer_blocks.\1.attn2.to_v.\2",
-            r"^net\.blocks\.(\d+)\.cross_attn\.output_proj\.(.*)$": r"transformer_blocks.\1.attn2.to_out.\2",
-            r"^net\.blocks\.(\d+)\.cross_attn\.q_norm\.weight$": r"transformer_blocks.\1.attn2.norm_q.weight",
-            r"^net\.blocks\.(\d+)\.cross_attn\.k_norm\.weight$": r"transformer_blocks.\1.attn2.norm_k.weight",
+            r"^net\.blocks\.(\d+)\.cross_attn\.q_proj\.(.*)$":
+            r"transformer_blocks.\1.attn2.to_q.\2",
+            r"^net\.blocks\.(\d+)\.cross_attn\.k_proj\.(.*)$":
+            r"transformer_blocks.\1.attn2.to_k.\2",
+            r"^net\.blocks\.(\d+)\.cross_attn\.v_proj\.(.*)$":
+            r"transformer_blocks.\1.attn2.to_v.\2",
+            r"^net\.blocks\.(\d+)\.cross_attn\.output_proj\.(.*)$":
+            r"transformer_blocks.\1.attn2.to_out.\2",
+            r"^net\.blocks\.(\d+)\.cross_attn\.q_norm\.weight$":
+            r"transformer_blocks.\1.attn2.norm_q.weight",
+            r"^net\.blocks\.(\d+)\.cross_attn\.k_norm\.weight$":
+            r"transformer_blocks.\1.attn2.norm_k.weight",
             # RMSNorm _extra_state keys for cross-attention
-            r"^net\.blocks\.(\d+)\.cross_attn\.q_norm\._extra_state$": r"transformer_blocks.\1.attn2.norm_q._extra_state",
-            r"^net\.blocks\.(\d+)\.cross_attn\.k_norm\._extra_state$": r"transformer_blocks.\1.attn2.norm_k._extra_state",
-
+            r"^net\.blocks\.(\d+)\.cross_attn\.q_norm\._extra_state$":
+            r"transformer_blocks.\1.attn2.norm_q._extra_state",
+            r"^net\.blocks\.(\d+)\.cross_attn\.k_norm\._extra_state$":
+            r"transformer_blocks.\1.attn2.norm_k._extra_state",
+
             # MLP: net.blocks.N.mlp.layer1 -> transformer_blocks.N.mlp.fc_in
-            r"^net\.blocks\.(\d+)\.mlp\.layer1\.(.*)$": r"transformer_blocks.\1.mlp.fc_in.\2",
-            r"^net\.blocks\.(\d+)\.mlp\.layer2\.(.*)$": r"transformer_blocks.\1.mlp.fc_out.\2",
-
+            r"^net\.blocks\.(\d+)\.mlp\.layer1\.(.*)$":
+            r"transformer_blocks.\1.mlp.fc_in.\2",
+            r"^net\.blocks\.(\d+)\.mlp\.layer2\.(.*)$":
+            r"transformer_blocks.\1.mlp.fc_out.\2",
+
             # AdaLN-LoRA modulations: net.blocks.N.adaln_modulation_* -> transformer_blocks.N.adaln_modulation_*
             # These are now at the block level, not inside norm layers
-            r"^net\.blocks\.(\d+)\.adaln_modulation_self_attn\.1\.(.*)$": r"transformer_blocks.\1.adaln_modulation_self_attn.1.\2",
-            r"^net\.blocks\.(\d+)\.adaln_modulation_self_attn\.2\.(.*)$": r"transformer_blocks.\1.adaln_modulation_self_attn.2.\2",
-            r"^net\.blocks\.(\d+)\.adaln_modulation_cross_attn\.1\.(.*)$": r"transformer_blocks.\1.adaln_modulation_cross_attn.1.\2",
-            r"^net\.blocks\.(\d+)\.adaln_modulation_cross_attn\.2\.(.*)$": r"transformer_blocks.\1.adaln_modulation_cross_attn.2.\2",
-            r"^net\.blocks\.(\d+)\.adaln_modulation_mlp\.1\.(.*)$": r"transformer_blocks.\1.adaln_modulation_mlp.1.\2",
-            r"^net\.blocks\.(\d+)\.adaln_modulation_mlp\.2\.(.*)$": r"transformer_blocks.\1.adaln_modulation_mlp.2.\2",
-
+            r"^net\.blocks\.(\d+)\.adaln_modulation_self_attn\.1\.(.*)$":
+            r"transformer_blocks.\1.adaln_modulation_self_attn.1.\2",
+            r"^net\.blocks\.(\d+)\.adaln_modulation_self_attn\.2\.(.*)$":
+            r"transformer_blocks.\1.adaln_modulation_self_attn.2.\2",
+            r"^net\.blocks\.(\d+)\.adaln_modulation_cross_attn\.1\.(.*)$":
+            r"transformer_blocks.\1.adaln_modulation_cross_attn.1.\2",
+            r"^net\.blocks\.(\d+)\.adaln_modulation_cross_attn\.2\.(.*)$":
+            r"transformer_blocks.\1.adaln_modulation_cross_attn.2.\2",
+            r"^net\.blocks\.(\d+)\.adaln_modulation_mlp\.1\.(.*)$":
+            r"transformer_blocks.\1.adaln_modulation_mlp.1.\2",
+            r"^net\.blocks\.(\d+)\.adaln_modulation_mlp\.2\.(.*)$":
+            r"transformer_blocks.\1.adaln_modulation_mlp.2.\2",
+
             # Layer norms: net.blocks.N.layer_norm_* -> transformer_blocks.N.norm*.norm
-            r"^net\.blocks\.(\d+)\.layer_norm_self_attn\._extra_state$": r"transformer_blocks.\1.norm1.norm._extra_state",
-            r"^net\.blocks\.(\d+)\.layer_norm_cross_attn\._extra_state$": r"transformer_blocks.\1.norm2.norm._extra_state",
-            r"^net\.blocks\.(\d+)\.layer_norm_mlp\._extra_state$": r"transformer_blocks.\1.norm3.norm._extra_state",
-
+            r"^net\.blocks\.(\d+)\.layer_norm_self_attn\._extra_state$":
+            r"transformer_blocks.\1.norm1.norm._extra_state",
+            r"^net\.blocks\.(\d+)\.layer_norm_cross_attn\._extra_state$":
+            r"transformer_blocks.\1.norm2.norm._extra_state",
+            r"^net\.blocks\.(\d+)\.layer_norm_mlp\._extra_state$":
+            r"transformer_blocks.\1.norm3.norm._extra_state",
+
             # Final layer: net.final_layer.linear -> final_layer.proj_out
-            r"^net\.final_layer\.linear\.(.*)$": r"final_layer.proj_out.\1",
+            r"^net\.final_layer\.linear\.(.*)$":
+            r"final_layer.proj_out.\1",
             # Final layer AdaLN-LoRA: net.final_layer.adaln_modulation -> final_layer.linear_*
-            r"^net\.final_layer\.adaln_modulation\.1\.(.*)$": r"final_layer.linear_1.\1",
-            r"^net\.final_layer\.adaln_modulation\.2\.(.*)$": r"final_layer.linear_2.\1",
-
+            r"^net\.final_layer\.adaln_modulation\.1\.(.*)$":
+            r"final_layer.linear_1.\1",
+            r"^net\.final_layer\.adaln_modulation\.2\.(.*)$":
+            r"final_layer.linear_2.\1",
+
             # Note: The following keys from official checkpoint are NOT mapped and can be safely ignored:
             # - net.pos_embedder.* (seq, dim_spatial_range, dim_temporal_range) - These are computed dynamically
             #   in FastVideo's Cosmos25RotaryPosEmbed forward() method, so they don't need to be loaded.
@@ -87,15 +123,24 @@ class Cosmos25ArchConfig(DiTArchConfig):
 
     lora_param_names_mapping: dict = field(
         default_factory=lambda: {
-            r"^transformer_blocks\.(\d+)\.attn1\.to_q\.(.*)$": r"transformer_blocks.\1.attn1.to_q.\2",
-            r"^transformer_blocks\.(\d+)\.attn1\.to_k\.(.*)$": r"transformer_blocks.\1.attn1.to_k.\2",
-            r"^transformer_blocks\.(\d+)\.attn1\.to_v\.(.*)$": r"transformer_blocks.\1.attn1.to_v.\2",
-            r"^transformer_blocks\.(\d+)\.attn1\.to_out\.(.*)$": r"transformer_blocks.\1.attn1.to_out.\2",
-            r"^transformer_blocks\.(\d+)\.attn2\.to_q\.(.*)$": r"transformer_blocks.\1.attn2.to_q.\2",
-            r"^transformer_blocks\.(\d+)\.attn2\.to_k\.(.*)$": r"transformer_blocks.\1.attn2.to_k.\2",
-            r"^transformer_blocks\.(\d+)\.attn2\.to_v\.(.*)$": r"transformer_blocks.\1.attn2.to_v.\2",
-            r"^transformer_blocks\.(\d+)\.attn2\.to_out\.(.*)$": r"transformer_blocks.\1.attn2.to_out.\2",
-            r"^transformer_blocks\.(\d+)\.mlp\.(.*)$": r"transformer_blocks.\1.mlp.\2",
+            r"^transformer_blocks\.(\d+)\.attn1\.to_q\.(.*)$":
+            r"transformer_blocks.\1.attn1.to_q.\2",
+            r"^transformer_blocks\.(\d+)\.attn1\.to_k\.(.*)$":
+            r"transformer_blocks.\1.attn1.to_k.\2",
+            r"^transformer_blocks\.(\d+)\.attn1\.to_v\.(.*)$":
+            r"transformer_blocks.\1.attn1.to_v.\2",
+            r"^transformer_blocks\.(\d+)\.attn1\.to_out\.(.*)$":
+            r"transformer_blocks.\1.attn1.to_out.\2",
+            r"^transformer_blocks\.(\d+)\.attn2\.to_q\.(.*)$":
+            r"transformer_blocks.\1.attn2.to_q.\2",
+            r"^transformer_blocks\.(\d+)\.attn2\.to_k\.(.*)$":
+            r"transformer_blocks.\1.attn2.to_k.\2",
+            r"^transformer_blocks\.(\d+)\.attn2\.to_v\.(.*)$":
+            r"transformer_blocks.\1.attn2.to_v.\2",
+            r"^transformer_blocks\.(\d+)\.attn2\.to_out\.(.*)$":
+            r"transformer_blocks.\1.attn2.to_out.\2",
+            r"^transformer_blocks\.(\d+)\.mlp\.(.*)$":
+            r"transformer_blocks.\1.mlp.\2",
         })
 
     # Cosmos 2.5 specific config parameters
@@ -134,4 +179,3 @@ class Cosmos25VideoConfig(DiTConfig):
     """Configuration for Cosmos 2.5 video generation model."""
     arch_config: DiTArchConfig = field(default_factory=Cosmos25ArchConfig)
     prefix: str = "Cosmos25"
-
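
For context, here is a minimal sketch (not part of this diff) of how a regex-based `param_names_mapping` like the one above can be applied when converting an official checkpoint's state dict to FastVideo's naming. The helper name `remap_state_dict`, the first-match-wins loop, and the drop-unmatched behavior are illustrative assumptions, not necessarily how FastVideo's loader is implemented.

```python
import re

def remap_state_dict(state_dict: dict, param_names_mapping: dict) -> dict:
    """Rename checkpoint keys using the first regex pattern that matches.

    NOTE: hypothetical helper for illustration only; the actual FastVideo
    weight loader may differ.
    """
    remapped = {}
    for old_key, tensor in state_dict.items():
        for pattern, replacement in param_names_mapping.items():
            new_key, n_subs = re.subn(pattern, replacement, old_key)
            if n_subs:
                remapped[new_key] = tensor
                break
        else:
            # Keys with no matching pattern (e.g. net.pos_embedder.*) are
            # dropped here, mirroring the "safely ignored" note above.
            pass
    return remapped

# Example: an official checkpoint key is rewritten to the FastVideo name.
mapping = {
    r"^net\.blocks\.(\d+)\.self_attn\.q_proj\.(.*)$":
    r"transformer_blocks.\1.attn1.to_q.\2",
}
print(remap_state_dict({"net.blocks.3.self_attn.q_proj.weight": None}, mapping))
# -> {'transformer_blocks.3.attn1.to_q.weight': None}
```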