diff --git a/examples/inference/basic/basic_wan2_2_Fun.py b/examples/inference/basic/basic_wan2_2_Fun.py new file mode 100644 index 000000000..ca779941b --- /dev/null +++ b/examples/inference/basic/basic_wan2_2_Fun.py @@ -0,0 +1,36 @@ +from fastvideo import VideoGenerator + +# from fastvideo.configs.sample import SamplingParam + +OUTPUT_PATH = "video_samples_wan2_1_Fun" +OUTPUT_NAME = "wan2.1_test" +def main(): + # FastVideo will automatically use the optimal default arguments for the + # model. + # If a local path is provided, FastVideo will make a best-effort + # attempt to identify the optimal arguments. + generator = VideoGenerator.from_pretrained( + "IRMChen/Wan2.1-Fun-1.3B-Control-Diffusers", + # "alibaba-pai/Wan2.2-Fun-A14B-Control", + # FastVideo will automatically handle distributed setup + num_gpus=1, + use_fsdp_inference=True, + dit_cpu_offload=True, # DiT needs to be offloaded for MoE + vae_cpu_offload=False, + text_encoder_cpu_offload=True, + # Set pin_cpu_memory to False if CPU RAM is limited and there are no frequent CPU-GPU transfers + pin_cpu_memory=True, + # image_encoder_cpu_offload=False, + ) + + prompt = "一位年轻女性穿着一件粉色的连衣裙,裙子上有白色的装饰和粉色的纽扣。她的头发是紫色的,头上戴着一个红色的大蝴蝶结,显得非常可爱和精致。她还戴着一个红色的领结,整体造型充满了少女感和活力。她的表情温柔,双手轻轻交叉放在身前,姿态优雅。背景是简单的灰色,没有任何多余的装饰,使得人物更加突出。她的妆容清淡自然,突显了她的清新气质。整体画面给人一种甜美、梦幻的感觉,仿佛置身于童话世界中。" + negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" + # prompt = "A young woman with beautiful, clear eyes and blonde hair stands in the forest, wearing a white dress and a crown. Her expression is serene, reminiscent of a movie star, with fair and youthful skin. Her brown long hair flows in the wind. The video quality is very high, with a clear view. High quality, masterpiece, best quality, high resolution, ultra-fine, fantastical." + # negative_prompt = "Twisted body, limb deformities, text captions, comic, static, ugly, error, messy code." 
+ image_path = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/wan_fun/asset_Wan2_2/v1.0/8.png" + control_video_path = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/wan_fun/asset_Wan2_2/v1.0/pose.mp4" + + video = generator.generate_video(prompt, negative_prompt=negative_prompt, image_path=image_path, video_path=control_video_path, output_path=OUTPUT_PATH, output_video_name=OUTPUT_NAME, save_video=True) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/fastvideo/configs/models/encoders/__init__.py b/fastvideo/configs/models/encoders/__init__.py index f783a2106..e56dd3a3c 100644 --- a/fastvideo/configs/models/encoders/__init__.py +++ b/fastvideo/configs/models/encoders/__init__.py @@ -2,13 +2,13 @@ EncoderConfig, ImageEncoderConfig, TextEncoderConfig) -from fastvideo.configs.models.encoders.clip import (CLIPTextConfig, - CLIPVisionConfig) +from fastvideo.configs.models.encoders.clip import ( + CLIPTextConfig, CLIPVisionConfig, WAN2_1ControlCLIPVisionConfig) from fastvideo.configs.models.encoders.llama import LlamaConfig from fastvideo.configs.models.encoders.t5 import T5Config __all__ = [ "EncoderConfig", "TextEncoderConfig", "ImageEncoderConfig", - "BaseEncoderOutput", "CLIPTextConfig", "CLIPVisionConfig", "LlamaConfig", - "T5Config" + "BaseEncoderOutput", "CLIPTextConfig", "CLIPVisionConfig", + "WAN2_1ControlCLIPVisionConfig", "LlamaConfig", "T5Config" ] diff --git a/fastvideo/configs/models/encoders/clip.py b/fastvideo/configs/models/encoders/clip.py index a7d313a86..d233872f1 100644 --- a/fastvideo/configs/models/encoders/clip.py +++ b/fastvideo/configs/models/encoders/clip.py @@ -77,6 +77,8 @@ class CLIPTextConfig(TextEncoderConfig): num_hidden_layers_override: int | None = None require_post_norm: bool | None = None + enable_scale: bool = True + is_causal: bool = True prefix: str = "clip" @@ -87,4 +89,14 @@ class CLIPVisionConfig(ImageEncoderConfig): num_hidden_layers_override: int | None = None require_post_norm: bool | None = None + enable_scale: bool = True + is_causal: bool = True prefix: str = "clip" + + +@dataclass +class WAN2_1ControlCLIPVisionConfig(CLIPVisionConfig): + num_hidden_layers_override: int | None = 31 + require_post_norm: bool | None = False + enable_scale: bool = False + is_causal: bool = False diff --git a/fastvideo/configs/pipelines/registry.py b/fastvideo/configs/pipelines/registry.py index 62a4eefe5..8803d0765 100644 --- a/fastvideo/configs/pipelines/registry.py +++ b/fastvideo/configs/pipelines/registry.py @@ -13,7 +13,7 @@ FastWan2_1_T2V_480P_Config, FastWan2_2_TI2V_5B_Config, Wan2_2_I2V_A14B_Config, Wan2_2_T2V_A14B_Config, Wan2_2_TI2V_5B_Config, WanI2V480PConfig, WanI2V720PConfig, WanT2V480PConfig, WanT2V720PConfig, - SelfForcingWanT2V480PConfig) + SelfForcingWanT2V480PConfig, WANV2VConfig) # isort: on from fastvideo.logger import init_logger from fastvideo.utils import (maybe_download_model_index, @@ -27,6 +27,7 @@ "hunyuanvideo-community/HunyuanVideo": HunyuanConfig, "Wan-AI/Wan2.1-T2V-1.3B-Diffusers": WanT2V480PConfig, "weizhou03/Wan2.1-Fun-1.3B-InP-Diffusers": WanI2V480PConfig, + "IRMChen/Wan2.1-Fun-1.3B-Control-Diffusers": WANV2VConfig, "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers": WanI2V480PConfig, "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers": WanI2V720PConfig, "Wan-AI/Wan2.1-T2V-14B-Diffusers": WanT2V720PConfig, diff --git a/fastvideo/configs/pipelines/wan.py b/fastvideo/configs/pipelines/wan.py index 4a45c4df7..6102282c2 100644 --- a/fastvideo/configs/pipelines/wan.py +++ b/fastvideo/configs/pipelines/wan.py @@ -7,7 
+7,8 @@ from fastvideo.configs.models import DiTConfig, EncoderConfig, VAEConfig from fastvideo.configs.models.dits import WanVideoConfig from fastvideo.configs.models.encoders import (BaseEncoderOutput, - CLIPVisionConfig, T5Config) + CLIPVisionConfig, T5Config, + WAN2_1ControlCLIPVisionConfig) from fastvideo.configs.models.vaes import WanVAEConfig from fastvideo.configs.pipelines.base import PipelineConfig @@ -97,6 +98,16 @@ class WanI2V720PConfig(WanI2V480PConfig): flow_shift: float | None = 5.0 +@dataclass +class WANV2VConfig(WanI2V480PConfig): + """Configuration for WAN2.1 1.3B Control pipeline.""" + + image_encoder_config: EncoderConfig = field( + default_factory=WAN2_1ControlCLIPVisionConfig) + # CLIP encoder precision + image_encoder_precision: str = 'bf16' + + @dataclass class FastWan2_1_T2V_480P_Config(WanT2V480PConfig): """Base configuration for FastWan T2V 1.3B 480P pipeline architecture with DMD""" diff --git a/fastvideo/configs/sample/base.py b/fastvideo/configs/sample/base.py index c4108613b..873777561 100644 --- a/fastvideo/configs/sample/base.py +++ b/fastvideo/configs/sample/base.py @@ -18,6 +18,9 @@ class SamplingParam: # Image inputs image_path: str | None = None + # Video inputs + video_path: str | None = None + # Text inputs prompt: str | list[str] | None = None negative_prompt: str = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" @@ -200,6 +203,12 @@ def add_cli_args(parser: Any) -> Any: default=SamplingParam.image_path, help="Path to input image for image-to-video generation", ) + parser.add_argument( + "--video_path", + type=str, + default=SamplingParam.video_path, + help="Path to input video for video-to-video generation", + ) parser.add_argument( "--moba-config-path", type=str, diff --git a/fastvideo/configs/sample/registry.py b/fastvideo/configs/sample/registry.py index 7bd4c4f77..1a551cafd 100644 --- a/fastvideo/configs/sample/registry.py +++ b/fastvideo/configs/sample/registry.py @@ -18,6 +18,7 @@ WanI2V_14B_720P_SamplingParam, WanT2V_1_3B_SamplingParam, WanT2V_14B_SamplingParam, + Wan2_1_Fun_1_3B_Control_SamplingParam, SelfForcingWanT2V480PConfig, ) # isort: on @@ -39,6 +40,8 @@ "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers": WanI2V_14B_720P_SamplingParam, "weizhou03/Wan2.1-Fun-1.3B-InP-Diffusers": Wan2_1_Fun_1_3B_InP_SamplingParam, + "IRMChen/Wan2.1-Fun-1.3B-Control-Diffusers": + Wan2_1_Fun_1_3B_Control_SamplingParam, # Wan2.2 "Wan-AI/Wan2.2-TI2V-5B-Diffusers": Wan2_2_TI2V_5B_SamplingParam, diff --git a/fastvideo/configs/sample/wan.py b/fastvideo/configs/sample/wan.py index d779cf2c0..8beab07f9 100644 --- a/fastvideo/configs/sample/wan.py +++ b/fastvideo/configs/sample/wan.py @@ -122,6 +122,17 @@ class Wan2_1_Fun_1_3B_InP_SamplingParam(SamplingParam): num_inference_steps: int = 50 +@dataclass +class Wan2_1_Fun_1_3B_Control_SamplingParam(SamplingParam): + fps: int = 16 + num_frames: int = 49 + height: int = 832 + width: int = 480 + guidance_scale: float = 6.0 + teacache_params: WanTeaCacheParams = field( + default_factory=lambda: WanTeaCacheParams(teacache_thresh=0.1, )) + + # ============================================= # ============= Wan2.2 TI2V Models ============= # ============================================= @@ 
-162,6 +173,12 @@ class Wan2_2_I2V_A14B_SamplingParam(Wan2_2_Base_SamplingParam): # can be overridden during sampling +@dataclass +class Wan2_2_Fun_A14B_Control_SamplingParam( + Wan2_1_Fun_1_3B_Control_SamplingParam): + num_frames: int = 81 + + # ============================================= # ============= Causal Self-Forcing ============= # ============================================= diff --git a/fastvideo/entrypoints/video_generator.py b/fastvideo/entrypoints/video_generator.py index 7d97b21bc..2606ecb56 100644 --- a/fastvideo/entrypoints/video_generator.py +++ b/fastvideo/entrypoints/video_generator.py @@ -298,7 +298,6 @@ def _generate_single_video( eta=0.0, n_tokens=n_tokens, VSA_sparsity=fastvideo_args.VSA_sparsity, - extra={}, ) # Use prompt[:100] for video name diff --git a/fastvideo/models/encoders/clip.py b/fastvideo/models/encoders/clip.py index 63c3ede47..940571b36 100644 --- a/fastvideo/models/encoders/clip.py +++ b/fastvideo/models/encoders/clip.py @@ -140,7 +140,7 @@ def __init__( "embed_dim must be divisible by num_heads " f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads}).") - self.scale = self.head_dim**-0.5 + self.scale = self.head_dim**-0.5 if config.enable_scale else None self.dropout = config.attention_dropout self.qkv_proj = QKVParallelLinear( @@ -166,7 +166,7 @@ def __init__( self.head_dim, self.num_heads_per_partition, softmax_scale=self.scale, - causal=True, + causal=config.is_causal, supported_attention_backends=config._supported_attention_backends) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): diff --git a/fastvideo/models/vision_utils.py b/fastvideo/models/vision_utils.py index d2a5e131a..f961ae3c5 100644 --- a/fastvideo/models/vision_utils.py +++ b/fastvideo/models/vision_utils.py @@ -3,6 +3,7 @@ import os import tempfile from collections.abc import Callable +from typing import Any from urllib.parse import unquote, urlparse import imageio @@ -11,6 +12,8 @@ import PIL.ImageOps import requests import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF from packaging import version if version.parse(version.parse( @@ -131,12 +134,88 @@ def load_image( return image +def _load_gif(gif_path: str) -> tuple[list[PIL.Image.Image], float | None]: + """ + Load frames from a GIF file. + + Args: + gif_path: Path to the GIF file + + Returns: + Tuple of (list of PIL images, original FPS or None) + """ + pil_images = [] + original_fps = None + + with PIL.Image.open(gif_path) as gif: + # Extract FPS from GIF metadata + if hasattr(gif, 'info') and 'duration' in gif.info: + duration_ms = gif.info['duration'] + if duration_ms > 0: + original_fps = 1000.0 / duration_ms + + # Extract all frames + try: + while True: + pil_images.append(gif.copy()) + gif.seek(gif.tell() + 1) + except EOFError: + # End of GIF reached + pass + + return pil_images, original_fps + + +def _load_video_with_ffmpeg( + video_path: str) -> tuple[list[PIL.Image.Image], float | None]: + """ + Load frames from a video file using ffmpeg. + + Args: + video_path: Path to the video file + + Returns: + Tuple of (list of PIL images, original FPS or None) + + Raises: + AttributeError: If ffmpeg is not installed + """ + # Verify ffmpeg is available + try: + imageio.plugins.ffmpeg.get_exe() + except AttributeError as e: + raise AttributeError( + "Unable to find an ffmpeg installation on your machine. 
" + "Please install via `pip install imageio-ffmpeg`") from e + + pil_images = [] + original_fps = None + + with imageio.get_reader(video_path) as reader: + # Try to extract FPS metadata + metadata = reader.get_meta_data() + original_fps = metadata.get('fps') + + # Fallback: try format-specific metadata + if original_fps is None: + source_size = metadata.get('source_size', {}) + if isinstance(source_size, dict): + original_fps = source_size.get('fps') + + # Extract all frames + for frame in reader: + pil_images.append(PIL.Image.fromarray(frame)) + + return pil_images, original_fps + + # adapted from diffusers.utils import load_video def load_video( video: str, convert_method: Callable[[list[PIL.Image.Image]], list[PIL.Image.Image]] | None = None, -) -> list[PIL.Image.Image]: + return_fps: bool = False, +) -> tuple[list[PIL.Image.Image], float | Any] | list[PIL.Image.Image]: """ Loads `video` to a list of PIL Image. Args: @@ -145,9 +224,12 @@ def load_video( convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*): A conversion method to apply to the video after loading it. When set to `None` the images will be converted to "RGB". + return_fps (`bool`, *optional*, defaults to `False`): + Whether to return the FPS of the video. If `True`, returns a tuple of (images, fps). + If `False`, returns only the list of images. Returns: - `List[PIL.Image.Image]`: - The video as a list of PIL images. + `List[PIL.Image.Image]` or `Tuple[List[PIL.Image.Image], float | None]`: + The video as a list of PIL images. If `return_fps` is True, also returns the original FPS. """ is_url = video.startswith("http://") or video.startswith("https://") is_file = os.path.isfile(video) @@ -175,39 +257,27 @@ def load_video( video_data = response.iter_content(chunk_size=8192) for chunk in video_data: temp_file.write(chunk) - - video = video_path - - pil_images = [] - if video.endswith(".gif"): - gif = PIL.Image.open(video) - try: - while True: - pil_images.append(gif.copy()) - gif.seek(gif.tell() + 1) - except EOFError: - pass - + was_tempfile_created = True else: - try: - imageio.plugins.ffmpeg.get_exe() - except AttributeError: - raise AttributeError( - "`Unable to find an ffmpeg installation on your machine. Please install via `pip install imageio-ffmpeg" - ) from None + video_path = video - with imageio.get_reader(video) as reader: - # Read all frames - for frame in reader: - pil_images.append(PIL.Image.fromarray(frame)) + pil_images = [] + original_fps = None - if was_tempfile_created: - os.remove(video_path) + try: + if video_path.endswith(".gif"): + pil_images, original_fps = _load_gif(video_path) + else: + pil_images, original_fps = _load_video_with_ffmpeg(video_path) + finally: + # Clean up temporary file if it was created + if was_tempfile_created and os.path.exists(video_path): + os.remove(video_path) if convert_method is not None: pil_images = convert_method(pil_images) - return pil_images + return pil_images, original_fps if return_fps else pil_images def get_default_height_width( @@ -297,3 +367,50 @@ def resize( else: raise ValueError(f"resize_mode {resize_mode} is not supported") return image + + +def create_default_image(width: int = 512, height: int = 512, color: tuple[int, int, int] = (0, 0, 0)) -> PIL.Image.Image: + """ + Create a default black PIL image. 
+ + Args: + width: Image width in pixels + height: Image height in pixels + color: RGB color tuple + + Returns: + PIL.Image.Image: A new PIL image with specified dimensions and color + """ + return PIL.Image.new("RGB", (width, height), color=color) + + +def preprocess_reference_image_for_clip(image: PIL.Image.Image, device: torch.device) -> PIL.Image.Image: + """ + Preprocess reference image to match CLIP encoder requirements. + + Applies normalization, resizing to 224x224, and denormalization to ensure + the image is in the correct format for CLIP processing. + + Args: + image: Input PIL image + device: Target device for tensor operations + + Returns: + Preprocessed PIL image ready for CLIP encoder + """ + # Convert PIL to tensor and normalize to [-1, 1] range + image_tensor = TF.to_tensor(image).sub_(0.5).div_(0.5).to(device) + + # Resize to CLIP's expected input size (224x224) using bicubic interpolation + resized_tensor = F.interpolate( + image_tensor.unsqueeze(0), + size=(224, 224), + mode='bicubic', + align_corners=False + ).squeeze(0) + + # Denormalize back to [0, 1] range + denormalized_tensor = resized_tensor.mul_(0.5).add_(0.5) + + return TF.to_pil_image(denormalized_tensor) + diff --git a/fastvideo/pipelines/basic/wan/wan_v2v_pipeline.py b/fastvideo/pipelines/basic/wan/wan_v2v_pipeline.py new file mode 100644 index 000000000..0e0862b8e --- /dev/null +++ b/fastvideo/pipelines/basic/wan/wan_v2v_pipeline.py @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Wan video-to-video diffusion pipeline implementation. + +This module contains an implementation of the Wan video-to-video diffusion pipeline +using the modular pipeline architecture. +""" + +from fastvideo.fastvideo_args import FastVideoArgs +from fastvideo.logger import init_logger +from fastvideo.pipelines.composed_pipeline_base import ComposedPipelineBase +from fastvideo.pipelines.lora_pipeline import LoRAPipeline + +# isort: off +from fastvideo.pipelines.stages import ( + RefImageEncodingStage, ConditioningStage, DecodingStage, DenoisingStage, + VideoVAEEncodingStage, InputValidationStage, LatentPreparationStage, + TextEncodingStage, TimestepPreparationStage) +# isort: on +from fastvideo.models.schedulers.scheduling_flow_unipc_multistep import ( + FlowUniPCMultistepScheduler) + +logger = init_logger(__name__) + + +class WanVideoToVideoPipeline(LoRAPipeline, ComposedPipelineBase): + + _required_config_modules = [ + "text_encoder", "tokenizer", "vae", "transformer", "scheduler", \ + "image_encoder", "image_processor" + ] + + def initialize_pipeline(self, fastvideo_args: FastVideoArgs): + self.modules["scheduler"] = FlowUniPCMultistepScheduler( + shift=fastvideo_args.pipeline_config.flow_shift) + + def create_pipeline_stages(self, fastvideo_args: FastVideoArgs): + """Set up pipeline stages with proper dependency injection.""" + + self.add_stage(stage_name="input_validation_stage", + stage=InputValidationStage()) + + self.add_stage(stage_name="prompt_encoding_stage", + stage=TextEncodingStage( + text_encoders=[self.get_module("text_encoder")], + tokenizers=[self.get_module("tokenizer")], + )) + + if (self.get_module("image_encoder") is not None + and self.get_module("image_processor") is not None): + self.add_stage( + stage_name="ref_image_encoding_stage", + stage=RefImageEncodingStage( + image_encoder=self.get_module("image_encoder"), + image_processor=self.get_module("image_processor"), + )) + + self.add_stage(stage_name="conditioning_stage", + stage=ConditioningStage()) + + 
self.add_stage(stage_name="timestep_preparation_stage", + stage=TimestepPreparationStage( + scheduler=self.get_module("scheduler"))) + + self.add_stage(stage_name="latent_preparation_stage", + stage=LatentPreparationStage( + scheduler=self.get_module("scheduler"), + transformer=self.get_module("transformer"))) + + self.add_stage(stage_name="video_latent_preparation_stage", + stage=VideoVAEEncodingStage(vae=self.get_module("vae"))) + + self.add_stage(stage_name="denoising_stage", + stage=DenoisingStage( + transformer=self.get_module("transformer"), + transformer_2=self.get_module("transformer_2"), + scheduler=self.get_module("scheduler"))) + + self.add_stage(stage_name="decoding_stage", + stage=DecodingStage(vae=self.get_module("vae"))) + + +EntryClass = WanVideoToVideoPipeline diff --git a/fastvideo/pipelines/pipeline_batch_info.py b/fastvideo/pipelines/pipeline_batch_info.py index 7369d9b45..adb0f8840 100644 --- a/fastvideo/pipelines/pipeline_batch_info.py +++ b/fastvideo/pipelines/pipeline_batch_info.py @@ -86,6 +86,11 @@ class ForwardBatch: prompt_path: str | None = None output_path: str = "outputs/" output_video_name: str | None = None + + # Video inputs + video_path: str | None = None + video_latent: torch.Tensor | None = None + # Primary encoder embeddings prompt_embeds: list[torch.Tensor] = field(default_factory=list) negative_prompt_embeds: list[torch.Tensor] | None = None diff --git a/fastvideo/pipelines/pipeline_registry.py b/fastvideo/pipelines/pipeline_registry.py index 9955c3729..2f56d06c2 100644 --- a/fastvideo/pipelines/pipeline_registry.py +++ b/fastvideo/pipelines/pipeline_registry.py @@ -21,6 +21,7 @@ "WanPipeline": "wan", "WanDMDPipeline": "wan", "WanImageToVideoPipeline": "wan", + "WanVideoToVideoPipeline": "wan", "WanCausalDMDPipeline": "wan", "StepVideoPipeline": "stepvideo", "HunyuanVideoPipeline": "hunyuan", diff --git a/fastvideo/pipelines/stages/__init__.py b/fastvideo/pipelines/stages/__init__.py index 9566b426f..2880db3f2 100644 --- a/fastvideo/pipelines/stages/__init__.py +++ b/fastvideo/pipelines/stages/__init__.py @@ -14,7 +14,9 @@ DmdDenoisingStage) from fastvideo.pipelines.stages.encoding import EncodingStage from fastvideo.pipelines.stages.image_encoding import (ImageEncodingStage, - ImageVAEEncodingStage) + RefImageEncodingStage, + ImageVAEEncodingStage, + VideoVAEEncodingStage) from fastvideo.pipelines.stages.input_validation import InputValidationStage from fastvideo.pipelines.stages.latent_preparation import LatentPreparationStage from fastvideo.pipelines.stages.stepvideo_encoding import ( @@ -35,7 +37,9 @@ "EncodingStage", "DecodingStage", "ImageEncodingStage", + "RefImageEncodingStage", "ImageVAEEncodingStage", + "VideoVAEEncodingStage", "TextEncodingStage", "StepvideoPromptEncodingStage", ] diff --git a/fastvideo/pipelines/stages/denoising.py b/fastvideo/pipelines/stages/denoising.py index ed9a7d06e..e0510d6d9 100644 --- a/fastvideo/pipelines/stages/denoising.py +++ b/fastvideo/pipelines/stages/denoising.py @@ -283,9 +283,15 @@ def forward( current_guidance_scale = batch.guidance_scale_2 assert current_model is not None, "current_model is None" - # Expand latents for I2V + # Expand latents for V2V/I2V latent_model_input = latents.to(target_dtype) - if batch.image_latent is not None: + if batch.video_latent is not None: + latent_model_input = torch.cat([ + latent_model_input, batch.video_latent, + torch.zeros_like(latents) + ], + dim=1).to(target_dtype) + elif batch.image_latent is not None: assert not fastvideo_args.pipeline_config.ti2v_task, "image 
latents should not be provided for TI2V task" latent_model_input = torch.cat( [latent_model_input, batch.image_latent], diff --git a/fastvideo/pipelines/stages/image_encoding.py b/fastvideo/pipelines/stages/image_encoding.py index 4e817e4ae..3c6ff9b7a 100644 --- a/fastvideo/pipelines/stages/image_encoding.py +++ b/fastvideo/pipelines/stages/image_encoding.py @@ -1,8 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 """ -Image encoding stages for I2V diffusion pipelines. +Image and video encoding stages for diffusion pipelines. -This module contains implementations of image encoding stages for diffusion pipelines. +This module contains implementations of encoding stages for diffusion pipelines: +- ImageEncodingStage: Encodes images using image encoders (e.g., CLIP) +- RefImageEncodingStage: Encodes reference images for the Wan2.1 control pipeline +- ImageVAEEncodingStage: Encodes images to latent space using the VAE for I2V generation +- VideoVAEEncodingStage: Encodes videos to latent space using the VAE for V2V and control tasks """ import PIL @@ -14,7 +18,9 @@ from fastvideo.logger import init_logger from fastvideo.models.vaes.common import ParallelTiledVAE from fastvideo.models.vision_utils import (get_default_height_width, normalize, - numpy_to_pt, pil_to_numpy, resize) + numpy_to_pt, pil_to_numpy, resize, + create_default_image, + preprocess_reference_image_for_clip) from fastvideo.pipelines.pipeline_batch_info import ForwardBatch from fastvideo.pipelines.stages.base import PipelineStage from fastvideo.pipelines.stages.validators import StageValidators as V @@ -94,12 +100,60 @@ def verify_output(self, batch: ForwardBatch, return result +class RefImageEncodingStage(ImageEncodingStage): + """ + Stage for encoding reference images into embeddings for Wan2.1 Control models. + + This stage extends ImageEncodingStage with specialized preprocessing for reference images. + """ + + @torch.no_grad() + def forward( + self, + batch: ForwardBatch, + fastvideo_args: FastVideoArgs, + ) -> ForwardBatch: + """ + Encode the reference image into image encoder hidden states. + + Args: + batch: The current batch information. + fastvideo_args: The inference arguments. + + Returns: + The batch with the encoded reference image embeddings. + """ + self.image_encoder = self.image_encoder.to(get_local_torch_device()) + + image = batch.pil_image + if image is None: + image = create_default_image() + # Preprocess reference image for CLIP encoder + image_tensor = preprocess_reference_image_for_clip( + image, get_local_torch_device()) + + image_inputs = self.image_processor(images=image_tensor, + return_tensors="pt").to( + get_local_torch_device()) + with set_forward_context(current_timestep=0, attn_metadata=None): + outputs = self.image_encoder(**image_inputs) + image_embeds = outputs.last_hidden_state + batch.image_embeds.append(image_embeds) + + if batch.pil_image is None: + batch.image_embeds = [ + torch.zeros_like(x) for x in batch.image_embeds + ] + + return batch + + class ImageVAEEncodingStage(PipelineStage): """ - Stage for encoding pixel representations into latent space. - - This stage handles the encoding of pixel representations into the final - input format (e.g., latents). + Stage for encoding image pixel representations into latent space. + + This stage handles the encoding of image pixel representations into the final + input format (e.g., latents) for image-to-video generation. 
""" def __init__(self, vae: ParallelTiledVAE) -> None: @@ -144,9 +198,9 @@ def forward( self.vae = self.vae.to(get_local_torch_device()) + # Process single image for I2V latent_height = height // self.vae.spatial_compression_ratio latent_width = width // self.vae.spatial_compression_ratio - image = batch.pil_image image = self.preprocess( image, @@ -296,3 +350,179 @@ def verify_output(self, batch: ForwardBatch, result.add_check("image_latent", batch.image_latent, [V.is_tensor, V.with_dims(5)]) return result + + +class VideoVAEEncodingStage(ImageVAEEncodingStage): + """ + Stage for encoding video pixel representations into latent space. + + This stage handles the encoding of video pixel representations for video-to-video generation and control. + Inherits from ImageVAEEncodingStage to reuse common functionality. + """ + + def forward( + self, + batch: ForwardBatch, + fastvideo_args: FastVideoArgs, + ) -> ForwardBatch: + """ + Encode video pixel representations into latent space. + + Args: + batch: The current batch information. + fastvideo_args: The inference arguments. + + Returns: + The batch with encoded outputs. + """ + assert batch.video_latent is not None, "Video latent input is required for VideoVAEEncodingStage" + + if fastvideo_args.mode == ExecutionMode.INFERENCE: + assert batch.height is not None and isinstance(batch.height, int) + assert batch.width is not None and isinstance(batch.width, int) + assert batch.num_frames is not None and isinstance( + batch.num_frames, int) + height = batch.height + width = batch.width + num_frames = batch.num_frames + elif fastvideo_args.mode == ExecutionMode.PREPROCESS: + assert batch.height is not None and isinstance(batch.height, list) + assert batch.width is not None and isinstance(batch.width, list) + assert batch.num_frames is not None and isinstance( + batch.num_frames, list) + num_frames = batch.num_frames[0] + height = batch.height[0] + width = batch.width[0] + + self.vae = self.vae.to(get_local_torch_device()) + + # Prepare video tensor from control video + video_condition = self._prepare_control_video_tensor( + batch.video_latent, num_frames, height, + width).to(get_local_torch_device(), dtype=torch.float32) + + # Setup VAE precision + vae_dtype = PRECISION_TO_TYPE[ + fastvideo_args.pipeline_config.vae_precision] + vae_autocast_enabled = ( + vae_dtype != torch.float32) and not fastvideo_args.disable_autocast + + # Encode control video + with torch.autocast(device_type="cuda", + dtype=vae_dtype, + enabled=vae_autocast_enabled): + if fastvideo_args.pipeline_config.vae_tiling: + self.vae.enable_tiling() + if not vae_autocast_enabled: + video_condition = video_condition.to(vae_dtype) + encoder_output = self.vae.encode(video_condition) + + generator = batch.generator + if generator is None: + raise ValueError("Generator must be provided") + latent_condition = self.retrieve_latents(encoder_output, generator) + + if (hasattr(self.vae, "shift_factor") + and self.vae.shift_factor is not None): + if isinstance(self.vae.shift_factor, torch.Tensor): + latent_condition -= self.vae.shift_factor.to( + latent_condition.device, latent_condition.dtype) + else: + latent_condition -= self.vae.shift_factor + + if isinstance(self.vae.scaling_factor, torch.Tensor): + latent_condition = latent_condition * self.vae.scaling_factor.to( + latent_condition.device, latent_condition.dtype) + else: + latent_condition = latent_condition * self.vae.scaling_factor + + batch.video_latent = latent_condition + + # Offload models if needed + if hasattr(self, 
'maybe_free_model_hooks'): + self.maybe_free_model_hooks() + + self.vae.to("cpu") + + return batch + + def _prepare_control_video_tensor(self, control_video, num_frames: int, + height: int, width: int) -> torch.Tensor: + """ + Prepare video tensor from control video input. + """ + if isinstance(control_video, list): + processed_frames = [] + for i, frame in enumerate(control_video): + if i >= num_frames: + break + processed_frame = self.preprocess( + frame, + vae_scale_factor=self.vae.spatial_compression_ratio, + height=height, + width=width).to(get_local_torch_device(), + dtype=torch.float32) + processed_frames.append(processed_frame) + + if processed_frames: + video_tensor = torch.cat( + [f.unsqueeze(2) for f in processed_frames], dim=2) + else: + video_tensor = torch.zeros(1, + 3, + 0, + height, + width, + device=get_local_torch_device(), + dtype=torch.float32) + elif isinstance(control_video, torch.Tensor): + # Handle tensor input [batch, channels, frames, height, width] + video_tensor = control_video.to(get_local_torch_device(), + dtype=torch.float32) + + if video_tensor.shape[2] > num_frames: + video_tensor = video_tensor[:, :, :num_frames] + else: + raise ValueError( + f"Unsupported control_video type: {type(control_video)}. " + "Expected list of PIL Images or torch.Tensor.") + + # Pad with zeros if we have fewer frames than required + current_frames = video_tensor.shape[2] + if current_frames < num_frames: + padding_frames = num_frames - current_frames + zero_padding = torch.zeros(video_tensor.shape[0], + video_tensor.shape[1], + padding_frames, + height, + width, + device=video_tensor.device, + dtype=video_tensor.dtype) + video_tensor = torch.cat([video_tensor, zero_padding], dim=2) + + return video_tensor + + def verify_input(self, batch: ForwardBatch, + fastvideo_args: FastVideoArgs) -> VerificationResult: + """Verify video encoding stage inputs.""" + result = VerificationResult() + result.add_check("video_latent", batch.video_latent, V.not_none) + result.add_check("generator", batch.generator, + V.generator_or_list_generators) + if fastvideo_args.mode == ExecutionMode.PREPROCESS: + result.add_check("height", batch.height, V.list_not_empty) + result.add_check("width", batch.width, V.list_not_empty) + result.add_check("num_frames", batch.num_frames, V.list_not_empty) + else: + result.add_check("height", batch.height, V.positive_int) + result.add_check("width", batch.width, V.positive_int) + result.add_check("num_frames", batch.num_frames, V.positive_int) + return result + + def verify_output(self, batch: ForwardBatch, + fastvideo_args: FastVideoArgs) -> VerificationResult: + """Verify video encoding stage outputs.""" + result = VerificationResult() + result.add_check("video_latent", batch.video_latent, + [V.is_tensor, V.with_dims(5)]) + return result diff --git a/fastvideo/pipelines/stages/input_validation.py b/fastvideo/pipelines/stages/input_validation.py index ff37bb13c..5b06e968e 100644 --- a/fastvideo/pipelines/stages/input_validation.py +++ b/fastvideo/pipelines/stages/input_validation.py @@ -9,7 +9,7 @@ from fastvideo.fastvideo_args import FastVideoArgs from fastvideo.logger import init_logger -from fastvideo.models.vision_utils import load_image, load_video +from fastvideo.models.vision_utils import load_image, load_video, pil_to_numpy, numpy_to_pt, normalize, resize from fastvideo.pipelines.pipeline_batch_info import ForwardBatch from fastvideo.pipelines.stages.base import PipelineStage from fastvideo.pipelines.stages.validators import (StageValidators, @@ -135,6 +135,55 
@@ def forward( batch.width = ow batch.pil_image = img + # for v2v, get control video from video path + if batch.video_path is not None: + pil_images, original_fps = load_video(batch.video_path, + return_fps=True) + logger.info("Loaded video with %s frames, original FPS: %s", + len(pil_images), original_fps) + + # Get target parameters from batch + target_fps = batch.fps + target_num_frames = batch.num_frames + target_height = batch.height + target_width = batch.width + + if target_fps is not None and original_fps is not None: + frame_skip = max(1, int(original_fps // target_fps)) + if frame_skip > 1: + pil_images = pil_images[::frame_skip] + effective_fps = original_fps / frame_skip + logger.info( + "Resampled video from %.1f fps to %.1f fps (skip=%s)", + original_fps, effective_fps, frame_skip) + + # Limit to target number of frames + if target_num_frames is not None and len( + pil_images) > target_num_frames: + logger.info("Limiting video to %s frames (from %s total)", + target_num_frames, len(pil_images)) + pil_images = pil_images[:target_num_frames] + + # Resize each PIL image to target dimensions + resized_images = [] + for pil_img in pil_images: + resized_img = resize(pil_img, + target_height, + target_width, + resize_mode="default", + resample="lanczos") + resized_images.append(resized_img) + + # Convert PIL images to numpy array + video_numpy = pil_to_numpy(resized_images) + video_numpy = normalize(video_numpy) + video_tensor = numpy_to_pt(video_numpy) + + # Rearrange to [C, T, H, W] and add batch dimension -> [B, C, T, H, W] + input_video = video_tensor.permute(1, 0, 2, 3).unsqueeze(0) + + batch.video_latent = input_video + return batch def verify_input(self, batch: ForwardBatch, diff --git a/fastvideo/utils.py b/fastvideo/utils.py index 50937aa24..39bf651b7 100644 --- a/fastvideo/utils.py +++ b/fastvideo/utils.py @@ -26,7 +26,7 @@ import imageio import numpy as np import torch -import torchvision +from torchvision.utils import make_grid import yaml from diffusers.loaders.lora_base import ( _best_guess_weight_name) # watch out for potetential removal from diffusers @@ -898,7 +898,7 @@ def save_decoded_latents_as_video(decoded_latents: list[torch.Tensor], videos = rearrange(decoded_latents, "b c t h w -> t b c h w") frames = [] for x in videos: - x = torchvision.utils.make_grid(x, nrow=6) + x = make_grid(x, nrow=6) x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) frames.append((x * 255).numpy().astype(np.uint8))
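A minimal Python usage sketch of the new load_video(..., return_fps=True) helper, applying the same FPS-based resampling and frame truncation that InputValidationStage performs for control videos. The control-video URL is the one used in the example above; the target FPS and frame count mirror Wan2_1_Fun_1_3B_Control_SamplingParam and are assumed values, not requirements.

from fastvideo.models.vision_utils import load_video

CONTROL_VIDEO = ("https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/"
                 "wan_fun/asset_Wan2_2/v1.0/pose.mp4")

# With return_fps=True, load_video returns (frames, fps); fps may be None.
frames, original_fps = load_video(CONTROL_VIDEO, return_fps=True)

# Assumed targets, mirroring Wan2_1_Fun_1_3B_Control_SamplingParam.
target_fps, target_num_frames = 16, 49

if original_fps is not None:
    # Drop frames so the effective FPS is close to the target FPS.
    frame_skip = max(1, int(original_fps // target_fps))
    frames = frames[::frame_skip]

# Keep only the number of frames the pipeline will consume.
frames = frames[:target_num_frames]
print(f"Kept {len(frames)} frames (source FPS: {original_fps})")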