34 changes: 34 additions & 0 deletions examples/inference/basic/basic_ltx.py
@@ -0,0 +1,34 @@
from fastvideo import VideoGenerator

# from fastvideo.configs.sample import SamplingParam

OUTPUT_PATH = "video_samples_ltx"


def main():
    # FastVideo will automatically use the optimal default arguments for the
    # model.
    # If a local path is provided, FastVideo will make a best-effort attempt
    # to identify the optimal arguments.
    generator = VideoGenerator.from_pretrained(
        "Lightricks/LTX-Video",
        # FastVideo will automatically handle distributed setup
        num_gpus=1,
        use_fsdp_inference=False,
        # dit_cpu_offload=True,
        # vae_cpu_offload=False,
        # text_encoder_cpu_offload=True,
        # Set pin_cpu_memory to False if CPU RAM is limited and there are no
        # frequent CPU-GPU transfers
        # pin_cpu_memory=True,
        # image_encoder_cpu_offload=False,
    )

    prompt = "A cute little penguin takes out a book and starts reading it"
    image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png"

    video = generator.generate_video(prompt, image_path=image_path, output_path=OUTPUT_PATH, save_video=True, height=512, width=768, num_frames=20)


if __name__ == "__main__":
    main()
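The commented-out `SamplingParam` import points at the config-object route for sampling settings. A minimal sketch, assuming `SamplingParam.from_pretrained` follows the same pattern as FastVideo's other basic examples (not verified against the LTX pipeline):

```python
from fastvideo import VideoGenerator
from fastvideo.configs.sample import SamplingParam

generator = VideoGenerator.from_pretrained("Lightricks/LTX-Video", num_gpus=1)

# Load the model's default sampling settings, then override a few fields
sampling_param = SamplingParam.from_pretrained("Lightricks/LTX-Video")
sampling_param.num_frames = 20
sampling_param.height = 512
sampling_param.width = 768

generator.generate_video(
    "A cute little penguin takes out a book and starts reading it",
    sampling_param=sampling_param,
    output_path="video_samples_ltx",
    save_video=True)
```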
5 changes: 4 additions & 1 deletion fastvideo/configs/models/dits/__init__.py
@@ -1,5 +1,8 @@
from fastvideo.configs.models.dits.hunyuanvideo import HunyuanVideoConfig
from fastvideo.configs.models.dits.ltxvideo import LTXVideoConfig
from fastvideo.configs.models.dits.stepvideo import StepVideoConfig
from fastvideo.configs.models.dits.wanvideo import WanVideoConfig

__all__ = ["HunyuanVideoConfig", "WanVideoConfig", "StepVideoConfig"]
__all__ = [
    "HunyuanVideoConfig", "WanVideoConfig", "StepVideoConfig", "LTXVideoConfig"
]
87 changes: 87 additions & 0 deletions fastvideo/configs/models/dits/ltxvideo.py
@@ -0,0 +1,87 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field

from fastvideo.configs.models.dits.base import DiTArchConfig, DiTConfig


def is_blocks(n: str, m) -> bool:
    return "blocks" in n and n.split(".")[-1].isdigit()


@dataclass
class LTXVideoArchConfig(DiTArchConfig):
    _fsdp_shard_conditions: list = field(default_factory=lambda: [is_blocks])

    # Parameter name mappings for loading pretrained HuggingFace/Diffusers weights.
    # Maps source model parameter names to FastVideo LTX implementation names.
    param_names_mapping: dict = field(
        default_factory=lambda: {
            # TODO: double-check all of these mappings
            r"^transformer_blocks\.(\d+)\.norm1\.weight$":
            r"transformer_blocks.\1.norm1.weight",
            r"^transformer_blocks\.(\d+)\.norm2\.weight$":
            r"transformer_blocks.\1.norm2.weight",

            # FeedForward network mappings (double-check)
            r"^transformer_blocks\.(\d+)\.ff\.net\.0\.weight$":
            r"transformer_blocks.\1.ff.net.0.weight",
            r"^transformer_blocks\.(\d+)\.ff\.net\.0\.bias$":
            r"transformer_blocks.\1.ff.net.0.bias",
            r"^transformer_blocks\.(\d+)\.ff\.net\.3\.weight$":
            r"transformer_blocks.\1.ff.net.3.weight",
            r"^transformer_blocks\.(\d+)\.ff\.net\.3\.bias$":
            r"transformer_blocks.\1.ff.net.3.bias",

            # Scale-shift table for adaptive layer norm
            r"^transformer_blocks\.(\d+)\.scale_shift_table$":
            r"transformer_blocks.\1.scale_shift_table",

            # Time embedding mappings
            r"^time_embed\.emb\.timestep_embedder\.linear_1\.(weight|bias)$":
            r"time_embed.emb.mlp.fc_in.\1",
            r"^time_embed\.emb\.timestep_embedder\.linear_2\.(weight|bias)$":
            r"time_embed.emb.mlp.fc_out.\1",

            # Caption projection mappings
            r"^caption_projection\.linear_1\.(weight|bias)$":
            r"caption_projection.fc_in.\1",
            r"^caption_projection\.linear_2\.(weight|bias)$":
            r"caption_projection.fc_out.\1",

            # Output normalization (FP32LayerNorm)
            r"^norm_out\.weight$": r"norm_out.weight",

            # Global scale-shift table
            r"^scale_shift_table$": r"scale_shift_table",
        })

    num_attention_heads: int = 32
    attention_head_dim: int = 64
    in_channels: int = 128
    out_channels: int | None = 128
    num_layers: int = 28
    dropout: float = 0.0
    patch_size: int = 1
    patch_size_t: int = 1
    norm_elementwise_affine: bool = False
    norm_eps: float = 1e-6
    activation_fn: str = "gelu-approximate"
    attention_bias: bool = True
    attention_out_bias: bool = True
    caption_channels: int | list[int] | tuple[int, ...] | None = 4096
    cross_attention_dim: int = 2048
    qk_norm: str = "rms_norm_across_heads"
    attention_type: str | None = "torch"
    use_additional_conditions: bool | None = False
    exclude_lora_layers: list[str] = field(default_factory=lambda: [])

    def __post_init__(self):
        self.hidden_size = self.num_attention_heads * self.attention_head_dim
        self.out_channels = self.in_channels if self.out_channels is None else self.out_channels
        self.num_channels_latents = self.out_channels


@dataclass
class LTXVideoConfig(DiTConfig):
    arch_config: DiTArchConfig = field(default_factory=LTXVideoArchConfig)
    prefix: str = "LTX"
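For reference, a table like `param_names_mapping` is typically consumed by regex-rewriting each checkpoint key. The helper below is a minimal illustrative sketch (`remap_param_names` is hypothetical, not FastVideo's actual weight loader):

```python
import re


def remap_param_names(state_dict: dict, mapping: dict[str, str]) -> dict:
    """Rename checkpoint keys using the first regex pattern that matches."""
    remapped = {}
    for name, tensor in state_dict.items():
        for pattern, replacement in mapping.items():
            if re.match(pattern, name):
                name = re.sub(pattern, replacement, name)
                break
        remapped[name] = tensor
    return remapped

# Example: under the LTX mapping above,
# "time_embed.emb.timestep_embedder.linear_1.weight"
# becomes "time_embed.emb.mlp.fc_in.weight".
```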
2 changes: 2 additions & 0 deletions fastvideo/configs/models/vaes/__init__.py
@@ -1,9 +1,11 @@
from fastvideo.configs.models.vaes.hunyuanvae import HunyuanVAEConfig
from fastvideo.configs.models.vaes.ltxvae import LTXVAEConfig
from fastvideo.configs.models.vaes.stepvideovae import StepVideoVAEConfig
from fastvideo.configs.models.vaes.wanvae import WanVAEConfig

__all__ = [
    "HunyuanVAEConfig",
    "WanVAEConfig",
    "StepVideoVAEConfig",
    "LTXVAEConfig",
]
55 changes: 55 additions & 0 deletions fastvideo/configs/models/vaes/ltxvae.py
@@ -0,0 +1,55 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field

import torch

from fastvideo.configs.models.vaes.base import VAEArchConfig, VAEConfig


@dataclass
class LTXVAEArchConfig(VAEArchConfig):
    block_out_channels: tuple[int, ...] = (128, 256, 512, 512)
    decoder_causal: bool = False
    encoder_causal: bool = True
    in_channels: int = 3
    latent_channels: int = 128
    layers_per_block: tuple[int, ...] = (4, 3, 3, 3, 4)
    out_channels: int = 3
    patch_size: int = 4
    patch_size_t: int = 1
    resnet_norm_eps: float = 1e-06
    scaling_factor: float = 1.0
    spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, False)

    # Additional fields that might be inherited from the base class
    z_dim: int = 128  # Using latent_channels as z_dim
    is_residual: bool = False
    clip_output: bool = True

    def __post_init__(self):
        # Calculate compression ratios based on patch sizes and downsampling
        self.temporal_compression_ratio = self.patch_size_t * 2**sum(
            self.spatio_temporal_scaling)  # 1 * 2**3 = 8 for LTX
        # Spatial compression is patch_size times one 2x downsample per
        # block transition: 4 * 2**3 = 32 for LTX
        self.spatial_compression_ratio = self.patch_size * (2**(
            len(self.block_out_channels) - 1))

        if isinstance(self.scaling_factor, int | float):
            self.scaling_factor_tensor: torch.Tensor = torch.tensor(
                self.scaling_factor)


@dataclass
class LTXVAEConfig(VAEConfig):
    arch_config: LTXVAEArchConfig = field(default_factory=LTXVAEArchConfig)
    use_feature_cache: bool = True
    use_tiling: bool = False
    use_temporal_tiling: bool = False
    use_parallel_tiling: bool = False

    def __post_init__(self):
        if hasattr(self, 'tile_sample_min_num_frames') and hasattr(
                self, 'tile_sample_stride_num_frames'):
            self.blend_num_frames = (self.tile_sample_min_num_frames -
                                     self.tile_sample_stride_num_frames) * 2
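With the defaults above, the computed ratios line up with LTX-Video's published 32x spatial / 8x temporal latent compression. A quick check of the arithmetic:

```python
patch_size, patch_size_t = 4, 1
block_out_channels = (128, 256, 512, 512)
spatio_temporal_scaling = (True, True, True, False)

spatial = patch_size * 2**(len(block_out_channels) - 1)    # 4 * 8 = 32
temporal = patch_size_t * 2**sum(spatio_temporal_scaling)  # 1 * 8 = 8
```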
4 changes: 3 additions & 1 deletion fastvideo/configs/pipelines/__init__.py
@@ -1,6 +1,7 @@
from fastvideo.configs.pipelines.base import (PipelineConfig,
                                              SlidingTileAttnConfig)
from fastvideo.configs.pipelines.hunyuan import FastHunyuanConfig, HunyuanConfig
from fastvideo.configs.pipelines.ltx import LTXConfig
from fastvideo.configs.pipelines.registry import (
    get_pipeline_config_cls_from_name)
from fastvideo.configs.pipelines.stepvideo import StepVideoT2VConfig
@@ -12,5 +13,6 @@
"HunyuanConfig", "FastHunyuanConfig", "PipelineConfig",
"SlidingTileAttnConfig", "WanT2V480PConfig", "WanI2V480PConfig",
"WanT2V720PConfig", "WanI2V720PConfig", "StepVideoT2VConfig",
"SelfForcingWanT2V480PConfig", "get_pipeline_config_cls_from_name"
"SelfForcingWanT2V480PConfig", "LTXConfig",
"get_pipeline_config_cls_from_name"
]
23 changes: 23 additions & 0 deletions fastvideo/configs/pipelines/ltx.py
@@ -0,0 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field

from fastvideo.configs.models import DiTConfig, VAEConfig
from fastvideo.configs.models.dits import LTXVideoConfig
from fastvideo.configs.models.vaes import LTXVAEConfig
from fastvideo.configs.pipelines.base import PipelineConfig


@dataclass
class LTXConfig(PipelineConfig):
    """Base configuration for LTX pipeline architecture."""

    # DiT
    dit_config: DiTConfig = field(default_factory=LTXVideoConfig)
    # VAE
    vae_config: VAEConfig = field(default_factory=LTXVAEConfig)
    vae_tiling: bool = False
    vae_sp: bool = False

    # Precision for each component
    precision: str = "bf16"
    vae_precision: str = "bf16"
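As a sanity check on how the nested dataclass defaults compose (illustrative, using only values defined in this PR):

```python
from fastvideo.configs.pipelines.ltx import LTXConfig

cfg = LTXConfig()
assert cfg.precision == "bf16"
# hidden_size is derived in LTXVideoArchConfig.__post_init__:
# num_attention_heads (32) * attention_head_dim (64)
assert cfg.dit_config.arch_config.hidden_size == 2048
```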
11 changes: 9 additions & 2 deletions fastvideo/configs/pipelines/registry.py
@@ -7,6 +7,7 @@
from fastvideo.configs.pipelines.base import PipelineConfig
from fastvideo.configs.pipelines.hunyuan import FastHunyuanConfig, HunyuanConfig
from fastvideo.configs.pipelines.ltx import LTXConfig
from fastvideo.configs.pipelines.stepvideo import StepVideoT2VConfig

# isort: off
from fastvideo.configs.pipelines.wan import (
@@ -39,6 +40,7 @@
"Wan-AI/Wan2.2-TI2V-5B-Diffusers": Wan2_2_TI2V_5B_Config,
"Wan-AI/Wan2.2-T2V-A14B-Diffusers": Wan2_2_T2V_A14B_Config,
"Wan-AI/Wan2.2-I2V-A14B-Diffusers": Wan2_2_I2V_A14B_Config,
"Lightricks/LTX-Video": LTXConfig,
# Add other specific weight variants
}

@@ -50,6 +52,7 @@
"wandmdpipeline": lambda id: "wandmdpipeline" in id.lower(),
"wancausaldmdpipeline": lambda id: "wancausaldmdpipeline" in id.lower(),
"stepvideo": lambda id: "stepvideo" in id.lower(),
"ltx": lambda id: "ltx" in id.lower(),
# Add other pipeline architecture detectors
}

@@ -62,7 +65,8 @@
"wanimagetovideo": WanI2V480PConfig,
"wandmdpipeline": FastWan2_1_T2V_480P_Config,
"wancausaldmdpipeline": SelfForcingWanT2V480PConfig,
"stepvideo": StepVideoT2VConfig
"stepvideo": StepVideoT2VConfig,
"ltx": LTXConfig,
# Other fallbacks by architecture
}

@@ -100,17 +104,20 @@ def get_pipeline_config_cls_from_name(
    pipeline_config_cls: type[PipelineConfig] | None = None

    # First try exact match for specific weights
    logger.debug("Resolving pipeline config for %s", pipeline_name_or_path)
    if pipeline_name_or_path in PIPE_NAME_TO_CONFIG:
        pipeline_config_cls = PIPE_NAME_TO_CONFIG[pipeline_name_or_path]

    logger.debug("Exact match result: %s", pipeline_config_cls)
    # Try partial matches (for local paths that might include the weight ID)
    for registered_id, config_class in PIPE_NAME_TO_CONFIG.items():
        if registered_id in pipeline_name_or_path:
            pipeline_config_cls = config_class
            logger.debug("Partial match: %s", pipeline_config_cls)
            break

    # If no match, try to use the fallback config
    if pipeline_config_cls is None:
        logger.debug("No registry match; trying fallback config")
        if os.path.exists(pipeline_name_or_path):
            config = verify_model_config_and_directory(pipeline_name_or_path)
        else:
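Resolution order in `get_pipeline_config_cls_from_name` is: exact weight ID, then substring match (covering local paths that embed the ID), then the per-architecture fallback detectors. A short illustration of the new LTX entries (the local path is hypothetical):

```python
from fastvideo.configs.pipelines.registry import (
    get_pipeline_config_cls_from_name)

# Exact match on the registered weight ID
assert get_pipeline_config_cls_from_name(
    "Lightricks/LTX-Video").__name__ == "LTXConfig"

# Partial match: a local checkout whose path embeds the weight ID
assert get_pipeline_config_cls_from_name(
    "/data/checkpoints/Lightricks/LTX-Video").__name__ == "LTXConfig"
```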
2 changes: 2 additions & 0 deletions fastvideo/entrypoints/video_generator.py
@@ -90,7 +90,9 @@ def from_fastvideo_args(cls,
        # Initialize distributed environment if needed
        # initialize_distributed_and_parallelism(fastvideo_args)

        logger.debug("fastvideo_args: %s", fastvideo_args)
        executor_class = Executor.get_class(fastvideo_args)
        logger.debug("executor_class: %s", executor_class)
        return cls(
            fastvideo_args=fastvideo_args,
            executor_class=executor_class,