# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field

from fastvideo.configs.models.dits.base import DiTArchConfig, DiTConfig


def is_transformer_blocks(n: str, m) -> bool:
    return "transformer_blocks" in n and str.isdigit(n.split(".")[-1])
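
# For reference: the predicate above returns True for top-level block module
# names such as "transformer_blocks.3" and False for their children (e.g.
# "transformer_blocks.3.attn1"), so FSDP sharding is applied per transformer block.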


@dataclass
class Cosmos25ArchConfig(DiTArchConfig):
    """Configuration for Cosmos 2.5 architecture (MiniTrainDIT)."""

    _fsdp_shard_conditions: list = field(
        default_factory=lambda: [is_transformer_blocks])

    param_names_mapping: dict = field(
        default_factory=lambda: {
            # Remove "net." prefix and map official structure to FastVideo
            # Patch embedding: net.x_embedder.proj.1.weight -> patch_embed.proj.weight
            r"^net\.x_embedder\.proj\.1\.(.*)$":
            r"patch_embed.proj.\1",

            # Time embedding: net.t_embedder.1.linear_1.weight -> time_embed.t_embedder.linear_1.weight
            r"^net\.t_embedder\.1\.linear_1\.(.*)$":
            r"time_embed.t_embedder.linear_1.\1",
            r"^net\.t_embedder\.1\.linear_2\.(.*)$":
            r"time_embed.t_embedder.linear_2.\1",
            # Time embedding norm: net.t_embedding_norm.weight -> time_embed.norm.weight
            # Note: This also handles _extra_state if present
            r"^net\.t_embedding_norm\.(.*)$":
            r"time_embed.norm.\1",

            # Cross-attention projection (optional): net.crossattn_proj.0.weight -> crossattn_proj.0.weight
            r"^net\.crossattn_proj\.0\.weight$":
            r"crossattn_proj.0.weight",
            r"^net\.crossattn_proj\.0\.bias$":
            r"crossattn_proj.0.bias",

            # Transformer blocks: net.blocks.N -> transformer_blocks.N
            # Self-attention (self_attn -> attn1)
            r"^net\.blocks\.(\d+)\.self_attn\.q_proj\.(.*)$":
            r"transformer_blocks.\1.attn1.to_q.\2",
            r"^net\.blocks\.(\d+)\.self_attn\.k_proj\.(.*)$":
            r"transformer_blocks.\1.attn1.to_k.\2",
            r"^net\.blocks\.(\d+)\.self_attn\.v_proj\.(.*)$":
            r"transformer_blocks.\1.attn1.to_v.\2",
            r"^net\.blocks\.(\d+)\.self_attn\.output_proj\.(.*)$":
            r"transformer_blocks.\1.attn1.to_out.\2",
            r"^net\.blocks\.(\d+)\.self_attn\.q_norm\.weight$":
            r"transformer_blocks.\1.attn1.norm_q.weight",
            r"^net\.blocks\.(\d+)\.self_attn\.k_norm\.weight$":
            r"transformer_blocks.\1.attn1.norm_k.weight",
            # RMSNorm _extra_state keys (internal PyTorch state, will be recomputed automatically)
            r"^net\.blocks\.(\d+)\.self_attn\.q_norm\._extra_state$":
            r"transformer_blocks.\1.attn1.norm_q._extra_state",
            r"^net\.blocks\.(\d+)\.self_attn\.k_norm\._extra_state$":
            r"transformer_blocks.\1.attn1.norm_k._extra_state",

            # Cross-attention (cross_attn -> attn2)
            r"^net\.blocks\.(\d+)\.cross_attn\.q_proj\.(.*)$":
            r"transformer_blocks.\1.attn2.to_q.\2",
            r"^net\.blocks\.(\d+)\.cross_attn\.k_proj\.(.*)$":
            r"transformer_blocks.\1.attn2.to_k.\2",
            r"^net\.blocks\.(\d+)\.cross_attn\.v_proj\.(.*)$":
            r"transformer_blocks.\1.attn2.to_v.\2",
            r"^net\.blocks\.(\d+)\.cross_attn\.output_proj\.(.*)$":
            r"transformer_blocks.\1.attn2.to_out.\2",
            r"^net\.blocks\.(\d+)\.cross_attn\.q_norm\.weight$":
            r"transformer_blocks.\1.attn2.norm_q.weight",
            r"^net\.blocks\.(\d+)\.cross_attn\.k_norm\.weight$":
            r"transformer_blocks.\1.attn2.norm_k.weight",
            # RMSNorm _extra_state keys for cross-attention
            r"^net\.blocks\.(\d+)\.cross_attn\.q_norm\._extra_state$":
            r"transformer_blocks.\1.attn2.norm_q._extra_state",
            r"^net\.blocks\.(\d+)\.cross_attn\.k_norm\._extra_state$":
            r"transformer_blocks.\1.attn2.norm_k._extra_state",

            # MLP: net.blocks.N.mlp.layer1 -> transformer_blocks.N.mlp.fc_in
            r"^net\.blocks\.(\d+)\.mlp\.layer1\.(.*)$":
            r"transformer_blocks.\1.mlp.fc_in.\2",
            r"^net\.blocks\.(\d+)\.mlp\.layer2\.(.*)$":
            r"transformer_blocks.\1.mlp.fc_out.\2",

            # AdaLN-LoRA modulations: net.blocks.N.adaln_modulation_* -> transformer_blocks.N.adaln_modulation_*
            # These are at the block level in FastVideo, not inside norm layers
            r"^net\.blocks\.(\d+)\.adaln_modulation_self_attn\.1\.(.*)$":
            r"transformer_blocks.\1.adaln_modulation_self_attn.1.\2",
            r"^net\.blocks\.(\d+)\.adaln_modulation_self_attn\.2\.(.*)$":
            r"transformer_blocks.\1.adaln_modulation_self_attn.2.\2",
            r"^net\.blocks\.(\d+)\.adaln_modulation_cross_attn\.1\.(.*)$":
            r"transformer_blocks.\1.adaln_modulation_cross_attn.1.\2",
            r"^net\.blocks\.(\d+)\.adaln_modulation_cross_attn\.2\.(.*)$":
            r"transformer_blocks.\1.adaln_modulation_cross_attn.2.\2",
            r"^net\.blocks\.(\d+)\.adaln_modulation_mlp\.1\.(.*)$":
            r"transformer_blocks.\1.adaln_modulation_mlp.1.\2",
            r"^net\.blocks\.(\d+)\.adaln_modulation_mlp\.2\.(.*)$":
            r"transformer_blocks.\1.adaln_modulation_mlp.2.\2",

            # Layer norms: net.blocks.N.layer_norm_* -> transformer_blocks.N.norm*.norm
            r"^net\.blocks\.(\d+)\.layer_norm_self_attn\._extra_state$":
            r"transformer_blocks.\1.norm1.norm._extra_state",
            r"^net\.blocks\.(\d+)\.layer_norm_cross_attn\._extra_state$":
            r"transformer_blocks.\1.norm2.norm._extra_state",
            r"^net\.blocks\.(\d+)\.layer_norm_mlp\._extra_state$":
            r"transformer_blocks.\1.norm3.norm._extra_state",

            # Final layer: net.final_layer.linear -> final_layer.proj_out
            r"^net\.final_layer\.linear\.(.*)$":
            r"final_layer.proj_out.\1",
            # Final layer AdaLN-LoRA: net.final_layer.adaln_modulation -> final_layer.linear_*
            r"^net\.final_layer\.adaln_modulation\.1\.(.*)$":
            r"final_layer.linear_1.\1",
            r"^net\.final_layer\.adaln_modulation\.2\.(.*)$":
            r"final_layer.linear_2.\1",

            # Note: the following keys from the official checkpoint are NOT mapped and can be safely ignored:
            # - net.pos_embedder.* (seq, dim_spatial_range, dim_temporal_range) - these are computed dynamically
            #   in FastVideo's Cosmos25RotaryPosEmbed forward() method, so they don't need to be loaded.
            # - net.accum_* keys (training metadata) - these are skipped during checkpoint loading.
        })
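
    # Worked example of the mapping above (see also the sketch at the end of
    # this file): "net.blocks.3.self_attn.q_proj.weight" is rewritten to
    # "transformer_blocks.3.attn1.to_q.weight".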

    lora_param_names_mapping: dict = field(
        default_factory=lambda: {
            r"^transformer_blocks\.(\d+)\.attn1\.to_q\.(.*)$":
            r"transformer_blocks.\1.attn1.to_q.\2",
            r"^transformer_blocks\.(\d+)\.attn1\.to_k\.(.*)$":
            r"transformer_blocks.\1.attn1.to_k.\2",
            r"^transformer_blocks\.(\d+)\.attn1\.to_v\.(.*)$":
            r"transformer_blocks.\1.attn1.to_v.\2",
            r"^transformer_blocks\.(\d+)\.attn1\.to_out\.(.*)$":
            r"transformer_blocks.\1.attn1.to_out.\2",
            r"^transformer_blocks\.(\d+)\.attn2\.to_q\.(.*)$":
            r"transformer_blocks.\1.attn2.to_q.\2",
            r"^transformer_blocks\.(\d+)\.attn2\.to_k\.(.*)$":
            r"transformer_blocks.\1.attn2.to_k.\2",
            r"^transformer_blocks\.(\d+)\.attn2\.to_v\.(.*)$":
            r"transformer_blocks.\1.attn2.to_v.\2",
            r"^transformer_blocks\.(\d+)\.attn2\.to_out\.(.*)$":
            r"transformer_blocks.\1.attn2.to_out.\2",
            r"^transformer_blocks\.(\d+)\.mlp\.(.*)$":
            r"transformer_blocks.\1.mlp.\2",
        })
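    # Note: the LoRA mapping above maps each FastVideo-style name to itself;
    # LoRA parameter names already follow the attn1/attn2/mlp naming, so no
    # renaming is performed.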

    # Cosmos 2.5 specific config parameters
    in_channels: int = 16
    out_channels: int = 16
    num_attention_heads: int = 16
    attention_head_dim: int = 128  # 2048 / 16
    num_layers: int = 28
    mlp_ratio: float = 4.0
    text_embed_dim: int = 1024
    adaln_lora_dim: int = 256
    use_adaln_lora: bool = True
    max_size: tuple[int, int, int] = (128, 240, 240)
    patch_size: tuple[int, int, int] = (1, 2, 2)
    rope_scale: tuple[float, float, float] = (1.0, 3.0, 3.0)  # T, H, W scaling
    concat_padding_mask: bool = True
    extra_pos_embed_type: str | None = None  # "learnable" or None
    # Note: the official checkpoint uses use_crossattn_projection=True with a
    # ~100K-dim input from Qwen 7B. When enabled, 100,352-dim embeddings must be
    # provided to match the projection layer in the checkpoint.
    use_crossattn_projection: bool = False
    crossattn_proj_in_channels: int = 100352  # Qwen 7B embedding dimension
    rope_enable_fps_modulation: bool = True
    qk_norm: str = "rms_norm"
    eps: float = 1e-6
    exclude_lora_layers: list[str] = field(default_factory=lambda: ["embedder"])

    def __post_init__(self):
        super().__post_init__()
        self.out_channels = self.out_channels or self.in_channels
        self.hidden_size = self.num_attention_heads * self.attention_head_dim
        self.num_channels_latents = self.in_channels


@dataclass
class Cosmos25VideoConfig(DiTConfig):
    """Configuration for Cosmos 2.5 video generation model."""
    arch_config: DiTArchConfig = field(default_factory=Cosmos25ArchConfig)
    prefix: str = "Cosmos25"
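

if __name__ == "__main__":
    # Minimal, self-contained sketch (not FastVideo's actual checkpoint loader):
    # demonstrate how one official checkpoint key is rewritten by a pattern from
    # Cosmos25ArchConfig.param_names_mapping above.
    import re

    _pattern = r"^net\.blocks\.(\d+)\.self_attn\.q_proj\.(.*)$"
    _replacement = r"transformer_blocks.\1.attn1.to_q.\2"
    _official_key = "net.blocks.3.self_attn.q_proj.weight"
    print(re.sub(_pattern, _replacement, _official_key))
    # -> transformer_blocks.3.attn1.to_q.weight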