32 changes: 32 additions & 0 deletions examples/inference/lora/wan_multi_lora_inference.py
@@ -0,0 +1,32 @@
from fastvideo import VideoGenerator

OUTPUT_PATH = "./multi_lora"


def main():
    # Create a generator for WanVideo2.1 I2V
    generator = VideoGenerator.from_pretrained(
        "Wan-AI/Wan2.1-I2V-14B-480P",
        num_gpus=1,
    )

    # Load three LoRA adapters into the pipeline
    generator.set_lora_adapter("lora1", "path/to/first_lora")
    generator.set_lora_adapter("lora2", "path/to/second_lora")
    generator.set_lora_adapter("lora3", "path/to/third_lora")
Comment on lines +14 to +16

Collaborator: use real paths?

Collaborator: I feel examples/inference/lora/wan_lora_inference.py already includes this


    # The last call activates lora3. Generate a video with it
    prompt = "An astronaut explores a strange new world, cinematic scene"
    generator.generate_video(prompt, output_path=OUTPUT_PATH, save_video=True)

    # Switch to lora1 and generate another video
    generator.set_lora_adapter("lora1")
    generator.generate_video(prompt, output_path=OUTPUT_PATH, save_video=True)

    # Switch to lora2 and generate one more video
    generator.set_lora_adapter("lora2")
    generator.generate_video(prompt, output_path=OUTPUT_PATH, save_video=True)


if __name__ == "__main__":
    main()
36 changes: 36 additions & 0 deletions examples/inference/optimizations/optimized_wan_i2v_example.py
@@ -0,0 +1,36 @@
from fastvideo import VideoGenerator
from fastvideo.v1.configs.sample import SamplingParam


OUTPUT_PATH = "./optimized_output"


def main():
    """Run WanVideo2.1 I2V pipeline with all optimizations enabled."""
    generator = VideoGenerator.from_pretrained(
        "Wan-AI/Wan2.1-I2V-14B-480P",
        num_gpus=1,
        skip_layer_guidance=0.2,
        use_normalized_attention=True,
        nag_scale=1.5,
        nag_tau=2.5,
        nag_alpha=0.125,
        use_dcm=True,
        use_taylor_seer=True,
        taylor_seer_order=2,
    )

    sampling = SamplingParam.from_pretrained("Wan-AI/Wan2.1-I2V-14B-480P")

    prompt = "A lone explorer crosses a vast alien desert under twin moons"
    generator.generate_video(
        prompt,
        sampling_param=sampling,
        output_path=OUTPUT_PATH,
        save_video=True,
    )


if __name__ == "__main__":
    main()

67 changes: 67 additions & 0 deletions fastvideo/v1/configs/pipelines/base.py
@@ -89,6 +89,16 @@ class PipelineConfig:
    STA_mode: STA_Mode = STA_Mode.STA_INFERENCE
    skip_time_steps: int = 15

    # Additional guidance/optimization parameters
    skip_layer_guidance: float | None = None  # fraction of initial denoising steps run without CFG
    use_normalized_attention: bool = False
    nag_scale: float = 1.5
    nag_tau: float = 2.5
    nag_alpha: float = 0.125
    use_dcm: bool = False
    use_taylor_seer: bool = False
    taylor_seer_order: int = 2

    # Compilation
    # enable_torch_compile: bool = False

@@ -206,6 +216,63 @@ def add_cli_args(parser: FlexibleArgumentParser,
"Bool for applying scheduler scale in set_timesteps, used in stepvideo",
)

        parser.add_argument(
            f"--{prefix_with_dot}skip-layer-guidance",
            type=float,
            dest=f"{prefix_with_dot.replace('-', '_')}skip_layer_guidance",
            default=PipelineConfig.skip_layer_guidance,
            help="Fraction of initial denoising steps that skip CFG (SkipLayerGuidance)",
        )
        parser.add_argument(
            f"--{prefix_with_dot}use-normalized-attention",
            action=StoreBoolean,
            dest=f"{prefix_with_dot.replace('-', '_')}use_normalized_attention",
            default=PipelineConfig.use_normalized_attention,
            help="Enable Normalized Attention Guidance (NAG)",
        )
        parser.add_argument(
            f"--{prefix_with_dot}nag-scale",
            type=float,
            dest=f"{prefix_with_dot.replace('-', '_')}nag_scale",
            default=PipelineConfig.nag_scale,
            help="Scale for Normalized Attention Guidance",
        )
        parser.add_argument(
            f"--{prefix_with_dot}nag-tau",
            type=float,
            dest=f"{prefix_with_dot.replace('-', '_')}nag_tau",
            default=PipelineConfig.nag_tau,
            help="Tau parameter for Normalized Attention Guidance",
        )
        parser.add_argument(
            f"--{prefix_with_dot}nag-alpha",
            type=float,
            dest=f"{prefix_with_dot.replace('-', '_')}nag_alpha",
            default=PipelineConfig.nag_alpha,
            help="Alpha parameter for Normalized Attention Guidance",
        )
        parser.add_argument(
            f"--{prefix_with_dot}use-dcm",
            action=StoreBoolean,
            dest=f"{prefix_with_dot.replace('-', '_')}use_dcm",
            default=PipelineConfig.use_dcm,
            help="Enable the Dynamic Convolution Module (DCM)",
        )
        parser.add_argument(
            f"--{prefix_with_dot}use-taylor-seer",
            action=StoreBoolean,
            dest=f"{prefix_with_dot.replace('-', '_')}use_taylor_seer",
            default=PipelineConfig.use_taylor_seer,
            help="Enable TaylorSeer optimization",
        )
        parser.add_argument(
            f"--{prefix_with_dot}taylor-seer-order",
            type=int,
            dest=f"{prefix_with_dot.replace('-', '_')}taylor_seer_order",
            default=PipelineConfig.taylor_seer_order,
            help="Derivative order for TaylorSeer optimization",
        )

        # Add VAE configuration arguments
        from fastvideo.v1.configs.models.vaes.base import VAEConfig
        VAEConfig.add_cli_args(parser, prefix=f"{prefix_with_dot}vae-config")
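
For orientation, a rough sketch (not part of this diff) of how the new options surface as flags; the plain argparse parser and empty prefix here are assumptions, and FastVideo's StoreBoolean action is swapped for a standard store_true flag for illustration.

# Illustrative only: mirrors the dests and defaults registered above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--skip-layer-guidance", dest="skip_layer_guidance",
                    type=float, default=None)
parser.add_argument("--nag-scale", dest="nag_scale", type=float, default=1.5)
parser.add_argument("--use-taylor-seer", dest="use_taylor_seer",
                    action="store_true", default=False)
parser.add_argument("--taylor-seer-order", dest="taylor_seer_order",
                    type=int, default=2)

args = parser.parse_args([
    "--skip-layer-guidance", "0.2", "--nag-scale", "1.5",
    "--use-taylor-seer", "--taylor-seer-order", "2",
])
print(args.skip_layer_guidance, args.nag_scale,
      args.use_taylor_seer, args.taylor_seer_order)
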
140 changes: 138 additions & 2 deletions fastvideo/v1/pipelines/stages/denoising.py
@@ -45,6 +45,99 @@
logger = init_logger(__name__)


def apply_normalized_attention_guidance(
    pos: torch.Tensor,
    neg: torch.Tensor | None = None,
    nag_scale: float = 1.5,
    nag_tau: float = 2.5,
    nag_alpha: float = 0.125,
) -> torch.Tensor:
    """Apply Normalized Attention Guidance (NAG) to noise predictions.

    This implementation follows the formula from the official NAG repository
    and operates on the positive and negative noise predictions.
    """

    if neg is None:
        flat = pos.flatten(2)
        mean = flat.mean(dim=-1, keepdim=True)
        var = flat.var(dim=-1, unbiased=False, keepdim=True)
        normalized = (flat - mean) / (var + 1e-6).sqrt()
        return normalized.view_as(pos)

    pos_flat = pos.flatten(2)
    neg_flat = neg.flatten(2)

    guidance = pos_flat * nag_scale - neg_flat * (nag_scale - 1)
    norm_pos = pos_flat.norm(p=2, dim=-1, keepdim=True)
    norm_guidance = guidance.norm(p=2, dim=-1, keepdim=True)
    scale = norm_guidance / (norm_pos + 1e-7)
    guidance = guidance * torch.minimum(scale, scale.new_ones(1) * nag_tau) / (
        scale + 1e-7
    )

    out = guidance * nag_alpha + pos_flat * (1 - nag_alpha)
    return out.view_as(pos)
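
For illustration only (not part of this diff): a minimal sketch of exercising the NAG helper on dummy noise predictions, assuming the function is in scope and a (B, C, T, H, W) latent layout.

# The helper flattens everything after the channel dim, clips the norm ratio
# of the extrapolated prediction at nag_tau, and blends it back with the
# positive branch via nag_alpha; the output keeps the input shape.
import torch

pos = torch.randn(1, 16, 4, 32, 32)  # conditional noise prediction (assumed shape)
neg = torch.randn(1, 16, 4, 32, 32)  # unconditional noise prediction

guided = apply_normalized_attention_guidance(
    pos, neg, nag_scale=1.5, nag_tau=2.5, nag_alpha=0.125)
assert guided.shape == pos.shape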


_dcm_modules: dict[torch.device, tuple[torch.nn.Conv3d, torch.nn.Conv3d, torch.nn.Conv3d]] = {}
Contributor (medium): Consider using torch.nn.ModuleDict instead of a regular Python dict to store the conv modules. This will properly register the modules with PyTorch and ensure they are moved to the correct device when the model is moved.

Suggested change:
_dcm_modules: dict[torch.device, tuple[torch.nn.Conv3d, torch.nn.Conv3d, torch.nn.Conv3d]] = {}
from torch import nn
_dcm_modules: nn.ModuleDict[torch.device, tuple[torch.nn.Conv3d, torch.nn.Conv3d, torch.nn.Conv3d]] = nn.ModuleDict()



def apply_dcm(tensor: torch.Tensor) -> torch.Tensor:
    """Apply Dynamic Convolution Module (DCM)."""
    global _dcm_modules
    conv_offset, conv_weight, conv_gate = _dcm_modules.get(tensor.device, (None, None, None))
    if conv_offset is None:
        channels = tensor.size(1)
        conv_offset = torch.nn.Conv3d(channels, channels, kernel_size=3, padding=1, bias=False).to(
            tensor.device, tensor.dtype
        )
        conv_weight = torch.nn.Conv3d(channels, channels, kernel_size=3, padding=1, bias=False).to(
            tensor.device, tensor.dtype
        )
        conv_gate = torch.nn.Conv3d(channels, channels, kernel_size=3, padding=1).to(
            tensor.device, tensor.dtype
        )
        _dcm_modules[tensor.device] = (conv_offset, conv_weight, conv_gate)

    offset = conv_offset(tensor)
    out = conv_weight(tensor + offset)
    gate = torch.sigmoid(conv_gate(tensor))
    return tensor + gate * out
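
A similarly hedged sketch (not part of this diff) of calling the DCM helper on a dummy latent; the Conv3d layers are created lazily per device with default random initialization, so this only demonstrates shapes and the gated residual, not a trained module.

import torch

latent = torch.randn(2, 16, 4, 32, 32)  # assumed (B, C, T, H, W) layout
out = apply_dcm(latent)
assert out.shape == latent.shape  # gated residual preserves the input shape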


_taylor_cache: dict[torch.device, dict[str, Any]] = {}


def apply_taylor_seer(tensor: torch.Tensor, step: int, order: int = 2) -> torch.Tensor:
    """Apply TaylorSeer optimization using a simple derivative cache."""
    cache = _taylor_cache.setdefault(tensor.device, {
        "prev": None,
        "prev_diff": None,
        "prev_step": None,
    })
Comment on lines +114 to +118

Contributor (medium): The cache is initialized without a specified type. It's recommended to explicitly define the type for clarity and to avoid potential issues with type checking.

Suggested change:
    cache = _taylor_cache.setdefault(tensor.device, {
        "prev": None,
        "prev_diff": None,
        "prev_step": None,
    })
    cache: dict[str, Any] = _taylor_cache.setdefault(tensor.device, {
        "prev": None,
        "prev_diff": None,
        "prev_step": None,
    })


    if cache["prev"] is None:
        cache["prev"] = tensor.detach()
        cache["prev_step"] = step
        return tensor

    dt = step - cache["prev_step"]
    if dt == 0:
        return tensor

    diff = (tensor - cache["prev"]) / dt
Contributor (high): Division by dt can lead to instability if dt is very small. Consider adding a small constant to the denominator to prevent division by zero or very small values.

Suggested change:
    diff = (tensor - cache["prev"]) / dt
    diff = (tensor - cache["prev"]) / (dt + 1e-8)

    result = cache["prev"] + diff * dt
    if order >= 2 and cache["prev_diff"] is not None:
        second = (diff - cache["prev_diff"]) / dt
        result = result + 0.5 * second * dt * dt

    cache["prev"] = tensor.detach()
    cache["prev_diff"] = diff
    cache["prev_step"] = step
    return result
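
A small sketch (not part of this diff) of the cache behaviour on toy tensors, assuming the helper and _taylor_cache are in scope: the first call only seeds the cache, the second reproduces the input via the first-order term, and the third adds the second-order correction from the stored difference.

import torch

_taylor_cache.clear()
x0, x1, x2 = torch.zeros(1, 4), torch.ones(1, 4), torch.full((1, 4), 4.0)

print(apply_taylor_seer(x0, step=0))           # seeds the cache, returns x0 unchanged
print(apply_taylor_seer(x1, step=1))           # prev + diff * dt reproduces x1
print(apply_taylor_seer(x2, step=2, order=2))  # adds 0.5 * second * dt**2, yields a tensor of 5.0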


class DenoisingStage(PipelineStage):
    """
    Stage for running the denoising loop in diffusion pipelines.
@@ -83,6 +176,9 @@ def forward(
        Returns:
            The batch with denoised latents.
        """
        # Reset caches for optional optimizations
        _taylor_cache.clear()

        # Prepare extra step kwargs for scheduler
        extra_step_kwargs = self.prepare_extra_func_kwargs(
            self.scheduler.step,
@@ -264,8 +360,31 @@ def forward(
                            **neg_cond_kwargs,
                        )
                        noise_pred_text = noise_pred
                        noise_pred = noise_pred_uncond + batch.guidance_scale * (
                            noise_pred_text - noise_pred_uncond)
                        if fastvideo_args.pipeline_config.skip_layer_guidance and (
                                i / len(timesteps)
                                < fastvideo_args.pipeline_config.skip_layer_guidance
                        ):
                            noise_pred = noise_pred_text
                        else:
                            noise_pred = noise_pred_uncond + batch.guidance_scale * (
                                noise_pred_text - noise_pred_uncond)

                        if fastvideo_args.pipeline_config.use_normalized_attention:
                            noise_pred = apply_normalized_attention_guidance(
                                noise_pred_text,
                                noise_pred_uncond,
                                nag_scale=fastvideo_args.pipeline_config.nag_scale * batch.guidance_scale,
                                nag_tau=fastvideo_args.pipeline_config.nag_tau,
                                nag_alpha=fastvideo_args.pipeline_config.nag_alpha,
                            )
Comment on lines +372 to +379

Contributor (medium): Consider moving the nag_scale multiplication into the apply_normalized_attention_guidance function to encapsulate the logic within the function itself.

                            noise_pred = apply_normalized_attention_guidance(
                                noise_pred_text,
                                noise_pred_uncond,
                                nag_scale=fastvideo_args.pipeline_config.nag_scale, # move multiplication here
                                nag_tau=fastvideo_args.pipeline_config.nag_tau,
                                nag_alpha=fastvideo_args.pipeline_config.nag_alpha,
                            )

                        if fastvideo_args.pipeline_config.use_dcm:
                            noise_pred = apply_dcm(noise_pred)
Contributor (medium): Consider adding a check to ensure that apply_dcm and apply_taylor_seer are only applied if the corresponding flags in fastvideo_args.pipeline_config are set to True. This can prevent unnecessary computations when these optimizations are not enabled.

                        if fastvideo_args.pipeline_config.use_dcm:
                            noise_pred = apply_dcm(noise_pred)
                        if fastvideo_args.pipeline_config.use_taylor_seer:

                        if fastvideo_args.pipeline_config.use_taylor_seer:
                            noise_pred = apply_taylor_seer(
                                noise_pred,
                                i,
                                order=fastvideo_args.pipeline_config.taylor_seer_order,
                            )
Comment on lines +383 to +387

Contributor (high): The apply_taylor_seer function is called with the current step i as an argument. Ensure that the step argument in apply_taylor_seer is correctly representing the timestep for the Taylor series approximation. Mismatched timesteps could lead to incorrect derivative calculations.

                            noise_pred = apply_taylor_seer(
                                noise_pred,
                                i, # verify this is the correct timestep
                                order=fastvideo_args.pipeline_config.taylor_seer_order,
                            )
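
For clarity (not part of this diff), a toy illustration of how the skip_layer_guidance fraction gates CFG in the loop above: with 50 timesteps and skip_layer_guidance=0.2, steps 0-9 use the conditional prediction directly and steps 10-49 apply standard CFG.

num_steps = 50
skip_layer_guidance = 0.2
for i in range(num_steps):
    skip_cfg = bool(skip_layer_guidance) and (i / num_steps < skip_layer_guidance)
    # skip_cfg is True for i = 0..9 and False for i = 10..49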


                    # Apply guidance rescale if needed
                    if batch.guidance_rescale > 0.0:
@@ -276,6 +395,23 @@ def forward(
                            guidance_rescale=batch.guidance_rescale,
                        )

                    if not batch.do_classifier_free_guidance:
                        if fastvideo_args.pipeline_config.use_normalized_attention:
                            noise_pred = apply_normalized_attention_guidance(
                                noise_pred,
                                nag_scale=fastvideo_args.pipeline_config.nag_scale,
                                nag_tau=fastvideo_args.pipeline_config.nag_tau,
                                nag_alpha=fastvideo_args.pipeline_config.nag_alpha,
Comment on lines +400 to +404

Contributor (medium): Consider moving the nag_scale multiplication into the apply_normalized_attention_guidance function to encapsulate the logic within the function itself.

                            noise_pred = apply_normalized_attention_guidance(
                                noise_pred,
                                nag_scale=fastvideo_args.pipeline_config.nag_scale, # move multiplication here
                                nag_tau=fastvideo_args.pipeline_config.nag_tau,
                                nag_alpha=fastvideo_args.pipeline_config.nag_alpha,
                            )

                            )
                        if fastvideo_args.pipeline_config.use_dcm:
                            noise_pred = apply_dcm(noise_pred)
                        if fastvideo_args.pipeline_config.use_taylor_seer:
                            noise_pred = apply_taylor_seer(
                                noise_pred,
                                i,
                                order=fastvideo_args.pipeline_config.taylor_seer_order,
                            )

                    # Compute the previous noisy sample
                    latents = self.scheduler.step(noise_pred,
                                                  t,