diff --git a/fastvideo/configs/models/dits/__init__.py b/fastvideo/configs/models/dits/__init__.py index 72271a525..0abc716e2 100644 --- a/fastvideo/configs/models/dits/__init__.py +++ b/fastvideo/configs/models/dits/__init__.py @@ -1,5 +1,9 @@ +from fastvideo.configs.models.dits.cosmos import CosmosVideoConfig from fastvideo.configs.models.dits.hunyuanvideo import HunyuanVideoConfig from fastvideo.configs.models.dits.stepvideo import StepVideoConfig from fastvideo.configs.models.dits.wanvideo import WanVideoConfig -__all__ = ["HunyuanVideoConfig", "WanVideoConfig", "StepVideoConfig"] +__all__ = [ + "HunyuanVideoConfig", "WanVideoConfig", "StepVideoConfig", + "CosmosVideoConfig" +] diff --git a/fastvideo/configs/models/dits/cosmos.py b/fastvideo/configs/models/dits/cosmos.py new file mode 100644 index 000000000..b76e67ed9 --- /dev/null +++ b/fastvideo/configs/models/dits/cosmos.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +from fastvideo.configs.models.dits.base import DiTArchConfig, DiTConfig + + +def is_transformer_blocks(n: str, m) -> bool: + return "transformer_blocks" in n and str.isdigit(n.split(".")[-1]) + + +@dataclass +class CosmosArchConfig(DiTArchConfig): + _fsdp_shard_conditions: list = field( + default_factory=lambda: [is_transformer_blocks]) + + param_names_mapping: dict = field( + default_factory=lambda: { + r"^patch_embed\.(.*)$": r"patch_embed.\1", + r"^time_embed\.time_proj\.(.*)$": r"time_embed.time_proj.\1", + r"^time_embed\.t_embedder\.(.*)$": r"time_embed.t_embedder.\1", + r"^time_embed\.norm\.(.*)$": r"time_embed.norm.\1", + r"^transformer_blocks\.(\d+)\.attn1\.to_q\.(.*)$": + r"transformer_blocks.\1.attn1.to_q.\2", + r"^transformer_blocks\.(\d+)\.attn1\.to_k\.(.*)$": + r"transformer_blocks.\1.attn1.to_k.\2", + r"^transformer_blocks\.(\d+)\.attn1\.to_v\.(.*)$": + r"transformer_blocks.\1.attn1.to_v.\2", + r"^transformer_blocks\.(\d+)\.attn1\.to_out\.0\.(.*)$": + r"transformer_blocks.\1.attn1.to_out.\2", + r"^transformer_blocks\.(\d+)\.attn1\.norm_q\.(.*)$": + r"transformer_blocks.\1.attn1.norm_q.\2", + r"^transformer_blocks\.(\d+)\.attn1\.norm_k\.(.*)$": + r"transformer_blocks.\1.attn1.norm_k.\2", + r"^transformer_blocks\.(\d+)\.attn2\.to_q\.(.*)$": + r"transformer_blocks.\1.attn2.to_q.\2", + r"^transformer_blocks\.(\d+)\.attn2\.to_k\.(.*)$": + r"transformer_blocks.\1.attn2.to_k.\2", + r"^transformer_blocks\.(\d+)\.attn2\.to_v\.(.*)$": + r"transformer_blocks.\1.attn2.to_v.\2", + r"^transformer_blocks\.(\d+)\.attn2\.to_out\.0\.(.*)$": + r"transformer_blocks.\1.attn2.to_out.\2", + r"^transformer_blocks\.(\d+)\.attn2\.norm_q\.(.*)$": + r"transformer_blocks.\1.attn2.norm_q.\2", + r"^transformer_blocks\.(\d+)\.attn2\.norm_k\.(.*)$": + r"transformer_blocks.\1.attn2.norm_k.\2", + r"^transformer_blocks\.(\d+)\.ff\.net\.0\.proj\.(.*)$": + r"transformer_blocks.\1.ff.fc_in.\2", + r"^transformer_blocks\.(\d+)\.ff\.net\.2\.(.*)$": + r"transformer_blocks.\1.ff.fc_out.\2", + r"^norm_out\.(.*)$": r"norm_out.\1", + r"^proj_out\.(.*)$": r"proj_out.\1", + }) + + lora_param_names_mapping: dict = field( + default_factory=lambda: { + r"^transformer_blocks\.(\d+)\.attn1\.to_q\.(.*)$": + r"transformer_blocks.\1.attn1.to_q.\2", + r"^transformer_blocks\.(\d+)\.attn1\.to_k\.(.*)$": + r"transformer_blocks.\1.attn1.to_k.\2", + r"^transformer_blocks\.(\d+)\.attn1\.to_v\.(.*)$": + r"transformer_blocks.\1.attn1.to_v.\2", + r"^transformer_blocks\.(\d+)\.attn1\.to_out\.(.*)$": + r"transformer_blocks.\1.attn1.to_out.\2", + 
r"^transformer_blocks\.(\d+)\.attn2\.to_q\.(.*)$": + r"transformer_blocks.\1.attn2.to_q.\2", + r"^transformer_blocks\.(\d+)\.attn2\.to_k\.(.*)$": + r"transformer_blocks.\1.attn2.to_k.\2", + r"^transformer_blocks\.(\d+)\.attn2\.to_v\.(.*)$": + r"transformer_blocks.\1.attn2.to_v.\2", + r"^transformer_blocks\.(\d+)\.attn2\.to_out\.(.*)$": + r"transformer_blocks.\1.attn2.to_out.\2", + r"^transformer_blocks\.(\d+)\.ff\.(.*)$": + r"transformer_blocks.\1.ff.\2", + }) + + # Cosmos-specific config parameters based on transformer_cosmos.py + in_channels: int = 16 + out_channels: int = 16 + num_attention_heads: int = 16 + attention_head_dim: int = 128 + num_layers: int = 28 + mlp_ratio: float = 4.0 + text_embed_dim: int = 1024 + adaln_lora_dim: int = 256 + max_size: tuple[int, int, int] = (128, 240, 240) + patch_size: tuple[int, int, int] = (1, 2, 2) + rope_scale: tuple[float, float, float] = (1.0, 3.0, 3.0) + concat_padding_mask: bool = True + extra_pos_embed_type: str | None = None + qk_norm: str = "rms_norm" + eps: float = 1e-6 + exclude_lora_layers: list[str] = field(default_factory=lambda: ["embedder"]) + + def __post_init__(self): + super().__post_init__() + self.out_channels = self.out_channels or self.in_channels + self.hidden_size = self.num_attention_heads * self.attention_head_dim + self.num_channels_latents = self.in_channels + + +@dataclass +class CosmosVideoConfig(DiTConfig): + arch_config: DiTArchConfig = field(default_factory=CosmosArchConfig) + prefix: str = "Cosmos" diff --git a/fastvideo/configs/models/encoders/__init__.py b/fastvideo/configs/models/encoders/__init__.py index f783a2106..e2aa77787 100644 --- a/fastvideo/configs/models/encoders/__init__.py +++ b/fastvideo/configs/models/encoders/__init__.py @@ -5,10 +5,10 @@ from fastvideo.configs.models.encoders.clip import (CLIPTextConfig, CLIPVisionConfig) from fastvideo.configs.models.encoders.llama import LlamaConfig -from fastvideo.configs.models.encoders.t5 import T5Config +from fastvideo.configs.models.encoders.t5 import T5Config, T5LargeConfig __all__ = [ "EncoderConfig", "TextEncoderConfig", "ImageEncoderConfig", "BaseEncoderOutput", "CLIPTextConfig", "CLIPVisionConfig", "LlamaConfig", - "T5Config" + "T5Config", "T5LargeConfig" ] diff --git a/fastvideo/configs/models/encoders/t5.py b/fastvideo/configs/models/encoders/t5.py index 70649551b..c1de3609c 100644 --- a/fastvideo/configs/models/encoders/t5.py +++ b/fastvideo/configs/models/encoders/t5.py @@ -70,8 +70,31 @@ def __post_init__(self): } +@dataclass +class T5LargeArchConfig(T5ArchConfig): + """T5 Large architecture config with parameters for your specific model.""" + d_model: int = 1024 + d_kv: int = 128 + d_ff: int = 65536 + num_layers: int = 24 + num_decoder_layers: int | None = 24 + num_heads: int = 128 + decoder_start_token_id: int = 0 + n_positions: int = 512 + task_specific_params: dict | None = None + + @dataclass class T5Config(TextEncoderConfig): arch_config: TextEncoderArchConfig = field(default_factory=T5ArchConfig) prefix: str = "t5" + + +@dataclass +class T5LargeConfig(TextEncoderConfig): + """T5 Large configuration for your specific model.""" + arch_config: TextEncoderArchConfig = field( + default_factory=T5LargeArchConfig) + + prefix: str = "t5" diff --git a/fastvideo/configs/models/vaes/__init__.py b/fastvideo/configs/models/vaes/__init__.py index 700c8de1b..12bf5c609 100644 --- a/fastvideo/configs/models/vaes/__init__.py +++ b/fastvideo/configs/models/vaes/__init__.py @@ -1,3 +1,4 @@ +from fastvideo.configs.models.vaes.cosmosvae import CosmosVAEConfig 
from fastvideo.configs.models.vaes.hunyuanvae import HunyuanVAEConfig from fastvideo.configs.models.vaes.stepvideovae import StepVideoVAEConfig from fastvideo.configs.models.vaes.wanvae import WanVAEConfig @@ -6,4 +7,5 @@ "HunyuanVAEConfig", "WanVAEConfig", "StepVideoVAEConfig", + "CosmosVAEConfig", ] diff --git a/fastvideo/configs/models/vaes/cosmosvae.py b/fastvideo/configs/models/vaes/cosmosvae.py new file mode 100644 index 000000000..4680986f3 --- /dev/null +++ b/fastvideo/configs/models/vaes/cosmosvae.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +import torch + +from fastvideo.configs.models.vaes.base import VAEArchConfig, VAEConfig + + +@dataclass +class CosmosVAEArchConfig(VAEArchConfig): + _name_or_path: str = "" + base_dim: int = 96 + z_dim: int = 16 + dim_mult: tuple[int, ...] = (1, 2, 4, 4) + num_res_blocks: int = 2 + attn_scales: tuple[float, ...] = () + temperal_downsample: tuple[bool, ...] = (False, True, True) + dropout: float = 0.0 + decoder_base_dim: int | None = None + is_residual: bool = False + in_channels: int = 3 + out_channels: int = 3 + patch_size: int | None = None + scale_factor_temporal: int = 4 + scale_factor_spatial: int = 8 + clip_output: bool = True + latents_mean: tuple[float, ...] = ( + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, + ) + latents_std: tuple[float, ...] = ( + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, + ) + temporal_compression_ratio = 4 + spatial_compression_ratio = 8 + + def __post_init__(self): + self.scaling_factor: torch.Tensor = 1.0 / torch.tensor( + self.latents_std).view(1, self.z_dim, 1, 1, 1) + self.shift_factor: torch.Tensor = torch.tensor(self.latents_mean).view( + 1, self.z_dim, 1, 1, 1) + self.temporal_compression_ratio = self.scale_factor_temporal + self.spatial_compression_ratio = self.scale_factor_spatial + + +@dataclass +class CosmosVAEConfig(VAEConfig): + arch_config: CosmosVAEArchConfig = field( + default_factory=CosmosVAEArchConfig) + use_feature_cache: bool = True + + use_tiling: bool = False + use_temporal_tiling: bool = False + use_parallel_tiling: bool = False + + def __post_init__(self): + self.blend_num_frames = (self.tile_sample_min_num_frames - + self.tile_sample_stride_num_frames) * 2 diff --git a/fastvideo/configs/pipelines/__init__.py b/fastvideo/configs/pipelines/__init__.py index 82721919a..639ef3aa0 100644 --- a/fastvideo/configs/pipelines/__init__.py +++ b/fastvideo/configs/pipelines/__init__.py @@ -1,5 +1,6 @@ from fastvideo.configs.pipelines.base import (PipelineConfig, SlidingTileAttnConfig) +from fastvideo.configs.pipelines.cosmos import CosmosConfig from fastvideo.configs.pipelines.hunyuan import FastHunyuanConfig, HunyuanConfig from fastvideo.configs.pipelines.registry import ( get_pipeline_config_cls_from_name) @@ -11,5 +12,5 @@ "HunyuanConfig", "FastHunyuanConfig", "PipelineConfig", "SlidingTileAttnConfig", "WanT2V480PConfig", "WanI2V480PConfig", "WanT2V720PConfig", "WanI2V720PConfig", "StepVideoT2VConfig", - "get_pipeline_config_cls_from_name" + "CosmosConfig", "get_pipeline_config_cls_from_name" ] diff --git a/fastvideo/configs/pipelines/cosmos.py b/fastvideo/configs/pipelines/cosmos.py new file mode 100644 index 000000000..3ca78fe0f --- /dev/null +++ b/fastvideo/configs/pipelines/cosmos.py @@ -0,0 +1,66 @@ 
+# SPDX-License-Identifier: Apache-2.0 +from collections.abc import Callable +from dataclasses import dataclass, field + +import torch + +from fastvideo.configs.models import DiTConfig, EncoderConfig, VAEConfig +from fastvideo.configs.models.dits import CosmosVideoConfig +from fastvideo.configs.models.encoders import BaseEncoderOutput, T5LargeConfig +from fastvideo.configs.models.vaes import CosmosVAEConfig +from fastvideo.configs.pipelines.base import PipelineConfig + + +def t5_large_postprocess_text(outputs: BaseEncoderOutput) -> torch.Tensor: + """Postprocess T5 Large text encoder outputs for Cosmos pipeline. + + Return raw last_hidden_state without truncation/padding. + """ + hidden_state = outputs.last_hidden_state + + if hidden_state is None: + raise ValueError("T5 Large outputs missing last_hidden_state") + + nan_count = torch.isnan(hidden_state).sum() + if nan_count > 0: + hidden_state = hidden_state.masked_fill(torch.isnan(hidden_state), 0.0) + + return hidden_state + + +@dataclass +class CosmosConfig(PipelineConfig): + """Configuration for Cosmos2 Video2World pipeline matching diffusers.""" + + dit_config: DiTConfig = field(default_factory=CosmosVideoConfig) + + vae_config: VAEConfig = field(default_factory=CosmosVAEConfig) + + text_encoder_configs: tuple[EncoderConfig, ...] = field( + default_factory=lambda: (T5LargeConfig(), )) + postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], torch.Tensor], + ...] = field(default_factory=lambda: + (t5_large_postprocess_text, )) + + dit_precision: str = "bf16" + vae_precision: str = "fp16" + text_encoder_precisions: tuple[str, ...] = field( + default_factory=lambda: ("bf16", )) + + conditioning_strategy: str = "frame_replace" + min_num_conditional_frames: int = 1 + max_num_conditional_frames: int = 2 + sigma_conditional: float = 0.0001 + sigma_data: float = 1.0 + state_ch: int = 16 + state_t: int = 24 + text_encoder_class: str = "T5" + + embedded_cfg_scale: int = 6 + flow_shift: float = 1.0 + + def __post_init__(self): + self.vae_config.load_encoder = True + self.vae_config.load_decoder = True + + self._vae_latent_dim = 16 diff --git a/fastvideo/configs/pipelines/registry.py b/fastvideo/configs/pipelines/registry.py index cd8fb9938..979692628 100644 --- a/fastvideo/configs/pipelines/registry.py +++ b/fastvideo/configs/pipelines/registry.py @@ -5,6 +5,7 @@ from collections.abc import Callable from fastvideo.configs.pipelines.base import PipelineConfig +from fastvideo.configs.pipelines.cosmos import CosmosConfig from fastvideo.configs.pipelines.hunyuan import FastHunyuanConfig, HunyuanConfig from fastvideo.configs.pipelines.stepvideo import StepVideoT2VConfig from fastvideo.configs.pipelines.wan import (FastWan2_1_T2V_480P_Config, @@ -34,6 +35,7 @@ "Wan-AI/Wan2.2-TI2V-5B-Diffusers": WanT2V720PConfig, "Wan-AI/Wan2.2-T2V-A14B-Diffusers": WanT2V480PConfig, "Wan-AI/Wan2.2-I2V-A14B-Diffusers": WanI2V480PConfig, + "nvidia/Cosmos-Predict2-2B-Video2World": CosmosConfig, # Add other specific weight variants } @@ -44,6 +46,7 @@ "wanimagetovideo": lambda id: "wanimagetovideo" in id.lower(), "wandmdpipeline": lambda id: "wandmdpipeline" in id.lower(), "stepvideo": lambda id: "stepvideo" in id.lower(), + "cosmos": lambda id: "cosmos" in id.lower(), # Add other pipeline architecture detectors } diff --git a/fastvideo/configs/sample/cosmos.py b/fastvideo/configs/sample/cosmos.py new file mode 100644 index 000000000..32886151e --- /dev/null +++ b/fastvideo/configs/sample/cosmos.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: Apache-2.0 +from 
dataclasses import dataclass + +from fastvideo.configs.sample.base import SamplingParam + + +@dataclass +class Cosmos_Predict2_2B_Video2World_SamplingParam(SamplingParam): + # Video parameters + height: int = 704 + width: int = 1280 + num_frames: int = 93 + fps: int = 16 + + # Denoising stage + guidance_scale: float = 7.0 + negative_prompt: str = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality." + num_inference_steps: int = 35 diff --git a/fastvideo/image_processor.py b/fastvideo/image_processor.py new file mode 100644 index 000000000..3483631f5 --- /dev/null +++ b/fastvideo/image_processor.py @@ -0,0 +1,195 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Minimal image processing utilities for FastVideo. +This module provides lightweight image preprocessing without external dependencies beyond PyTorch/NumPy/PIL. +""" + +import numpy as np +import PIL.Image +import torch + + +class ImageProcessor: + """ + Minimal image processor for video frame preprocessing. + + This is a lightweight alternative to diffusers.VideoProcessor that handles: + - PIL image to tensor conversion + - Resizing to specified dimensions + - Normalization to [-1, 1] range + + Args: + vae_scale_factor: The VAE scale factor used to ensure dimensions are multiples of this value. + """ + + def __init__(self, vae_scale_factor: int = 8) -> None: + self.vae_scale_factor = vae_scale_factor + + def preprocess( + self, + image: PIL.Image.Image | np.ndarray | torch.Tensor, + height: int | None = None, + width: int | None = None, + ) -> torch.Tensor: + """ + Preprocess an image to a normalized torch tensor. + + Args: + image: Input image (PIL Image, NumPy array, or torch tensor) + height: Target height. If None, uses image's original height. + width: Target width. If None, uses image's original width. + + Returns: + torch.Tensor: Normalized tensor of shape (1, 3, height, width) or (1, 1, height, width) for grayscale, + with values in range [-1, 1]. + """ + # Handle different input types + if isinstance(image, PIL.Image.Image): + return self._preprocess_pil(image, height, width) + elif isinstance(image, np.ndarray): + return self._preprocess_numpy(image, height, width) + elif isinstance(image, torch.Tensor): + return self._preprocess_tensor(image, height, width) + else: + raise ValueError( + f"Unsupported image type: {type(image)}. 
" + "Supported types: PIL.Image.Image, np.ndarray, torch.Tensor") + + def _preprocess_pil( + self, + image: PIL.Image.Image, + height: int | None = None, + width: int | None = None, + ) -> torch.Tensor: + """Preprocess a PIL image.""" + if height is None: + height = image.height + if width is None: + width = image.width + + height = height - (height % self.vae_scale_factor) + width = width - (width % self.vae_scale_factor) + + image = image.resize((width, height), + resample=PIL.Image.Resampling.LANCZOS) + + image_np = np.array(image, dtype=np.float32) / 255.0 + + if image_np.ndim == 2: # Grayscale + image_np = np.expand_dims(image_np, axis=-1) + + return self._normalize_to_tensor(image_np) + + def _preprocess_numpy( + self, + image: np.ndarray, + height: int | None = None, + width: int | None = None, + ) -> torch.Tensor: + """Preprocess a numpy array.""" + # Determine target dimensions if not provided + if image.ndim == 3: + img_height, img_width = image.shape[:2] + elif image.ndim == 2: + img_height, img_width = image.shape + else: + raise ValueError(f"Expected 2D or 3D array, got {image.ndim}D") + + if height is None: + height = img_height + if width is None: + width = img_width + + height = height - (height % self.vae_scale_factor) + width = width - (width % self.vae_scale_factor) + + if image.dtype == np.uint8: + pil_image = PIL.Image.fromarray(image) + else: + # Assume normalized [0, 1] or similar + if image.max() <= 1.0: + image_uint8 = (image * 255).astype(np.uint8) + else: + image_uint8 = image.astype(np.uint8) + pil_image = PIL.Image.fromarray(image_uint8) + + pil_image = pil_image.resize((width, height), + resample=PIL.Image.Resampling.LANCZOS) + image_np = np.array(pil_image, dtype=np.float32) / 255.0 + + # Ensure 3D shape + if image_np.ndim == 2: + image_np = np.expand_dims(image_np, axis=-1) + + return self._normalize_to_tensor(image_np) + + def _preprocess_tensor( + self, + image: torch.Tensor, + height: int | None = None, + width: int | None = None, + ) -> torch.Tensor: + """Preprocess a torch tensor.""" + # Determine target dimensions + if image.ndim == 3: # (H, W, C) or (C, H, W) + if image.shape[0] in (1, 3, 4): # Likely (C, H, W) + img_height, img_width = image.shape[1], image.shape[2] + else: # Likely (H, W, C) + img_height, img_width = image.shape[0], image.shape[1] + elif image.ndim == 2: # (H, W) + img_height, img_width = image.shape + else: + raise ValueError(f"Expected 2D or 3D tensor, got {image.ndim}D") + + if height is None: + height = img_height + if width is None: + width = img_width + + height = height - (height % self.vae_scale_factor) + width = width - (width % self.vae_scale_factor) + + if image.ndim == 2: + image = image.unsqueeze(0).unsqueeze(0) # (1, 1, H, W) + elif image.ndim == 3: + if image.shape[0] in (1, 3, 4): # (C, H, W) + image = image.unsqueeze(0) # (1, C, H, W) + else: # (H, W, C) - need to rearrange + image = image.permute(2, 0, 1).unsqueeze(0) # (1, C, H, W) + + image = torch.nn.functional.interpolate(image, + size=(height, width), + mode="bilinear", + align_corners=False) + + if image.max() > 1.0: # Assume [0, 255] range + image = image / 255.0 + + image = 2.0 * image - 1.0 + + return image + + def _normalize_to_tensor(self, image_np: np.ndarray) -> torch.Tensor: + """ + Convert normalized numpy array [0, 1] to torch tensor [-1, 1]. 
+ + Args: + image_np: NumPy array with shape (H, W) or (H, W, C) with values in [0, 1] + + Returns: + torch.Tensor: Shape (1, C, H, W) or (1, 1, H, W) with values in [-1, 1] + """ + # Convert to tensor + if image_np.ndim == 2: # (H, W) - grayscale + tensor = torch.from_numpy(image_np).unsqueeze(0).unsqueeze( + 0) # (1, 1, H, W) + elif image_np.ndim == 3: # (H, W, C) + tensor = torch.from_numpy(image_np).permute(2, 0, 1).unsqueeze( + 0) # (1, C, H, W) + else: + raise ValueError(f"Expected 2D or 3D array, got {image_np.ndim}D") + + # Normalize to [-1, 1] + tensor = 2.0 * tensor - 1.0 + + return tensor diff --git a/fastvideo/layers/layernorm.py b/fastvideo/layers/layernorm.py index fd81db6d6..ad234d552 100644 --- a/fastvideo/layers/layernorm.py +++ b/fastvideo/layers/layernorm.py @@ -39,6 +39,22 @@ def __init__( if self.has_weight: self.weight = nn.Parameter(self.weight) + def forward_diffusers(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Forward method that matches the Diffusers RMSNorm implementation exactly.""" + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + + self.variance_epsilon) + + if self.has_weight and self.weight is not None: + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + hidden_states = hidden_states * self.weight + else: + hidden_states = hidden_states.to(input_dtype) + + return hidden_states + # if we do fully_shard(model.layer_norm), and we call layer_form.forward_native(input) instead of layer_norm(input), # we need to call model.layer_norm.register_fsdp_forward_method(model, "forward_native") to make sure fsdp2 hooks are triggered # for mixed precision and cpu offloading diff --git a/fastvideo/layers/rotary_embedding.py b/fastvideo/layers/rotary_embedding.py index 6abe90609..fb7f2a2a4 100644 --- a/fastvideo/layers/rotary_embedding.py +++ b/fastvideo/layers/rotary_embedding.py @@ -44,6 +44,59 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor: return x.flatten(-2) + +def apply_rotary_emb( + x: torch.Tensor, + freqs_cis: torch.Tensor | tuple[torch.Tensor, torch.Tensor], + use_real: bool = True, + use_real_unbind_dim: int = -1, +) -> torch.Tensor: + """ + Apply rotary embeddings to the input tensor using the given frequency tensor. The query or key tensor 'x' is + rotated with the provided frequencies 'freqs_cis'; the frequency tensor is reshaped for broadcasting + compatibility, and the result is returned as a real tensor. + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings to, of shape [B, H, S, D]. + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + Returns: + torch.Tensor: The query or key tensor with rotary embeddings applied. 
+ """ + if use_real: + cos, sin = freqs_cis # [S, D] + # Match Diffusers broadcasting (sequence_dim=2 case) + cos = cos[None, None, :, :] + sin = sin[None, None, :, :] + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + # Used for flux, cogvideox, hunyuan-dit + x_real, x_imag = x.reshape(*x.shape[:-1], -1, + 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Used for Stable Audio, OmniGen, CogView4 and Cosmos + x_real, x_imag = x.reshape(*x.shape[:-1], 2, + -1).unbind(-2) # [B, S, H, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError( + f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2." + ) + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + else: + # used for lumina + x_rotated = torch.view_as_complex(x.float().reshape( + *x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + + return x_out.type_as(x) + + def _apply_rotary_emb( x: torch.Tensor, cos: torch.Tensor, diff --git a/fastvideo/layers/visual_embedding.py b/fastvideo/layers/visual_embedding.py index 954cd925c..e03373be4 100644 --- a/fastvideo/layers/visual_embedding.py +++ b/fastvideo/layers/visual_embedding.py @@ -173,3 +173,79 @@ def unpatchify(x, t, h, w, patch_size, channels) -> torch.Tensor: imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw)) return imgs + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> torch.Tensor: + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + Args + timesteps (torch.Tensor): + a 1-D Tensor of N indices, one per batch element. These may be fractional. + embedding_dim (int): + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + Scaling factor applied to the embeddings. + max_period (int): + Controls the maximum frequency of the embeddings + Returns + torch.Tensor: an [N x dim] Tensor of positional embeddings. 
+ """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class Timesteps(nn.Module): + + def __init__(self, + num_channels: int, + flip_sin_to_cos: bool, + downscale_freq_shift: float, + scale: int = 1): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + self.scale = scale + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb \ No newline at end of file diff --git a/fastvideo/models/dits/cosmos.py b/fastvideo/models/dits/cosmos.py new file mode 100644 index 000000000..fbbdbced0 --- /dev/null +++ b/fastvideo/models/dits/cosmos.py @@ -0,0 +1,726 @@ +# SPDX-License-Identifier: Apache-2.0 + +import math +from typing import Any + +import numpy as np +import torch +import torch.nn as nn + +from fastvideo.attention import DistributedAttention, LocalAttention +from fastvideo.configs.models.dits.cosmos import CosmosVideoConfig +from fastvideo.forward_context import get_forward_context +from fastvideo.layers.layernorm import RMSNorm +from fastvideo.layers.linear import ReplicatedLinear +from fastvideo.layers.mlp import MLP +from fastvideo.layers.rotary_embedding import apply_rotary_emb +from fastvideo.layers.visual_embedding import Timesteps +from fastvideo.models.dits.base import BaseDiT +from fastvideo.platforms import AttentionBackendEnum + + +class CosmosPatchEmbed(nn.Module): + + def __init__(self, + in_channels: int, + out_channels: int, + patch_size: tuple[int, int, int], + bias: bool = True) -> None: + super().__init__() + self.patch_size = patch_size + + self.proj = nn.Linear(in_channels * patch_size[0] * patch_size[1] * + patch_size[2], + out_channels, + bias=bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.patch_size + hidden_states = hidden_states.reshape(batch_size, num_channels, + num_frames // p_t, p_t, + height // p_h, p_h, width // p_w, + p_w) + hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, + 7).flatten(4, 7) + hidden_states = self.proj(hidden_states) + return hidden_states + + +class CosmosTimestepEmbedding(nn.Module): + + def __init__(self, in_features: int, out_features: int) -> None: + super().__init__() + self.linear_1 = nn.Linear(in_features, out_features, bias=False) + self.activation = nn.SiLU() + self.linear_2 = nn.Linear(out_features, 3 * out_features, bias=False) + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + emb = self.linear_1(timesteps) + emb = self.activation(emb) + emb = self.linear_2(emb) + return emb + + +class 
CosmosEmbedding(nn.Module): + + def __init__(self, embedding_dim: int, condition_dim: int) -> None: + super().__init__() + + self.time_proj = Timesteps(embedding_dim, + flip_sin_to_cos=True, + downscale_freq_shift=0.0) + self.t_embedder = CosmosTimestepEmbedding(embedding_dim, condition_dim) + self.norm = RMSNorm(embedding_dim, eps=1e-6) + + def forward(self, hidden_states: torch.Tensor, + timestep: torch.LongTensor) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep).type_as(hidden_states) + temb = self.t_embedder(timesteps_proj) + embedded_timestep = self.norm(timesteps_proj) + return temb, embedded_timestep + + +class CosmosAdaLayerNorm(nn.Module): + + def __init__(self, in_features: int, hidden_features: int) -> None: + super().__init__() + self.embedding_dim = in_features + + self.activation = nn.SiLU() + self.norm = nn.LayerNorm(in_features, + elementwise_affine=False, + eps=1e-6) + self.linear_1 = nn.Linear(in_features, hidden_features, bias=False) + self.linear_2 = nn.Linear(hidden_features, 2 * in_features, bias=False) + + def forward(self, + hidden_states: torch.Tensor, + embedded_timestep: torch.Tensor, + temb: torch.Tensor | None = None) -> torch.Tensor: + embedded_timestep = self.activation(embedded_timestep) + embedded_timestep = self.linear_1(embedded_timestep) + embedded_timestep = self.linear_2(embedded_timestep) + + if temb is not None: + embedded_timestep = embedded_timestep + temb[..., :2 * + self.embedding_dim] + + shift, scale = embedded_timestep.chunk(2, dim=-1) + with torch.autocast(device_type="cuda", enabled=False): + hidden_states = self.norm(hidden_states) + + if embedded_timestep.ndim == 2: + shift, scale = (x.unsqueeze(1) for x in (shift, scale)) + + hidden_states = hidden_states * (1 + scale) + shift + return hidden_states + + +class CosmosAdaLayerNormZero(nn.Module): + + def __init__(self, + in_features: int, + hidden_features: int | None = None) -> None: + super().__init__() + + self.norm = nn.LayerNorm(in_features, + elementwise_affine=False, + eps=1e-6) + self.activation = nn.SiLU() + + if hidden_features is None: + self.linear_1 = nn.Identity() + else: + self.linear_1 = nn.Linear(in_features, hidden_features, bias=False) + + self.linear_2 = nn.Linear(hidden_features, 3 * in_features, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + embedded_timestep: torch.Tensor, + temb: torch.Tensor | None = None, + ) -> torch.Tensor: + embedded_timestep = self.activation(embedded_timestep) + embedded_timestep = self.linear_1(embedded_timestep) + embedded_timestep = self.linear_2(embedded_timestep) + + if temb is not None: + embedded_timestep = embedded_timestep + temb + + shift, scale, gate = embedded_timestep.chunk(3, dim=-1) + + with torch.autocast(device_type="cuda", enabled=False): + hidden_states = self.norm(hidden_states) + + if embedded_timestep.ndim == 2: + shift, scale, gate = (x.unsqueeze(1) for x in (shift, scale, gate)) + + hidden_states = hidden_states * (1 + scale) + shift + return hidden_states, gate + + +class CosmosSelfAttention(nn.Module): + + def __init__(self, + dim: int, + num_heads: int, + qk_norm=True, + eps=1e-6, + prefix: str = "") -> None: + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qk_norm = qk_norm + self.eps = eps + + # layers - use standard PyTorch layers when using torch backend + self.to_q = nn.Linear(dim, dim, bias=False) + self.to_k = nn.Linear(dim, dim, bias=False) + self.to_v = nn.Linear(dim, dim, bias=False) 
+ self.to_out = nn.Linear(dim, dim, bias=False) + self.dropout = nn.Dropout(0.0) + + self.norm_q = RMSNorm(self.head_dim, + eps=eps) if qk_norm else nn.Identity() + self.norm_k = RMSNorm(self.head_dim, + eps=eps) if qk_norm else nn.Identity() + + def forward(self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + image_rotary_emb: torch.Tensor | None = None) -> torch.Tensor: + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + # Get QKV + query = self.to_q(hidden_states) + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + # Reshape for multi-head attention + query = query.unflatten(2, (self.num_heads, -1)).transpose(1, 2) + key = key.unflatten(2, (self.num_heads, -1)).transpose(1, 2) + value = value.unflatten(2, (self.num_heads, -1)).transpose(1, 2) + + # Apply normalization + if self.norm_q is not None: + query = self.norm_q(query) + if self.norm_k is not None: + key = self.norm_k(key) + + # Apply RoPE if provided + if image_rotary_emb is not None: + query = apply_rotary_emb(query, + image_rotary_emb, + use_real=True, + use_real_unbind_dim=-2) + key = apply_rotary_emb(key, + image_rotary_emb, + use_real=True, + use_real_unbind_dim=-2) + + # Prepare for GQA (Grouped Query Attention) + if torch.onnx.is_in_onnx_export(): + query_idx = torch.tensor(query.size(3), device=query.device) + key_idx = torch.tensor(key.size(3), device=key.device) + value_idx = torch.tensor(value.size(3), device=value.device) + else: + query_idx = query.size(3) + key_idx = key.size(3) + value_idx = value.size(3) + key = key.repeat_interleave(query_idx // key_idx, dim=3) + value = value.repeat_interleave(query_idx // value_idx, dim=3) + + # Attention computation + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + attn_output = attn_output.transpose(1, 2).flatten(2, 3).type_as(query) + + # Output projection + attn_output = self.to_out(attn_output) + attn_output = self.dropout(attn_output) + + return attn_output + + +class CosmosCrossAttention(nn.Module): + + def __init__(self, + dim: int, + cross_attention_dim: int, + num_heads: int, + qk_norm=True, + eps=1e-6, + prefix: str = "") -> None: + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.cross_attention_dim = cross_attention_dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qk_norm = qk_norm + self.eps = eps + + self.to_q = nn.Linear(dim, dim, bias=False) + self.to_k = nn.Linear(cross_attention_dim, dim, bias=False) + self.to_v = nn.Linear(cross_attention_dim, dim, bias=False) + self.to_out = nn.Linear(dim, dim, bias=False) + self.dropout = nn.Dropout(0.0) + + self.norm_q = RMSNorm(self.head_dim, + eps=eps) if qk_norm else nn.Identity() + self.norm_k = RMSNorm(self.head_dim, + eps=eps) if qk_norm else nn.Identity() + + def forward(self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None) -> torch.Tensor: + + # Get QKV + query = self.to_q(hidden_states) + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + # Reshape for multi-head attention + query = query.unflatten(2, (self.num_heads, -1)).transpose(1, 2) + key = key.unflatten(2, (self.num_heads, -1)).transpose(1, 2) + value = value.unflatten(2, (self.num_heads, -1)).transpose(1, 2) + + # Apply normalization + if self.norm_q is not 
None: + query = self.norm_q(query) + if self.norm_k is not None: + key = self.norm_k(key) + + # Prepare for GQA (Grouped Query Attention) + if torch.onnx.is_in_onnx_export(): + query_idx = torch.tensor(query.size(3), device=query.device) + key_idx = torch.tensor(key.size(3), device=key.device) + value_idx = torch.tensor(value.size(3), device=value.device) + else: + query_idx = query.size(3) + key_idx = key.size(3) + value_idx = value.size(3) + key = key.repeat_interleave(query_idx // key_idx, dim=3) + value = value.repeat_interleave(query_idx // value_idx, dim=3) + + # Attention computation + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + attn_output = attn_output.transpose(1, 2).flatten(2, 3).type_as(query) + + # Output projection + attn_output = self.to_out(attn_output) + attn_output = self.dropout(attn_output) + + return attn_output + + +class CosmosTransformerBlock(nn.Module): + + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + cross_attention_dim: int, + mlp_ratio: float = 4.0, + adaln_lora_dim: int = 256, + qk_norm: str = "rms_norm", + out_bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.norm1 = CosmosAdaLayerNormZero(in_features=hidden_size, + hidden_features=adaln_lora_dim) + self.attn1 = CosmosSelfAttention( + dim=hidden_size, + num_heads=num_attention_heads, + qk_norm=(qk_norm == "rms_norm"), + eps=1e-5, + prefix=f"{prefix}.attn1") + + self.norm2 = CosmosAdaLayerNormZero(in_features=hidden_size, + hidden_features=adaln_lora_dim) + self.attn2 = CosmosCrossAttention( + dim=hidden_size, + cross_attention_dim=cross_attention_dim, + num_heads=num_attention_heads, + qk_norm=(qk_norm == "rms_norm"), + eps=1e-5, + prefix=f"{prefix}.attn2") + + self.norm3 = CosmosAdaLayerNormZero(in_features=hidden_size, + hidden_features=adaln_lora_dim) + self.ff = MLP(hidden_size, + int(hidden_size * mlp_ratio), + act_type="gelu", + bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + embedded_timestep: torch.Tensor, + temb: torch.Tensor | None = None, + image_rotary_emb: torch.Tensor | None = None, + extra_pos_emb: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + if extra_pos_emb is not None: + hidden_states = hidden_states + extra_pos_emb + + norm_hidden_states, gate = self.norm1(hidden_states, embedded_timestep, + temb) + + attn_output = self.attn1(norm_hidden_states, + image_rotary_emb=image_rotary_emb) + hidden_states = hidden_states + gate * attn_output + + norm_hidden_states, gate = self.norm2(hidden_states, embedded_timestep, + temb) + attn_output = self.attn2(norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask) + + hidden_states = hidden_states + gate * attn_output + + norm_hidden_states, gate = self.norm3(hidden_states, embedded_timestep, + temb) + ff_output = self.ff(norm_hidden_states) + hidden_states = hidden_states + gate * ff_output + + return hidden_states + + +class CosmosRotaryPosEmbed(nn.Module): + + def __init__( + self, + hidden_size: int, + max_size: tuple[int, int, int] = (128, 240, 240), + patch_size: tuple[int, int, int] = (1, 2, 2), + base_fps: int = 24, + rope_scale: tuple[float, float, float] = (2.0, 1.0, 1.0), + ) -> None: + super().__init__() + + self.max_size = [ + size // patch + for size, patch in 
zip(max_size, patch_size, strict=False) + ] + self.patch_size = patch_size + self.base_fps = base_fps + + self.dim_h = hidden_size // 6 * 2 + self.dim_w = hidden_size // 6 * 2 + self.dim_t = hidden_size - self.dim_h - self.dim_w + + self.h_ntk_factor = rope_scale[1]**(self.dim_h / (self.dim_h - 2)) + self.w_ntk_factor = rope_scale[2]**(self.dim_w / (self.dim_w - 2)) + self.t_ntk_factor = rope_scale[0]**(self.dim_t / (self.dim_t - 2)) + + + def forward(self, + hidden_states: torch.Tensor, + fps: int | None = None) -> tuple[torch.Tensor, torch.Tensor]: + fps = 16 + batch_size, num_channels, num_frames, height, width = hidden_states.shape + pe_size = [ + num_frames // self.patch_size[0], height // self.patch_size[1], + width // self.patch_size[2] + ] + device = hidden_states.device + + h_theta = 10000.0 * self.h_ntk_factor + w_theta = 10000.0 * self.w_ntk_factor + t_theta = 10000.0 * self.t_ntk_factor + + seq = torch.arange(max(self.max_size), + device=device, + dtype=torch.float32) + dim_h_range = ( + torch.arange(0, self.dim_h, 2, device=device, + dtype=torch.float32)[:(self.dim_h // 2)] / self.dim_h) + dim_w_range = ( + torch.arange(0, self.dim_w, 2, device=device, + dtype=torch.float32)[:(self.dim_w // 2)] / self.dim_w) + dim_t_range = ( + torch.arange(0, self.dim_t, 2, device=device, + dtype=torch.float32)[:(self.dim_t // 2)] / self.dim_t) + + h_spatial_freqs = 1.0 / (h_theta**dim_h_range) + w_spatial_freqs = 1.0 / (w_theta**dim_w_range) + temporal_freqs = 1.0 / (t_theta**dim_t_range) + + emb_h = torch.outer(seq[:pe_size[1]], + h_spatial_freqs)[None, :, None, :].repeat( + pe_size[0], 1, pe_size[2], 1) + emb_w = torch.outer(seq[:pe_size[2]], + w_spatial_freqs)[None, None, :, :].repeat( + pe_size[0], pe_size[1], 1, 1) + + if fps is None: + emb_t = torch.outer(seq[:pe_size[0]], temporal_freqs) + else: + temporal_scale = seq[:pe_size[0]] / fps * self.base_fps + emb_t = torch.outer(temporal_scale, + temporal_freqs) + + emb_t = emb_t[:, None, None, :].repeat(1, pe_size[1], pe_size[2], 1) + freqs = torch.cat([emb_t, emb_h, emb_w] * 2, dim=-1).flatten(0, + 2).float() + cos = torch.cos(freqs) + sin = torch.sin(freqs) + return cos, sin + + +class CosmosLearnablePositionalEmbed(nn.Module): + + def __init__( + self, + hidden_size: int, + max_size: tuple[int, int, int], + patch_size: tuple[int, int, int], + eps: float = 1e-6, + ) -> None: + super().__init__() + + self.max_size = [ + size // patch + for size, patch in zip(max_size, patch_size, strict=False) + ] + self.patch_size = patch_size + self.eps = eps + + self.pos_emb_t = nn.Parameter(torch.zeros(self.max_size[0], + hidden_size)) + self.pos_emb_h = nn.Parameter(torch.zeros(self.max_size[1], + hidden_size)) + self.pos_emb_w = nn.Parameter(torch.zeros(self.max_size[2], + hidden_size)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + pe_size = [ + num_frames // self.patch_size[0], height // self.patch_size[1], + width // self.patch_size[2] + ] + + emb_t = self.pos_emb_t[:pe_size[0]][None, :, None, None, :].repeat( + batch_size, 1, pe_size[1], pe_size[2], 1) + emb_h = self.pos_emb_h[:pe_size[1]][None, None, :, None, :].repeat( + batch_size, pe_size[0], 1, pe_size[2], 1) + emb_w = self.pos_emb_w[:pe_size[2]][None, None, None, :, :].repeat( + batch_size, pe_size[0], pe_size[1], 1, 1) + emb = emb_t + emb_h + emb_w + emb = emb.flatten(1, 3) + + norm = torch.linalg.vector_norm(emb, + dim=-1, + keepdim=True, + dtype=torch.float32) + norm = torch.add(self.eps, 
+ norm, + alpha=np.sqrt(norm.numel() / emb.numel())) + return (emb / norm).type_as(hidden_states) + + +class CosmosTransformer3DModel(BaseDiT): + _fsdp_shard_conditions = CosmosVideoConfig()._fsdp_shard_conditions + _compile_conditions = CosmosVideoConfig()._compile_conditions + # _supported_attention_backends = CosmosVideoConfig()._supported_attention_backends + param_names_mapping = CosmosVideoConfig().param_names_mapping + lora_param_names_mapping = CosmosVideoConfig().lora_param_names_mapping + + def __init__(self, config: CosmosVideoConfig, hf_config: dict[str, Any]) -> None: + super().__init__(config=config, hf_config=hf_config) + + inner_dim = config.num_attention_heads * config.attention_head_dim + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.in_channels = config.in_channels + self.out_channels = config.out_channels + self.num_channels_latents = config.num_channels_latents + self.patch_size = config.patch_size + self.max_size = config.max_size + self.rope_scale = config.rope_scale + self.concat_padding_mask = config.concat_padding_mask + self.extra_pos_embed_type = config.extra_pos_embed_type + + # 1. Patch Embedding + patch_embed_in_channels = config.in_channels + 1 if config.concat_padding_mask else config.in_channels + self.patch_embed = CosmosPatchEmbed(patch_embed_in_channels, + inner_dim, + config.patch_size, + bias=False) + + # 2. Positional Embedding + self.rope = CosmosRotaryPosEmbed(hidden_size=config.attention_head_dim, + max_size=config.max_size, + patch_size=config.patch_size, + rope_scale=config.rope_scale) + + self.learnable_pos_embed = None + if config.extra_pos_embed_type == "learnable": + self.learnable_pos_embed = CosmosLearnablePositionalEmbed( + hidden_size=inner_dim, + max_size=config.max_size, + patch_size=config.patch_size, + ) + + # 3. Time Embedding + self.time_embed = CosmosEmbedding(inner_dim, inner_dim) + + # 4. Transformer Blocks + self.transformer_blocks = nn.ModuleList([ + CosmosTransformerBlock( + num_attention_heads=config.num_attention_heads, + attention_head_dim=config.attention_head_dim, + cross_attention_dim=config.text_embed_dim, + mlp_ratio=config.mlp_ratio, + adaln_lora_dim=config.adaln_lora_dim, + qk_norm=config.qk_norm, + out_bias=False, + prefix=f"{config.prefix}.transformer_blocks.{i}", + ) for i in range(config.num_layers) + ]) + + # 5. 
Output norm & projection + self.norm_out = CosmosAdaLayerNorm(inner_dim, config.adaln_lora_dim) + self.proj_out = nn.Linear(inner_dim, + config.out_channels * + math.prod(config.patch_size), + bias=False) + + self.gradient_checkpointing = False + + # For TeaCache + self.previous_e0_even = None + self.previous_e0_odd = None + self.previous_residual_even = None + self.previous_residual_odd = None + self.is_even = True + self.should_calc_even = True + self.should_calc_odd = True + self.accumulated_rel_l1_distance_even = 0 + self.accumulated_rel_l1_distance_odd = 0 + self.cnt = 0 + self.__post_init__() + + def forward(self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor | list[torch.Tensor], + attention_mask: torch.Tensor | None = None, + fps: int | None = None, + condition_mask: torch.Tensor | None = None, + padding_mask: torch.Tensor | None = None, + **kwargs) -> torch.Tensor: + forward_batch = get_forward_context().forward_batch + enable_teacache = forward_batch is not None and forward_batch.enable_teacache + + orig_dtype = hidden_states.dtype + if not isinstance(encoder_hidden_states, torch.Tensor): + encoder_hidden_states = encoder_hidden_states[0] + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + # 1. Concatenate padding mask if needed & prepare attention mask + if condition_mask is not None: + hidden_states = torch.cat([hidden_states, condition_mask], dim=1) + + if self.concat_padding_mask: + from torchvision import transforms + padding_mask = transforms.functional.resize( + padding_mask, list(hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST + ) + hidden_states = torch.cat( + [hidden_states, padding_mask.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)], dim=1 + ) + + if attention_mask is not None: + attention_mask = attention_mask.unsqueeze(1).unsqueeze( + 1) # [B, 1, 1, S] + + # 2. Generate positional embeddings + image_rotary_emb = self.rope(hidden_states, fps=fps) + extra_pos_emb = self.learnable_pos_embed( + hidden_states) if self.extra_pos_embed_type == "learnable" else None + + # 3. Patchify input + p_t, p_h, p_w = self.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + hidden_states = self.patch_embed(hidden_states) + hidden_states = hidden_states.flatten( + 1, 3) # [B, T, H, W, C] -> [B, THW, C] codespell:ignore + + # 4. Timestep embeddings + if timestep.ndim == 1: + temb, embedded_timestep = self.time_embed(hidden_states, timestep) + elif timestep.ndim == 5: + assert timestep.shape == (batch_size, 1, num_frames, 1, 1), ( + f"Expected timestep to have shape [B, 1, T, 1, 1], but got {timestep.shape}" + ) + timestep = timestep.flatten() + temb, embedded_timestep = self.time_embed(hidden_states, timestep) + # We can do this because num_frames == post_patch_num_frames, as p_t is 1 + temb, embedded_timestep = ( + x.view(batch_size, post_patch_num_frames, 1, 1, + -1).expand(-1, -1, post_patch_height, post_patch_width, + -1).flatten(1, 3) + for x in (temb, embedded_timestep) + ) # [BT, C] -> [B, T, 1, 1, C] -> [B, T, H, W, C] -> [B, THW, C] codespell:ignore + else: + raise ValueError(f"Unsupported timestep shape: {timestep.shape}") + + # 6. 
Transformer blocks + if torch.is_grad_enabled() and self.gradient_checkpointing: + for i, block in enumerate(self.transformer_blocks): + hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + embedded_timestep, + temb, + image_rotary_emb, + extra_pos_emb, + attention_mask, + ) + else: + for i, block in enumerate(self.transformer_blocks): + hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + embedded_timestep=embedded_timestep, + temb=temb, + image_rotary_emb=image_rotary_emb, + extra_pos_emb=extra_pos_emb, + attention_mask=attention_mask, + ) + + # 7. Output norm & projection & unpatchify + hidden_states = self.norm_out(hidden_states, embedded_timestep, temb) + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.unflatten(2, (p_h, p_w, p_t, -1)) + hidden_states = hidden_states.unflatten( + 1, (post_patch_num_frames, post_patch_height, post_patch_width)) + # NOTE: The permutation order here is not the inverse operation of what happens when patching as usually expected. + # It might be a source of confusion to the reader, but this is correct + hidden_states = hidden_states.permute(0, 7, 1, 6, 2, 4, 3, 5) + hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + return hidden_states \ No newline at end of file diff --git a/fastvideo/models/encoders/t5.py b/fastvideo/models/encoders/t5.py index 5671b96ce..f4bcdbbbb 100644 --- a/fastvideo/models/encoders/t5.py +++ b/fastvideo/models/encoders/t5.py @@ -181,7 +181,8 @@ def __init__(self, self.qkv_proj = QKVParallelLinear( self.d_model, - self.d_model // self.total_num_heads, + #self.d_model // self.total_num_heads, + self.key_value_proj_dim, self.total_num_heads, self.total_num_kv_heads, bias=False, @@ -199,7 +200,8 @@ def __init__(self, padding_size=self.relative_attention_num_buckets, quant_config=quant_config) self.o = RowParallelLinear( - self.d_model, + #self.d_model, + self.total_num_heads * self.key_value_proj_dim, self.d_model, bias=False, quant_config=quant_config, @@ -298,10 +300,12 @@ def forward( ) -> torch.Tensor: bs, seq_len, _ = hidden_states.shape num_seqs = bs - n, c = self.n_heads, self.d_model // self.total_num_heads + #n, c = self.n_heads, self.d_model // self.total_num_heads + n, c = self.n_heads, self.key_value_proj_dim qkv, _ = self.qkv_proj(hidden_states) # Projection of 'own' hidden state (self-attention). No GQA here. 
- q, k, v = qkv.split(self.inner_dim, dim=-1) + #q, k, v = qkv.split(self.inner_dim, dim=-1) + q, k, v = qkv.split(self.qkv_proj.output_sizes, dim=-1) q = q.reshape(bs, seq_len, n, c) k = k.reshape(bs, seq_len, n, c) v = v.reshape(bs, seq_len, n, c) diff --git a/fastvideo/models/registry.py b/fastvideo/models/registry.py index 5c585fa16..fda28fe03 100644 --- a/fastvideo/models/registry.py +++ b/fastvideo/models/registry.py @@ -25,7 +25,8 @@ "HunyuanVideoTransformer3DModel": ("dits", "hunyuanvideo", "HunyuanVideoTransformer3DModel"), "WanTransformer3DModel": ("dits", "wanvideo", "WanTransformer3DModel"), - "StepVideoModel": ("dits", "stepvideo", "StepVideoModel") + "StepVideoModel": ("dits", "stepvideo", "StepVideoModel"), + "CosmosTransformer3DModel": ("dits", "cosmos", "CosmosTransformer3DModel") } _IMAGE_TO_VIDEO_DIT_MODELS = { @@ -37,6 +38,7 @@ "CLIPTextModel": ("encoders", "clip", "CLIPTextModel"), "LlamaModel": ("encoders", "llama", "LlamaModel"), "UMT5EncoderModel": ("encoders", "t5", "UMT5EncoderModel"), + "T5EncoderModel": ("encoders", "t5", "T5EncoderModel"), "STEP1TextEncoder": ("encoders", "stepllm", "STEP1TextEncoder"), "BertModel": ("encoders", "clip", "CLIPTextModel"), } @@ -234,7 +236,6 @@ def register_model( def _raise_for_unsupported(self, architectures: list[str]) -> NoReturn: all_supported_archs = self.get_supported_archs() - if any(arch in all_supported_archs for arch in architectures): raise ValueError( f"Model architectures {architectures} failed " diff --git a/fastvideo/models/schedulers/scheduling_flow_match_euler_discrete.py b/fastvideo/models/schedulers/scheduling_flow_match_euler_discrete.py index 73ce8c954..45c69d004 100644 --- a/fastvideo/models/schedulers/scheduling_flow_match_euler_discrete.py +++ b/fastvideo/models/schedulers/scheduling_flow_match_euler_discrete.py @@ -88,6 +88,14 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin, The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear". stochastic_sampling (`bool`, defaults to False): Whether to use stochastic sampling. + final_sigmas_type (`str`, defaults to "sigma_min"): + The type of final sigmas to use. Either "sigma_min" or "zero". + sigma_max (`float`, *optional*): + The maximum sigma value for the noise schedule. + sigma_min (`float`, *optional*): + The minimum sigma value for the noise schedule. + sigma_data (`float`, *optional*): + The sigma data value for scaling. 
""" _compatibles: list[Any] = [] @@ -110,6 +118,10 @@ def __init__( use_beta_sigmas: bool | None = False, time_shift_type: str = "exponential", stochastic_sampling: bool = False, + final_sigmas_type: str = "sigma_min", + sigma_max: float | None = None, + sigma_min: float | None = None, + sigma_data: float | None = None, ): if sum([ self.config.use_beta_sigmas, self.config.use_exponential_sigmas, @@ -336,9 +348,9 @@ def set_timesteps( sigmas_array: np.ndarray if sigmas is None: if timesteps_array is None: - timesteps_array = np.linspace(self._sigma_to_t(self.sigma_max), - self._sigma_to_t(self.sigma_min), - num_inference_steps) + t_max = self._sigma_to_t(self.sigma_max) + t_min = self._sigma_to_t(self.sigma_min) + timesteps_array = np.linspace(t_max, t_min, num_inference_steps) sigmas_array = timesteps_array / self.config.num_train_timesteps else: sigmas_array = np.array(sigmas).astype(np.float32) @@ -403,9 +415,7 @@ def set_timesteps( [sigmas_tensor, torch.ones(1, device=sigmas_tensor.device)]) else: - sigmas_tensor = torch.cat( - [sigmas_tensor, - torch.zeros(1, device=sigmas_tensor.device)]) + sigmas_tensor = torch.cat([sigmas_tensor, torch.zeros(1, device=sigmas_tensor.device)]) self.timesteps = timesteps_tensor self.sigmas = sigmas_tensor @@ -505,7 +515,9 @@ def step( next_sigma = lower_sigmas[..., None] dt = current_sigma - next_sigma else: - assert self.step_index is not None, "step_index should not be None" + if self.step_index is None: + self._init_step_index(timestep) + sigma_idx = self.step_index sigma = self.sigmas[sigma_idx] sigma_next = self.sigmas[sigma_idx + 1] @@ -522,7 +534,6 @@ def step( prev_sample = sample + dt * model_output # upon completion increase step index by one - assert self._step_index is not None, "_step_index should not be None" self._step_index += 1 if per_token_timesteps is None: # Cast sample back to model compatible dtype @@ -558,7 +569,7 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, min_inv_rho = sigma_min**(1 / rho) max_inv_rho = sigma_max**(1 / rho) sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**rho - return sigmas + return torch.from_numpy(sigmas).to(dtype=in_sigmas.dtype, device=in_sigmas.device) # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential def _convert_to_exponential(self, in_sigmas: torch.Tensor, @@ -583,7 +594,7 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, sigmas = np.exp( np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)) - return sigmas + return torch.from_numpy(sigmas).to(dtype=in_sigmas.dtype, device=in_sigmas.device) # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta def _convert_to_beta(self, @@ -614,7 +625,7 @@ def _convert_to_beta(self, for timestep in 1 - np.linspace(0, 1, num_inference_steps) ] ]) - return sigmas + return torch.from_numpy(sigmas).to(dtype=in_sigmas.dtype, device=in_sigmas.device) def _time_shift_exponential( self, mu: float, sigma: float, diff --git a/fastvideo/pipelines/basic/cosmos/__init__.py b/fastvideo/pipelines/basic/cosmos/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fastvideo/pipelines/basic/cosmos/cosmos_pipeline.py b/fastvideo/pipelines/basic/cosmos/cosmos_pipeline.py new file mode 100644 index 000000000..f3b7c9cd7 --- /dev/null +++ b/fastvideo/pipelines/basic/cosmos/cosmos_pipeline.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Cosmos video diffusion pipeline implementation. 
+ +This module contains an implementation of the Cosmos video diffusion pipeline +using the modular pipeline architecture. +""" + +from fastvideo.fastvideo_args import FastVideoArgs +from fastvideo.logger import init_logger +from fastvideo.models.schedulers.scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler) +from fastvideo.pipelines.composed_pipeline_base import ComposedPipelineBase +from fastvideo.pipelines.stages import (ConditioningStage, CosmosDenoisingStage, + CosmosLatentPreparationStage, + DecodingStage, InputValidationStage, + TextEncodingStage, + TimestepPreparationStage) + +logger = init_logger(__name__) + + +class Cosmos2VideoToWorldPipeline(ComposedPipelineBase): + + _required_config_modules = [ + "text_encoder", "tokenizer", "vae", "transformer", "scheduler", + "safety_checker" + ] + + def initialize_pipeline(self, fastvideo_args: FastVideoArgs): + self.modules["scheduler"] = FlowMatchEulerDiscreteScheduler( + shift=fastvideo_args.pipeline_config.flow_shift, + use_karras_sigmas=True) + + sigma_max = 80.0 + sigma_min = 0.002 + sigma_data = 1.0 + final_sigmas_type = "sigma_min" + + if self.modules["scheduler"] is not None: + scheduler = self.modules["scheduler"] + scheduler.config.sigma_max = sigma_max + scheduler.config.sigma_min = sigma_min + scheduler.config.sigma_data = sigma_data + scheduler.config.final_sigmas_type = final_sigmas_type + scheduler.sigma_max = sigma_max + scheduler.sigma_min = sigma_min + scheduler.sigma_data = sigma_data + + def create_pipeline_stages(self, fastvideo_args: FastVideoArgs): + """Set up pipeline stages with proper dependency injection.""" + + self.add_stage(stage_name="input_validation_stage", + stage=InputValidationStage()) + + self.add_stage(stage_name="prompt_encoding_stage", + stage=TextEncodingStage( + text_encoders=[self.get_module("text_encoder")], + tokenizers=[self.get_module("tokenizer")], + )) + + self.add_stage(stage_name="conditioning_stage", + stage=ConditioningStage()) + + self.add_stage(stage_name="timestep_preparation_stage", + stage=TimestepPreparationStage( + scheduler=self.get_module("scheduler"))) + + self.add_stage(stage_name="latent_preparation_stage", + stage=CosmosLatentPreparationStage( + scheduler=self.get_module("scheduler"), + transformer=self.get_module("transformer"), + vae=self.get_module("vae"))) + + self.add_stage(stage_name="denoising_stage", + stage=CosmosDenoisingStage( + transformer=self.get_module("transformer"), + scheduler=self.get_module("scheduler"))) + + self.add_stage(stage_name="decoding_stage", + stage=DecodingStage(vae=self.get_module("vae"))) + + +EntryClass = Cosmos2VideoToWorldPipeline diff --git a/fastvideo/pipelines/pipeline_registry.py b/fastvideo/pipelines/pipeline_registry.py index d80fd40b5..5ce030b58 100644 --- a/fastvideo/pipelines/pipeline_registry.py +++ b/fastvideo/pipelines/pipeline_registry.py @@ -23,6 +23,7 @@ "WanImageToVideoPipeline": "wan", "StepVideoPipeline": "stepvideo", "HunyuanVideoPipeline": "hunyuan", + "Cosmos2VideoToWorldPipeline": "cosmos" } _PREPROCESS_WORKLOAD_TYPE_TO_PIPELINE_NAME: dict[WorkloadType, str] = { diff --git a/fastvideo/pipelines/stages/__init__.py b/fastvideo/pipelines/stages/__init__.py index 1852edc49..efd45e838 100644 --- a/fastvideo/pipelines/stages/__init__.py +++ b/fastvideo/pipelines/stages/__init__.py @@ -9,13 +9,15 @@ from fastvideo.pipelines.stages.base import PipelineStage from fastvideo.pipelines.stages.conditioning import ConditioningStage from fastvideo.pipelines.stages.decoding import DecodingStage -from 
fastvideo.pipelines.stages.denoising import (DenoisingStage, +from fastvideo.pipelines.stages.denoising import (CosmosDenoisingStage, + DenoisingStage, DmdDenoisingStage) from fastvideo.pipelines.stages.encoding import EncodingStage from fastvideo.pipelines.stages.image_encoding import (ImageEncodingStage, ImageVAEEncodingStage) from fastvideo.pipelines.stages.input_validation import InputValidationStage -from fastvideo.pipelines.stages.latent_preparation import LatentPreparationStage +from fastvideo.pipelines.stages.latent_preparation import ( + CosmosLatentPreparationStage, LatentPreparationStage) from fastvideo.pipelines.stages.stepvideo_encoding import ( StepvideoPromptEncodingStage) from fastvideo.pipelines.stages.text_encoding import TextEncodingStage @@ -27,9 +29,11 @@ "InputValidationStage", "TimestepPreparationStage", "LatentPreparationStage", + "CosmosLatentPreparationStage", "ConditioningStage", "DenoisingStage", "DmdDenoisingStage", + "CosmosDenoisingStage", "EncodingStage", "DecodingStage", "ImageEncodingStage", diff --git a/fastvideo/pipelines/stages/decoding.py b/fastvideo/pipelines/stages/decoding.py index 8f226b5c4..3abc319d5 100644 --- a/fastvideo/pipelines/stages/decoding.py +++ b/fastvideo/pipelines/stages/decoding.py @@ -92,20 +92,21 @@ def forward( vae_autocast_enabled = (vae_dtype != torch.float32 ) and not fastvideo_args.disable_autocast - if isinstance(self.vae.scaling_factor, torch.Tensor): - latents = latents / self.vae.scaling_factor.to( - latents.device, latents.dtype) - else: - latents = latents / self.vae.scaling_factor - - # Apply shifting if needed - if (hasattr(self.vae, "shift_factor") - and self.vae.shift_factor is not None): - if isinstance(self.vae.shift_factor, torch.Tensor): - latents += self.vae.shift_factor.to(latents.device, - latents.dtype) + if hasattr(self.vae, 'scaling_factor'): + if isinstance(self.vae.scaling_factor, torch.Tensor): + latents = latents / self.vae.scaling_factor.to( + latents.device, latents.dtype) else: - latents += self.vae.shift_factor + latents = latents / self.vae.scaling_factor + + # Apply shifting if needed + if (hasattr(self.vae, "shift_factor") + and self.vae.shift_factor is not None): + if isinstance(self.vae.shift_factor, torch.Tensor): + latents += self.vae.shift_factor.to( + latents.device, latents.dtype) + else: + latents += self.vae.shift_factor # Decode latents with torch.autocast(device_type="cuda", @@ -117,6 +118,7 @@ def forward( # self.vae.enable_parallel() if not vae_autocast_enabled: latents = latents.to(vae_dtype) + image = self.vae.decode(latents) # Normalize image to [0, 1] range diff --git a/fastvideo/pipelines/stages/denoising.py b/fastvideo/pipelines/stages/denoising.py index 8a2d9f0b6..71c77376a 100644 --- a/fastvideo/pipelines/stages/denoising.py +++ b/fastvideo/pipelines/stages/denoising.py @@ -293,16 +293,13 @@ def forward( **pos_cond_kwargs, ) - # Apply guidance if batch.do_classifier_free_guidance: batch.is_cfg_negative = True with set_forward_context( current_timestep=i, attn_metadata=attn_metadata, forward_batch=batch, - # fastvideo_args=fastvideo_args ): - # Run transformer noise_pred_uncond = current_model( latent_model_input, neg_prompt_embeds, @@ -311,6 +308,7 @@ def forward( **image_kwargs, **neg_cond_kwargs, ) + noise_pred_text = noise_pred noise_pred = noise_pred_uncond + current_guidance_scale * ( noise_pred_text - noise_pred_uncond) @@ -329,6 +327,7 @@ def forward( latents, **extra_step_kwargs, return_dict=False)[0] + # Update progress bar if i == len(timesteps) - 1 or ( (i + 1) > 
num_warmup_steps and @@ -600,6 +599,289 @@ def verify_output(self, batch: ForwardBatch, return result +class CosmosDenoisingStage(DenoisingStage): + """ + Denoising stage for Cosmos models using FlowMatchEulerDiscreteScheduler. + """ + + def __init__(self, transformer, scheduler, pipeline=None) -> None: + super().__init__(transformer, scheduler, pipeline) + + def forward( + self, + batch: ForwardBatch, + fastvideo_args: FastVideoArgs, + ) -> ForwardBatch: + pipeline = self.pipeline() if self.pipeline else None + if not fastvideo_args.model_loaded["transformer"]: + loader = TransformerLoader() + self.transformer = loader.load( + fastvideo_args.model_paths["transformer"], fastvideo_args) + if pipeline: + pipeline.add_module("transformer", self.transformer) + fastvideo_args.model_loaded["transformer"] = True + + extra_step_kwargs = self.prepare_extra_func_kwargs( + self.scheduler.step, + { + "generator": batch.generator, + "eta": batch.eta + }, + ) + + if hasattr(self.transformer, 'module'): + transformer_dtype = next(self.transformer.module.parameters()).dtype + else: + transformer_dtype = next(self.transformer.parameters()).dtype + target_dtype = transformer_dtype + autocast_enabled = (target_dtype != torch.float32 + ) and not fastvideo_args.disable_autocast + + latents = batch.latents + num_inference_steps = batch.num_inference_steps + guidance_scale = batch.guidance_scale + + sigma_max = 80.0 + sigma_min = 0.002 + sigma_data = 1.0 + final_sigmas_type = "sigma_min" + + if self.scheduler is not None: + self.scheduler.register_to_config( + sigma_max=sigma_max, + sigma_min=sigma_min, + sigma_data=sigma_data, + final_sigmas_type=final_sigmas_type, + ) + + self.scheduler.set_timesteps(num_inference_steps, device=latents.device) + timesteps = self.scheduler.timesteps + + if (hasattr(self.scheduler.config, 'final_sigmas_type') + and self.scheduler.config.final_sigmas_type == "sigma_min" + and len(self.scheduler.sigmas) > 1): + self.scheduler.sigmas[-1] = self.scheduler.sigmas[-2] + + conditioning_latents = getattr(batch, 'conditioning_latents', None) + unconditioning_latents = conditioning_latents + + # Sampling loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if hasattr(self, 'interrupt') and self.interrupt: + continue + + current_sigma = self.scheduler.sigmas[i] + current_t = current_sigma / (current_sigma + 1) + c_in = 1 - current_t + c_skip = 1 - current_t + c_out = -current_t + + timestep = current_t.view(1, 1, 1, 1, + 1).expand(latents.size(0), -1, + latents.size(2), -1, + -1) # [B, 1, T, 1, 1] + + with torch.autocast(device_type="cuda", + dtype=target_dtype, + enabled=autocast_enabled): + + # Conditional forward pass + cond_latent = latents * c_in + + if hasattr( + batch, 'cond_indicator' + ) and batch.cond_indicator is not None and conditioning_latents is not None: + cond_latent = batch.cond_indicator * conditioning_latents + ( + 1 - batch.cond_indicator) * cond_latent + else: + logger.warning( + "Step %s: Missing conditioning data - cond_indicator: %s, conditioning_latents: %s", + i, hasattr(batch, 'cond_indicator'), + conditioning_latents is not None) + + cond_latent = cond_latent.to(target_dtype) + + # Apply conditional timestep processing + cond_timestep = timestep + if hasattr(batch, 'cond_indicator' + ) and batch.cond_indicator is not None: + sigma_conditioning = 0.0001 + t_conditioning = sigma_conditioning / ( + sigma_conditioning + 1) + cond_timestep = batch.cond_indicator * t_conditioning + ( + 1 - 
batch.cond_indicator) * timestep + cond_timestep = cond_timestep.to(target_dtype) + + with set_forward_context( + current_timestep=i, + attn_metadata=None, + forward_batch=batch, + ): + # Use conditioning masks from CosmosLatentPreparationStage + condition_mask = batch.cond_mask.to( + target_dtype) if hasattr(batch, + 'cond_mask') else None + padding_mask = torch.zeros(1, + 1, + batch.height, + batch.width, + device=cond_latent.device, + dtype=target_dtype) + + # Fallback if masks not available + if condition_mask is None: + batch_size, num_channels, num_frames, height, width = cond_latent.shape + condition_mask = torch.zeros( + batch_size, + 1, + num_frames, + height, + width, + device=cond_latent.device, + dtype=target_dtype) + + noise_pred = self.transformer( + hidden_states=cond_latent, + timestep=cond_timestep.to(target_dtype), + encoder_hidden_states=batch.prompt_embeds[0].to( + target_dtype), + fps=24, # TODO: get fps from batch or config + condition_mask=condition_mask, + padding_mask=padding_mask, + return_dict=False, + )[0] + + cond_pred = (c_skip * latents + + c_out * noise_pred.float()).to(target_dtype) + + if hasattr( + batch, 'cond_indicator' + ) and batch.cond_indicator is not None and conditioning_latents is not None: + cond_pred = batch.cond_indicator * conditioning_latents + ( + 1 - batch.cond_indicator) * cond_pred + + if batch.do_classifier_free_guidance and batch.negative_prompt_embeds is not None: + uncond_latent = latents * c_in + + if hasattr( + batch, 'uncond_indicator' + ) and batch.uncond_indicator is not None and unconditioning_latents is not None: + uncond_latent = batch.uncond_indicator * unconditioning_latents + ( + 1 - batch.uncond_indicator) * uncond_latent + + with set_forward_context( + current_timestep=i, + attn_metadata=None, + forward_batch=batch, + ): + # Use uncond_mask for unconditional pass if available + uncond_condition_mask = batch.uncond_mask.to( + target_dtype + ) if hasattr( + batch, 'uncond_mask' + ) and batch.uncond_mask is not None else condition_mask + + # Apply same conditional timestep processing for unconditional pass + uncond_timestep = timestep + if hasattr(batch, 'uncond_indicator' + ) and batch.uncond_indicator is not None: + sigma_conditioning = 0.0001 # Same as Diffusers default + t_conditioning = sigma_conditioning / ( + sigma_conditioning + 1) + uncond_timestep = batch.uncond_indicator * t_conditioning + ( + 1 - batch.uncond_indicator) * timestep + uncond_timestep = uncond_timestep.to( + target_dtype) + + noise_pred_uncond = self.transformer( + hidden_states=uncond_latent.to(target_dtype), + timestep=uncond_timestep.to(target_dtype), + encoder_hidden_states=batch. 
+ negative_prompt_embeds[0].to(target_dtype), + fps=24, # TODO: get fps from batch or config + condition_mask=uncond_condition_mask, + padding_mask=padding_mask, + return_dict=False, + )[0] + + uncond_pred = ( + c_skip * latents + + c_out * noise_pred_uncond.float()).to(target_dtype) + + # Apply conditional indicator masking for unconditional prediction like diffusers + if hasattr( + batch, 'uncond_indicator' + ) and batch.uncond_indicator is not None and unconditioning_latents is not None: + uncond_pred = batch.uncond_indicator * unconditioning_latents + ( + 1 - batch.uncond_indicator) * uncond_pred + + guidance_diff = cond_pred - uncond_pred + final_pred = cond_pred + guidance_scale * guidance_diff + else: + final_pred = cond_pred + + # Convert to noise for scheduler step + if current_sigma > 1e-8: + noise_for_scheduler = (latents - final_pred) / current_sigma + else: + logger.warning( + "Step %s: current_sigma too small (%s), using final_pred directly", + i, current_sigma) + noise_for_scheduler = final_pred + + # Debug: Check for NaN values before scheduler step + if torch.isnan(noise_for_scheduler).sum() > 0: + logger.error( + "Step %s: NaN detected in noise_for_scheduler, sum: %s", + i, + noise_for_scheduler.float().sum().item()) + logger.error( + "Step %s: latents sum: %s, final_pred sum: %s, current_sigma: %s", + i, + latents.float().sum().item(), + final_pred.float().sum().item(), current_sigma) + + latents = self.scheduler.step(noise_for_scheduler, + t, + latents, + **extra_step_kwargs, + return_dict=False)[0] + + progress_bar.update() + + # Update batch with final latents + batch.latents = latents + + return batch + + def verify_input(self, batch: ForwardBatch, + fastvideo_args: FastVideoArgs) -> VerificationResult: + """Verify Cosmos denoising stage inputs.""" + result = VerificationResult() + result.add_check("latents", batch.latents, + [V.is_tensor, V.with_dims(5)]) + result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty) + result.add_check("num_inference_steps", batch.num_inference_steps, + V.positive_int) + result.add_check("guidance_scale", batch.guidance_scale, + V.positive_float) + result.add_check("do_classifier_free_guidance", + batch.do_classifier_free_guidance, V.bool_value) + result.add_check( + "negative_prompt_embeds", batch.negative_prompt_embeds, lambda x: + not batch.do_classifier_free_guidance or V.list_not_empty(x)) + return result + + def verify_output(self, batch: ForwardBatch, + fastvideo_args: FastVideoArgs) -> VerificationResult: + """Verify Cosmos denoising stage outputs.""" + result = VerificationResult() + result.add_check("latents", batch.latents, + [V.is_tensor, V.with_dims(5)]) + return result + + class DmdDenoisingStage(DenoisingStage): """ Denoising stage for DMD. diff --git a/fastvideo/pipelines/stages/input_validation.py b/fastvideo/pipelines/stages/input_validation.py index 2f8cb4811..b2f677689 100644 --- a/fastvideo/pipelines/stages/input_validation.py +++ b/fastvideo/pipelines/stages/input_validation.py @@ -38,7 +38,7 @@ def _generate_seeds(self, batch: ForwardBatch, batch.seeds = seeds # Peiyuan: using GPU seed will cause A100 and H100 to generate different results... 
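
Reviewer note: the per-step arithmetic in CosmosDenoisingStage above is easy to lose among the mask bookkeeping. A condensed, standalone restatement of one step without conditioning masks (plain PyTorch; cosmos_step and the stand-in model callable are illustrative, not FastVideo APIs):

    # EDM-style scalings, classifier-free guidance, and the conversion back to a
    # "noise" direction for the Euler scheduler, as implemented in the stage above.
    import torch


    def cosmos_step(latents: torch.Tensor, sigma: torch.Tensor,
                    guidance_scale: float, model) -> torch.Tensor:
        t = sigma / (sigma + 1)          # current_t in the stage above
        c_in, c_skip, c_out = 1 - t, 1 - t, -t

        cond_out = model(latents * c_in, t, cond=True)
        uncond_out = model(latents * c_in, t, cond=False)

        cond_pred = c_skip * latents + c_out * cond_out
        uncond_pred = c_skip * latents + c_out * uncond_out

        # Guidance is applied to the two denoised predictions, not the raw outputs.
        final_pred = cond_pred + guidance_scale * (cond_pred - uncond_pred)

        # The Euler scheduler expects a flow/noise direction, so convert back.
        return (latents - final_pred) / sigma


    if __name__ == "__main__":
        def fake_model(x, t, cond):
            return torch.zeros_like(x)  # stand-in for the Cosmos transformer

        x = torch.randn(1, 16, 2, 4, 4)
        print(cosmos_step(x, torch.tensor(80.0), 7.0, fake_model).shape)
        # torch.Size([1, 16, 2, 4, 4])
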
batch.generator = [ - torch.Generator("cpu").manual_seed(seed) for seed in seeds + torch.Generator(device="cpu").manual_seed(seed) for seed in seeds ] def forward( diff --git a/fastvideo/pipelines/stages/latent_preparation.py b/fastvideo/pipelines/stages/latent_preparation.py index ea23a7daa..0cf15d7d8 100644 --- a/fastvideo/pipelines/stages/latent_preparation.py +++ b/fastvideo/pipelines/stages/latent_preparation.py @@ -3,10 +3,14 @@ Latent preparation stage for diffusion pipelines. """ +from typing import Any + +import torch from diffusers.utils.torch_utils import randn_tensor from fastvideo.distributed import get_local_torch_device from fastvideo.fastvideo_args import FastVideoArgs +from fastvideo.image_processor import ImageProcessor from fastvideo.logger import init_logger from fastvideo.pipelines.pipeline_batch_info import ForwardBatch from fastvideo.pipelines.stages.base import PipelineStage @@ -108,6 +112,238 @@ def forward( return batch + +class CosmosLatentPreparationStage(PipelineStage): + """ + Cosmos-specific latent preparation stage that properly handles the tensor shapes + and conditioning masks required by the Cosmos transformer. + + This stage replicates the logic from diffusers' Cosmos2VideoToWorldPipeline.prepare_latents() + """ + + def __init__(self, scheduler, transformer, vae=None) -> None: + super().__init__() + self.scheduler = scheduler + self.transformer = transformer + self.vae = vae + + def forward( + self, + batch: ForwardBatch, + fastvideo_args: FastVideoArgs, + ) -> ForwardBatch: + # Determine batch size + if isinstance(batch.prompt, list): + batch_size = len(batch.prompt) + elif batch.prompt is not None: + batch_size = 1 + else: + batch_size = batch.prompt_embeds[0].shape[0] + + # Adjust batch size for number of videos per prompt + batch_size *= batch.num_videos_per_prompt + + # Get required parameters + # Force float32 for latent preparation + dtype = torch.float32 + device = get_local_torch_device() + generator = batch.generator + latents = batch.latents + num_frames = batch.num_frames + height = batch.height + width = batch.width + + if height is None or width is None: + raise ValueError("Height and width must be provided") + + vae_scale_factor_spatial = 8 + vae_scale_factor_temporal = 4 + + latent_height = height // 8 + latent_width = width // vae_scale_factor_spatial + + num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1 + + # Cosmos transformer expects in_channels - 1 for the latent channels + num_channels_latents = self.transformer.config.in_channels - 1 + + shape = (batch_size, num_channels_latents, num_latent_frames, + latent_height, latent_width) + + init_latents = None + conditioning_latents = None + + video = None + + if hasattr(batch, 'video') and batch.video is not None: + video = batch.video + elif hasattr(batch, 'pil_image') and batch.pil_image is not None: + vae_scale_factor_spatial = 8 + image_processor = ImageProcessor( + vae_scale_factor=vae_scale_factor_spatial) + + processed_image = image_processor.preprocess( + batch.pil_image, height, width) + + # Add time dimension + video = processed_image.unsqueeze(2) + + video = video.to(device=device, dtype=torch.bfloat16) + elif hasattr( + batch, + 'preprocessed_image') and batch.preprocessed_image is not None: + # Convert preprocessed image to video format + if isinstance(batch.preprocessed_image, torch.Tensor): + if batch.preprocessed_image.dim( + ) == 4: # [B, C, H, W] -> [B, C, T, H, W] + video = batch.preprocessed_image.unsqueeze(2) + elif batch.preprocessed_image.dim( + 
) == 5: # Already [B, C, T, H, W] + video = batch.preprocessed_image + else: + logger.info( + "CosmosLatentPreparationStage - No video input sources found") + + if video is not None: + num_cond_frames = video.size(2) + + if num_cond_frames >= num_frames: + # Take the last `num_frames` frames for conditioning + num_cond_latent_frames = (num_frames - + 1) // vae_scale_factor_temporal + 1 + video = video[:, :, -num_frames:] + else: + num_cond_latent_frames = (num_cond_frames - + 1) // vae_scale_factor_temporal + 1 + num_padding_frames = num_frames - num_cond_frames + last_frame = video[:, :, -1:] + padding = last_frame.repeat(1, 1, num_padding_frames, 1, 1) + video = torch.cat([video, padding], dim=2) + + if self.vae is not None: + # Move VAE to correct device before encoding + self.vae = self.vae.to(device) + self.vae = self.vae.to(dtype=video.dtype) + + def retrieve_latents( + encoder_output: Any, + generator: Any | None = None) -> torch.Tensor: + if hasattr(encoder_output, "latent_dist"): + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + elif hasattr(encoder_output, "sample"): + return encoder_output.sample(generator) + elif isinstance(encoder_output, torch.Tensor): + return encoder_output + else: + attrs = [ + attr for attr in dir(encoder_output) + if not attr.startswith('_') + ] + raise AttributeError( + f"Could not access latents of provided encoder_output. Available attributes: {attrs}" + ) + + if isinstance(generator, list): + init_latents = [ + retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), + generator=torch.Generator( + device="cpu").manual_seed(100)) + for i in range(batch_size) + ] + else: + init_latents = [ + retrieve_latents( + self.vae.encode(vid.unsqueeze(0)), + torch.Generator(device="cpu").manual_seed(100)) + for vid in video + ] + + init_latents = torch.cat(init_latents, dim=0).to(dtype) + + # Apply latent normalization + if hasattr(self.vae.config, 'latents_mean') and hasattr( + self.vae.config, 'latents_std'): + latents_mean = torch.tensor( + self.vae.config.latents_mean).view( + 1, self.vae.config.z_dim, 1, 1, + 1).to(device, dtype) + latents_std = torch.tensor( + self.vae.config.latents_std).view( + 1, self.vae.config.z_dim, 1, 1, + 1).to(device, dtype) + init_latents = (init_latents - latents_mean + ) / latents_std * self.scheduler.sigma_data + + conditioning_latents = init_latents + + # Offload VAE to CPU after encoding to save memory + self.vae.to("cpu") + else: + num_cond_latent_frames = 0 + + # Generate or use provided latents + if latents is None: + # Use float32 for randn_tensor + latents = randn_tensor( + shape, + generator=torch.Generator(device="cpu").manual_seed(200), + device=device, + dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + latents = latents * self.scheduler.sigma_max + + padding_shape = (batch_size, 1, num_latent_frames, latent_height, + latent_width) + ones_padding = latents.new_ones(padding_shape) + zeros_padding = latents.new_zeros(padding_shape) + + cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1) + cond_indicator[:, :, :num_cond_latent_frames] = 1.0 + cond_mask = cond_indicator * ones_padding + ( + 1 - cond_indicator) * zeros_padding + + uncond_indicator = None + uncond_mask = None + if batch.do_classifier_free_guidance: + uncond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1) + uncond_indicator[:, :, :num_cond_latent_frames] = 1.0 + uncond_mask = uncond_indicator * ones_padding + ( + 1 - 
uncond_indicator) * zeros_padding + + batch.latents = latents + batch.raw_latent_shape = latents.shape + + batch.conditioning_latents = conditioning_latents + batch.cond_indicator = cond_indicator + batch.uncond_indicator = uncond_indicator + batch.cond_mask = cond_mask + batch.uncond_mask = uncond_mask + + return batch + + def verify_input(self, batch: ForwardBatch, + fastvideo_args: FastVideoArgs) -> VerificationResult: + """Verify Cosmos latent preparation stage inputs.""" + result = VerificationResult() + result.add_check( + "prompt_or_embeds", None, lambda _: V.string_or_list_strings( + batch.prompt) or V.list_not_empty(batch.prompt_embeds)) + result.add_check("prompt_embeds", batch.prompt_embeds, + V.list_of_tensors) + result.add_check("num_videos_per_prompt", batch.num_videos_per_prompt, + V.positive_int) + result.add_check("generator", batch.generator, + V.generator_or_list_generators) + result.add_check("num_frames", batch.num_frames, V.positive_int) + result.add_check("height", batch.height, V.positive_int) + result.add_check("width", batch.width, V.positive_int) + result.add_check("latents", batch.latents, V.none_or_tensor) + return result + def adjust_video_length(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> int: """ @@ -130,25 +366,6 @@ def adjust_video_length(self, batch: ForwardBatch, latent_num_frames = video_length // 17 * 3 return int(latent_num_frames) - def verify_input(self, batch: ForwardBatch, - fastvideo_args: FastVideoArgs) -> VerificationResult: - """Verify latent preparation stage inputs.""" - result = VerificationResult() - result.add_check( - "prompt_or_embeds", None, lambda _: V.string_or_list_strings( - batch.prompt) or V.list_not_empty(batch.prompt_embeds)) - result.add_check("prompt_embeds", batch.prompt_embeds, - V.list_of_tensors) - result.add_check("num_videos_per_prompt", batch.num_videos_per_prompt, - V.positive_int) - result.add_check("generator", batch.generator, - V.generator_or_list_generators) - result.add_check("num_frames", batch.num_frames, V.positive_int) - result.add_check("height", batch.height, V.positive_int) - result.add_check("width", batch.width, V.positive_int) - result.add_check("latents", batch.latents, V.none_or_tensor) - return result - def verify_output(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult: """Verify latent preparation stage outputs.""" diff --git a/fastvideo/pipelines/stages/text_encoding.py b/fastvideo/pipelines/stages/text_encoding.py index 5a25236da..67dcd4652 100644 --- a/fastvideo/pipelines/stages/text_encoding.py +++ b/fastvideo/pipelines/stages/text_encoding.py @@ -85,6 +85,11 @@ def forward( output_hidden_states=True, ) prompt_embeds = postprocess_func(outputs) + + lengths = attention_mask.sum(dim=1).cpu() + for i, length in enumerate(lengths): + prompt_embeds[i, length:] = 0 + batch.prompt_embeds.append(prompt_embeds) if batch.prompt_attention_mask is not None: batch.prompt_attention_mask.append(attention_mask) @@ -106,6 +111,10 @@ def forward( ) negative_prompt_embeds = postprocess_func(negative_outputs) + lengths = negative_attention_mask.sum(dim=1).cpu() + for i, length in enumerate(lengths): + negative_prompt_embeds[i, length:] = 0 + assert batch.negative_prompt_embeds is not None batch.negative_prompt_embeds.append(negative_prompt_embeds) if batch.negative_attention_mask is not None: diff --git a/fastvideo/pipelines/stages/utils.py b/fastvideo/pipelines/stages/utils.py new file mode 100644 index 000000000..c7c272ab8 --- /dev/null +++ 
b/fastvideo/pipelines/stages/utils.py @@ -0,0 +1,79 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Utility functions for pipeline stages. +""" + +import inspect +from typing import Any + +import torch + + +def retrieve_timesteps( + scheduler: Any, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs: Any, +) -> tuple[Any, int]: + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" + ) + if timesteps is not None: + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + if timesteps is None: + raise ValueError("scheduler.timesteps is None after set_timesteps") + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + if timesteps is None: + raise ValueError("scheduler.timesteps is None after set_timesteps") + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + if timesteps is None: + raise ValueError("scheduler.timesteps is None after set_timesteps") + num_inference_steps = len(timesteps) + return timesteps, num_inference_steps diff --git a/fastvideo/tests/encoders/test_t5_encoder.py b/fastvideo/tests/encoders/test_t5_encoder.py index fa75d3691..0e93b53bb 100644 --- a/fastvideo/tests/encoders/test_t5_encoder.py +++ b/fastvideo/tests/encoders/test_t5_encoder.py @@ -6,7 +6,7 @@ import torch from torch.distributed.tensor import DTensor from torch.testing import assert_close -from transformers import AutoConfig, AutoTokenizer, UMT5EncoderModel +from transformers import AutoConfig, AutoTokenizer, UMT5EncoderModel, T5EncoderModel from fastvideo.configs.pipelines import PipelineConfig from fastvideo.forward_context import set_forward_context @@ -14,14 +14,15 @@ from fastvideo.models.loader.component_loader import TextEncoderLoader from fastvideo.utils import maybe_download_model, PRECISION_TO_TYPE from fastvideo.fastvideo_args import FastVideoArgs -from fastvideo.configs.models.encoders import T5Config +from fastvideo.configs.models.encoders import T5Config, T5LargeConfig logger = init_logger(__name__) os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "29503" -BASE_MODEL_PATH = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" +#BASE_MODEL_PATH = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" +BASE_MODEL_PATH = "nvidia/Cosmos-Predict2-2B-Video2World" MODEL_PATH = maybe_download_model(BASE_MODEL_PATH, local_dir=os.path.join( 'data', BASE_MODEL_PATH)) @@ -118,6 +119,123 @@ def test_t5_encoder(): last_hidden_state1 = outputs1[tokens.attention_mask == 1] last_hidden_state2 = outputs2[tokens.attention_mask == 1] + assert last_hidden_state1.shape == last_hidden_state2.shape, \ + f"Hidden state shapes don't match: {last_hidden_state1.shape} vs {last_hidden_state2.shape}" + + max_diff_hidden = torch.max( + torch.abs(last_hidden_state1 - last_hidden_state2)) + mean_diff_hidden = torch.mean( + torch.abs(last_hidden_state1 - last_hidden_state2)) + + logger.info("Maximum difference in last hidden states: %s", + max_diff_hidden.item()) + logger.info("Mean difference in last hidden states: %s", + mean_diff_hidden.item()) + logger.info("Max memory allocated: %s GB", torch.cuda.max_memory_allocated() / 1024**3) + # Check if outputs are similar (allowing for small numerical differences) + assert mean_diff_hidden < 1e-4, \ + f"Hidden states differ significantly: mean diff = {mean_diff_hidden.item()}" + assert max_diff_hidden < 1e-4, \ + f"Hidden states differ significantly: max diff = {max_diff_hidden.item()}" + + +@pytest.mark.usefixtures("distributed_setup") +def test_t5_large_encoder(): + # Initialize the two model implementations + hf_config = AutoConfig.from_pretrained(TEXT_ENCODER_PATH) + print(hf_config) + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + precision_str = "fp32" + precision = PRECISION_TO_TYPE[precision_str] + model1 = T5EncoderModel.from_pretrained(TEXT_ENCODER_PATH).to( + precision).to(device).eval() + tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH) + + args = FastVideoArgs(model_path=TEXT_ENCODER_PATH, + 
pipeline_config=PipelineConfig(text_encoder_configs=(T5LargeConfig(),), + text_encoder_precisions=(precision_str,)), + pin_cpu_memory=False) + loader = TextEncoderLoader() + model2 = loader.load(TEXT_ENCODER_PATH, args) + model2 = model2.to(precision) + model2.eval() + + # Sanity check weights between the two models + logger.info("Comparing model weights for sanity check...") + params1 = dict(model1.named_parameters()) + params2 = dict(model2.named_parameters()) + + # Check number of parameters + logger.info("Model1 has %s parameters", len(params1)) + logger.info("Model2 has %s parameters", len(params2)) + + # # Print parameter names for comparison + # logger.info("Model1 parameters:") + # for name in sorted(params1.keys()): + # logger.info(" %s: %s", name, params1[name].shape) + + # logger.info("Model2 parameters:") + # for name in sorted(params2.keys()): + # logger.info(" %s: %s", name, params2[name].shape) + + weight_diffs = [] + # check if embed_tokens are the same + # weights = ["encoder.block.{}.layer.0.layer_norm.weight", "encoder.block.{}.layer.0.SelfAttention.relative_attention_bias.weight", \ + # "encoder.block.{}.layer.0.SelfAttention.o.weight", "encoder.block.{}.layer.1.DenseReluDense.wi_0.weight", "encoder.block.{}.layer.1.DenseReluDense.wi_1.weight",\ + # "encoder.block.{}.layer.1.DenseReluDense.wo.weight", \ + # "encoder.block.{}.layer.1.layer_norm.weight", "encoder.final_layer_norm.weight"] + + # for idx in range(hf_config.num_hidden_layers): + # for w in weights: + # name1 = w.format(idx) + # name2 = w.format(idx) + # p1 = params1[name1] + # p2 = params2[name2] + # p2 = (p2.to_local() if isinstance(p2, DTensor) else p2).to(p1) + # assert_close(p1, p2, atol=1e-4, rtol=1e-4) + + + # Test with some sample prompts + prompts = [ + "Once upon a time", "The quick brown fox jumps over", + "In a galaxy far, far away" + ] + + logger.info("Testing T5 Large encoder with sample prompts") + + with torch.no_grad(): + for prompt in prompts: + logger.info("Testing prompt: %s", prompt) + + # Tokenize the prompt + tokens = tokenizer(prompt, + padding="max_length", + max_length=512, + truncation=True, + return_tensors="pt").to(device) + + # Get outputs from HuggingFace implementation + # filter out padding input_ids + # tokens.input_ids = tokens.input_ids[tokens.attention_mask==1] + # tokens.attention_mask = tokens.attention_mask[tokens.attention_mask==1] + outputs1 = model1(input_ids=tokens.input_ids, + attention_mask=tokens.attention_mask, + output_hidden_states=True).last_hidden_state + print("--------------------------------") + logger.info("Testing model2 with T5LargeConfig") + + # Get outputs from our implementation + with set_forward_context(current_timestep=0, attn_metadata=None): + outputs2 = model2( + input_ids=tokens.input_ids, + attention_mask=tokens.attention_mask, + ).last_hidden_state + + # Compare last hidden states + last_hidden_state1 = outputs1[tokens.attention_mask == 1] + last_hidden_state2 = outputs2[tokens.attention_mask == 1] + assert last_hidden_state1.shape == last_hidden_state2.shape, \ f"Hidden state shapes don't match: {last_hidden_state1.shape} vs {last_hidden_state2.shape}" diff --git a/fastvideo/tests/transformers/test_cosmos.py b/fastvideo/tests/transformers/test_cosmos.py new file mode 100644 index 000000000..2b6ca5dbb --- /dev/null +++ b/fastvideo/tests/transformers/test_cosmos.py @@ -0,0 +1,257 @@ +# SPDX-License-Identifier: Apache-2.0 +import os + +import numpy as np +import pytest +import torch +from diffusers.models.transformers.transformer_cosmos import 
CosmosTransformer3DModel + +from fastvideo.configs.pipelines import PipelineConfig +from fastvideo.forward_context import set_forward_context +from fastvideo.fastvideo_args import FastVideoArgs +from fastvideo.logger import init_logger +from fastvideo.models.loader.component_loader import TransformerLoader +from fastvideo.utils import maybe_download_model +from fastvideo.configs.models.dits import CosmosVideoConfig +from fastvideo.pipelines.pipeline_batch_info import ForwardBatch + + +logger = init_logger(__name__) + +os.environ["MASTER_ADDR"] = "localhost" +os.environ["MASTER_PORT"] = "29504" + +# BASE_MODEL_PATH = "nvidia/Cosmos-Predict2-2B-Text2Image" +BASE_MODEL_PATH = "nvidia/Cosmos-Predict2-2B-Video2World" +MODEL_PATH = maybe_download_model(BASE_MODEL_PATH, + local_dir=os.path.join( + 'data', BASE_MODEL_PATH)) +TRANSFORMER_PATH = os.path.join(MODEL_PATH, "transformer") + + +@pytest.mark.usefixtures("distributed_setup") +def test_cosmos2_transformer(): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + precision = torch.bfloat16 + precision_str = "bf16" + args = FastVideoArgs(model_path=TRANSFORMER_PATH, + dit_cpu_offload=False, + use_fsdp_inference=False, + pipeline_config=PipelineConfig(dit_config=CosmosVideoConfig(), dit_precision=precision_str)) + + loader = TransformerLoader() + model2 = loader.load(TRANSFORMER_PATH, args).to(device, dtype=precision) + + model1 = CosmosTransformer3DModel.from_pretrained( + TRANSFORMER_PATH, torch_dtype=precision).to(device, dtype=precision).requires_grad_(False) + + total_params = sum(p.numel() for p in model1.parameters()) + # Calculate weight sum for model1 (converting to float64 to avoid overflow) + weight_sum_model1 = sum( + p.to(torch.float64).sum().item() for p in model1.parameters()) + # Also calculate mean for more stable comparison + weight_mean_model1 = weight_sum_model1 / total_params + logger.info("Model 1 weight sum: %s", weight_sum_model1) + logger.info("Model 1 weight mean: %s", weight_mean_model1) + + # Calculate weight sum for model2 (converting to float64 to avoid overflow) + total_params_model2 = sum(p.numel() for p in model2.parameters()) + weight_sum_model2 = sum( + p.to(torch.float64).sum().item() for p in model2.parameters()) + # Also calculate mean for more stable comparison + weight_mean_model2 = weight_sum_model2 / total_params_model2 + logger.info("Model 2 weight sum: %s", weight_sum_model2) + logger.info("Model 2 weight mean: %s", weight_mean_model2) + + weight_sum_diff = abs(weight_sum_model1 - weight_sum_model2) + logger.info("Weight sum difference: %s", weight_sum_diff) + weight_mean_diff = abs(weight_mean_model1 - weight_mean_model2) + logger.info("Weight mean difference: %s", weight_mean_diff) + + # Set both models to eval mode + model1 = model1.eval() + model2 = model2.eval() + + # Create identical inputs for both models + batch_size = 1 + seq_len = 30 + + # Video latents [B, C, T, H, W] - Cosmos2 specific dimensions + hidden_states = torch.randn(batch_size, + 17, + 1, # Single frame for image generation + 32, # Height patches + 32, # Width patches + device=device, + dtype=precision) + + # Text embeddings [B, L, D] - Cosmos2 uses T5 embeddings with 1024 dim + encoder_hidden_states = torch.randn(batch_size, + seq_len, + 1024, # T5 embedding dimension + device=device, + dtype=precision) + + # Timestep + timestep = torch.tensor([500], device=device, dtype=precision) + + # padding mask + padding_mask = hidden_states.new_zeros(1, 1, 32, 32, device=device, dtype=precision) + # 
print(padding_mask.shape) + + forward_batch = ForwardBatch( + data_type="dummy", + ) + + with torch.autocast('cuda', dtype=precision): + output1 = model1( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + padding_mask=padding_mask, + return_dict=False, + )[0] + with set_forward_context( + current_timestep=0, + attn_metadata=None, + forward_batch=forward_batch, + ): + output2 = model2(hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + padding_mask=padding_mask) + + # Check if outputs have the same shape + assert output1.shape == output2.shape, f"Output shapes don't match: {output1.shape} vs {output2.shape}" + assert output1.dtype == output2.dtype, f"Output dtype don't match: {output1.dtype} vs {output2.dtype}" + + # Check if outputs are similar (allowing for small numerical differences) + max_diff = torch.max(torch.abs(output1 - output2)) + mean_diff = torch.mean(torch.abs(output1 - output2)) + logger.info("Max Diff: %s", max_diff.item()) + logger.info("Mean Diff: %s", mean_diff.item()) + assert max_diff < 1e-1, f"Maximum difference between outputs: {max_diff.item()}" + # mean diff + assert mean_diff < 1e-2, f"Mean difference between outputs: {mean_diff.item()}" + + +@pytest.mark.usefixtures("distributed_setup") +def test_cosmos2_transformer_video2world(): + """Test Cosmos2 Video2World variant""" + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + precision = torch.bfloat16 + precision_str = "bf16" + + # Use Video2World model path + base_model_path = "nvidia/Cosmos-Predict2-2B-Video2World" + model_path = maybe_download_model(base_model_path, + local_dir=os.path.join( + 'data', base_model_path)) + transformer_path = os.path.join(model_path, "transformer") + + # Use torch attention backend for exact diffusers matching + cosmos_config = CosmosVideoConfig() + cosmos_config.arch_config.attention_backend = "torch" + + args = FastVideoArgs(model_path=transformer_path, + dit_cpu_offload=False, + use_fsdp_inference=False, + enable_torch_compile=False, + disable_autocast=True, + pipeline_config=PipelineConfig(dit_config=cosmos_config, dit_precision=precision_str)) + + loader = TransformerLoader() + model2 = loader.load(transformer_path, args).to(device, dtype=precision) + + model1 = CosmosTransformer3DModel.from_pretrained( + transformer_path, torch_dtype=precision).to(device, dtype=precision).requires_grad_(False) + + total_params = sum(p.numel() for p in model1.parameters()) + # Calculate weight sum for model1 (converting to float64 to avoid overflow) + weight_sum_model1 = sum( + p.to(torch.float64).sum().item() for p in model1.parameters()) + # Also calculate mean for more stable comparison + weight_mean_model1 = weight_sum_model1 / total_params + logger.info("Model 1 weight sum: %s", weight_sum_model1) + logger.info("Model 1 weight mean: %s", weight_mean_model1) + + # Calculate weight sum for model2 (converting to float64 to avoid overflow) + total_params_model2 = sum(p.numel() for p in model2.parameters()) + weight_sum_model2 = sum( + p.to(torch.float64).sum().item() for p in model2.parameters()) + # Also calculate mean for more stable comparison + weight_mean_model2 = weight_sum_model2 / total_params_model2 + logger.info("Model 2 weight sum: %s", weight_sum_model2) + logger.info("Model 2 weight mean: %s", weight_mean_model2) + + weight_sum_diff = abs(weight_sum_model1 - weight_sum_model2) + logger.info("Weight sum difference: %s", weight_sum_diff) + weight_mean_diff = 
abs(weight_mean_model1 - weight_mean_model2) + logger.info("Weight mean difference: %s", weight_mean_diff) + + # Set both models to eval mode + model1 = model1.eval() + model2 = model2.eval() + + # Create identical inputs for both models + batch_size = 1 + seq_len = 30 + + # Video latents [B, C, T, H, W] - Video2World has additional condition channel + hidden_states = torch.randn(batch_size, + 17, # 16 + 1 for condition channel + 8, # Multiple frames for video + 32, # Height patches + 32, # Width patches + device=device, + dtype=precision) + + # Text embeddings [B, L, D] + encoder_hidden_states = torch.randn(batch_size, + seq_len, + 1024, # T5 embedding dimension + device=device, + dtype=precision) + + # Timestep + timestep = torch.tensor([500], device=device, dtype=precision) + + # padding mask + padding_mask = hidden_states.new_zeros(1, 1, 32, 32, device=device, dtype=precision) + + forward_batch = ForwardBatch( + data_type="dummy", + ) + + with torch.autocast('cuda', dtype=precision, enabled=False): + output1 = model1( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + padding_mask=padding_mask, + return_dict=False, + )[0] + with set_forward_context( + current_timestep=0, + attn_metadata=None, + forward_batch=forward_batch, + ): + output2 = model2(hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + padding_mask=padding_mask) + + # Check if outputs have the same shape + assert output1.shape == output2.shape, f"Output shapes don't match: {output1.shape} vs {output2.shape}" + assert output1.dtype == output2.dtype, f"Output dtype don't match: {output1.dtype} vs {output2.dtype}" + + # Check if outputs are similar (allowing for small numerical differences) + max_diff = torch.max(torch.abs(output1 - output2)) + mean_diff = torch.mean(torch.abs(output1 - output2)) + logger.info("Max Diff: %s", max_diff.item()) + logger.info("Mean Diff: %s", mean_diff.item()) + + # With torch attention backend, outputs should now match closely + assert max_diff < 1e-1, f"Maximum difference between outputs: {max_diff.item()}" + # mean diff + assert mean_diff < 1e-2, f"Mean difference between outputs: {mean_diff.item()}" \ No newline at end of file diff --git a/test_fastvideo_pipeline.py b/test_fastvideo_pipeline.py new file mode 100644 index 000000000..2913767fe --- /dev/null +++ b/test_fastvideo_pipeline.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +""" +Simple script to generate a video using the FastVideo generator. +""" + +import os +import sys + +from fastvideo.entrypoints.video_generator import VideoGenerator + + +def generate_video() -> bool: + """Generate a video using the FastVideo generator.""" + + # Configuration + #input_image_path = "/mnt/fast-disks/hao_lab/kevin/FastVideo/tennis.jpg" + #prompt = "A tennis ball bouncing on a racquet, the ball moves in a smooth arc as it hits the strings and rebounds with natural physics. The racquet strings vibrate slightly from the impact, and the ball continues its trajectory with realistic motion." + input_image_path = "/mnt/fast-disks/hao_lab/kevin/FastVideo/yellow-scrubber.png" + prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. 
The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess." + negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality." + output_path = "/mnt/fast-disks/hao_lab/kevin/FastVideo/cosmos2_fastvideo_output.mp4" + + # Check if input image exists + if not os.path.exists(input_image_path): + print(f"Error: Input image not found: {input_image_path}") + return False + + try: + # Create video generator + print("Creating FastVideo generator...") + generator = VideoGenerator.from_pretrained( + model_path="nvidia/Cosmos-Predict2-2B-Video2World", + num_gpus=1, + ) + + print("Generator created successfully") + + # Run inference + print("Generating video...") + result = generator.generate_video(height=704, + width=1280, + prompt=prompt, + negative_prompt=negative_prompt, + num_frames=21, + image_path=input_image_path, + num_inference_steps=35, + guidance_scale=7.0, + seed=1, + save_video=True, + output_path=output_path, + fps=16) + + if result: + print("Video generation completed successfully!") + return True + else: + print("Video generation failed - no result returned") + return False + + except Exception as e: + print(f"Error during video generation: {e}") + import traceback + traceback.print_exc() + return False + + +if __name__ == "__main__": + success = generate_video() + if success: + print("✅ Video generation completed successfully") + sys.exit(0) + else: + print("❌ Video generation failed") + sys.exit(1)
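
Finally, a reviewer note on CosmosLatentPreparationStage: the indicator and mask construction is the part most likely to drift from the diffusers reference, so a standalone sketch of just that bookkeeping is included here for comparison (plain PyTorch; build_cond_masks is a hypothetical helper mirroring the stage's defaults, not part of this diff):

    # Sketch of the conditioning-mask bookkeeping performed by
    # CosmosLatentPreparationStage. Illustrative only.
    import torch


    def build_cond_masks(num_frames: int, num_cond_frames: int,
                         latent_height: int, latent_width: int,
                         vae_scale_factor_temporal: int = 4):
        # Pixel-frame counts map to latent frames the same way the stage does.
        num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1
        num_cond_latent_frames = (min(num_cond_frames, num_frames) -
                                  1) // vae_scale_factor_temporal + 1

        ones = torch.ones(1, 1, num_latent_frames, latent_height, latent_width)
        zeros = torch.zeros_like(ones)

        # Indicator is 1 for latent frames that come from the conditioning video.
        cond_indicator = torch.zeros(1, 1, num_latent_frames, 1, 1)
        cond_indicator[:, :, :num_cond_latent_frames] = 1.0
        cond_mask = cond_indicator * ones + (1 - cond_indicator) * zeros
        return cond_indicator, cond_mask


    if __name__ == "__main__":
        # 704x1280 input with the stage's spatial factor of 8 gives 88x160 latents.
        indicator, mask = build_cond_masks(num_frames=21, num_cond_frames=1,
                                           latent_height=88, latent_width=160)
        print(indicator.flatten().tolist())  # 1.0 for the first latent frame only
        print(mask.shape)                    # torch.Size([1, 1, 6, 88, 160])
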