From bc0c1bde34df597a21476f838513db7ba4a58eb2 Mon Sep 17 00:00:00 2001 From: Aarti Lalwani Date: Thu, 4 Sep 2025 19:12:56 -0700 Subject: [PATCH 1/6] initial changes for dit and vae without full optimizations --- fastvideo/configs/models/dits/__init__.py | 4 +- fastvideo/configs/models/dits/ltxvideo.py | 88 ++ fastvideo/configs/models/vaes/__init__.py | 3 + fastvideo/configs/models/vaes/ltxvae.py | 48 + fastvideo/models/dits/ltxvideo.py | 934 ++++++++++++++++++ fastvideo/models/loader/fsdp_load.py | 4 + fastvideo/models/registry.py | 6 +- .../pipelines/basic/ltxvideo/__init__.py | 0 .../pipelines/basic/ltxvideo/ltx_pipeline.py | 74 ++ fastvideo/tests/transformers/test_ltxvideo.py | 426 ++++++++ fastvideo/tests/vaes/test_ltx_vae.py | 199 ++++ 11 files changed, 1782 insertions(+), 4 deletions(-) create mode 100644 fastvideo/configs/models/dits/ltxvideo.py create mode 100644 fastvideo/configs/models/vaes/ltxvae.py create mode 100644 fastvideo/models/dits/ltxvideo.py create mode 100644 fastvideo/pipelines/basic/ltxvideo/__init__.py create mode 100644 fastvideo/pipelines/basic/ltxvideo/ltx_pipeline.py create mode 100644 fastvideo/tests/transformers/test_ltxvideo.py create mode 100644 fastvideo/tests/vaes/test_ltx_vae.py diff --git a/fastvideo/configs/models/dits/__init__.py b/fastvideo/configs/models/dits/__init__.py index 72271a525..0f6c7694a 100644 --- a/fastvideo/configs/models/dits/__init__.py +++ b/fastvideo/configs/models/dits/__init__.py @@ -1,5 +1,7 @@ from fastvideo.configs.models.dits.hunyuanvideo import HunyuanVideoConfig from fastvideo.configs.models.dits.stepvideo import StepVideoConfig from fastvideo.configs.models.dits.wanvideo import WanVideoConfig +from fastvideo.configs.models.dits.ltxvideo import LTXVideoConfig -__all__ = ["HunyuanVideoConfig", "WanVideoConfig", "StepVideoConfig"] + +__all__ = ["HunyuanVideoConfig", "WanVideoConfig", "StepVideoConfig","LTXVideoConfig"] diff --git a/fastvideo/configs/models/dits/ltxvideo.py b/fastvideo/configs/models/dits/ltxvideo.py new file mode 100644 index 000000000..4719fc118 --- /dev/null +++ b/fastvideo/configs/models/dits/ltxvideo.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field +from fastvideo.configs.models.dits.base import DiTArchConfig, DiTConfig + + +@dataclass +class LTXVideoArchConfig(DiTArchConfig): + fsdp_shard_conditions: list = field( + default_factory=lambda: + [lambda n, m: "transformer_blocks" in n and n.split(".")[-1].isdigit()]) + + # Parameter name mappings for loading pretrained weights from HuggingFace/Diffusers + # Maps from source model parameter names to FastVideo LTX implementation names + param_names_mapping: dict = field( + default_factory=lambda: { + # todo: double check all of this + r"^transformer_blocks\.(\d+)\.norm1\.weight$": + r"transformer_blocks.\1.norm1.weight", + r"^transformer_blocks\.(\d+)\.norm2\.weight$": + r"transformer_blocks.\1.norm2.weight", + + + # FeedForward network mappings (double check) + r"^transformer_blocks\.(\d+)\.ff\.net\.0\.weight$": + r"transformer_blocks.\1.ff.net.0.weight", + r"^transformer_blocks\.(\d+)\.ff\.net\.0\.bias$": + r"transformer_blocks.\1.ff.net.0.bias", + r"^transformer_blocks\.(\d+)\.ff\.net\.3\.weight$": + r"transformer_blocks.\1.ff.net.3.weight", + r"^transformer_blocks\.(\d+)\.ff\.net\.3\.bias$": + r"transformer_blocks.\1.ff.net.3.bias", + + + # Scale-shift table for adaptive layer norm + r"^transformer_blocks\.(\d+)\.scale_shift_table$": + r"transformer_blocks.\1.scale_shift_table", + + # Time embedding 
mappings + r"^time_embed\.emb\.timestep_embedder\.linear_1\.(weight|bias)$": + r"time_embed.emb.mlp.fc_in.\1", + r"^time_embed\.emb\.timestep_embedder\.linear_2\.(weight|bias)$": + r"time_embed.emb.mlp.fc_out.\1", + + # Caption projection mappings + r"^caption_projection\.linear_1\.(weight|bias)$": + r"caption_projection.fc_in.\1", + r"^caption_projection\.linear_2\.(weight|bias)$": + r"caption_projection.fc_out.\1", + + # Output normalization (FP32LayerNorm) + r"^norm_out\.weight$": + r"norm_out.weight", + + # Global scale-shift table + r"^scale_shift_table$": + r"scale_shift_table", + }) + + num_attention_heads: int = 32 + attention_head_dim: int = 64 + in_channels: int = 128 + out_channels: int | None = 128 + num_layers: int = 28 + dropout: float = 0.0 + patch_size: int = 1 + patch_size_t: int = 1 + norm_elementwise_affine: bool = False + norm_eps: float = 1e-6 + activation_fn: str = "gelu-approximate" + attention_bias: bool = True + attention_out_bias: bool = True + caption_channels: int | list[int] | tuple[int, ...] | None = 4096 + cross_attention_dim: int = 2048 + qk_norm: str = "rms_norm_across_heads" + attention_type: str | None = "torch" + use_additional_conditions: bool | None = False + exclude_lora_layers: list[str] = field(default_factory=lambda: []) + + def __post_init__(self): + self.hidden_size = self.num_attention_heads * self.attention_head_dim + self.out_channels = self.in_channels if self.out_channels is None else self.out_channels + self.num_channels_latents = self.out_channels + + +@dataclass +class LTXVideoConfig(DiTConfig): + arch_config: DiTArchConfig = field(default_factory=LTXVideoArchConfig) + prefix: str = "LTXVideo" \ No newline at end of file diff --git a/fastvideo/configs/models/vaes/__init__.py b/fastvideo/configs/models/vaes/__init__.py index 700c8de1b..bb5f4d464 100644 --- a/fastvideo/configs/models/vaes/__init__.py +++ b/fastvideo/configs/models/vaes/__init__.py @@ -1,9 +1,12 @@ from fastvideo.configs.models.vaes.hunyuanvae import HunyuanVAEConfig from fastvideo.configs.models.vaes.stepvideovae import StepVideoVAEConfig from fastvideo.configs.models.vaes.wanvae import WanVAEConfig +from fastvideo.configs.models.vaes.ltxvae import LTXVAEConfig + __all__ = [ "HunyuanVAEConfig", "WanVAEConfig", "StepVideoVAEConfig", + "LTXVAEConfig", ] diff --git a/fastvideo/configs/models/vaes/ltxvae.py b/fastvideo/configs/models/vaes/ltxvae.py new file mode 100644 index 000000000..18b1e0887 --- /dev/null +++ b/fastvideo/configs/models/vaes/ltxvae.py @@ -0,0 +1,48 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field +import torch +from fastvideo.configs.models.vaes.base import VAEArchConfig, VAEConfig + + +@dataclass +class LTXVAEArchConfig(VAEArchConfig): + block_out_channels: tuple[int, ...] = (128, 256, 512, 512) + decoder_causal: bool = False + encoder_causal: bool = True + in_channels: int = 3 + latent_channels: int = 128 + layers_per_block: tuple[int, ...] = (4, 3, 3, 3, 4) + out_channels: int = 3 + patch_size: int = 4 + patch_size_t: int = 1 + resnet_norm_eps: float = 1e-06 + scaling_factor: float = 1.0 + spatio_temporal_scaling: tuple[bool, ...] 
= (True, True, True, False) + + # Additional fields that might be inherited from base class + z_dim: int = 128 # Using latent_channels as z_dim + is_residual: bool = False + clip_output: bool = True + + def __post_init__(self): + # Calculate compression ratios based on patch sizes and downsampling + self.temporal_compression_ratio = self.patch_size_t + # Spatial compression is usually patch_size * product of spatial downsampling + self.spatial_compression_ratio = self.patch_size * (2 ** (len(self.block_out_channels) - 1)) + + if isinstance(self.scaling_factor, (int, float)): + self.scaling_factor_tensor: torch.Tensor = torch.tensor(self.scaling_factor) + + +@dataclass +class LTXVAEConfig(VAEConfig): + arch_config: LTXVAEArchConfig = field(default_factory=LTXVAEArchConfig) + use_feature_cache: bool = True + use_tiling: bool = False + use_temporal_tiling: bool = False + use_parallel_tiling: bool = False + + def __post_init__(self): + if hasattr(self, 'tile_sample_min_num_frames') and hasattr(self, 'tile_sample_stride_num_frames'): + self.blend_num_frames = (self.tile_sample_min_num_frames - + self.tile_sample_stride_num_frames) * 2 \ No newline at end of file diff --git a/fastvideo/models/dits/ltxvideo.py b/fastvideo/models/dits/ltxvideo.py new file mode 100644 index 000000000..786a1a615 --- /dev/null +++ b/fastvideo/models/dits/ltxvideo.py @@ -0,0 +1,934 @@ +# Copyright 2025 The Genmo team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
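+# FastVideo LTX-Video transformer (DiT) module. This file defines the attention
+# processor (LTXVideoAttentionProcessor2_0), the rotary position embedding
+# (FastVideoLTXRotaryPosEmbed), the local-attention wrapper (LTXVideoLocalAttention),
+# the transformer block (FastVideoLTXTransformerBlock), and the 3D transformer
+# model (LTXVideoTransformer3DModel).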
+ +import math +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +# FastVideo optimized imports +from fastvideo.layers.layernorm import RMSNorm, FP32LayerNorm, LayerNormScaleShift, ScaleResidualLayerNormScaleShift +from fastvideo.layers.layernorm import ScaleResidual +from fastvideo.layers.activation import NewGELU +from fastvideo.layers.vocab_parallel_embedding import VocabParallelEmbedding, UnquantizedEmbeddingMethod +from fastvideo.configs.models.dits import LTXVideoConfig +from fastvideo.layers.visual_embedding import TimestepEmbedder +from fastvideo.layers.linear import ReplicatedLinear, ColumnParallelLinear, RowParallelLinear, QKVParallelLinear +from fastvideo.platforms import AttentionBackendEnum, current_platform + +from diffusers.models.attention import FeedForward + + +#from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.normalization import AdaLayerNormSingle + +from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin +from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from diffusers.utils.torch_utils import maybe_allow_in_graph +# from ..attention import FeedForward +from fastvideo.attention import DistributedAttention, LocalAttention +#from diffusers.attention_processor import Attention +from fastvideo.models.dits.base import CachableDiT + +from diffusers.models.cache_utils import CacheMixin +from diffusers.models.embeddings import PixArtAlphaTextProjection +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +# from ..normalization import AdaLayerNormSingle + +from fastvideo.layers.linear import ReplicatedLinear + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + + +class LTXVideoAttentionProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "LTXVideoAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
+ ) + + def __call__( + self, + attn: LocalAttention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + # TODO: Optimize with fused QKV projection for better performance + # Current implementation uses separate projections + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # # Fused QKV projection example + # from fastvideo.layers.linear import QKVParallelLinear + # self.qkv_proj = QKVParallelLinear( + # hidden_size=hidden_size, + # head_size=attention_head_dim, + # total_num_heads=num_attention_heads, + # bias=True + # ) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + # TODO: Consider local attention patterns for long sequences + # TODO: Add distributed attention support for multi-GPU setups + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + # TODO: Add output dimension parallelism here + #example + # # Output dimension parallelism + # from fastvideo.layers.linear import ColumnParallelLinear + # self.q_proj = ColumnParallelLinear( + # input_size=hidden_size, + # output_size=head_size * num_heads, + # bias=bias, + # gather_output=False + # ) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + + + +class FastVideoLTXRotaryPosEmbed(nn.Module): + """FastVideo optimized rotary position embedding for LTX model.""" + + def __init__( + self, + dim: int, + base_num_frames: int = 20, + base_height: int = 2048, + base_width: int = 2048, + patch_size: int = 1, + patch_size_t: int = 1, + theta: float = 10000.0, + ) -> None: + super().__init__() + + self.dim = dim + self.base_num_frames = base_num_frames + self.base_height = base_height + self.base_width = base_width + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.theta = theta + + def _prepare_video_coords( + self, + batch_size: int, + num_frames: int, + height: int, + width: int, + rope_interpolation_scale: Tuple[torch.Tensor, float, float], + device: torch.device, + ) -> torch.Tensor: + # Add defaults based on base dimensions if None + num_frames = num_frames or self.base_num_frames + height = height or (self.base_height // self.patch_size) + width = width or (self.base_width // self.patch_size) + print(f"num_frames {num_frames} height{height} width {width}") + # Always compute rope in fp32 + grid_h = torch.arange(height, dtype=torch.float32, device=device) + grid_w = 
torch.arange(width, dtype=torch.float32, device=device) + grid_f = torch.arange(num_frames, dtype=torch.float32, device=device) + grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij") + grid = torch.stack(grid, dim=0) + grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + + if rope_interpolation_scale is not None: + grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames + grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height + grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width + + grid = grid.flatten(2, 4).transpose(1, 2) + + return grid + + def forward( + self, + hidden_states: torch.Tensor, + num_frames: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, + rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None, + video_coords: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = hidden_states.size(0) + + if video_coords is None: + grid = self._prepare_video_coords( + batch_size, + num_frames, + height, + width, + rope_interpolation_scale=rope_interpolation_scale, + device=hidden_states.device, + ) + else: + grid = torch.stack( + [ + video_coords[:, 0] / self.base_num_frames, + video_coords[:, 1] / self.base_height, + video_coords[:, 2] / self.base_width, + ], + dim=-1, + ) + + start = 1.0 + end = self.theta + freqs = self.theta ** torch.linspace( + math.log(start, self.theta), + math.log(end, self.theta), + self.dim // 6, + device=hidden_states.device, + dtype=torch.float32, + ) + freqs = freqs * math.pi / 2.0 + freqs = freqs * (grid.unsqueeze(-1) * 2 - 1) + freqs = freqs.transpose(-1, -2).flatten(2) + + cos_freqs = freqs.cos().repeat_interleave(2, dim=-1) + sin_freqs = freqs.sin().repeat_interleave(2, dim=-1) + + if self.dim % 6 != 0: + cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % 6]) + sin_padding = torch.zeros_like(cos_freqs[:, :, : self.dim % 6]) + cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) + sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) + + return cos_freqs, sin_freqs + +# class LTXVideoLocalAttention(nn.Module): +# def __init__( +# self, +# query_dim: int, +# heads: int, +# dim_head: int, +# kv_heads: Optional[int] = None, +# bias: bool = True, +# cross_attention_dim: Optional[int] = None, +# out_bias: bool = True, +# qk_norm: Optional[str] = None, +# eps: float = 1e-6, +# ): +# super().__init__() + +# self.inner_dim = heads * dim_head +# self.cross_attention_dim = cross_attention_dim +# self.heads = heads +# self.dim_head = dim_head + +# # Handle self-attention vs cross-attention projections +# if cross_attention_dim is None: # Self-attention case +# # Use QKVParallelLinear for fused Q, K, V projections +# self.qkv_proj = QKVParallelLinear( +# hidden_size=query_dim, +# head_size=dim_head, +# total_num_heads=heads, +# total_num_kv_heads=kv_heads or heads, +# bias=bias +# ) +# # No separate to_q, to_k, to_v for self-attention +# self.to_q = None +# self.to_k = None +# self.to_v = None +# else: # Cross-attention case +# # Keep separate projections for cross-attention +# self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) +# self.to_k = nn.Linear(cross_attention_dim, self.inner_dim, bias=bias) +# self.to_v = nn.Linear(cross_attention_dim, self.inner_dim, bias=bias) +# self.qkv_proj = None + +# # Output projection (same for both cases) +# self.to_out = nn.ModuleList([ +# nn.Linear(self.inner_dim, query_dim, 
bias=out_bias), +# nn.Dropout(0.0) +# ]) + +# # Norms (same for both cases) +# if qk_norm == "rms_norm_across_heads": +# norm_eps = 1e-5 +# norm_elementwise_affine = True +# self.norm_q = torch.nn.RMSNorm(self.inner_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) +# self.norm_k = torch.nn.RMSNorm(self.inner_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) +# else: +# self.norm_q = None +# self.norm_k = None + +# def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None): +# if encoder_hidden_states is None: +# encoder_hidden_states = hidden_states + +# batch_size, seq_len, _ = hidden_states.shape + +# # Handle Q, K, V projection based on attention type +# if self.qkv_proj is not None: # Self-attention case +# # Use fused QKV projection +# qkv_output, _ = self.qkv_proj(hidden_states) +# # Split the output into Q, K, V +# q_size = self.heads * self.dim_head +# k_size = self.heads * self.dim_head # Assuming same size for simplicity +# v_size = self.heads * self.dim_head + +# query = qkv_output[..., :q_size] +# key = qkv_output[..., q_size:q_size + k_size] +# value = qkv_output[..., q_size + k_size:q_size + k_size + v_size] +# else: # Cross-attention case +# # Use separate projections +# query = self.to_q(hidden_states) +# key = self.to_k(encoder_hidden_states) +# value = self.to_v(encoder_hidden_states) + +# # Rest of the forward method remains the same +# if hasattr(self, 'norm_q') and self.norm_q is not None: +# query = self.norm_q(query) +# key = self.norm_k(key) + +# if image_rotary_emb is not None: +# query = apply_rotary_emb(query, image_rotary_emb) +# key = apply_rotary_emb(key, image_rotary_emb) + +# # Reshape and attention computation +# query = query.view(batch_size, -1, self.heads, self.dim_head).transpose(1, 2) +# key = key.view(batch_size, -1, self.heads, self.dim_head).transpose(1, 2) +# value = value.view(batch_size, -1, self.heads, self.dim_head).transpose(1, 2) + +# hidden_states = F.scaled_dot_product_attention( +# query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False +# ) +# hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + +# hidden_states = self.to_out[0](hidden_states) +# hidden_states = self.to_out[1](hidden_states) + +# return hidden_states + +class LTXVideoLocalAttention(nn.Module): + """Wrapper for LocalAttention""" + + def __init__( + self, + query_dim: int, + heads: int, + dim_head: int, + kv_heads: Optional[int] = None, + bias: bool = True, + cross_attention_dim: Optional[int] = None, + out_bias: bool = True, + qk_norm: Optional[str] = None, + eps: float = 1e-6, + ): + super().__init__() + + self.inner_dim = heads * dim_head + self.cross_attention_dim = cross_attention_dim + self.heads = heads + self.dim_head = dim_head + + #Projections + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim or query_dim, self.inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim or query_dim, self.inner_dim, bias=bias) + # self.to_out = nn.ModuleList([ + # nn.Linear(self.inner_dim, query_dim, bias=out_bias), + # nn.Dropout(0.0) + # ]) + self.to_out = nn.ModuleList([ + RowParallelLinear( + input_size=self.inner_dim, + output_size=query_dim, + bias=out_bias, + input_is_parallel=True, # Input comes from parallel attention computation + reduce_results=True # All-reduce to gather full result + ), + nn.Dropout(0.0) + ]) + + #self.to_out = + # self.out_proj = RowParallelLinear( + # input_size=head_size * num_heads, + # 
output_size=hidden_size, + # bias=bias, + # input_is_parallel=True + # ) + + + + # #QKVParallelLinear for self-attention: + # if cross_attention_dim is None: # Self-attention + # self.qkv_proj = QKVParallelLinear( + # hidden_size=query_dim, + # head_size=dim_head, + # total_num_heads=heads, + # bias=bias + # ) + # else: # Cross-attention - keep separate for different input dims + # self.to_q = ColumnParallelLinear(query_dim, self.inner_dim, bias=bias, gather_output=False) + # self.to_k = ColumnParallelLinear(cross_attention_dim, self.inner_dim, bias=bias, gather_output=False) + # self.to_v = ColumnParallelLinear(cross_attention_dim, self.inner_dim, bias=bias, gather_output=False) + + + # Replace output projection: + # self.to_out[0] = nn.Linear(self.inner_dim, query_dim, bias=out_bias) + # self.to_out = nn.ModuleList([ + # RowParallelLinear(self.inner_dim, query_dim, bias=out_bias, input_is_parallel=True), + # nn.Dropout(0.0) + # ]) + + + # self.to_out = nn.ModuleList([ + # nn.Linear(self.inner_dim, query_dim, bias=out_bias), + # nn.Dropout(0.0) + # ]) + + + # Norms + if qk_norm == "rms_norm_across_heads": + norm_eps = 1e-5 + norm_elementwise_affine = True + self.norm_q = torch.nn.RMSNorm(self.inner_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + self.norm_k = torch.nn.RMSNorm(self.inner_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + + else: + self.norm_q = None + self.norm_k = None + + + self.attention = LocalAttention( + num_heads=heads, + head_size=dim_head, + num_kv_heads=kv_heads or heads, + softmax_scale=None, + causal=False, + supported_attention_backends=(AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA), + ) + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + batch_size, seq_len, _ = hidden_states.shape + + query = self.to_q(hidden_states) + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + if hasattr(self, 'norm_q'): + query = self.norm_q(query) + key = self.norm_k(key) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + # Use your original working approach temporarily + query = query.view(batch_size, -1, self.heads, self.dim_head).transpose(1, 2) + key = key.view(batch_size, -1, self.heads, self.dim_head).transpose(1, 2) + value = value.view(batch_size, -1, self.heads, self.dim_head).transpose(1, 2) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + + hidden_states, _ = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states) # Dropout + return hidden_states + +class GELU(nn.Module): + def __init__(self, dim_in, dim_out, approximate="tanh", bias=True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.approximate = approximate + + def forward(self, x): + x = self.proj(x) + if self.approximate == "tanh": + x = F.gelu(x, approximate="tanh") + else: + x = F.gelu(x) + return x + + +class FastVideoFeedForward(nn.Module): + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "gelu-approximate", + bias: bool = True, + ): + super().__init__() + inner_dim = int(dim * mult) # 8192 + dim_out = dim_out if 
dim_out is not None else dim + + if activation_fn == "gelu-approximate": + self.net = nn.ModuleList([ + GELU(dim, inner_dim, approximate="tanh", bias=bias), + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out, bias=bias) + ]) + + # TODO: fix below + # if activation_fn == "gelu-approximate": + # # Replace the Linear layers with parallel versions: + # self.net = nn.ModuleList([ + # # First layer: Column parallel (split output features) + # ColumnParallelLinear(dim, inner_dim, bias=bias, gather_output=False), + # GELU(dim, inner_dim, approximate="tanh", bias=bias), + # nn.Dropout(dropout), + # # Second layer: Row parallel (split input features, gather output) + # RowParallelLinear(inner_dim, dim_out, bias=bias, input_is_parallel=True) + # ]) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + +@maybe_allow_in_graph +class FastVideoLTXTransformerBlock(nn.Module): + r""" + FastVideo Transformer block for LTX model. + + TODO: describe the parts I changed for FastVideo + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + cross_attention_dim: int, + qk_norm: str = "rms_norm_across_heads", + activation_fn: str = "gelu-approximate", + attention_bias: bool = True, + attention_out_bias: bool = True, + eps: float = 1e-6, + elementwise_affine: bool = False, + ): + super().__init__() + + # Use FastVideo RMSNorm + self.norm1 = RMSNorm(dim, eps=eps, has_weight=elementwise_affine) + + # Self-attention using the wrapper + self.attn1 = LTXVideoLocalAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + kv_heads=num_attention_heads, # LTX doesn't use GQA + bias=attention_bias, + cross_attention_dim=None, # Self-attention + out_bias=attention_out_bias, + qk_norm=qk_norm, + eps=1e-5, + ) + + self.norm2 = RMSNorm(dim, eps=eps, has_weight=elementwise_affine) + + + # Cross-attention using wrapper + self.attn2 = LTXVideoLocalAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + kv_heads=num_attention_heads, + bias=attention_bias, + cross_attention_dim=cross_attention_dim, + out_bias=attention_out_bias, + qk_norm=qk_norm, + eps=1e-5, + ) + + self.ff = FastVideoFeedForward(dim, activation_fn=activation_fn) + + + # Scale-shift table for adaptive layer norm + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size = hidden_states.size(0) + + # First self-attention block with scale/shift modulation + norm_hidden_states = self.norm1(hidden_states) + + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + + # Self-attention - the wrapper handles all projections and normalization + attn_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=None, # None for self-attention + image_rotary_emb=image_rotary_emb, + ) + + # Gated residual connection + hidden_states = hidden_states + attn_hidden_states * gate_msa + + # 
Cross-attention block + attn_hidden_states = self.attn2( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + image_rotary_emb=None, # No rotary embeddings for cross-attention + ) + hidden_states = hidden_states + attn_hidden_states + + # Feed-forward block with scale/shift modulation + norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp + ff_output = self.ff(norm_hidden_states) + + # Gated residual connection + hidden_states = hidden_states + ff_output * gate_mlp + + return hidden_states + +@maybe_allow_in_graph +class LTXVideoTransformer3DModel(CachableDiT): + r""" + FastVideo optimized Transformer model for video-like data used in LTX. + + Key optimizations: + - RMSNorm for normalization layers + - QuickGELU activation functions + - Fused scale/shift operations where possible + - Prepared for distributed attention and dimension parallelism + """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["norm"] + _repeated_blocks = ["FastVideoLTXTransformerBlock"] + + def __init__( + self, + config: Optional[LTXVideoConfig] = None, + hf_config: Optional[Dict] = None, + **kwargs + ) -> None: + super().__init__(config=config, hf_config=hf_config) + + # Handle both config object and kwargs + if config is not None: + # Extract parameters from config + if hasattr(config, 'arch_config'): + # FastVideo style config + arch_config = config.arch_config + in_channels = getattr(arch_config, 'in_channels', 128) + out_channels = getattr(arch_config, 'out_channels', 128) + patch_size = getattr(arch_config, 'patch_size', 1) + patch_size_t = getattr(arch_config, 'patch_size_t', 1) + num_attention_heads = getattr(arch_config, 'num_attention_heads', 32) + attention_head_dim = getattr(arch_config, 'attention_head_dim', 64) + cross_attention_dim = getattr(arch_config, 'cross_attention_dim', 2048) + num_layers = getattr(arch_config, 'num_layers', 28) + activation_fn = getattr(arch_config, 'activation_fn', 'gelu-approximate') + qk_norm = getattr(arch_config, 'qk_norm', 'rms_norm_across_heads') + norm_elementwise_affine = getattr(arch_config, 'norm_elementwise_affine', False) + norm_eps = getattr(arch_config, 'norm_eps', 1e-6) + caption_channels = getattr(arch_config, 'caption_channels', 4096) + attention_bias = getattr(arch_config, 'attention_bias', True) + attention_out_bias = getattr(arch_config, 'attention_out_bias', True) + else: + # Try to get from hf_config if provided + if hf_config: + in_channels = hf_config.get('in_channels', 128) + out_channels = hf_config.get('out_channels', 128) + patch_size = hf_config.get('patch_size', 1) + patch_size_t = hf_config.get('patch_size_t', 1) + num_attention_heads = hf_config.get('num_attention_heads', 32) + attention_head_dim = hf_config.get('attention_head_dim', 64) + cross_attention_dim = hf_config.get('cross_attention_dim', 2048) + num_layers = hf_config.get('num_layers', 28) + activation_fn = hf_config.get('activation_fn', 'gelu-approximate') + qk_norm = hf_config.get('qk_norm', 'rms_norm_across_heads') + norm_elementwise_affine = hf_config.get('norm_elementwise_affine', False) + norm_eps = hf_config.get('norm_eps', 1e-6) + caption_channels = hf_config.get('caption_channels', 4096) + attention_bias = hf_config.get('attention_bias', True) + attention_out_bias = hf_config.get('attention_out_bias', True) + else: + # Default values + raise ValueError("Either config or hf_config must be provided") + else: + # Use kwargs + in_channels = 
kwargs.get('in_channels', 128) + out_channels = kwargs.get('out_channels', 128) + # anythig else from kwargs?? + + out_channels = out_channels or in_channels + inner_dim = num_attention_heads * attention_head_dim + + # TODO: Add input dimension parallelism for distributed training + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + + # Use AdaLayerNormSingle for time embedding (keep original for compatibility) + self.time_embed = AdaLayerNormSingle(inner_dim, use_additional_conditions=False) + + + # Caption projection + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + + # FastVideo optimized rotary position embedding + self.rope = FastVideoLTXRotaryPosEmbed( + dim=inner_dim, + base_num_frames=20, + base_height=2048, + base_width=2048, + patch_size=patch_size, + patch_size_t=patch_size_t, + theta=10000.0, + ) + + + # FastVideo optimized transformer blocks + self.transformer_blocks = nn.ModuleList( + [ + FastVideoLTXTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + cross_attention_dim=cross_attention_dim, + qk_norm=qk_norm, + activation_fn=activation_fn, + attention_bias=attention_bias, + attention_out_bias=attention_out_bias, + eps=norm_eps, + elementwise_affine=norm_elementwise_affine, + ) + for _ in range(num_layers) + ] + ) + + # Using FastVideo FP32LayerNorm for output normalization + self.norm_out = FP32LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False) + + # # TODO: Add output dimension parallelism for distributed training + # example: + # # Output dimension parallelism + # from fastvideo.layers.linear import ColumnParallelLinear + # self.q_proj = ColumnParallelLinear( + # input_size=hidden_size, + # output_size=head_size * num_heads, + # bias=bias, + # gather_output=False + # ) + self.proj_out = nn.Linear(inner_dim, out_channels) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_attention_mask: torch.Tensor, + num_frames: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, + rope_interpolation_scale: Optional[Union[Tuple[float, float, float], torch.Tensor]] = None, + video_coords: Optional[torch.Tensor] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> torch.Tensor: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." 
+ ) + + image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale, video_coords) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + batch_size = hidden_states.size(0) + + # TODO: Add input dimension parallelism here + # example + # # Input dimension parallelism + # from fastvideo.layers.linear import RowParallelLinear + # self.out_proj = RowParallelLinear( + # input_size=head_size * num_heads, + # output_size=hidden_size, + # bias=bias, + # input_is_parallel=True + # ) + hidden_states = self.proj_in(hidden_states) + + temb, embedded_timestep = self.time_embed( + timestep.flatten(), + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + + temb = temb.view(batch_size, -1, temb.size(-1)) + embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1)) + + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1)) + + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + encoder_attention_mask, + ) + else: + hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + encoder_attention_mask=encoder_attention_mask, + ) + + scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None] + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + + # FP32 normalization for better numerical stability + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + + # TODO: Add output dimension parallelism here + # # Output dimension parallelism + # from fastvideo.layers.linear import ColumnParallelLinear + # self.q_proj = ColumnParallelLinear( + # input_size=hidden_size, + # output_size=head_size * num_heads, + # bias=bias, + # gather_output=False + # ) + + output = self.proj_out(hidden_states) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + + +def apply_rotary_emb(x, freqs): + """Apply rotary embeddings to input tensors.""" + cos, sin = freqs + x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, C // 2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2) + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + return out +# when i use newgelu, passes and same output difference as this version, just the intermediate ones difer a lot but its not relative big diff +# INFO 08-27 17:43:12 [test_ltxvideo.py:331] Max Diff: 0.03125 +# INFO 08-27 17:43:12 [test_ltxvideo.py:332] Mean Diff: 0.0026702880859375 + +# #copy pasted from stepvideo, move it to a shared file and use this instead? +# class AdaLayerNormSingle(nn.Module): +# r""" +# Norm layer adaptive layer norm single (adaLN-single). + +# As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). 
+ +# Parameters: +# embedding_dim (`int`): The size of each embedding vector. +# use_additional_conditions (`bool`): To use additional conditions for normalization or not. +# """ + +# def __init__(self, embedding_dim: int, time_step_rescale=1000): +# super().__init__() + +# self.emb = TimestepEmbedder(embedding_dim) + +# self.silu = nn.SiLU() +# self.linear = ReplicatedLinear(embedding_dim, +# 6 * embedding_dim, +# bias=True) + +# self.time_step_rescale = time_step_rescale ## timestep usually in [0, 1], we rescale it to [0,1000] for stability + +# def forward( +# self, +# timestep: torch.Tensor, +# added_cond_kwargs: dict[str, torch.Tensor] | None = None, +# ) -> tuple[torch.Tensor, torch.Tensor]: +# embedded_timestep = self.emb(timestep * self.time_step_rescale) + +# out, _ = self.linear(self.silu(embedded_timestep)) + +# return out, embedded_timestep diff --git a/fastvideo/models/loader/fsdp_load.py b/fastvideo/models/loader/fsdp_load.py index 117b1a728..59f3b4db0 100644 --- a/fastvideo/models/loader/fsdp_load.py +++ b/fastvideo/models/loader/fsdp_load.py @@ -245,9 +245,13 @@ def load_model_from_full_model_state_dict( NotImplementedError: If got FSDP with more than 1D. """ meta_sd = model.state_dict() + # TODO: remove (for debugging) + print(list(meta_sd.keys())) sharded_sd = {} custom_param_sd, reverse_param_names_mapping = hf_to_custom_state_dict( full_sd_iterator, param_names_mapping) # type: ignore + # TODO: remove (for debugging) + print(list(custom_param_sd.keys())) for target_param_name, full_tensor in custom_param_sd.items(): meta_sharded_param = meta_sd.get(target_param_name) if meta_sharded_param is None: diff --git a/fastvideo/models/registry.py b/fastvideo/models/registry.py index 68be9bc13..e899dcb37 100644 --- a/fastvideo/models/registry.py +++ b/fastvideo/models/registry.py @@ -25,14 +25,13 @@ "HunyuanVideoTransformer3DModel": ("dits", "hunyuanvideo", "HunyuanVideoTransformer3DModel"), "WanTransformer3DModel": ("dits", "wanvideo", "WanTransformer3DModel"), - "CausalWanTransformer3DModel": ("dits", "causal_wanvideo", "CausalWanTransformer3DModel"), "StepVideoModel": ("dits", "stepvideo", "StepVideoModel") } _IMAGE_TO_VIDEO_DIT_MODELS = { # "HunyuanVideoTransformer3DModel": ("dits", "hunyuanvideo", "HunyuanVideoDiT"), "WanTransformer3DModel": ("dits", "wanvideo", "WanTransformer3DModel"), - "CausalWanTransformer3DModel": ("dits", "causal_wanvideo", "CausalWanTransformer3DModel"), + "LTXVideoTransformer3DModel": ("dits", "ltxvideo", "LTXVideoTransformer3DModel"), } _TEXT_ENCODER_MODELS = { @@ -52,7 +51,8 @@ "AutoencoderKLHunyuanVideo": ("vaes", "hunyuanvae", "AutoencoderKLHunyuanVideo"), "AutoencoderKLWan": ("vaes", "wanvae", "AutoencoderKLWan"), - "AutoencoderKLStepvideo": ("vaes", "stepvideovae", "AutoencoderKLStepvideo") + "AutoencoderKLStepvideo": ("vaes", "stepvideovae", "AutoencoderKLStepvideo"), + "AutoencoderKLLTXVideo":( "vaes", "ltxvae", "AutoencoderKLLTXVideo"), } _SCHEDULERS = { diff --git a/fastvideo/pipelines/basic/ltxvideo/__init__.py b/fastvideo/pipelines/basic/ltxvideo/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fastvideo/pipelines/basic/ltxvideo/ltx_pipeline.py b/fastvideo/pipelines/basic/ltxvideo/ltx_pipeline.py new file mode 100644 index 000000000..58e42add8 --- /dev/null +++ b/fastvideo/pipelines/basic/ltxvideo/ltx_pipeline.py @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +LTX video diffusion pipeline implementation. 
+ +This module contains an implementation of the LTX video diffusion pipeline +using the modular pipeline architecture. +""" + +from fastvideo.fastvideo_args import FastVideoArgs +from fastvideo.logger import init_logger + +from fastvideo.pipelines import ComposedPipelineBase +from fastvideo.pipelines.stages import (ConditioningStage, DecodingStage, + DenoisingStage, InputValidationStage, + LatentPreparationStage,CLIPImageEncodingStage, + TextEncodingStage, + TimestepPreparationStage) + +logger = init_logger(__name__) + + +class LTXPipeline(ComposedPipelineBase): + """ + LTX video diffusion pipeline with LoRA support. + """ + + _required_config_modules = [ + "text_encoder", "tokenizer", "vae", "transformer", "scheduler" + ] + + @property + def required_config_modules(self) -> List[str]: + return self._required_config_modules + + def initialize_pipeline(self, fastvideo_args: FastVideoArgs): + """Initialize pipeline-specific components.""" + pass + + def create_pipeline_stages(self, fastvideo_args: FastVideoArgs) -> None: + """Set up pipeline stages with proper dependency injection.""" + + self.add_stage(stage_name="input_validation_stage", + stage=InputValidationStage()) + + self.add_stage(stage_name="prompt_encoding_stage", + stage=TextEncodingStage( + text_encoders=[self.get_module("text_encoder")], + tokenizers=[self.get_module("tokenizer")], + )) + + self.add_stage(stage_name="conditioning_stage", + stage=ConditioningStage()) + + self.add_stage(stage_name="timestep_preparation_stage", + stage=TimestepPreparationStage( + scheduler=self.get_module("scheduler"))) + + self.add_stage(stage_name="latent_preparation_stage", + stage=LatentPreparationStage( + scheduler=self.get_module("scheduler"), + transformer=self.get_module("transformer", None))) + + self.add_stage(stage_name="denoising_stage", + stage=DenoisingStage( + transformer=self.get_module("transformer"), + scheduler=self.get_module("scheduler"), + pipeline=self)) + + self.add_stage(stage_name="decoding_stage", + stage=DecodingStage(vae=self.get_module("vae"), + pipeline=self)) + + +EntryClass = LTXPipeline diff --git a/fastvideo/tests/transformers/test_ltxvideo.py b/fastvideo/tests/transformers/test_ltxvideo.py new file mode 100644 index 000000000..03ffa0eda --- /dev/null +++ b/fastvideo/tests/transformers/test_ltxvideo.py @@ -0,0 +1,426 @@ +import os + +import numpy as np +import pytest +import torch +from diffusers import LTXVideoTransformer3DModel + +from fastvideo.configs.pipelines import PipelineConfig +from fastvideo.forward_context import set_forward_context +from fastvideo.fastvideo_args import FastVideoArgs +from fastvideo.logger import init_logger +from fastvideo.models.loader.component_loader import TransformerLoader +from fastvideo.utils import maybe_download_model +from fastvideo.configs.models.dits import LTXVideoConfig +from huggingface_hub import snapshot_download + +from fastvideo.pipelines.pipeline_batch_info import ForwardBatch + + +logger = init_logger(__name__) + +os.environ["MASTER_ADDR"] = "localhost" +os.environ["MASTER_PORT"] = "29503" + +# BASE_MODEL_PATH = "Lightricks/LTX-Video" +# MODEL_PATH = maybe_download_model(BASE_MODEL_PATH, +# local_dir=os.path.join( +# 'data', BASE_MODEL_PATH)) +# TRANSFORMER_PATH = os.path.join(MODEL_PATH, "transformer") + + +snapshot_download( + "Lightricks/LTX-Video", + local_dir="data/Lightricks/LTX-Video", + allow_patterns=[ + "vae/*.json", + "vae/*.safetensors", # Explicitly allow safetensors in vae + "transformer/*.json", + "transformer/*.safetensors", # Explicitly allow 
safetensors in transformer + "tokenizer/*", + "scheduler/*", + "*.json", + "README.md" + ] +) +#BASE_MODEL_PATH = "Lightricks/LTX-Video" +TRANSFORMER_PATH = "data/Lightricks/LTX-Video/transformer" +#TRANSFORMER_PATH = os.path.join(BASE_MODEL_PATH, "transformer") +print(f"TRANSFORMER_PATH {TRANSFORMER_PATH}") + + +def add_debug_hooks(model, model_name): + hooks = [] + for i, block in enumerate(model.transformer_blocks): + def make_hook(block_idx, name): + def hook(module, input, output): + if isinstance(output, tuple): + tensor = output[0] + else: + tensor = output + print(f"{name} Block {block_idx}: max={tensor.max():.6f}") + return hook + + hook = block.register_forward_hook(make_hook(i, model_name)) + hooks.append(hook) + return hooks + + + +@pytest.mark.usefixtures("distributed_setup") +def test_ltx_transformer(): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + precision = torch.bfloat16 + precision_str = "bf16" + args = FastVideoArgs( + model_path=TRANSFORMER_PATH, + dit_cpu_offload=True, + pipeline_config=PipelineConfig( + dit_config=LTXVideoConfig(), + dit_precision=precision_str + ) + ) + args.device = device + + loader = TransformerLoader() + model2 = loader.load(TRANSFORMER_PATH, args).to(dtype=precision) + + model1 = LTXVideoTransformer3DModel.from_pretrained( + TRANSFORMER_PATH, device=device, + torch_dtype=precision).to(device, dtype=precision).requires_grad_(False) + + total_params = sum(p.numel() for p in model1.parameters()) + # Calculate weight sum for model1 (converting to float64 to avoid overflow) + weight_sum_model1 = sum( + p.to(torch.float64).sum().item() for p in model1.parameters()) + # Also calculate mean for more stable comparison + weight_mean_model1 = weight_sum_model1 / total_params + logger.info("Model 1 weight sum: %s", weight_sum_model1) + logger.info("Model 1 weight mean: %s", weight_mean_model1) + + # Calculate weight sum for model2 (converting to float64 to avoid overflow) + total_params_model2 = sum(p.numel() for p in model2.parameters()) + weight_sum_model2 = sum( + p.to(torch.float64).sum().item() for p in model2.parameters()) + # Also calculate mean for more stable comparison + weight_mean_model2 = weight_sum_model2 / total_params_model2 + logger.info("Model 2 weight sum: %s", weight_sum_model2) + logger.info("Model 2 weight mean: %s", weight_mean_model2) + + weight_sum_diff = abs(weight_sum_model1 - weight_sum_model2) + logger.info("Weight sum difference: %s", weight_sum_diff) + weight_mean_diff = abs(weight_mean_model1 - weight_mean_model2) + logger.info("Weight mean difference: %s", weight_mean_diff) + + # Set both models to eval mode + model1 = model1.eval() + model2 = model2.eval() + + # Create identical inputs for both models + batch_size = 1 + seq_len = 128 + + # Video latents [B, C, T, H, W] + # LTX uses 128 latent channels + hidden_states = torch.randn(batch_size, + 128, # LTX latent channels + 5, # temporal dimension (17 frames -> 5 latent frames) + 64, # height (256 -> 64 in latent space) + 64, # width (256 -> 64 in latent space) + device=device, + dtype=precision) + + # Text embeddings [B, L, D] + # LTX uses 4096 dimensional embeddings + encoder_hidden_states = torch.randn(batch_size, + seq_len, + 4096, # LTX embedding dimension + device=device, + dtype=precision) + + # Timestep + timestep = torch.tensor([500], device=device, dtype=torch.long) + + # Create attention mask for LTX + encoder_attention_mask = torch.ones(batch_size, seq_len, device=device, dtype=precision) + + forward_batch = ForwardBatch( + 
data_type="dummy", + ) + + + # TODO: clean this up + hidden_states_5d = torch.randn(batch_size, 128, 5, 64, 64, device=device, dtype=precision) + # Prepare reshaped version for model1 (original LTX) + hidden_states_3d = hidden_states_5d.permute(0, 2, 3, 4, 1).reshape(batch_size, -1, 128) + + # Add hooks to both models + hooks1 = add_debug_hooks(model1, "Model1") + hooks2 = add_debug_hooks(model2, "Model2") + + + with torch.amp.autocast('cuda', dtype=precision): + # LTX transformer expects different arguments + output1 = model1( + hidden_states=hidden_states_3d, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + encoder_attention_mask=encoder_attention_mask, + num_frames=5, + height=64, + width=64, + return_dict=False, + )[0] + + with set_forward_context( + current_timestep=0, + attn_metadata=None, + forward_batch=forward_batch, + ): + output2 = model2( + hidden_states=hidden_states_3d, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + encoder_attention_mask=encoder_attention_mask, + num_frames=5, # Same actual dimensions + height=64, + width=64, + return_dict=False, + + )[0] + + + + # Remove hooks when done + for hook in hooks1 + hooks2: + hook.remove() + + + def compare_models_layer_by_layer(model1, model2, hidden_states, encoder_hidden_states, timestep, encoder_attention_mask, num_frames, height, width): + """Compare outputs of each transformer block between two models.""" + + batch_size = hidden_states.size(0) + + with torch.no_grad(): + # Initial processing comparison + h1 = model1.proj_in(hidden_states) + h2 = model2.proj_in(hidden_states) + print(f"After proj_in: {(h1 - h2).abs().max():.6f}") + + # Time embedding comparison + temb1, embedded_timestep1 = model1.time_embed(timestep.flatten(), batch_size=batch_size, hidden_dtype=h1.dtype) + temb2, embedded_timestep2 = model2.time_embed(timestep.flatten(), batch_size=batch_size, hidden_dtype=h2.dtype) + print(f"Time embed diff: {(temb1 - temb2).abs().max():.6f}") + print(f"Embedded timestep diff: {(embedded_timestep1 - embedded_timestep2).abs().max():.6f}") + + temb1 = temb1.view(batch_size, -1, temb1.size(-1)) + temb2 = temb2.view(batch_size, -1, temb2.size(-1)) + + # Rotary embedding comparison + image_rotary_emb1 = model1.rope(hidden_states, num_frames, height, width) + image_rotary_emb2 = model2.rope(hidden_states, num_frames, height, width) + print(f"Rotary cos diff: {(image_rotary_emb1[0] - image_rotary_emb2[0]).abs().max():.6f}") + print(f"Rotary sin diff: {(image_rotary_emb1[1] - image_rotary_emb2[1]).abs().max():.6f}") + + # Caption projection comparison + enc_h1 = model1.caption_projection(encoder_hidden_states) + enc_h2 = model2.caption_projection(encoder_hidden_states) + enc_h1 = enc_h1.view(batch_size, -1, h1.size(-1)) + enc_h2 = enc_h2.view(batch_size, -1, h2.size(-1)) + print(f"Caption projection diff: {(enc_h1 - enc_h2).abs().max():.6f}") + + print("\n" + "="*80) + layer_diffs = [] + + for i, (block1, block2) in enumerate(zip(model1.transformer_blocks, model2.transformer_blocks)): + print(f"\nBlock {i}:") + print("-"*40) + + # Compare scale_shift_table + if i == 0: # Only check first block to avoid spam + table_diff = (block1.scale_shift_table - block2.scale_shift_table).abs().max() + print(f" Scale-shift table diff: {table_diff:.6f}") + + # Compute ada values before block + num_ada_params = block1.scale_shift_table.shape[0] + ada_values1 = block1.scale_shift_table[None, None] + temb1.reshape(batch_size, temb1.size(1), num_ada_params, -1) + ada_values2 = 
block2.scale_shift_table[None, None] + temb2.reshape(batch_size, temb2.size(1), num_ada_params, -1) + print(f" Ada values diff: {(ada_values1 - ada_values2).abs().max():.6f}") + + # Extract individual ada components + shift_msa1, scale_msa1, gate_msa1, shift_mlp1, scale_mlp1, gate_mlp1 = ada_values1.unbind(dim=2) + shift_msa2, scale_msa2, gate_msa2, shift_mlp2, scale_mlp2, gate_mlp2 = ada_values2.unbind(dim=2) + + print(f" MSA scale diff: {(scale_msa1 - scale_msa2).abs().max():.6f}") + print(f" MSA shift diff: {(shift_msa1 - shift_msa2).abs().max():.6f}") + print(f" MSA gate diff: {(gate_msa1 - gate_msa2).abs().max():.6f}") + + # Pre-block hidden states + print(f" Hidden states before block: {(h1 - h2).abs().max():.6f}") + + # Norm1 + norm_h1 = block1.norm1(h1) + norm_h2 = block2.norm1(h2) + print(f" After norm1: {(norm_h1 - norm_h2).abs().max():.6f}") + + # Apply ada modulation to norm + norm_h1_ada = norm_h1 * (1 + scale_msa1) + shift_msa1 + norm_h2_ada = norm_h2 * (1 + scale_msa2) + shift_msa2 + print(f" After ada modulation: {(norm_h1_ada - norm_h2_ada).abs().max():.6f}") + + # Self-attention + attn_out1 = block1.attn1(norm_h1_ada, encoder_hidden_states=None, image_rotary_emb=image_rotary_emb1) + attn_out2 = block2.attn1(norm_h2_ada, encoder_hidden_states=None, image_rotary_emb=image_rotary_emb2) + print(f" After self-attention: {(attn_out1 - attn_out2).abs().max():.6f}") + + # After gated residual + h1_after_sa = h1 + attn_out1 * gate_msa1 + h2_after_sa = h2 + attn_out2 * gate_msa2 + print(f" After self-attn residual: {(h1_after_sa - h2_after_sa).abs().max():.6f}") + + # Cross-attention + cross_attn1 = block1.attn2(h1_after_sa, encoder_hidden_states=enc_h1, attention_mask=encoder_attention_mask) + cross_attn2 = block2.attn2(h2_after_sa, encoder_hidden_states=enc_h2, attention_mask=encoder_attention_mask) + print(f" After cross-attention: {(cross_attn1 - cross_attn2).abs().max():.6f}") + + # After cross-attention residual + h1_after_ca = h1_after_sa + cross_attn1 + h2_after_ca = h2_after_sa + cross_attn2 + print(f" After cross-attn residual: {(h1_after_ca - h2_after_ca).abs().max():.6f}") + + # Norm2 and FF + norm2_h1 = block1.norm2(h1_after_ca) * (1 + scale_mlp1) + shift_mlp1 + norm2_h2 = block2.norm2(h2_after_ca) * (1 + scale_mlp2) + shift_mlp2 + print(f" After norm2 + ada: {(norm2_h1 - norm2_h2).abs().max():.6f}") + + ff_out1 = block1.ff(norm2_h1) + ff_out2 = block2.ff(norm2_h2) + print(f" After feedforward: {(ff_out1 - ff_out2).abs().max():.6f}") + + # Final output + h1 = h1_after_ca + ff_out1 * gate_mlp1 + h2 = h2_after_ca + ff_out2 * gate_mlp2 + print(f" Final block output: {(h1 - h2).abs().max():.6f}") + + layer_diffs.append((h1 - h2).abs().max().item()) + print("\n" + "="*80) + print(f"Maximum difference across all blocks: {max(layer_diffs):.6f}") + print(f"Average per-layer difference: {sum(layer_diffs)/len(layer_diffs):.6f}") + + return layer_diffs + layer_diffs = compare_models_layer_by_layer( + model1, model2, hidden_states_3d, encoder_hidden_states, + timestep, encoder_attention_mask, 5, 64, 64 + ) + + print("Block 0 scale_shift_table comparison:") + print(f"Model1: {model1.transformer_blocks[0].scale_shift_table}") + print(f"Model2: {model2.transformer_blocks[0].scale_shift_table}") + # Check if outputs have the same shape + assert output1.shape == output2.shape, f"Output shapes don't match: {output1.shape} vs {output2.shape}" + assert output1.dtype == output2.dtype, f"Output dtype don't match: {output1.dtype} vs {output2.dtype}" + + # Check if outputs are similar 
(allowing for small numerical differences) + max_diff = torch.max(torch.abs(output1 - output2)) + mean_diff = torch.mean(torch.abs(output1 - output2)) + logger.info(f"Output1 shape: {output1.shape}, dtype {output1.dtype}, range: [{output1.min().item():.4f}, {output1.max().item():.4f}]") + logger.info(f"Output2 shape: {output2.shape}, dtype {output2.dtype}, range: [{output2.min().item():.4f}, {output2.max().item():.4f}]") + + logger.info("Max Diff: %s", max_diff.item()) + logger.info("Mean Diff: %s", mean_diff.item()) + assert max_diff < 1e-1, f"Maximum difference between outputs: {max_diff.item()}" + # mean diff + assert mean_diff < 1e-2, f"Mean difference between outputs: {mean_diff.item()}" + + +# TODO: try testing this +# @pytest.mark.usefixtures("distributed_setup") +# def test_ltx_transformer_rope_interpolation(): +# """Test LTX transformer with rope interpolation for different resolutions.""" +# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +# precision = torch.bfloat16 +# precision_str = "bf16" +# args = FastVideoArgs( +# model_path=TRANSFORMER_PATH, +# dit_cpu_offload=False, +# pipeline_config=PipelineConfig( +# dit_config=LTXVideoConfig(), +# dit_precision=precision_str +# ) +# ) +# args.device = device + +# loader = TransformerLoader() +# model2 = loader.load(TRANSFORMER_PATH, args).to(dtype=precision) + +# model1 = LTXVideoTransformer3DModel.from_pretrained( +# TRANSFORMER_PATH, device=device, +# torch_dtype=precision).to(device, dtype=precision).requires_grad_(False) + +# model1 = model1.eval() +# model2 = model2.eval() + +# batch_size = 1 +# seq_len = 128 + +# # Test with different resolutions (rope interpolation) +# # Higher resolution than training +# hidden_states = torch.randn(batch_size, +# 128, +# 5, +# 96, # larger height +# 96, # larger width +# device=device, +# dtype=precision) + +# encoder_hidden_states = torch.randn(batch_size, +# seq_len, +# 4096, +# device=device, +# dtype=precision) + +# timestep = torch.tensor([500], device=device, dtype=torch.long) +# encoder_attention_mask = torch.ones(batch_size, seq_len, device=device, dtype=precision) + +# # Test with rope interpolation scale +# rope_interpolation_scale = (1.0, 1.5, 1.5) # temporal, height, width scales + +# forward_batch = ForwardBatch( +# data_type="dummy", +# ) + +# with torch.amp.autocast('cuda', dtype=precision): +# output1 = model1( +# hidden_states=hidden_states, +# encoder_hidden_states=encoder_hidden_states, +# timestep=timestep, +# encoder_attention_mask=encoder_attention_mask, +# num_frames=5, +# height=96, +# width=96, +# rope_interpolation_scale=rope_interpolation_scale, +# return_dict=False, +# )[0] + +# with set_forward_context( +# current_timestep=0, +# attn_metadata=None, +# forward_batch=forward_batch, +# ): +# output2 = model2( +# hidden_states=hidden_states, +# encoder_hidden_states=encoder_hidden_states, +# timestep=timestep, +# encoder_attention_mask=encoder_attention_mask, +# num_frames=5, +# height=96, +# width=96, +# rope_interpolation_scale=rope_interpolation_scale, +# ) + +# # Check outputs +# assert output1.shape == output2.shape, f"Output shapes don't match: {output1.shape} vs {output2.shape}" + +# max_diff = torch.max(torch.abs(output1 - output2)) +# logger.info("Max diff with rope interpolation: %s", max_diff.item()) +# assert max_diff < 1e-1, f"Outputs differ with rope interpolation: {max_diff.item()}" \ No newline at end of file diff --git a/fastvideo/tests/vaes/test_ltx_vae.py b/fastvideo/tests/vaes/test_ltx_vae.py new file mode 100644 index 
000000000..079b0feaf --- /dev/null +++ b/fastvideo/tests/vaes/test_ltx_vae.py @@ -0,0 +1,199 @@ +import os + +import numpy as np +import pytest +import torch +from diffusers import AutoencoderKLLTXVideo + +from fastvideo.configs.pipelines import PipelineConfig +from fastvideo.fastvideo_args import FastVideoArgs +from fastvideo.logger import init_logger +from fastvideo.models.loader.component_loader import VAELoader +from fastvideo.configs.models.vaes import LTXVAEConfig +from fastvideo.utils import maybe_download_model +from huggingface_hub import snapshot_download + + +logger = init_logger(__name__) + +os.environ["MASTER_ADDR"] = "localhost" +os.environ["MASTER_PORT"] = "29503" + +BASE_MODEL_PATH = "Lightricks/LTX-Video" + + +snapshot_download( + "Lightricks/LTX-Video", + local_dir="data/Lightricks/LTX-Video", + allow_patterns=[ + "vae/*.json", + "vae/*.safetensors", # Explicitly allow safetensors in vae + "transformer/*.json", + "transformer/*.safetensors", # Explicitly allow safetensors in transformer + "tokenizer/*", + "scheduler/*", + "*.json", + "README.md" + ] +) + +MODEL_PATH = "data/Lightricks/LTX-Video" +VAE_PATH = os.path.join(MODEL_PATH, "vae") +print(f"VAE_PATH {VAE_PATH}") + + +@pytest.mark.usefixtures("distributed_setup") +def test_ltx_vae(): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + precision = torch.bfloat16 + precision_str = "bf16" + args = FastVideoArgs( + model_path=VAE_PATH, + pipeline_config=PipelineConfig( + vae_config=LTXVAEConfig(), + vae_precision=precision_str + ) + ) + args.device = device + args.vae_cpu_offload = False + + loader = VAELoader() + #model2 is the one i implemented + model2 = loader.load(VAE_PATH, args) + + model1 = AutoencoderKLLTXVideo.from_pretrained( + VAE_PATH, torch_dtype=precision).to(device).eval() + + # Create identical inputs for both models + batch_size = 1 + + model1_decoder_keys = [k for k in model1.state_dict().keys() if 'decoder' in k] + model2_decoder_keys = [k for k in model2.state_dict().keys() if 'decoder' in k] + + logger.info(f"Model1 decoder keys sample: {model1_decoder_keys[:5]}") + logger.info(f"Model2 decoder keys sample: {model2_decoder_keys[:5]}") + + # Check if keys match + missing_in_model2 = set(model1_decoder_keys) - set(model2_decoder_keys) + extra_in_model2 = set(model2_decoder_keys) - set(model1_decoder_keys) + + if missing_in_model2: + logger.warning(f"Keys in model1 but not model2: {list(missing_in_model2)[:5]}") + if extra_in_model2: + logger.warning(f"Keys in model2 but not model1: {list(extra_in_model2)[:5]}") + + # Video input [B, C, T, H, W] + input_tensor = torch.randn(batch_size, + 3, + 17, # 17 frames + 256, # Height + 256, # Width + device=device, + dtype=precision) + + # Disable gradients for inference + with torch.no_grad(): + # Test encoding + logger.info("Testing encoding...") + latent1_dist = model1.encode(input_tensor).latent_dist + latent1 = latent1_dist.mean + + logger.info("FastVideo encoding...") + latent2_dist = model2.encode(input_tensor).latent_dist + latent2 = latent2_dist.mean + + # Check if latents have the same shape + assert latent1.shape == latent2.shape, f"Latent shapes don't match: {latent1.shape} vs {latent2.shape}" + + # Check if latents are similar + max_diff_encode = torch.max(torch.abs(latent1 - latent2)) + mean_diff_encode = torch.mean(torch.abs(latent1 - latent2)) + logger.info("Maximum difference between encoded latents: %s", + max_diff_encode.item()) + logger.info("Mean difference between encoded latents: %s", + mean_diff_encode.item()) + 
assert max_diff_encode < 1e-5, f"Encoded latents differ significantly: max diff = {max_diff_encode.item()}" + + # Test decoding + logger.info("Testing decoding...") + + # For LTX, we need to use the mode of the distribution + latent1_tensor = latent1_dist.mode() + latent2_tensor = latent2_dist.mode() + + latent_diff = torch.max(torch.abs(latent1_tensor - latent2_tensor)) + logger.info(f"Latent difference before decoding: {latent_diff.item()}") + + if hasattr(model1.config, 'scaling_factor'): + latent1_tensor = latent1_tensor * model1.config.scaling_factor + print(f"model1.config.scaling_factor{model1.config.scaling_factor}") + + if hasattr(model2.config.arch_config, 'scaling_factor'): + latent2_tensor = latent2_tensor * model2.config.arch_config.scaling_factor + print(f"model2.config.arch_config.scaling_factor{model2.config.arch_config.scaling_factor}") + output1 = model1.decode(latent1_tensor).sample + output2 = model2.decode(latent2_tensor).sample + logger.info(f"Output1 shape: {output1.shape}, range: [{output1.min().item():.4f}, {output1.max().item():.4f}]") + logger.info(f"Output2 shape: {output2.shape}, range: [{output2.min().item():.4f}, {output2.max().item():.4f}]") + + + # Check if outputs have the same shape + assert output1.shape == output2.shape, f"Output shapes don't match: {output1.shape} vs {output2.shape}" + + # Check if outputs are similar + max_diff_decode = torch.max(torch.abs(output1 - output2)) + mean_diff_decode = torch.mean(torch.abs(output1 - output2)) + logger.info("Maximum difference between decoded outputs: %s", + max_diff_decode.item()) + logger.info("Mean difference between decoded outputs: %s", + mean_diff_decode.item()) + assert max_diff_decode < 1e-5, f"Decoded outputs differ significantly: max diff = {max_diff_decode.item()}" + + +# #TODO: test this +# @pytest.mark.usefixtures("distributed_setup") +# def test_ltx_vae_tiling(): +# """Test LTX VAE with tiling enabled for large videos.""" +# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +# precision = torch.bfloat16 +# precision_str = "bf16" +# args = FastVideoArgs( +# model_path=VAE_PATH, +# pipeline_config=PipelineConfig( +# vae_config=LTXVAEConfig(), +# vae_precision=precision_str +# ) +# ) +# args.device = device +# args.vae_cpu_offload = False + +# loader = VAELoader() +# model2 = loader.load(VAE_PATH, args) + +# model1 = AutoencoderKLLTXVideo.from_pretrained( +# VAE_PATH, torch_dtype=precision).to(device).eval() + +# # Enable tiling for both models +# model1.enable_tiling() +# if hasattr(model2, 'enable_tiling'): +# model2.enable_tiling() + +# # Create larger input that requires tiling +# batch_size = 1 +# input_tensor = torch.randn(batch_size, +# 3, +# 17, # frames +# 768, # height (larger than default tile size) +# 768, # width (larger than default tile size) +# device=device, +# dtype=precision) + +# # Test with tiling +# with torch.no_grad(): +# logger.info("Testing tiled encoding...") +# latent1 = model1.encode(input_tensor).latent_dist.mean +# latent2 = model2.encode(input_tensor).mean + +# max_diff = torch.max(torch.abs(latent1 - latent2)) +# logger.info("Max difference with tiling: %s", max_diff.item()) +# assert max_diff < 1e-4, f"Tiled encoding differs: max diff = {max_diff.item()}" \ No newline at end of file From f659ed957a8f2f92b0ebe0204588828564969ecf Mon Sep 17 00:00:00 2001 From: Aarti Lalwani Date: Mon, 8 Sep 2025 15:59:55 -0700 Subject: [PATCH 2/6] remove unnecessary func to fix build error --- fastvideo/pipelines/basic/ltxvideo/ltx_pipeline.py | 4 ---- 1 file 
changed, 4 deletions(-) diff --git a/fastvideo/pipelines/basic/ltxvideo/ltx_pipeline.py b/fastvideo/pipelines/basic/ltxvideo/ltx_pipeline.py index 58e42add8..e98d62a4d 100644 --- a/fastvideo/pipelines/basic/ltxvideo/ltx_pipeline.py +++ b/fastvideo/pipelines/basic/ltxvideo/ltx_pipeline.py @@ -28,10 +28,6 @@ class LTXPipeline(ComposedPipelineBase): "text_encoder", "tokenizer", "vae", "transformer", "scheduler" ] - @property - def required_config_modules(self) -> List[str]: - return self._required_config_modules - def initialize_pipeline(self, fastvideo_args: FastVideoArgs): """Initialize pipeline-specific components.""" pass From 34cef9e0c6d8caee3409a7ac0aca54547d0a7c8b Mon Sep 17 00:00:00 2001 From: Aarti Lalwani Date: Mon, 8 Sep 2025 20:14:48 -0700 Subject: [PATCH 3/6] fix build --- fastvideo/configs/models/vaes/ltxvae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastvideo/configs/models/vaes/ltxvae.py b/fastvideo/configs/models/vaes/ltxvae.py index 18b1e0887..160ad6774 100644 --- a/fastvideo/configs/models/vaes/ltxvae.py +++ b/fastvideo/configs/models/vaes/ltxvae.py @@ -30,7 +30,7 @@ def __post_init__(self): # Spatial compression is usually patch_size * product of spatial downsampling self.spatial_compression_ratio = self.patch_size * (2 ** (len(self.block_out_channels) - 1)) - if isinstance(self.scaling_factor, (int, float)): + if isinstance(self.scaling_factor, int | float): self.scaling_factor_tensor: torch.Tensor = torch.tensor(self.scaling_factor) From bae8bf214966e5a830d520895a148767e96903bf Mon Sep 17 00:00:00 2001 From: Aarti Lalwani Date: Tue, 9 Sep 2025 13:58:14 -0700 Subject: [PATCH 4/6] pre commit formatting --- fastvideo/configs/models/dits/__init__.py | 7 ++--- fastvideo/configs/models/dits/ltxvideo.py | 27 +++++++++---------- fastvideo/configs/models/vaes/__init__.py | 3 +-- fastvideo/configs/models/vaes/ltxvae.py | 21 +++++++++------ .../pipelines/basic/ltxvideo/ltx_pipeline.py | 3 +-- 5 files changed, 31 insertions(+), 30 deletions(-) diff --git a/fastvideo/configs/models/dits/__init__.py b/fastvideo/configs/models/dits/__init__.py index 0f6c7694a..73594a816 100644 --- a/fastvideo/configs/models/dits/__init__.py +++ b/fastvideo/configs/models/dits/__init__.py @@ -1,7 +1,8 @@ from fastvideo.configs.models.dits.hunyuanvideo import HunyuanVideoConfig +from fastvideo.configs.models.dits.ltxvideo import LTXVideoConfig from fastvideo.configs.models.dits.stepvideo import StepVideoConfig from fastvideo.configs.models.dits.wanvideo import WanVideoConfig -from fastvideo.configs.models.dits.ltxvideo import LTXVideoConfig - -__all__ = ["HunyuanVideoConfig", "WanVideoConfig", "StepVideoConfig","LTXVideoConfig"] +__all__ = [ + "HunyuanVideoConfig", "WanVideoConfig", "StepVideoConfig", "LTXVideoConfig" +] diff --git a/fastvideo/configs/models/dits/ltxvideo.py b/fastvideo/configs/models/dits/ltxvideo.py index 4719fc118..de16893c4 100644 --- a/fastvideo/configs/models/dits/ltxvideo.py +++ b/fastvideo/configs/models/dits/ltxvideo.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass, field + from fastvideo.configs.models.dits.base import DiTArchConfig, DiTConfig @@ -8,7 +9,7 @@ class LTXVideoArchConfig(DiTArchConfig): fsdp_shard_conditions: list = field( default_factory=lambda: [lambda n, m: "transformer_blocks" in n and n.split(".")[-1].isdigit()]) - + # Parameter name mappings for loading pretrained weights from HuggingFace/Diffusers # Maps from source model parameter names to FastVideo LTX implementation names 
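    # A rough, illustrative sketch of how a regex table like this is typically
    # consumed when remapping checkpoint keys (the helper name below is hypothetical
    # and not part of this patch; FastVideo's actual loader utility is not shown here):
    #
    #     import re
    #
    #     def remap_param_name(name: str, mapping: dict[str, str]) -> str:
    #         for pattern, replacement in mapping.items():
    #             if re.match(pattern, name):
    #                 return re.sub(pattern, replacement, name)
    #         return name  # unmatched names pass through unchanged
    #
    #     # e.g. "time_embed.emb.timestep_embedder.linear_1.weight"
    #     #   -> "time_embed.emb.mlp.fc_in.weight"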
param_names_mapping: dict = field( @@ -18,8 +19,7 @@ class LTXVideoArchConfig(DiTArchConfig): r"transformer_blocks.\1.norm1.weight", r"^transformer_blocks\.(\d+)\.norm2\.weight$": r"transformer_blocks.\1.norm2.weight", - - + # FeedForward network mappings (double check) r"^transformer_blocks\.(\d+)\.ff\.net\.0\.weight$": r"transformer_blocks.\1.ff.net.0.weight", @@ -30,32 +30,29 @@ class LTXVideoArchConfig(DiTArchConfig): r"^transformer_blocks\.(\d+)\.ff\.net\.3\.bias$": r"transformer_blocks.\1.ff.net.3.bias", - # Scale-shift table for adaptive layer norm r"^transformer_blocks\.(\d+)\.scale_shift_table$": r"transformer_blocks.\1.scale_shift_table", - + # Time embedding mappings r"^time_embed\.emb\.timestep_embedder\.linear_1\.(weight|bias)$": r"time_embed.emb.mlp.fc_in.\1", r"^time_embed\.emb\.timestep_embedder\.linear_2\.(weight|bias)$": r"time_embed.emb.mlp.fc_out.\1", - + # Caption projection mappings r"^caption_projection\.linear_1\.(weight|bias)$": r"caption_projection.fc_in.\1", r"^caption_projection\.linear_2\.(weight|bias)$": r"caption_projection.fc_out.\1", - + # Output normalization (FP32LayerNorm) - r"^norm_out\.weight$": - r"norm_out.weight", - + r"^norm_out\.weight$": r"norm_out.weight", + # Global scale-shift table - r"^scale_shift_table$": - r"scale_shift_table", + r"^scale_shift_table$": r"scale_shift_table", }) - + num_attention_heads: int = 32 attention_head_dim: int = 64 in_channels: int = 128 @@ -75,7 +72,7 @@ class LTXVideoArchConfig(DiTArchConfig): attention_type: str | None = "torch" use_additional_conditions: bool | None = False exclude_lora_layers: list[str] = field(default_factory=lambda: []) - + def __post_init__(self): self.hidden_size = self.num_attention_heads * self.attention_head_dim self.out_channels = self.in_channels if self.out_channels is None else self.out_channels @@ -85,4 +82,4 @@ def __post_init__(self): @dataclass class LTXVideoConfig(DiTConfig): arch_config: DiTArchConfig = field(default_factory=LTXVideoArchConfig) - prefix: str = "LTXVideo" \ No newline at end of file + prefix: str = "LTXVideo" diff --git a/fastvideo/configs/models/vaes/__init__.py b/fastvideo/configs/models/vaes/__init__.py index bb5f4d464..7c3ab085a 100644 --- a/fastvideo/configs/models/vaes/__init__.py +++ b/fastvideo/configs/models/vaes/__init__.py @@ -1,8 +1,7 @@ from fastvideo.configs.models.vaes.hunyuanvae import HunyuanVAEConfig +from fastvideo.configs.models.vaes.ltxvae import LTXVAEConfig from fastvideo.configs.models.vaes.stepvideovae import StepVideoVAEConfig from fastvideo.configs.models.vaes.wanvae import WanVAEConfig -from fastvideo.configs.models.vaes.ltxvae import LTXVAEConfig - __all__ = [ "HunyuanVAEConfig", diff --git a/fastvideo/configs/models/vaes/ltxvae.py b/fastvideo/configs/models/vaes/ltxvae.py index 160ad6774..a41ea3a85 100644 --- a/fastvideo/configs/models/vaes/ltxvae.py +++ b/fastvideo/configs/models/vaes/ltxvae.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass, field + import torch + from fastvideo.configs.models.vaes.base import VAEArchConfig, VAEConfig @@ -18,20 +20,22 @@ class LTXVAEArchConfig(VAEArchConfig): resnet_norm_eps: float = 1e-06 scaling_factor: float = 1.0 spatio_temporal_scaling: tuple[bool, ...] 
= (True, True, True, False) - + # Additional fields that might be inherited from base class z_dim: int = 128 # Using latent_channels as z_dim is_residual: bool = False clip_output: bool = True - + def __post_init__(self): # Calculate compression ratios based on patch sizes and downsampling self.temporal_compression_ratio = self.patch_size_t # Spatial compression is usually patch_size * product of spatial downsampling - self.spatial_compression_ratio = self.patch_size * (2 ** (len(self.block_out_channels) - 1)) - + self.spatial_compression_ratio = self.patch_size * (2**( + len(self.block_out_channels) - 1)) + if isinstance(self.scaling_factor, int | float): - self.scaling_factor_tensor: torch.Tensor = torch.tensor(self.scaling_factor) + self.scaling_factor_tensor: torch.Tensor = torch.tensor( + self.scaling_factor) @dataclass @@ -41,8 +45,9 @@ class LTXVAEConfig(VAEConfig): use_tiling: bool = False use_temporal_tiling: bool = False use_parallel_tiling: bool = False - + def __post_init__(self): - if hasattr(self, 'tile_sample_min_num_frames') and hasattr(self, 'tile_sample_stride_num_frames'): + if hasattr(self, 'tile_sample_min_num_frames') and hasattr( + self, 'tile_sample_stride_num_frames'): self.blend_num_frames = (self.tile_sample_min_num_frames - - self.tile_sample_stride_num_frames) * 2 \ No newline at end of file + self.tile_sample_stride_num_frames) * 2 diff --git a/fastvideo/pipelines/basic/ltxvideo/ltx_pipeline.py b/fastvideo/pipelines/basic/ltxvideo/ltx_pipeline.py index e98d62a4d..df7b643eb 100644 --- a/fastvideo/pipelines/basic/ltxvideo/ltx_pipeline.py +++ b/fastvideo/pipelines/basic/ltxvideo/ltx_pipeline.py @@ -8,11 +8,10 @@ from fastvideo.fastvideo_args import FastVideoArgs from fastvideo.logger import init_logger - from fastvideo.pipelines import ComposedPipelineBase from fastvideo.pipelines.stages import (ConditioningStage, DecodingStage, DenoisingStage, InputValidationStage, - LatentPreparationStage,CLIPImageEncodingStage, + LatentPreparationStage, TextEncodingStage, TimestepPreparationStage) From b0455606f4c7b7fe2ff6ad60a0c1d1e794ce6465 Mon Sep 17 00:00:00 2001 From: Aarti Lalwani Date: Thu, 11 Sep 2025 18:45:15 -0700 Subject: [PATCH 5/6] add the vae code --- fastvideo/models/vaes/ltxvae.py | 1434 +++++++++++++++++++++++++++++++ 1 file changed, 1434 insertions(+) create mode 100644 fastvideo/models/vaes/ltxvae.py diff --git a/fastvideo/models/vaes/ltxvae.py b/fastvideo/models/vaes/ltxvae.py new file mode 100644 index 000000000..bd0f2d814 --- /dev/null +++ b/fastvideo/models/vaes/ltxvae.py @@ -0,0 +1,1434 @@ +# Copyright 2025 The Lightricks team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
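# Worked example for the compression ratios computed in the LTXVAEConfig hunk above,
# using the defaults shown there (patch_size=4, four block_out_channels, patch_size_t=1):
#   spatial_compression_ratio  = patch_size * 2 ** (len(block_out_channels) - 1)
#                              = 4 * 2 ** 3 = 32
#   temporal_compression_ratio = patch_size_t = 1
# i.e. a 768x768 input frame corresponds to a 24x24 latent grid spatially.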
+ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from fastvideo.configs.models.vaes import LTXVAEConfig + + +from fastvideo.layers.layernorm import RMSNorm, FP32LayerNorm, LayerNormScaleShift +from fastvideo.layers.activation import get_act_fn +from fastvideo.layers.layernorm import ScaleResidual +from fastvideo.layers.activation import QuickGELU +from fastvideo.layers.vocab_parallel_embedding import VocabParallelEmbedding, UnquantizedEmbeddingMethod + +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.loaders import FromOriginalModelMixin +from diffusers.configuration_utils import ConfigMixin + +from diffusers.utils.accelerate_utils import apply_forward_hook +from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers.models.modeling_utils import ModelMixin +from fastvideo.models.vaes.common import ParallelTiledVAE + +from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution + + +class FastVideoLTXCausalConv3d(nn.Module): + """FastVideo 3D causal convolution for LTX VAE.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]] = 3, + stride: Union[int, Tuple[int, int, int]] = 1, + dilation: Union[int, Tuple[int, int, int]] = 1, + groups: int = 1, + padding_mode: str = "zeros", + is_causal: bool = True, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.is_causal = is_causal + self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size) + + dilation = dilation if isinstance(dilation, tuple) else (dilation, 1, 1) + stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + height_pad = self.kernel_size[1] // 2 + width_pad = self.kernel_size[2] // 2 + padding = (0, height_pad, width_pad) + + # TODO: optimize convolution? + # TODO: Add parallelism across output channels + self.conv = nn.Conv3d( + in_channels, + out_channels, + self.kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + padding=padding, + padding_mode=padding_mode, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + time_kernel_size = self.kernel_size[0] + + if self.is_causal: + pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, time_kernel_size - 1, 1, 1)) + hidden_states = torch.concatenate([pad_left, hidden_states], dim=2) + else: + pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1)) + pad_right = hidden_states[:, :, -1:, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1)) + hidden_states = torch.concatenate([pad_left, hidden_states, pad_right], dim=2) + + hidden_states = self.conv(hidden_states) + return hidden_states + + +class FastVideoLTXResnetBlock3d(nn.Module): + r""" + FastVideo 3D ResNet block for LTX VAE. 
+ + TODO: fix optimizations + - QuickGELU/FastSiLU activation + - TOOD: add more parallelism + """ + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + eps: float = 1e-6, + elementwise_affine: bool = False, + non_linearity: str = "swish", + is_causal: bool = True, + inject_noise: bool = False, + timestep_conditioning: bool = False, + ) -> None: + super().__init__() + + out_channels = out_channels or in_channels + + # Use FastVideo activation functions + if non_linearity == "swish": + self.nonlinearity = get_act_fn("silu") + elif non_linearity == "gelu": + self.nonlinearity = get_act_fn("quick_gelu") + else: + + self.nonlinearity = get_act_fn(non_linearity) + + # Use FastVideo RMSNorm + self.norm1 = RMSNorm(in_channels, eps=1e-8, has_weight=elementwise_affine) + + self.conv1 = FastVideoLTXCausalConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal + ) + + self.norm2 = RMSNorm(out_channels, eps=1e-8, has_weight=elementwise_affine) + + self.dropout = nn.Dropout(dropout) + self.conv2 = FastVideoLTXCausalConv3d( + in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal + ) + + self.norm3 = None + self.conv_shortcut = None + if in_channels != out_channels: + # Use FP32LayerNorm for better numerical stability in shortcut path + self.norm3 = FP32LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True) + self.conv_shortcut = FastVideoLTXCausalConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal + ) + + self.per_channel_scale1 = None + self.per_channel_scale2 = None + if inject_noise: + self.per_channel_scale1 = nn.Parameter(torch.zeros(in_channels, 1, 1)) + self.per_channel_scale2 = nn.Parameter(torch.zeros(in_channels, 1, 1)) + + self.scale_shift_table = None + if timestep_conditioning: + self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5) + + # # Use ScaleResidual for residual connection + # self.residual = ScaleResidual() + + def forward( + self, inputs: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None + ) -> torch.Tensor: + hidden_states = inputs + + hidden_states = self.norm1(hidden_states.movedim(1, -1)).movedim(-1, 1) + + if self.scale_shift_table is not None: + temb = temb.unflatten(1, (4, -1)) + self.scale_shift_table[None, ..., None, None, None] + shift_1, scale_1, shift_2, scale_2 = temb.unbind(dim=1) + hidden_states = hidden_states * (1 + scale_1) + shift_1 + + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + if self.per_channel_scale1 is not None: + spatial_shape = hidden_states.shape[-2:] + spatial_noise = torch.randn( + spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype + )[None] + hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, ...] 
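        # temb was unflattened above to [B, 4, C, 1, 1, 1] and offset by
        # scale_shift_table, yielding (shift_1, scale_1, shift_2, scale_2):
        # the first pair modulated the norm1 output, and the second pair is
        # applied to the norm2 output just below, following the usual
        # x * (1 + scale) + shift adaptive-norm pattern.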
+ + hidden_states = self.norm2(hidden_states.movedim(1, -1)).movedim(-1, 1) + + if self.scale_shift_table is not None: + hidden_states = hidden_states * (1 + scale_2) + shift_2 + + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.per_channel_scale2 is not None: + spatial_shape = hidden_states.shape[-2:] + spatial_noise = torch.randn( + spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype + )[None] + hidden_states = hidden_states + (spatial_noise * self.per_channel_scale2)[None, :, None, ...] + + if self.norm3 is not None: + inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1) + + if self.conv_shortcut is not None: + inputs = self.conv_shortcut(inputs) + + hidden_states = hidden_states + inputs + return hidden_states + + +class FastVideoLTXDownsampler3d(nn.Module): + """FastVideo 3D downsampler for LTX VAE.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + stride: Union[int, Tuple[int, int, int]] = 1, + is_causal: bool = True, + padding_mode: str = "zeros", + ) -> None: + super().__init__() + + self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + self.group_size = (in_channels * stride[0] * stride[1] * stride[2]) // out_channels + + out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2]) + + # TODO: Add parallelism ? + self.conv = FastVideoLTXCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + is_causal=is_causal, + padding_mode=padding_mode, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2) + + residual = ( + hidden_states.unflatten(4, (-1, self.stride[2])) + .unflatten(3, (-1, self.stride[1])) + .unflatten(2, (-1, self.stride[0])) + ) + residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4) + residual = residual.unflatten(1, (-1, self.group_size)) + residual = residual.mean(dim=2) + + hidden_states = self.conv(hidden_states) + hidden_states = ( + hidden_states.unflatten(4, (-1, self.stride[2])) + .unflatten(3, (-1, self.stride[1])) + .unflatten(2, (-1, self.stride[0])) + ) + hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4) + hidden_states = hidden_states + residual + + return hidden_states + + +class FastVideoLTXUpsampler3d(nn.Module): + """FastVideo 3D upsampler for LTX VAE.""" + + def __init__( + self, + in_channels: int, + stride: Union[int, Tuple[int, int, int]] = 1, + is_causal: bool = True, + residual: bool = False, + upscale_factor: int = 1, + padding_mode: str = "zeros", + ) -> None: + super().__init__() + + self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + self.residual = residual + self.upscale_factor = upscale_factor + + out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor + + # TODO: Add parallelism ? 
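        # The conv below expands the channel dim by stride[0] * stride[1] * stride[2]
        # / upscale_factor; forward() then folds those extra channels back into the
        # time/height/width axes (a 3D depth-to-space). E.g. with in_channels=512,
        # stride=(2, 2, 2) and upscale_factor=1, the conv emits 4096 channels that
        # become 512 channels at 2x temporal and 2x spatial resolution.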
+ self.conv = FastVideoLTXCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + is_causal=is_causal, + padding_mode=padding_mode, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + if self.residual: + residual = hidden_states.reshape( + batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width + ) + residual = residual.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3) + repeats = (self.stride[0] * self.stride[1] * self.stride[2]) // self.upscale_factor + residual = residual.repeat(1, repeats, 1, 1, 1) + residual = residual[:, :, self.stride[0] - 1 :] + + hidden_states = self.conv(hidden_states) + hidden_states = hidden_states.reshape( + batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width + ) + hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3) + hidden_states = hidden_states[:, :, self.stride[0] - 1 :] + + if self.residual: + hidden_states = hidden_states + residual + + return hidden_states + + +class FastVideoLTXDownBlock3D(nn.Module): + r""" + FastVideo down block for LTX VAE. + + TODO: add more optimizations: + + - add parallelism + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + spatio_temporal_scale: bool = True, + is_causal: bool = True, + ): + super().__init__() + + out_channels = out_channels or in_channels + + resnets = [] + for _ in range(num_layers): + resnets.append( + FastVideoLTXResnetBlock3d( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.downsamplers = None + if spatio_temporal_scale: + self.downsamplers = nn.ModuleList( + [ + FastVideoLTXCausalConv3d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + stride=(2, 2, 2), + is_causal=is_causal, + ) + ] + ) + + self.conv_out = None + if in_channels != out_channels: + self.conv_out = FastVideoLTXResnetBlock3d( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + else: + hidden_states = resnet(hidden_states, temb, generator) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + if self.conv_out is not None: + hidden_states = self.conv_out(hidden_states, temb, generator) + + return hidden_states + + +class FastVideoLTX095DownBlock3D(nn.Module): + r""" + FastVideo down block for LTX VAE. 
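    Variant selected when a down block type of "FastVideoLTX095DownBlock3D" is
    configured: compared to FastVideoLTXDownBlock3D it exposes a configurable
    downsample_type ("conv", "spatial", "temporal" or "spatiotemporal") and omits
    the trailing conv_out projection.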
+ """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + spatio_temporal_scale: bool = True, + is_causal: bool = True, + downsample_type: str = "conv", + ): + super().__init__() + + out_channels = out_channels or in_channels + + resnets = [] + for _ in range(num_layers): + resnets.append( + FastVideoLTXResnetBlock3d( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.downsamplers = None + if spatio_temporal_scale: + self.downsamplers = nn.ModuleList() + + if downsample_type == "conv": + self.downsamplers.append( + FastVideoLTXCausalConv3d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + stride=(2, 2, 2), + is_causal=is_causal, + ) + ) + elif downsample_type == "spatial": + self.downsamplers.append( + FastVideoLTXDownsampler3d( + in_channels=in_channels, out_channels=out_channels, stride=(1, 2, 2), is_causal=is_causal + ) + ) + elif downsample_type == "temporal": + self.downsamplers.append( + FastVideoLTXDownsampler3d( + in_channels=in_channels, out_channels=out_channels, stride=(2, 1, 1), is_causal=is_causal + ) + ) + elif downsample_type == "spatiotemporal": + self.downsamplers.append( + FastVideoLTXDownsampler3d( + in_channels=in_channels, out_channels=out_channels, stride=(2, 2, 2), is_causal=is_causal + ) + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + else: + hidden_states = resnet(hidden_states, temb, generator) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class FastVideoLTXMidBlock3d(nn.Module): + r""" + FastVideo middle block for LTX VAE. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + is_causal: bool = True, + inject_noise: bool = False, + timestep_conditioning: bool = False, + ) -> None: + super().__init__() + + self.time_embedder = None + if timestep_conditioning: + # TODO: VocabParallelEmbedding? 
+ self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0) + + resnets = [] + for _ in range(num_layers): + resnets.append( + FastVideoLTXResnetBlock3d( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1) + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + else: + hidden_states = resnet(hidden_states, temb, generator) + + return hidden_states + + +class FastVideoLTXUpBlock3d(nn.Module): + r""" + FastVideo up block for LTX VAE. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + spatio_temporal_scale: bool = True, + is_causal: bool = True, + inject_noise: bool = False, + timestep_conditioning: bool = False, + upsample_residual: bool = False, + upscale_factor: int = 1, + ): + super().__init__() + + out_channels = out_channels or in_channels + + self.time_embedder = None + if timestep_conditioning: + # TODO: VocabParallelEmbedding? 
+ self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0) + + self.conv_in = None + if in_channels != out_channels: + self.conv_in = FastVideoLTXResnetBlock3d( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + ) + + self.upsamplers = None + if spatio_temporal_scale: + self.upsamplers = nn.ModuleList( + [ + FastVideoLTXUpsampler3d( + out_channels * upscale_factor, + stride=(2, 2, 2), + is_causal=is_causal, + residual=upsample_residual, + upscale_factor=upscale_factor, + ) + ] + ) + + resnets = [] + for _ in range(num_layers): + resnets.append( + FastVideoLTXResnetBlock3d( + in_channels=out_channels, + out_channels=out_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + if self.conv_in is not None: + hidden_states = self.conv_in(hidden_states, temb, generator) + + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + else: + hidden_states = resnet(hidden_states, temb, generator) + + return hidden_states + + +class FastVideoLTXEncoder3d(nn.Module): + r""" + FastVideo encoder for LTX VAE. + + TODO: describe optimization: + - add parallelism + - replace more fastvideo normalization and activations? + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 128, + block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + down_block_types: Tuple[str, ...] = ( + "FastVideoLTXDownBlock3D", + "FastVideoLTXDownBlock3D", + "FastVideoLTXDownBlock3D", + "FastVideoLTXDownBlock3D", + ), + spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False), + layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4), + downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"), + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_eps: float = 1e-6, + is_causal: bool = True, + ): + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.in_channels = in_channels * patch_size**2 + + output_channel = block_out_channels[0] + + # TODO: Add input dimension parallelism ? 
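        # conv_in sees patched input: forward() folds each patch_size_t x patch_size
        # x patch_size pixel block into the channel axis first, so with the defaults
        # (in_channels=3, patch_size=4, patch_size_t=1) the convolution receives
        # 3 * 4 * 4 = 48 channels at 1/4 of the input height and width.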
+ self.conv_in = FastVideoLTXCausalConv3d( + in_channels=self.in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + is_causal=is_causal, + ) + + # down blocks + is_ltx_095 = down_block_types[-1] == "FastVideoLTX095DownBlock3D" + num_block_out_channels = len(block_out_channels) - (1 if is_ltx_095 else 0) + self.down_blocks = nn.ModuleList([]) + for i in range(num_block_out_channels): + input_channel = output_channel + if not is_ltx_095: + output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i] + else: + output_channel = block_out_channels[i + 1] + + if down_block_types[i] == "FastVideoLTXDownBlock3D": + down_block = FastVideoLTXDownBlock3D( + in_channels=input_channel, + out_channels=output_channel, + num_layers=layers_per_block[i], + resnet_eps=resnet_norm_eps, + spatio_temporal_scale=spatio_temporal_scaling[i], + is_causal=is_causal, + ) + elif down_block_types[i] == "FastVideoLTX095DownBlock3D": + down_block = FastVideoLTX095DownBlock3D( + in_channels=input_channel, + out_channels=output_channel, + num_layers=layers_per_block[i], + resnet_eps=resnet_norm_eps, + spatio_temporal_scale=spatio_temporal_scaling[i], + is_causal=is_causal, + downsample_type=downsample_type[i], + ) + else: + raise ValueError(f"Unknown down block type: {down_block_types[i]}") + + self.down_blocks.append(down_block) + + # mid block + self.mid_block = FastVideoLTXMidBlock3d( + in_channels=output_channel, + num_layers=layers_per_block[-1], + resnet_eps=resnet_norm_eps, + is_causal=is_causal, + ) + + # out + self.norm_out = RMSNorm(output_channel, eps=1e-8, has_weight=False) + + + self.conv_act = nn.SiLU() + # TODO: Add output dimension parallelism + self.conv_out = FastVideoLTXCausalConv3d( + in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal + ) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + + p = self.patch_size + p_t = self.patch_size_t + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p + post_patch_width = width // p + + hidden_states = hidden_states.reshape( + batch_size, num_channels, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p + ) + # Thanks for driving me insane with the weird patching order :( + hidden_states = hidden_states.permute(0, 1, 3, 7, 5, 2, 4, 6).flatten(1, 4) + hidden_states = self.conv_in(hidden_states) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for down_block in self.down_blocks: + hidden_states = self._gradient_checkpointing_func(down_block, hidden_states) + + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states) + else: + for down_block in self.down_blocks: + hidden_states = down_block(hidden_states) + + hidden_states = self.mid_block(hidden_states) + + hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + last_channel = hidden_states[:, -1:] + last_channel = last_channel.repeat(1, hidden_states.size(1) - 2, 1, 1, 1) + hidden_states = torch.cat([hidden_states, last_channel], dim=1) + + return hidden_states + + +class FastVideoLTXDecoder3d(nn.Module): + r""" + FastVideo decoder for LTX VAE. 
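    Mirrors the encoder in reverse: conv_in lifts the latent channels, an (optionally
    timestep-conditioned) mid block and a stack of up blocks restore temporal and
    spatial resolution, and a final reshape un-patches the result back to pixel space.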
+ + TODO: update optimizations: + - add parallelism + """ + + def __init__( + self, + in_channels: int = 128, + out_channels: int = 3, + block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False), + layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4), + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_eps: float = 1e-6, + is_causal: bool = False, + inject_noise: Tuple[bool, ...] = (False, False, False, False), + timestep_conditioning: bool = False, + upsample_residual: Tuple[bool, ...] = (False, False, False, False), + upsample_factor: Tuple[bool, ...] = (1, 1, 1, 1), + ) -> None: + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.out_channels = out_channels * patch_size**2 + + block_out_channels = tuple(reversed(block_out_channels)) + spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling)) + layers_per_block = tuple(reversed(layers_per_block)) + inject_noise = tuple(reversed(inject_noise)) + upsample_residual = tuple(reversed(upsample_residual)) + upsample_factor = tuple(reversed(upsample_factor)) + output_channel = block_out_channels[0] + + # TODO: Add input dimension parallelism + self.conv_in = FastVideoLTXCausalConv3d( + in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal + ) + + self.mid_block = FastVideoLTXMidBlock3d( + in_channels=output_channel, + num_layers=layers_per_block[0], + resnet_eps=resnet_norm_eps, + is_causal=is_causal, + inject_noise=inject_noise[0], + timestep_conditioning=timestep_conditioning, + ) + + # up blocks + num_block_out_channels = len(block_out_channels) + self.up_blocks = nn.ModuleList([]) + for i in range(num_block_out_channels): + input_channel = output_channel // upsample_factor[i] + output_channel = block_out_channels[i] // upsample_factor[i] + + up_block = FastVideoLTXUpBlock3d( + in_channels=input_channel, + out_channels=output_channel, + num_layers=layers_per_block[i + 1], + resnet_eps=resnet_norm_eps, + spatio_temporal_scale=spatio_temporal_scaling[i], + is_causal=is_causal, + inject_noise=inject_noise[i + 1], + timestep_conditioning=timestep_conditioning, + upsample_residual=upsample_residual[i], + upscale_factor=upsample_factor[i], + ) + + self.up_blocks.append(up_block) + + # out + self.norm_out = RMSNorm(output_channel, eps=1e-8, has_weight=False) + self.conv_act = nn.SiLU() + # TODO: Add output dimension parallelism + self.conv_out = FastVideoLTXCausalConv3d( + in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal + ) + + # timestep embedding + self.time_embedder = None + self.scale_shift_table = None + self.timestep_scale_multiplier = None + if timestep_conditioning: + self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32)) + # TODO: VocabParallelEmbedding ? 
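            # Decoder-level timestep conditioning: timestep_scale_multiplier
            # (initialized to 1000.0) rescales the denoising timestep before it is
            # embedded, and the 2-row scale_shift_table below is added to the
            # embedded timestep in forward() to shift/scale the norm_out activations
            # right before conv_act/conv_out.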
+ self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0) + self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + hidden_states = self.conv_in(hidden_states) + + if self.timestep_scale_multiplier is not None: + temb = temb * self.timestep_scale_multiplier + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb) + + for up_block in self.up_blocks: + hidden_states = self._gradient_checkpointing_func(up_block, hidden_states, temb) + else: + hidden_states = self.mid_block(hidden_states, temb) + + for up_block in self.up_blocks: + hidden_states = up_block(hidden_states, temb) + + hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) + + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1).unflatten(1, (2, -1)) + temb = temb + self.scale_shift_table[None, ..., None, None, None] + shift, scale = temb.unbind(dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + p = self.patch_size + p_t = self.patch_size_t + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + hidden_states = hidden_states.reshape(batch_size, -1, p_t, p, p, num_frames, height, width) + hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 4, 7, 3).flatten(6, 7).flatten(4, 5).flatten(2, 3) + + return hidden_states + +class AutoencoderKLLTXVideo(nn.Module): + r""" + FastVideo VAE model with KL loss for encoding images into latents and decoding latent representations into images. 
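    Accepts either a FastVideo-style config (carrying an arch_config) or a
    Diffusers-style config exposing the same attributes, and supports sliced,
    spatially tiled and temporally tiled encode/decode for memory-constrained
    inference.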
+ + TODO: add more optimizations + - activation functions + - add parallelism + """ + + _supports_gradient_checkpointing = True + + + def __init__( + self, + config: LTXVAEConfig, + ) -> None: + nn.Module.__init__(self) + ParallelTiledVAE.__init__(self, config) + + # Extract parameters from config + # First check if config has arch_config (FastVideo style) or direct attributes (Diffusers style) + if hasattr(config, 'arch_config'): + # FastVideo style config + arch_config = config.arch_config + in_channels = getattr(arch_config, 'in_channels', 3) + out_channels = getattr(arch_config, 'out_channels', 3) + latent_channels = getattr(arch_config, 'latent_channels', 128) + block_out_channels = getattr(arch_config, 'block_out_channels', (128, 256, 512, 512)) + layers_per_block = getattr(arch_config, 'layers_per_block', (4, 3, 3, 3, 4)) + spatio_temporal_scaling = getattr(arch_config, 'spatio_temporal_scaling', (True, True, True, False)) + patch_size = getattr(arch_config, 'patch_size', 4) + patch_size_t = getattr(arch_config, 'patch_size_t', 1) + resnet_norm_eps = getattr(arch_config, 'resnet_norm_eps', 1e-6) + scaling_factor = getattr(arch_config, 'scaling_factor', 1.0) + encoder_causal = getattr(arch_config, 'encoder_causal', True) + decoder_causal = getattr(arch_config, 'decoder_causal', False) + spatial_compression_ratio = getattr(arch_config, 'spatial_compression_ratio', None) + temporal_compression_ratio = getattr(arch_config, 'temporal_compression_ratio', None) + + # Decoder specific parameters + decoder_block_out_channels = getattr(arch_config, 'decoder_block_out_channels', block_out_channels) + decoder_layers_per_block = getattr(arch_config, 'decoder_layers_per_block', layers_per_block) + decoder_spatio_temporal_scaling = getattr(arch_config, 'decoder_spatio_temporal_scaling', spatio_temporal_scaling) + decoder_inject_noise = getattr(arch_config, 'decoder_inject_noise', (False, False, False, False, False)) + downsample_type = getattr(arch_config, 'downsample_type', ("conv", "conv", "conv", "conv")) + upsample_residual = getattr(arch_config, 'upsample_residual', (False, False, False, False)) + upsample_factor = getattr(arch_config, 'upsample_factor', (1, 1, 1, 1)) + timestep_conditioning = getattr(arch_config, 'timestep_conditioning', False) + else: + # Diffusers style config - attributes directly on config + in_channels = getattr(config, 'in_channels', 3) + out_channels = getattr(config, 'out_channels', 3) + latent_channels = getattr(config, 'latent_channels', 128) + block_out_channels = getattr(config, 'block_out_channels', (128, 256, 512, 512)) + layers_per_block = getattr(config, 'layers_per_block', (4, 3, 3, 3, 4)) + spatio_temporal_scaling = getattr(config, 'spatio_temporal_scaling', (True, True, True, False)) + patch_size = getattr(config, 'patch_size', 4) + patch_size_t = getattr(config, 'patch_size_t', 1) + resnet_norm_eps = getattr(config, 'resnet_norm_eps', 1e-6) + scaling_factor = getattr(config, 'scaling_factor', 1.0) + encoder_causal = getattr(config, 'encoder_causal', True) + decoder_causal = getattr(config, 'decoder_causal', False) + spatial_compression_ratio = getattr(config, 'spatial_compression_ratio', None) + temporal_compression_ratio = getattr(config, 'temporal_compression_ratio', None) + + # Decoder specific parameters + decoder_block_out_channels = getattr(config, 'decoder_block_out_channels', block_out_channels) + decoder_layers_per_block = getattr(config, 'decoder_layers_per_block', layers_per_block) + decoder_spatio_temporal_scaling = getattr(config, 
'decoder_spatio_temporal_scaling', spatio_temporal_scaling) + decoder_inject_noise = getattr(config, 'decoder_inject_noise', (False, False, False, False, False)) + downsample_type = getattr(config, 'downsample_type', ("conv", "conv", "conv", "conv")) + upsample_residual = getattr(config, 'upsample_residual', (False, False, False, False)) + upsample_factor = getattr(config, 'upsample_factor', (1, 1, 1, 1)) + timestep_conditioning = getattr(config, 'timestep_conditioning', False) + + # Handle down_block_types - use FastVideo versions + down_block_types = getattr(config, 'down_block_types', None) + if down_block_types is None or all('LTXVideoDownBlock3D' in t for t in down_block_types): + down_block_types = ( + "FastVideoLTXDownBlock3D", + "FastVideoLTXDownBlock3D", + "FastVideoLTXDownBlock3D", + "FastVideoLTXDownBlock3D", + ) + + # Create encoder and decoder with extracted parameters + self.encoder = FastVideoLTXEncoder3d( + in_channels=in_channels, + out_channels=latent_channels, + block_out_channels=block_out_channels, + down_block_types=down_block_types, + spatio_temporal_scaling=spatio_temporal_scaling, + layers_per_block=layers_per_block, + downsample_type=downsample_type, + patch_size=patch_size, + patch_size_t=patch_size_t, + resnet_norm_eps=resnet_norm_eps, + is_causal=encoder_causal, + ) + + self.decoder = FastVideoLTXDecoder3d( + in_channels=latent_channels, + out_channels=out_channels, + block_out_channels=decoder_block_out_channels, + spatio_temporal_scaling=decoder_spatio_temporal_scaling, + layers_per_block=decoder_layers_per_block, + patch_size=patch_size, + patch_size_t=patch_size_t, + resnet_norm_eps=resnet_norm_eps, + is_causal=decoder_causal, + timestep_conditioning=timestep_conditioning, + inject_noise=decoder_inject_noise, + upsample_residual=upsample_residual, + upsample_factor=upsample_factor, + ) + + # Register buffers for latent normalization + latents_mean = torch.zeros((latent_channels,), requires_grad=False) + latents_std = torch.ones((latent_channels,), requires_grad=False) + self.register_buffer("latents_mean", latents_mean, persistent=True) + self.register_buffer("latents_std", latents_std, persistent=True) + + # Set compression ratios + self.spatial_compression_ratio = ( + patch_size * 2 ** sum(spatio_temporal_scaling) + if spatial_compression_ratio is None + else spatial_compression_ratio + ) + self.temporal_compression_ratio = ( + patch_size_t * 2 ** sum(spatio_temporal_scaling) + if temporal_compression_ratio is None + else temporal_compression_ratio + ) + + # Memory optimization settings + self.use_slicing = False + self.use_tiling = False + self.use_framewise_encoding = False + self.use_framewise_decoding = False + + # Batch sizes for memory management + self.num_sample_frames_batch_size = 16 + self.num_latent_frames_batch_size = 2 + + # Tiling parameters + self.tile_sample_min_height = 512 + self.tile_sample_min_width = 512 + self.tile_sample_min_num_frames = 16 + self.tile_sample_stride_height = 448 + self.tile_sample_stride_width = 448 + self.tile_sample_stride_num_frames = 8 + + # Store config for compatibility + self.config = config + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_sample_min_num_frames: Optional[int] = None, + tile_sample_stride_height: Optional[float] = None, + tile_sample_stride_width: Optional[float] = None, + tile_sample_stride_num_frames: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. 
When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self) -> None: + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = x.shape + + if self.use_framewise_encoding and num_frames > self.tile_sample_min_num_frames: + return self._temporal_tiled_encode(x) + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + enc = self.encoder(x) + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. 
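        Returns an AutoencoderKLOutput wrapping a DiagonalGaussianDistribution (or a
        one-element tuple when return_dict=False); with use_slicing enabled, batches
        are encoded one sample at a time and concatenated.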
+ """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode( + self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + + if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames: + return self._temporal_tiled_decode(z, temb, return_dict=return_dict) + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, temb, return_dict=return_dict) + + dec = self.decoder(z, temb) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode( + self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: + """ + Decode a batch of images. + """ + if self.use_slicing and z.shape[0] > 1: + if temb is not None: + decoded_slices = [ + self._decode(z_slice, t_slice).sample for z_slice, t_slice in (z.split(1), temb.split(1)) + ] + else: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z, temb).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-3], b.shape[-3], blend_extent) + for x in range(blend_extent): + b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder.""" + batch_size, num_channels, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - 
tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + time = self.encoder( + x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + ) + + row.append(time) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode( + self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + """ + + batch_size, num_channels, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
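+        # The overlap between neighbouring tiles (tile min size minus tile stride) is
+        # linearly cross-faded by blend_v/blend_h below, and each tile is then cropped
+        # back to the stride size, so adjacent tiles join without visible seams.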
+ rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb) + + row.append(time) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def _temporal_tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = x.shape + latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1 + + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio + blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames + + row = [] + for i in range(0, num_frames, self.tile_sample_stride_num_frames): + tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :] + if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width): + tile = self.tiled_encode(tile) + else: + tile = self.encoder(tile) + if i > 0: + tile = tile[:, :, 1:, :, :] + row.append(tile) + + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_num_frames) + result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :]) + else: + result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :]) + + enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames] + return enc + + def _temporal_tiled_decode( + self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = z.shape + num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1 + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio + blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames + + row = [] + for i in range(0, num_frames, tile_latent_stride_num_frames): + tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :] + if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height): + decoded = self.tiled_decode(tile, temb, return_dict=True).sample + else: + decoded = self.decoder(tile, temb) + if i > 0: + decoded = decoded[:, :, :-1, :, :] + row.append(decoded) + + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_num_frames) + tile = tile[:, :, : self.tile_sample_stride_num_frames, :, 
:]
+                result_row.append(tile)
+            else:
+                result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :])
+
+        dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames]
+
+        if not return_dict:
+            return (dec,)
+        return DecoderOutput(sample=dec)
+
+    def forward(
+        self,
+        sample: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        sample_posterior: bool = False,
+        return_dict: bool = True,
+        generator: Optional[torch.Generator] = None,
+    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
+        x = sample
+        posterior = self.encode(x).latent_dist
+        if sample_posterior:
+            z = posterior.sample(generator=generator)
+        else:
+            z = posterior.mode()
+        dec = self.decode(z, temb)
+        if not return_dict:
+            return (dec.sample,)
+        return dec
+

From d02725f6f88ab18c0300b9082386be151f026612 Mon Sep 17 00:00:00 2001
From: Aarti Lalwani
Date: Tue, 30 Sep 2025 14:53:57 -0700
Subject: [PATCH 6/6] initial pipeline logic

---
 examples/inference/basic/basic_ltx.py | 29 +++++++++++++++++++
 fastvideo/configs/models/dits/ltxvideo.py | 10 ++++---
 fastvideo/configs/pipelines/__init__.py | 4 ++-
 fastvideo/configs/pipelines/ltx.py | 23 +++++++++++++++
 fastvideo/configs/pipelines/registry.py | 11 +++++--
 fastvideo/entrypoints/video_generator.py | 2 ++
 fastvideo/models/dits/ltxvideo.py | 6 +++-
 .../pipelines/basic/ltxvideo/ltx_pipeline.py | 2 +-
 fastvideo/pipelines/pipeline_registry.py | 1 +
 9 files changed, 79 insertions(+), 9 deletions(-)
 create mode 100644 examples/inference/basic/basic_ltx.py
 create mode 100644 fastvideo/configs/pipelines/ltx.py

diff --git a/examples/inference/basic/basic_ltx.py b/examples/inference/basic/basic_ltx.py
new file mode 100644
index 000000000..84c93b30d
--- /dev/null
+++ b/examples/inference/basic/basic_ltx.py
@@ -0,0 +1,29 @@
+
+from fastvideo import VideoGenerator
+# from fastvideo.configs.sample import SamplingParam
+
+OUTPUT_PATH = "video_samples_ltx"
+def main():
+    # FastVideo will automatically use the optimal default arguments for the
+    # model.
+    # If a local path is provided, FastVideo will make a best effort
+    # attempt to identify the optimal arguments.
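+    # "Lightricks/LTX-Video" is resolved to LTXConfig through the exact-match table
+    # added to fastvideo/configs/pipelines/registry.py later in this patch.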
+ generator = VideoGenerator.from_pretrained( + "Lightricks/LTX-Video", + # # FastVideo will automatically handle distributed setup + num_gpus=1, + use_fsdp_inference=False, + # dit_cpu_offload=True, + # vae_cpu_offload=False, + # text_encoder_cpu_offload=True, + # # Set pin_cpu_memory to false if CPU RAM is limited and there're no frequent CPU-GPU transfer + # pin_cpu_memory=True, + # # image_encoder_cpu_offload=False, + ) + + prompt = "A cute little penguin takes out a book and starts reading it" + image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png" + + video = generator.generate_video(prompt, image_path=image_path, output_path=OUTPUT_PATH, save_video=True, height=512, width=768, num_frames=20) + +if __name__ == "__main__": + main() diff --git a/fastvideo/configs/models/dits/ltxvideo.py b/fastvideo/configs/models/dits/ltxvideo.py index de16893c4..b2cf70011 100644 --- a/fastvideo/configs/models/dits/ltxvideo.py +++ b/fastvideo/configs/models/dits/ltxvideo.py @@ -4,11 +4,13 @@ from fastvideo.configs.models.dits.base import DiTArchConfig, DiTConfig +def is_blocks(n: str, m) -> bool: + return "blocks" in n and str.isdigit(n.split(".")[-1]) + + @dataclass class LTXVideoArchConfig(DiTArchConfig): - fsdp_shard_conditions: list = field( - default_factory=lambda: - [lambda n, m: "transformer_blocks" in n and n.split(".")[-1].isdigit()]) + _fsdp_shard_conditions: list = field(default_factory=lambda: [is_blocks]) # Parameter name mappings for loading pretrained weights from HuggingFace/Diffusers # Maps from source model parameter names to FastVideo LTX implementation names @@ -82,4 +84,4 @@ def __post_init__(self): @dataclass class LTXVideoConfig(DiTConfig): arch_config: DiTArchConfig = field(default_factory=LTXVideoArchConfig) - prefix: str = "LTXVideo" + prefix: str = "LTX" diff --git a/fastvideo/configs/pipelines/__init__.py b/fastvideo/configs/pipelines/__init__.py index 6ff503848..ad028d2e7 100644 --- a/fastvideo/configs/pipelines/__init__.py +++ b/fastvideo/configs/pipelines/__init__.py @@ -1,6 +1,7 @@ from fastvideo.configs.pipelines.base import (PipelineConfig, SlidingTileAttnConfig) from fastvideo.configs.pipelines.hunyuan import FastHunyuanConfig, HunyuanConfig +from fastvideo.configs.pipelines.ltx import LTXConfig from fastvideo.configs.pipelines.registry import ( get_pipeline_config_cls_from_name) from fastvideo.configs.pipelines.stepvideo import StepVideoT2VConfig @@ -12,5 +13,6 @@ "HunyuanConfig", "FastHunyuanConfig", "PipelineConfig", "SlidingTileAttnConfig", "WanT2V480PConfig", "WanI2V480PConfig", "WanT2V720PConfig", "WanI2V720PConfig", "StepVideoT2VConfig", - "SelfForcingWanT2V480PConfig", "get_pipeline_config_cls_from_name" + "SelfForcingWanT2V480PConfig", "LTXConfig", + "get_pipeline_config_cls_from_name" ] diff --git a/fastvideo/configs/pipelines/ltx.py b/fastvideo/configs/pipelines/ltx.py new file mode 100644 index 000000000..5c9c2a720 --- /dev/null +++ b/fastvideo/configs/pipelines/ltx.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +from fastvideo.configs.models import DiTConfig, VAEConfig +from fastvideo.configs.models.dits import LTXVideoConfig +from fastvideo.configs.models.vaes import LTXVAEConfig +from fastvideo.configs.pipelines.base import PipelineConfig + + +@dataclass +class LTXConfig(PipelineConfig): + """Base configuration for LTX pipeline architecture.""" + + # DiT + dit_config: DiTConfig = field(default_factory=LTXVideoConfig) + # VAE + 
vae_config: VAEConfig = field(default_factory=LTXVAEConfig) + vae_tiling: bool = False + vae_sp: bool = False + + # Precision for each component + precision: str = "bf16" + vae_precision: str = "bf16" \ No newline at end of file diff --git a/fastvideo/configs/pipelines/registry.py b/fastvideo/configs/pipelines/registry.py index 62a4eefe5..e31867690 100644 --- a/fastvideo/configs/pipelines/registry.py +++ b/fastvideo/configs/pipelines/registry.py @@ -7,6 +7,7 @@ from fastvideo.configs.pipelines.base import PipelineConfig from fastvideo.configs.pipelines.hunyuan import FastHunyuanConfig, HunyuanConfig from fastvideo.configs.pipelines.stepvideo import StepVideoT2VConfig +from fastvideo.configs.pipelines.ltx import LTXConfig # isort: off from fastvideo.configs.pipelines.wan import ( @@ -39,6 +40,7 @@ "Wan-AI/Wan2.2-TI2V-5B-Diffusers": Wan2_2_TI2V_5B_Config, "Wan-AI/Wan2.2-T2V-A14B-Diffusers": Wan2_2_T2V_A14B_Config, "Wan-AI/Wan2.2-I2V-A14B-Diffusers": Wan2_2_I2V_A14B_Config, + "Lightricks/LTX-Video": LTXConfig, # Add other specific weight variants } @@ -50,6 +52,7 @@ "wandmdpipeline": lambda id: "wandmdpipeline" in id.lower(), "wancausaldmdpipeline": lambda id: "wancausaldmdpipeline" in id.lower(), "stepvideo": lambda id: "stepvideo" in id.lower(), + "ltx": lambda id: "ltx" in id.lower(), # Add other pipeline architecture detectors } @@ -62,7 +65,8 @@ "wanimagetovideo": WanI2V480PConfig, "wandmdpipeline": FastWan2_1_T2V_480P_Config, "wancausaldmdpipeline": SelfForcingWanT2V480PConfig, - "stepvideo": StepVideoT2VConfig + "stepvideo": StepVideoT2VConfig, + "ltx": LTXConfig, # Other fallbacks by architecture } @@ -100,17 +104,20 @@ def get_pipeline_config_cls_from_name( pipeline_config_cls: type[PipelineConfig] | None = None # First try exact match for specific weights + print(pipeline_name_or_path) if pipeline_name_or_path in PIPE_NAME_TO_CONFIG: pipeline_config_cls = PIPE_NAME_TO_CONFIG[pipeline_name_or_path] - + print(f" exact {pipeline_config_cls}") # Try partial matches (for local paths that might include the weight ID) for registered_id, config_class in PIPE_NAME_TO_CONFIG.items(): if registered_id in pipeline_name_or_path: pipeline_config_cls = config_class + print(f" partial {pipeline_config_cls}") break # If no match, try to use the fallback config if pipeline_config_cls is None: + print(f" trying fallback {pipeline_config_cls}") if os.path.exists(pipeline_name_or_path): config = verify_model_config_and_directory(pipeline_name_or_path) else: diff --git a/fastvideo/entrypoints/video_generator.py b/fastvideo/entrypoints/video_generator.py index 7d97b21bc..a4f6fc618 100644 --- a/fastvideo/entrypoints/video_generator.py +++ b/fastvideo/entrypoints/video_generator.py @@ -90,7 +90,9 @@ def from_fastvideo_args(cls, # Initialize distributed environment if needed # initialize_distributed_and_parallelism(fastvideo_args) + print(f"fastvideo_args {fastvideo_args}") executor_class = Executor.get_class(fastvideo_args) + print(f"executor_class {executor_class}") return cls( fastvideo_args=fastvideo_args, executor_class=executor_class, diff --git a/fastvideo/models/dits/ltxvideo.py b/fastvideo/models/dits/ltxvideo.py index 786a1a615..ca051738d 100644 --- a/fastvideo/models/dits/ltxvideo.py +++ b/fastvideo/models/dits/ltxvideo.py @@ -660,9 +660,13 @@ class LTXVideoTransformer3DModel(CachableDiT): _skip_layerwise_casting_patterns = ["norm"] _repeated_blocks = ["FastVideoLTXTransformerBlock"] + _fsdp_shard_conditions = LTXVideoConfig()._fsdp_shard_conditions + param_names_mapping = 
LTXVideoConfig().param_names_mapping + + def __init__( self, - config: Optional[LTXVideoConfig] = None, + config: LTXVideoConfig, hf_config: Optional[Dict] = None, **kwargs ) -> None: diff --git a/fastvideo/pipelines/basic/ltxvideo/ltx_pipeline.py b/fastvideo/pipelines/basic/ltxvideo/ltx_pipeline.py index df7b643eb..b9f30b0d7 100644 --- a/fastvideo/pipelines/basic/ltxvideo/ltx_pipeline.py +++ b/fastvideo/pipelines/basic/ltxvideo/ltx_pipeline.py @@ -20,7 +20,7 @@ class LTXPipeline(ComposedPipelineBase): """ - LTX video diffusion pipeline with LoRA support. + LTX video diffusion pipeline """ _required_config_modules = [ diff --git a/fastvideo/pipelines/pipeline_registry.py b/fastvideo/pipelines/pipeline_registry.py index 9955c3729..6cdcada1d 100644 --- a/fastvideo/pipelines/pipeline_registry.py +++ b/fastvideo/pipelines/pipeline_registry.py @@ -24,6 +24,7 @@ "WanCausalDMDPipeline": "wan", "StepVideoPipeline": "stepvideo", "HunyuanVideoPipeline": "hunyuan", + "LTXPipeline": "ltx", } _PREPROCESS_WORKLOAD_TYPE_TO_PIPELINE_NAME: dict[WorkloadType, str] = {
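
For reference, below is a minimal, self-contained sketch of the overlap-blend-crop scheme that the tiled_encode / tiled_decode paths in the VAE above rely on; it is not part of the patch. The tile and stride sizes are illustrative, the per-tile "encoder" is an identity op, and blend_h mirrors the column cross-fade used by the VAE, so the reconstruction can be checked exactly.

# Standalone sketch of the overlap-blend-crop tiling (illustrative sizes only).
import torch


def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Linearly cross-fade the last `blend_extent` columns of `a` into the first
    # `blend_extent` columns of `b`, mirroring the blend_h method above.
    blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
    for x in range(blend_extent):
        b[..., x] = a[..., -blend_extent + x] * (1 - x / blend_extent) + b[..., x] * (x / blend_extent)
    return b


def tiled_identity(x: torch.Tensor, tile_min: int = 8, tile_stride: int = 6) -> torch.Tensor:
    # "Encode" overlapping tiles with an identity op, blend the overlaps,
    # crop every tile back to the stride, and concatenate along the last dim.
    blend = tile_min - tile_stride
    tiles = [x[..., j:j + tile_min].clone() for j in range(0, x.shape[-1], tile_stride)]
    out = []
    for j, tile in enumerate(tiles):
        if j > 0:
            tile = blend_h(tiles[j - 1], tile, blend)
        out.append(tile[..., :tile_stride])
    return torch.cat(out, dim=-1)[..., :x.shape[-1]]


x = torch.arange(20, dtype=torch.float32).reshape(1, 1, 1, 1, 20)  # (B, C, T, H, W)
assert torch.allclose(tiled_identity(x), x)  # identity tiles blend back to the input

With a real encoder or decoder the overlapping columns of neighbouring tiles differ slightly, and the cross-fade is what hides that difference once each tile is cropped back to its stride.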