3 changes: 2 additions & 1 deletion fastvideo/configs/models/dits/__init__.py
@@ -1,5 +1,6 @@
from fastvideo.configs.models.dits.hunyuanvideo import HunyuanVideoConfig
from fastvideo.configs.models.dits.stepvideo import StepVideoConfig
from fastvideo.configs.models.dits.wanvideo import WanVideoConfig
print("WOW")
from fastvideo.configs.models.dits.cosmos import CosmosVideoConfig

__all__ = ["HunyuanVideoConfig", "WanVideoConfig", "StepVideoConfig"]
__all__ = ["HunyuanVideoConfig", "WanVideoConfig", "StepVideoConfig", "CosmosVideoConfig"]
105 changes: 105 additions & 0 deletions fastvideo/configs/models/dits/cosmos.py
@@ -0,0 +1,105 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field

from fastvideo.configs.models.dits.base import DiTArchConfig, DiTConfig


def is_transformer_blocks(n: str, m) -> bool:
    # FSDP shard condition: matches submodules named "transformer_blocks.<index>".
    return "transformer_blocks" in n and n.split(".")[-1].isdigit()


@dataclass
class CosmosArchConfig(DiTArchConfig):
_fsdp_shard_conditions: list = field(
default_factory=lambda: [is_transformer_blocks])

param_names_mapping: dict = field(
default_factory=lambda: {
r"^patch_embed\.(.*)$": r"patch_embed.\1",
r"^time_embed\.time_proj\.(.*)$": r"time_embed.time_proj.\1",
r"^time_embed\.t_embedder\.(.*)$": r"time_embed.t_embedder.\1",
r"^time_embed\.norm\.(.*)$": r"time_embed.norm.\1",
r"^transformer_blocks\.(\d+)\.attn1\.to_q\.(.*)$":
r"transformer_blocks.\1.attn1.to_q.\2",
r"^transformer_blocks\.(\d+)\.attn1\.to_k\.(.*)$":
r"transformer_blocks.\1.attn1.to_k.\2",
r"^transformer_blocks\.(\d+)\.attn1\.to_v\.(.*)$":
r"transformer_blocks.\1.attn1.to_v.\2",
r"^transformer_blocks\.(\d+)\.attn1\.to_out\.0\.(.*)$":
r"transformer_blocks.\1.attn1.to_out.\2",
r"^transformer_blocks\.(\d+)\.attn1\.norm_q\.(.*)$":
r"transformer_blocks.\1.attn1.norm_q.\2",
r"^transformer_blocks\.(\d+)\.attn1\.norm_k\.(.*)$":
r"transformer_blocks.\1.attn1.norm_k.\2",
r"^transformer_blocks\.(\d+)\.attn2\.to_q\.(.*)$":
r"transformer_blocks.\1.attn2.to_q.\2",
r"^transformer_blocks\.(\d+)\.attn2\.to_k\.(.*)$":
r"transformer_blocks.\1.attn2.to_k.\2",
r"^transformer_blocks\.(\d+)\.attn2\.to_v\.(.*)$":
r"transformer_blocks.\1.attn2.to_v.\2",
r"^transformer_blocks\.(\d+)\.attn2\.to_out\.0\.(.*)$":
r"transformer_blocks.\1.attn2.to_out.\2",
r"^transformer_blocks\.(\d+)\.attn2\.norm_q\.(.*)$":
r"transformer_blocks.\1.attn2.norm_q.\2",
r"^transformer_blocks\.(\d+)\.attn2\.norm_k\.(.*)$":
r"transformer_blocks.\1.attn2.norm_k.\2",
r"^transformer_blocks\.(\d+)\.ff\.net\.0\.proj\.(.*)$":
r"transformer_blocks.\1.ff.fc_in.\2",
r"^transformer_blocks\.(\d+)\.ff\.net\.2\.(.*)$":
r"transformer_blocks.\1.ff.fc_out.\2",
r"^norm_out\.(.*)$": r"norm_out.\1",
r"^proj_out\.(.*)$": r"proj_out.\1",
})

lora_param_names_mapping: dict = field(
default_factory=lambda: {
r"^transformer_blocks\.(\d+)\.attn1\.to_q\.(.*)$":
r"transformer_blocks.\1.attn1.to_q.\2",
r"^transformer_blocks\.(\d+)\.attn1\.to_k\.(.*)$":
r"transformer_blocks.\1.attn1.to_k.\2",
r"^transformer_blocks\.(\d+)\.attn1\.to_v\.(.*)$":
r"transformer_blocks.\1.attn1.to_v.\2",
r"^transformer_blocks\.(\d+)\.attn1\.to_out\.(.*)$":
r"transformer_blocks.\1.attn1.to_out.\2",
r"^transformer_blocks\.(\d+)\.attn2\.to_q\.(.*)$":
r"transformer_blocks.\1.attn2.to_q.\2",
r"^transformer_blocks\.(\d+)\.attn2\.to_k\.(.*)$":
r"transformer_blocks.\1.attn2.to_k.\2",
r"^transformer_blocks\.(\d+)\.attn2\.to_v\.(.*)$":
r"transformer_blocks.\1.attn2.to_v.\2",
r"^transformer_blocks\.(\d+)\.attn2\.to_out\.(.*)$":
r"transformer_blocks.\1.attn2.to_out.\2",
r"^transformer_blocks\.(\d+)\.ff\.(.*)$":
r"transformer_blocks.\1.ff.\2",
})

# Cosmos-specific config parameters based on transformer_cosmos.py
in_channels: int = 16
out_channels: int = 16
num_attention_heads: int = 16
attention_head_dim: int = 128
num_layers: int = 28
mlp_ratio: float = 4.0
text_embed_dim: int = 1024
adaln_lora_dim: int = 256
max_size: tuple[int, int, int] = (128, 240, 240)
patch_size: tuple[int, int, int] = (1, 2, 2)
rope_scale: tuple[float, float, float] = (1.0, 3.0, 3.0)
concat_padding_mask: bool = True
extra_pos_embed_type: str | None = None
qk_norm: str = "rms_norm"
eps: float = 1e-6
exclude_lora_layers: list[str] = field(default_factory=lambda: ["embedder"])


def __post_init__(self):
super().__post_init__()
self.out_channels = self.out_channels or self.in_channels
self.hidden_size = self.num_attention_heads * self.attention_head_dim
self.num_channels_latents = self.in_channels


@dataclass
class CosmosVideoConfig(DiTConfig):
arch_config: DiTArchConfig = field(default_factory=CosmosArchConfig)
prefix: str = "Cosmos"
4 changes: 2 additions & 2 deletions fastvideo/configs/models/encoders/__init__.py
@@ -5,10 +5,10 @@
from fastvideo.configs.models.encoders.clip import (CLIPTextConfig,
CLIPVisionConfig)
from fastvideo.configs.models.encoders.llama import LlamaConfig
from fastvideo.configs.models.encoders.t5 import T5Config
from fastvideo.configs.models.encoders.t5 import T5Config, T5LargeConfig

__all__ = [
"EncoderConfig", "TextEncoderConfig", "ImageEncoderConfig",
"BaseEncoderOutput", "CLIPTextConfig", "CLIPVisionConfig", "LlamaConfig",
"T5Config"
"T5Config", "T5LargeConfig"
]
22 changes: 22 additions & 0 deletions fastvideo/configs/models/encoders/t5.py
@@ -70,8 +70,30 @@ def __post_init__(self):
}


@dataclass
class T5LargeArchConfig(T5ArchConfig):
"""T5 Large architecture config with parameters for your specific model."""
d_model: int = 1024
d_kv: int = 128
d_ff: int = 65536
num_layers: int = 24
num_decoder_layers: int | None = 24
num_heads: int = 128
decoder_start_token_id: int = 0
n_positions: int = 512
task_specific_params: dict | None = None


@dataclass
class T5Config(TextEncoderConfig):
arch_config: TextEncoderArchConfig = field(default_factory=T5ArchConfig)

prefix: str = "t5"


@dataclass
class T5LargeConfig(TextEncoderConfig):
"""T5 Large configuration for your specific model."""
arch_config: TextEncoderArchConfig = field(default_factory=T5LargeArchConfig)

prefix: str = "t5"
2 changes: 2 additions & 0 deletions fastvideo/configs/models/vaes/__init__.py
@@ -1,9 +1,11 @@
from fastvideo.configs.models.vaes.hunyuanvae import HunyuanVAEConfig
from fastvideo.configs.models.vaes.stepvideovae import StepVideoVAEConfig
from fastvideo.configs.models.vaes.wanvae import WanVAEConfig
from fastvideo.configs.models.vaes.cosmosvae import CosmosVAEConfig

__all__ = [
"HunyuanVAEConfig",
"WanVAEConfig",
"StepVideoVAEConfig",
"CosmosVAEConfig",
]
88 changes: 88 additions & 0 deletions fastvideo/configs/models/vaes/cosmosvae.py
@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field

import torch

from fastvideo.configs.models.vaes.base import VAEArchConfig, VAEConfig


@dataclass
class CosmosVAEArchConfig(VAEArchConfig):
_class_name: str = "AutoencoderKLWan"
_diffusers_version: str = "0.34.0.dev0"
_name_or_path: str = ""
Comment on lines +11 to +13
Collaborator

These fields should be popped/removed in the loader, so they can be removed here. You can refer to how Wan's VAE config is defined.

base_dim: int = 96
z_dim: int = 16
dim_mult: tuple[int, ...] = (1, 2, 4, 4)
num_res_blocks: int = 2
attn_scales: tuple[float, ...] = ()
temperal_downsample: tuple[bool, ...] = (False, True, True)
dropout: float = 0.0
decoder_base_dim: int | None = None
is_residual: bool = False
in_channels: int = 3
out_channels: int = 3
patch_size: int | None = None
scale_factor_temporal: int = 4
scale_factor_spatial: int = 8
clip_output: bool = True
latents_mean: tuple[float, ...] = (
-0.7571,
-0.7089,
-0.9113,
0.1075,
-0.1745,
0.9653,
-0.1517,
1.5508,
0.4134,
-0.0715,
0.5517,
-0.3632,
-0.1922,
-0.9497,
0.2503,
-0.2921,
)
latents_std: tuple[float, ...] = (
2.8184,
1.4541,
2.3275,
2.6558,
1.2196,
1.7708,
2.6052,
2.0743,
3.2687,
2.1526,
2.8652,
1.5579,
1.6382,
1.1253,
2.8251,
1.9160,
)
temporal_compression_ratio = 4
spatial_compression_ratio = 8

def __post_init__(self):
self.scaling_factor: torch.Tensor = 1.0 / torch.tensor(
self.latents_std).view(1, self.z_dim, 1, 1, 1)
self.shift_factor: torch.Tensor = torch.tensor(self.latents_mean).view(
1, self.z_dim, 1, 1, 1)
self.temporal_compression_ratio = self.scale_factor_temporal
self.spatial_compression_ratio = self.scale_factor_spatial


@dataclass
class CosmosVAEConfig(VAEConfig):
arch_config: CosmosVAEArchConfig = field(default_factory=CosmosVAEArchConfig)
use_feature_cache: bool = True

use_tiling: bool = False
use_temporal_tiling: bool = False
use_parallel_tiling: bool = False

def __post_init__(self):
self.blend_num_frames = (self.tile_sample_min_num_frames -
self.tile_sample_stride_num_frames) * 2
Collaborator

Newline char
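
__post_init__ above precomputes per-channel scaling_factor and shift_factor tensors of shape (1, z_dim, 1, 1, 1), so downstream code can normalize latents channel-wise before the DiT. A sketch of that arithmetic (normalize_latents is illustrative, and this assumes the inherited VAEArchConfig fields have defaults):

import torch

from fastvideo.configs.models.vaes.cosmosvae import CosmosVAEArchConfig

arch = CosmosVAEArchConfig()

def normalize_latents(z: torch.Tensor) -> torch.Tensor:
    # Stats broadcast from (1, 16, 1, 1, 1) over (B, C, T, H, W) latents.
    return (z - arch.shift_factor) * arch.scaling_factor

z = torch.randn(2, arch.z_dim, 4, 30, 30)
assert normalize_latents(z).shape == z.shape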

4 changes: 3 additions & 1 deletion fastvideo/configs/pipelines/__init__.py
@@ -7,9 +7,11 @@
from fastvideo.configs.pipelines.wan import (WanI2V480PConfig, WanI2V720PConfig,
WanT2V480PConfig, WanT2V720PConfig)

from fastvideo.configs.pipelines.cosmos import CosmosConfig

__all__ = [
"HunyuanConfig", "FastHunyuanConfig", "PipelineConfig",
"SlidingTileAttnConfig", "WanT2V480PConfig", "WanI2V480PConfig",
"WanT2V720PConfig", "WanI2V720PConfig", "StepVideoT2VConfig",
"get_pipeline_config_cls_from_name"
"CosmosConfig", "get_pipeline_config_cls_from_name"
]
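
With the re-export above, the new config is importable alongside the existing ones; whether get_pipeline_config_cls_from_name also resolves a "cosmos" name depends on registry changes not shown in this diff:

# Both names are exported by fastvideo.configs.pipelines.
from fastvideo.configs.pipelines import (CosmosConfig,
                                         get_pipeline_config_cls_from_name)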