Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions examples/inference/basic/basic_wan2_2_Fun.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from fastvideo import VideoGenerator

# from fastvideo.configs.sample import SamplingParam

# Output directory and file stem for the generated video
# (consumed by generator.generate_video in main()).
OUTPUT_PATH = "video_samples_wan2_1_Fun"
OUTPUT_NAME = "wan2.1_test"
def main() -> None:
    """Generate a pose-controlled video with the Wan2.1-Fun 1.3B Control model.

    Loads the Diffusers-format checkpoint via ``VideoGenerator.from_pretrained``,
    then drives generation with a text prompt, a reference image, and a pose
    control video. The result is saved under ``OUTPUT_PATH`` as ``OUTPUT_NAME``.
    Requires network access (model hub + remote image/video URLs) and a GPU.
    """
    # FastVideo will automatically use the optimal default arguments for the
    # model.
    # If a local path is provided, FastVideo will make a best effort
    # attempt to identify the optimal arguments.
    generator = VideoGenerator.from_pretrained(
        "IRMChen/Wan2.1-Fun-1.3B-Control-Diffusers",
        # Alternative (larger) checkpoint: "alibaba-pai/Wan2.2-Fun-A14B-Control"
        # FastVideo will automatically handle distributed setup
        num_gpus=1,
        use_fsdp_inference=True,
        dit_cpu_offload=True,  # DiT needs to be offloaded for MoE
        vae_cpu_offload=False,
        text_encoder_cpu_offload=True,
        # Set pin_cpu_memory to False if CPU RAM is limited and there are no
        # frequent CPU-GPU transfers.
        pin_cpu_memory=True,
        # image_encoder_cpu_offload=False,
    )

    # Chinese-language prompt/negative prompt (runtime strings, left verbatim).
    prompt = "一位年轻女性穿着一件粉色的连衣裙,裙子上有白色的装饰和粉色的纽扣。她的头发是紫色的,头上戴着一个红色的大蝴蝶结,显得非常可爱和精致。她还戴着一个红色的领结,整体造型充满了少女感和活力。她的表情温柔,双手轻轻交叉放在身前,姿态优雅。背景是简单的灰色,没有任何多余的装饰,使得人物更加突出。她的妆容清淡自然,突显了她的清新气质。整体画面给人一种甜美、梦幻的感觉,仿佛置身于童话世界中。"
    negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
    # English alternatives:
    # prompt = "A young woman with beautiful, clear eyes and blonde hair stands in the forest, wearing a white dress and a crown. Her expression is serene, reminiscent of a movie star, with fair and youthful skin. Her brown long hair flows in the wind. The video quality is very high, with a clear view. High quality, masterpiece, best quality, high resolution, ultra-fine, fantastical."
    # negative_prompt = "Twisted body, limb deformities, text captions, comic, static, ugly, error, messy code."

    # Reference image (appearance) and pose video (motion control) inputs.
    image_path = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/wan_fun/asset_Wan2_2/v1.0/8.png"
    control_video_path = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/wan_fun/asset_Wan2_2/v1.0/pose.mp4"

    # video_path carries the control video for the V2V/control pipeline.
    video = generator.generate_video(prompt, negative_prompt=negative_prompt, image_path=image_path, video_path=control_video_path, output_path=OUTPUT_PATH, output_video_name=OUTPUT_NAME, save_video=True)

if __name__ == "__main__":
    main()
8 changes: 4 additions & 4 deletions fastvideo/configs/models/encoders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
EncoderConfig,
ImageEncoderConfig,
TextEncoderConfig)
from fastvideo.configs.models.encoders.clip import (CLIPTextConfig,
CLIPVisionConfig)
from fastvideo.configs.models.encoders.clip import (
CLIPTextConfig, CLIPVisionConfig, WAN2_1ControlCLIPVisionConfig)
from fastvideo.configs.models.encoders.llama import LlamaConfig
from fastvideo.configs.models.encoders.t5 import T5Config

__all__ = [
"EncoderConfig", "TextEncoderConfig", "ImageEncoderConfig",
"BaseEncoderOutput", "CLIPTextConfig", "CLIPVisionConfig", "LlamaConfig",
"T5Config"
"BaseEncoderOutput", "CLIPTextConfig", "CLIPVisionConfig",
"WAN2_1ControlCLIPVisionConfig", "LlamaConfig", "T5Config"
]
12 changes: 12 additions & 0 deletions fastvideo/configs/models/encoders/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ class CLIPTextConfig(TextEncoderConfig):

num_hidden_layers_override: int | None = None
require_post_norm: bool | None = None
enable_scale: bool = True
is_causal: bool = True
prefix: str = "clip"


Expand All @@ -87,4 +89,14 @@ class CLIPVisionConfig(ImageEncoderConfig):

num_hidden_layers_override: int | None = None
require_post_norm: bool | None = None
enable_scale: bool = True
is_causal: bool = True
prefix: str = "clip"


@dataclass
class WAN2_1ControlCLIPVisionConfig(CLIPVisionConfig):
    """CLIP vision-encoder config variant for the Wan2.1-Fun Control pipeline."""

    # Truncate the vision tower — presumably to the first 31 hidden layers;
    # confirm against the encoder loader's handling of this override.
    num_hidden_layers_override: int | None = 31
    # NOTE(review): False presumably skips the final post-LayerNorm on the
    # truncated tower — confirm against the CLIP model builder.
    require_post_norm: bool | None = False
    # Disable the explicit softmax scale in attention (scale becomes None).
    enable_scale: bool = False
    # Use non-causal (bidirectional) attention, unlike the text-encoder default.
    is_causal: bool = False
3 changes: 2 additions & 1 deletion fastvideo/configs/pipelines/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
FastWan2_1_T2V_480P_Config, FastWan2_2_TI2V_5B_Config,
Wan2_2_I2V_A14B_Config, Wan2_2_T2V_A14B_Config, Wan2_2_TI2V_5B_Config,
WanI2V480PConfig, WanI2V720PConfig, WanT2V480PConfig, WanT2V720PConfig,
SelfForcingWanT2V480PConfig)
SelfForcingWanT2V480PConfig, WANV2VConfig)
# isort: on
from fastvideo.logger import init_logger
from fastvideo.utils import (maybe_download_model_index,
Expand All @@ -27,6 +27,7 @@
"hunyuanvideo-community/HunyuanVideo": HunyuanConfig,
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers": WanT2V480PConfig,
"weizhou03/Wan2.1-Fun-1.3B-InP-Diffusers": WanI2V480PConfig,
"IRMChen/Wan2.1-Fun-1.3B-Control-Diffusers": WANV2VConfig,
"Wan-AI/Wan2.1-I2V-14B-480P-Diffusers": WanI2V480PConfig,
"Wan-AI/Wan2.1-I2V-14B-720P-Diffusers": WanI2V720PConfig,
"Wan-AI/Wan2.1-T2V-14B-Diffusers": WanT2V720PConfig,
Expand Down
13 changes: 12 additions & 1 deletion fastvideo/configs/pipelines/wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from fastvideo.configs.models import DiTConfig, EncoderConfig, VAEConfig
from fastvideo.configs.models.dits import WanVideoConfig
from fastvideo.configs.models.encoders import (BaseEncoderOutput,
CLIPVisionConfig, T5Config)
CLIPVisionConfig, T5Config,
WAN2_1ControlCLIPVisionConfig)
from fastvideo.configs.models.vaes import WanVAEConfig
from fastvideo.configs.pipelines.base import PipelineConfig

Expand Down Expand Up @@ -97,6 +98,16 @@ class WanI2V720PConfig(WanI2V480PConfig):
flow_shift: float | None = 5.0


@dataclass
class WANV2VConfig(WanI2V480PConfig):
    """Configuration for WAN2.1 1.3B Control (video-to-video) pipeline."""

    # Swap in the Control-specific CLIP vision config (truncated tower,
    # non-causal, unscaled attention) in place of the stock CLIPVisionConfig.
    image_encoder_config: EncoderConfig = field(
        default_factory=WAN2_1ControlCLIPVisionConfig)
    # CLIP encoder precision
    image_encoder_precision: str = 'bf16'


@dataclass
class FastWan2_1_T2V_480P_Config(WanT2V480PConfig):
"""Base configuration for FastWan T2V 1.3B 480P pipeline architecture with DMD"""
Expand Down
9 changes: 9 additions & 0 deletions fastvideo/configs/sample/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ class SamplingParam:
# Image inputs
image_path: str | None = None

# Video inputs
video_path: str | None = None

# Text inputs
prompt: str | list[str] | None = None
negative_prompt: str = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
Expand Down Expand Up @@ -200,6 +203,12 @@ def add_cli_args(parser: Any) -> Any:
default=SamplingParam.image_path,
help="Path to input image for image-to-video generation",
)
parser.add_argument(
"--video_path",
type=str,
default=SamplingParam.video_path,
help="Path to input video for video-to-video generation",
)
parser.add_argument(
"--moba-config-path",
type=str,
Expand Down
3 changes: 3 additions & 0 deletions fastvideo/configs/sample/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
WanI2V_14B_720P_SamplingParam,
WanT2V_1_3B_SamplingParam,
WanT2V_14B_SamplingParam,
Wan2_1_Fun_1_3B_Control_SamplingParam,
SelfForcingWanT2V480PConfig,
)
# isort: on
Expand All @@ -39,6 +40,8 @@
"Wan-AI/Wan2.1-I2V-14B-720P-Diffusers": WanI2V_14B_720P_SamplingParam,
"weizhou03/Wan2.1-Fun-1.3B-InP-Diffusers":
Wan2_1_Fun_1_3B_InP_SamplingParam,
"IRMChen/Wan2.1-Fun-1.3B-Control-Diffusers":
Wan2_1_Fun_1_3B_Control_SamplingParam,

# Wan2.2
"Wan-AI/Wan2.2-TI2V-5B-Diffusers": Wan2_2_TI2V_5B_SamplingParam,
Expand Down
17 changes: 17 additions & 0 deletions fastvideo/configs/sample/wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,17 @@ class Wan2_1_Fun_1_3B_InP_SamplingParam(SamplingParam):
num_inference_steps: int = 50


@dataclass
class Wan2_1_Fun_1_3B_Control_SamplingParam(SamplingParam):
    """Default sampling parameters for Wan2.1-Fun 1.3B Control."""

    fps: int = 16
    num_frames: int = 49
    # NOTE(review): 832x480 here reads as height x width (portrait); other
    # 480P presets are typically 480x832 — confirm this orientation is
    # intended rather than swapped.
    height: int = 832
    width: int = 480
    guidance_scale: float = 6.0
    # TeaCache enabled with a 0.1 threshold for this model.
    teacache_params: WanTeaCacheParams = field(
        default_factory=lambda: WanTeaCacheParams(teacache_thresh=0.1, ))


# =============================================
# ============= Wan2.2 TI2V Models =============
# =============================================
Expand Down Expand Up @@ -162,6 +173,12 @@ class Wan2_2_I2V_A14B_SamplingParam(Wan2_2_Base_SamplingParam):
# can be overridden during sampling


@dataclass
class Wan2_2_Fun_A14B_Control_SamplingParam(
        Wan2_1_Fun_1_3B_Control_SamplingParam):
    """Sampling defaults for Wan2.2-Fun A14B Control.

    Inherits the Wan2.1 Control preset, overriding only the frame count.
    """

    num_frames: int = 81


# =============================================
# ============= Causal Self-Forcing =============
# =============================================
Expand Down
1 change: 0 additions & 1 deletion fastvideo/entrypoints/video_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,6 @@ def _generate_single_video(
eta=0.0,
n_tokens=n_tokens,
VSA_sparsity=fastvideo_args.VSA_sparsity,
extra={},
)

# Use prompt[:100] for video name
Expand Down
4 changes: 2 additions & 2 deletions fastvideo/models/encoders/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(
"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads}).")
self.scale = self.head_dim**-0.5
self.scale = self.head_dim**-0.5 if config.enable_scale else None
self.dropout = config.attention_dropout

self.qkv_proj = QKVParallelLinear(
Expand All @@ -166,7 +166,7 @@ def __init__(
self.head_dim,
self.num_heads_per_partition,
softmax_scale=self.scale,
causal=True,
causal=config.is_causal,
supported_attention_backends=config._supported_attention_backends)

def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
Expand Down
Loading