
Commit e04a192

[feat]: add COSMOS 2.5 DiT implementation (#897)
Co-authored-by: KyleS1016 <kyle.s@gmicloud.ai>
1 parent c9ca6d1 commit e04a192

4 files changed: +1742 −1 lines changed

Lines changed: 2 additions & 1 deletion
@@ -1,9 +1,10 @@
 from fastvideo.configs.models.dits.cosmos import CosmosVideoConfig
+from fastvideo.configs.models.dits.cosmos2_5 import Cosmos25VideoConfig
 from fastvideo.configs.models.dits.hunyuanvideo import HunyuanVideoConfig
 from fastvideo.configs.models.dits.stepvideo import StepVideoConfig
 from fastvideo.configs.models.dits.wanvideo import WanVideoConfig
 
 __all__ = [
     "HunyuanVideoConfig", "WanVideoConfig", "StepVideoConfig",
-    "CosmosVideoConfig"
+    "CosmosVideoConfig", "Cosmos25VideoConfig"
 ]
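
The re-export above makes the new config available alongside the existing DiT configs. A minimal usage sketch, assuming this hunk edits fastvideo/configs/models/dits/__init__.py (the filename is not shown in this view, so the package-level import path is an assumption):

# The direct module import is taken verbatim from the new file added below;
# the package-level import assumes the hunk above is the package __init__.py.
from fastvideo.configs.models.dits.cosmos2_5 import Cosmos25VideoConfig
from fastvideo.configs.models.dits import Cosmos25VideoConfig as ReExported

assert Cosmos25VideoConfig is ReExported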
Lines changed: 181 additions & 0 deletions
@@ -0,0 +1,181 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field

from fastvideo.configs.models.dits.base import DiTArchConfig, DiTConfig


def is_transformer_blocks(n: str, m) -> bool:
    return "transformer_blocks" in n and str.isdigit(n.split(".")[-1])


@dataclass
class Cosmos25ArchConfig(DiTArchConfig):
    """Configuration for Cosmos 2.5 architecture (MiniTrainDIT)."""

    _fsdp_shard_conditions: list = field(
        default_factory=lambda: [is_transformer_blocks])

    param_names_mapping: dict = field(
        default_factory=lambda: {
            # Remove "net." prefix and map official structure to FastVideo
            # Patch embedding: net.x_embedder.proj.1.weight -> patch_embed.proj.weight
            r"^net\.x_embedder\.proj\.1\.(.*)$":
            r"patch_embed.proj.\1",

            # Time embedding: net.t_embedder.1.linear_1.weight -> time_embed.t_embedder.linear_1.weight
            r"^net\.t_embedder\.1\.linear_1\.(.*)$":
            r"time_embed.t_embedder.linear_1.\1",
            r"^net\.t_embedder\.1\.linear_2\.(.*)$":
            r"time_embed.t_embedder.linear_2.\1",
            # Time embedding norm: net.t_embedding_norm.weight -> time_embed.norm.weight
            # Note: this also handles _extra_state keys if present
            r"^net\.t_embedding_norm\.(.*)$":
            r"time_embed.norm.\1",

            # Cross-attention projection (optional): net.crossattn_proj.0.weight -> crossattn_proj.0.weight
            r"^net\.crossattn_proj\.0\.weight$":
            r"crossattn_proj.0.weight",
            r"^net\.crossattn_proj\.0\.bias$":
            r"crossattn_proj.0.bias",

            # Transformer blocks: net.blocks.N -> transformer_blocks.N
            # Self-attention (self_attn -> attn1)
            r"^net\.blocks\.(\d+)\.self_attn\.q_proj\.(.*)$":
            r"transformer_blocks.\1.attn1.to_q.\2",
            r"^net\.blocks\.(\d+)\.self_attn\.k_proj\.(.*)$":
            r"transformer_blocks.\1.attn1.to_k.\2",
            r"^net\.blocks\.(\d+)\.self_attn\.v_proj\.(.*)$":
            r"transformer_blocks.\1.attn1.to_v.\2",
            r"^net\.blocks\.(\d+)\.self_attn\.output_proj\.(.*)$":
            r"transformer_blocks.\1.attn1.to_out.\2",
            r"^net\.blocks\.(\d+)\.self_attn\.q_norm\.weight$":
            r"transformer_blocks.\1.attn1.norm_q.weight",
            r"^net\.blocks\.(\d+)\.self_attn\.k_norm\.weight$":
            r"transformer_blocks.\1.attn1.norm_k.weight",
            # RMSNorm _extra_state keys (internal PyTorch state, recomputed automatically)
            r"^net\.blocks\.(\d+)\.self_attn\.q_norm\._extra_state$":
            r"transformer_blocks.\1.attn1.norm_q._extra_state",
            r"^net\.blocks\.(\d+)\.self_attn\.k_norm\._extra_state$":
            r"transformer_blocks.\1.attn1.norm_k._extra_state",

            # Cross-attention (cross_attn -> attn2)
            r"^net\.blocks\.(\d+)\.cross_attn\.q_proj\.(.*)$":
            r"transformer_blocks.\1.attn2.to_q.\2",
            r"^net\.blocks\.(\d+)\.cross_attn\.k_proj\.(.*)$":
            r"transformer_blocks.\1.attn2.to_k.\2",
            r"^net\.blocks\.(\d+)\.cross_attn\.v_proj\.(.*)$":
            r"transformer_blocks.\1.attn2.to_v.\2",
            r"^net\.blocks\.(\d+)\.cross_attn\.output_proj\.(.*)$":
            r"transformer_blocks.\1.attn2.to_out.\2",
            r"^net\.blocks\.(\d+)\.cross_attn\.q_norm\.weight$":
            r"transformer_blocks.\1.attn2.norm_q.weight",
            r"^net\.blocks\.(\d+)\.cross_attn\.k_norm\.weight$":
            r"transformer_blocks.\1.attn2.norm_k.weight",
            # RMSNorm _extra_state keys for cross-attention
            r"^net\.blocks\.(\d+)\.cross_attn\.q_norm\._extra_state$":
            r"transformer_blocks.\1.attn2.norm_q._extra_state",
            r"^net\.blocks\.(\d+)\.cross_attn\.k_norm\._extra_state$":
            r"transformer_blocks.\1.attn2.norm_k._extra_state",

            # MLP: net.blocks.N.mlp.layer1 -> transformer_blocks.N.mlp.fc_in
            r"^net\.blocks\.(\d+)\.mlp\.layer1\.(.*)$":
            r"transformer_blocks.\1.mlp.fc_in.\2",
            r"^net\.blocks\.(\d+)\.mlp\.layer2\.(.*)$":
            r"transformer_blocks.\1.mlp.fc_out.\2",

            # AdaLN-LoRA modulations: net.blocks.N.adaln_modulation_* -> transformer_blocks.N.adaln_modulation_*
            # These live at the block level, not inside the norm layers
            r"^net\.blocks\.(\d+)\.adaln_modulation_self_attn\.1\.(.*)$":
            r"transformer_blocks.\1.adaln_modulation_self_attn.1.\2",
            r"^net\.blocks\.(\d+)\.adaln_modulation_self_attn\.2\.(.*)$":
            r"transformer_blocks.\1.adaln_modulation_self_attn.2.\2",
            r"^net\.blocks\.(\d+)\.adaln_modulation_cross_attn\.1\.(.*)$":
            r"transformer_blocks.\1.adaln_modulation_cross_attn.1.\2",
            r"^net\.blocks\.(\d+)\.adaln_modulation_cross_attn\.2\.(.*)$":
            r"transformer_blocks.\1.adaln_modulation_cross_attn.2.\2",
            r"^net\.blocks\.(\d+)\.adaln_modulation_mlp\.1\.(.*)$":
            r"transformer_blocks.\1.adaln_modulation_mlp.1.\2",
            r"^net\.blocks\.(\d+)\.adaln_modulation_mlp\.2\.(.*)$":
            r"transformer_blocks.\1.adaln_modulation_mlp.2.\2",

            # Layer norms: net.blocks.N.layer_norm_* -> transformer_blocks.N.norm*.norm
            r"^net\.blocks\.(\d+)\.layer_norm_self_attn\._extra_state$":
            r"transformer_blocks.\1.norm1.norm._extra_state",
            r"^net\.blocks\.(\d+)\.layer_norm_cross_attn\._extra_state$":
            r"transformer_blocks.\1.norm2.norm._extra_state",
            r"^net\.blocks\.(\d+)\.layer_norm_mlp\._extra_state$":
            r"transformer_blocks.\1.norm3.norm._extra_state",

            # Final layer: net.final_layer.linear -> final_layer.proj_out
            r"^net\.final_layer\.linear\.(.*)$":
            r"final_layer.proj_out.\1",
            # Final layer AdaLN-LoRA: net.final_layer.adaln_modulation -> final_layer.linear_*
            r"^net\.final_layer\.adaln_modulation\.1\.(.*)$":
            r"final_layer.linear_1.\1",
            r"^net\.final_layer\.adaln_modulation\.2\.(.*)$":
            r"final_layer.linear_2.\1",

            # Note: the following keys from the official checkpoint are NOT mapped and can be safely ignored:
            # - net.pos_embedder.* (seq, dim_spatial_range, dim_temporal_range) - these are computed dynamically
            #   in FastVideo's Cosmos25RotaryPosEmbed forward() method, so they don't need to be loaded.
            # - net.accum_* keys (training metadata) - these are skipped during checkpoint loading.
        })

    lora_param_names_mapping: dict = field(
        default_factory=lambda: {
            r"^transformer_blocks\.(\d+)\.attn1\.to_q\.(.*)$":
            r"transformer_blocks.\1.attn1.to_q.\2",
            r"^transformer_blocks\.(\d+)\.attn1\.to_k\.(.*)$":
            r"transformer_blocks.\1.attn1.to_k.\2",
            r"^transformer_blocks\.(\d+)\.attn1\.to_v\.(.*)$":
            r"transformer_blocks.\1.attn1.to_v.\2",
            r"^transformer_blocks\.(\d+)\.attn1\.to_out\.(.*)$":
            r"transformer_blocks.\1.attn1.to_out.\2",
            r"^transformer_blocks\.(\d+)\.attn2\.to_q\.(.*)$":
            r"transformer_blocks.\1.attn2.to_q.\2",
            r"^transformer_blocks\.(\d+)\.attn2\.to_k\.(.*)$":
            r"transformer_blocks.\1.attn2.to_k.\2",
            r"^transformer_blocks\.(\d+)\.attn2\.to_v\.(.*)$":
            r"transformer_blocks.\1.attn2.to_v.\2",
            r"^transformer_blocks\.(\d+)\.attn2\.to_out\.(.*)$":
            r"transformer_blocks.\1.attn2.to_out.\2",
            r"^transformer_blocks\.(\d+)\.mlp\.(.*)$":
            r"transformer_blocks.\1.mlp.\2",
        })

    # Cosmos 2.5 specific config parameters
    in_channels: int = 16
    out_channels: int = 16
    num_attention_heads: int = 16
    attention_head_dim: int = 128  # 2048 / 16
    num_layers: int = 28
    mlp_ratio: float = 4.0
    text_embed_dim: int = 1024
    adaln_lora_dim: int = 256
    use_adaln_lora: bool = True
    max_size: tuple[int, int, int] = (128, 240, 240)
    patch_size: tuple[int, int, int] = (1, 2, 2)
    rope_scale: tuple[float, float, float] = (1.0, 3.0, 3.0)  # T, H, W scaling
    concat_padding_mask: bool = True
    extra_pos_embed_type: str | None = None  # "learnable" or None
    # Note: the official checkpoint has use_crossattn_projection=True with 100K-dim input from Qwen 7B.
    # When enabled, 100,352-dim embeddings must be provided to match the projection layer in the checkpoint.
    use_crossattn_projection: bool = False
    crossattn_proj_in_channels: int = 100352  # Qwen 7B embedding dimension
    rope_enable_fps_modulation: bool = True
    qk_norm: str = "rms_norm"
    eps: float = 1e-6
    exclude_lora_layers: list[str] = field(default_factory=lambda: ["embedder"])

    def __post_init__(self):
        super().__post_init__()
        self.out_channels = self.out_channels or self.in_channels
        self.hidden_size = self.num_attention_heads * self.attention_head_dim
        self.num_channels_latents = self.in_channels


@dataclass
class Cosmos25VideoConfig(DiTConfig):
    """Configuration for Cosmos 2.5 video generation model."""
    arch_config: DiTArchConfig = field(default_factory=Cosmos25ArchConfig)
    prefix: str = "Cosmos25"
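
To show how the param_names_mapping table is meant to be consumed, here is a small standalone sketch that applies the regex pairs to a couple of official checkpoint keys. The remap_key helper and the sample keys are illustrative only, not FastVideo's actual checkpoint loader; it assumes a FastVideo install that includes this commit.

import re

from fastvideo.configs.models.dits.cosmos2_5 import Cosmos25ArchConfig

# Hypothetical helper: apply the first matching pattern from the mapping table.
def remap_key(name: str, mapping: dict) -> str:
    for pattern, repl in mapping.items():
        new_name, n = re.subn(pattern, repl, name)
        if n:
            return new_name
    return name  # unmapped keys (e.g. net.pos_embedder.*) pass through unchanged

mapping = Cosmos25ArchConfig().param_names_mapping
print(remap_key("net.blocks.3.self_attn.q_proj.weight", mapping))
# -> transformer_blocks.3.attn1.to_q.weight
print(remap_key("net.final_layer.linear.bias", mapping))
# -> final_layer.proj_out.bias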

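For reference, the values derived in __post_init__ with the defaults above work out as follows; this is a plain arithmetic sketch that does not import FastVideo.

# Derived values for the default Cosmos25ArchConfig, mirroring __post_init__.
num_attention_heads = 16
attention_head_dim = 128
in_channels = 16
out_channels = 16

out_channels = out_channels or in_channels               # stays 16
hidden_size = num_attention_heads * attention_head_dim   # 16 * 128 = 2048
num_channels_latents = in_channels                        # 16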