diff --git a/examples/inference/lora/wan_multi_lora_inference.py b/examples/inference/lora/wan_multi_lora_inference.py
new file mode 100644
index 000000000..f17d59eee
--- /dev/null
+++ b/examples/inference/lora/wan_multi_lora_inference.py
@@ -0,0 +1,32 @@
+from fastvideo import VideoGenerator
+
+OUTPUT_PATH = "./multi_lora"
+
+
+def main():
+    # Create a generator for the Wan2.1 I2V pipeline
+    generator = VideoGenerator.from_pretrained(
+        "Wan-AI/Wan2.1-I2V-14B-480P",
+        num_gpus=1,
+    )
+
+    # Load three LoRA adapters into the pipeline
+    generator.set_lora_adapter("lora1", "path/to/first_lora")
+    generator.set_lora_adapter("lora2", "path/to/second_lora")
+    generator.set_lora_adapter("lora3", "path/to/third_lora")
+
+    # The last call leaves lora3 active; generate a video with it
+    prompt = "An astronaut explores a strange new world, cinematic scene"
+    generator.generate_video(prompt, output_path=OUTPUT_PATH, save_video=True)
+
+    # Switch to lora1 and generate another video
+    generator.set_lora_adapter("lora1")
+    generator.generate_video(prompt, output_path=OUTPUT_PATH, save_video=True)
+
+    # Switch to lora2 and generate one more video
+    generator.set_lora_adapter("lora2")
+    generator.generate_video(prompt, output_path=OUTPUT_PATH, save_video=True)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/inference/optimizations/optimized_wan_i2v_example.py b/examples/inference/optimizations/optimized_wan_i2v_example.py
new file mode 100644
index 000000000..f029f431f
--- /dev/null
+++ b/examples/inference/optimizations/optimized_wan_i2v_example.py
@@ -0,0 +1,36 @@
+from fastvideo import VideoGenerator
+from fastvideo.v1.configs.sample import SamplingParam
+
+
+OUTPUT_PATH = "./optimized_output"
+
+
+def main():
+    """Run the Wan2.1 I2V pipeline with all optimizations enabled."""
+    generator = VideoGenerator.from_pretrained(
+        "Wan-AI/Wan2.1-I2V-14B-480P",
+        num_gpus=1,
+        skip_layer_guidance=0.2,
+        use_normalized_attention=True,
+        nag_scale=1.5,
+        nag_tau=2.5,
+        nag_alpha=0.125,
+        use_dcm=True,
+        use_taylor_seer=True,
+        taylor_seer_order=2,
+    )
+
+    sampling = SamplingParam.from_pretrained("Wan-AI/Wan2.1-I2V-14B-480P")
+
+    prompt = "A lone explorer crosses a vast alien desert under twin moons"
+    generator.generate_video(
+        prompt,
+        sampling_param=sampling,
+        output_path=OUTPUT_PATH,
+        save_video=True,
+    )
+
+
+if __name__ == "__main__":
+    main()
+
diff --git a/fastvideo/v1/configs/pipelines/base.py b/fastvideo/v1/configs/pipelines/base.py
index 7ec4027b2..1b35b069a 100644
--- a/fastvideo/v1/configs/pipelines/base.py
+++ b/fastvideo/v1/configs/pipelines/base.py
@@ -89,6 +89,16 @@ class PipelineConfig:
     STA_mode: STA_Mode = STA_Mode.STA_INFERENCE
     skip_time_steps: int = 15
 
+    # Additional guidance/optimization parameters
+    skip_layer_guidance: float | None = None  # fraction of denoising steps run without CFG
+    use_normalized_attention: bool = False
+    nag_scale: float = 1.5
+    nag_tau: float = 2.5
+    nag_alpha: float = 0.125
+    use_dcm: bool = False
+    use_taylor_seer: bool = False
+    taylor_seer_order: int = 2
+
     # Compilation
     # enable_torch_compile: bool = False
 
@@ -206,6 +216,66 @@ def add_cli_args(parser: FlexibleArgumentParser,
             "Bool for applying scheduler scale in set_timesteps, used in stepvideo",
         )
 
+        parser.add_argument(
+            f"--{prefix_with_dot}skip-layer-guidance",
+            type=float,
+            dest=f"{prefix_with_dot.replace('-', '_')}skip_layer_guidance",
+            default=PipelineConfig.skip_layer_guidance,
+            help="Fraction of initial denoising steps for which CFG is skipped",
+        )
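+        # NOTE: the NAG flags below mirror the PipelineConfig defaults above
+        # and gate apply_normalized_attention_guidance() in the denoising
+        # stage (fastvideo/v1/pipelines/stages/denoising.py).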
+        parser.add_argument(
+            f"--{prefix_with_dot}use-normalized-attention",
+            action=StoreBoolean,
+            dest=f"{prefix_with_dot.replace('-', '_')}use_normalized_attention",
+            default=PipelineConfig.use_normalized_attention,
+            help="Enable Normalized Attention Guidance",
+        )
+        parser.add_argument(
+            f"--{prefix_with_dot}nag-scale",
+            type=float,
+            dest=f"{prefix_with_dot.replace('-', '_')}nag_scale",
+            default=PipelineConfig.nag_scale,
+            help="Scale for Normalized Attention Guidance",
+        )
+        parser.add_argument(
+            f"--{prefix_with_dot}nag-tau",
+            type=float,
+            dest=f"{prefix_with_dot.replace('-', '_')}nag_tau",
+            default=PipelineConfig.nag_tau,
+            help="Tau (norm-ratio clamp) for Normalized Attention Guidance",
+        )
+        parser.add_argument(
+            f"--{prefix_with_dot}nag-alpha",
+            type=float,
+            dest=f"{prefix_with_dot.replace('-', '_')}nag_alpha",
+            default=PipelineConfig.nag_alpha,
+            help="Alpha (blend weight) for Normalized Attention Guidance",
+        )
+        parser.add_argument(
+            f"--{prefix_with_dot}use-dcm",
+            action=StoreBoolean,
+            dest=f"{prefix_with_dot.replace('-', '_')}use_dcm",
+            default=PipelineConfig.use_dcm,
+            help="Enable the Dynamic Convolution Module",
+        )
+        parser.add_argument(
+            f"--{prefix_with_dot}use-taylor-seer",
+            action=StoreBoolean,
+            dest=f"{prefix_with_dot.replace('-', '_')}use_taylor_seer",
+            default=PipelineConfig.use_taylor_seer,
+            help="Enable the TaylorSeer optimization",
+        )
+        parser.add_argument(
+            f"--{prefix_with_dot}taylor-seer-order",
+            type=int,
+            dest=f"{prefix_with_dot.replace('-', '_')}taylor_seer_order",
+            default=PipelineConfig.taylor_seer_order,
+            help="Derivative order for the TaylorSeer optimization",
+        )
+
         # Add VAE configuration arguments
         from fastvideo.v1.configs.models.vaes.base import VAEConfig
         VAEConfig.add_cli_args(parser, prefix=f"{prefix_with_dot}vae-config")
diff --git a/fastvideo/v1/pipelines/stages/denoising.py b/fastvideo/v1/pipelines/stages/denoising.py
index 46040644e..e5f0d9b46 100644
--- a/fastvideo/v1/pipelines/stages/denoising.py
+++ b/fastvideo/v1/pipelines/stages/denoising.py
@@ -45,6 +45,105 @@
 logger = init_logger(__name__)
 
 
+def apply_normalized_attention_guidance(
+    pos: torch.Tensor,
+    neg: torch.Tensor | None = None,
+    nag_scale: float = 1.5,
+    nag_tau: float = 2.5,
+    nag_alpha: float = 0.125,
+) -> torch.Tensor:
+    """Apply Normalized Attention Guidance (NAG) to noise predictions.
+
+    This implementation follows the formula from the official NAG repository
+    and operates on the positive and negative noise predictions; when no
+    negative prediction is given, it falls back to standardizing `pos`.
+    """
+
+    if neg is None:
+        flat = pos.flatten(2)
+        mean = flat.mean(dim=-1, keepdim=True)
+        var = flat.var(dim=-1, unbiased=False, keepdim=True)
+        normalized = (flat - mean) / (var + 1e-6).sqrt()
+        return normalized.view_as(pos)
+
+    pos_flat = pos.flatten(2)
+    neg_flat = neg.flatten(2)
+
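+    # The guided prediction is g = pos * nag_scale - neg * (nag_scale - 1);
+    # its norm ratio ||g|| / ||pos|| is clamped at nag_tau, and the output
+    # is the convex blend nag_alpha * g + (1 - nag_alpha) * pos.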
+ """ + + if neg is None: + flat = pos.flatten(2) + mean = flat.mean(dim=-1, keepdim=True) + var = flat.var(dim=-1, unbiased=False, keepdim=True) + normalized = (flat - mean) / (var + 1e-6).sqrt() + return normalized.view_as(pos) + + pos_flat = pos.flatten(2) + neg_flat = neg.flatten(2) + + guidance = pos_flat * nag_scale - neg_flat * (nag_scale - 1) + norm_pos = pos_flat.norm(p=2, dim=-1, keepdim=True) + norm_guidance = guidance.norm(p=2, dim=-1, keepdim=True) + scale = norm_guidance / (norm_pos + 1e-7) + guidance = guidance * torch.minimum(scale, scale.new_ones(1) * nag_tau) / ( + scale + 1e-7 + ) + + out = guidance * nag_alpha + pos_flat * (1 - nag_alpha) + return out.view_as(pos) + + +_dcm_modules: dict[torch.device, tuple[torch.nn.Conv3d, torch.nn.Conv3d, torch.nn.Conv3d]] = {} + + +def apply_dcm(tensor: torch.Tensor) -> torch.Tensor: + """Apply Dynamic Convolution Module (DCM).""" + global _dcm_modules + conv_offset, conv_weight, conv_gate = _dcm_modules.get(tensor.device, (None, None, None)) + if conv_offset is None: + channels = tensor.size(1) + conv_offset = torch.nn.Conv3d(channels, channels, kernel_size=3, padding=1, bias=False).to( + tensor.device, tensor.dtype + ) + conv_weight = torch.nn.Conv3d(channels, channels, kernel_size=3, padding=1, bias=False).to( + tensor.device, tensor.dtype + ) + conv_gate = torch.nn.Conv3d(channels, channels, kernel_size=3, padding=1).to( + tensor.device, tensor.dtype + ) + _dcm_modules[tensor.device] = (conv_offset, conv_weight, conv_gate) + + offset = conv_offset(tensor) + out = conv_weight(tensor + offset) + gate = torch.sigmoid(conv_gate(tensor)) + return tensor + gate * out + + +_taylor_cache: dict[torch.device, dict[str, Any]] = {} + + +def apply_taylor_seer(tensor: torch.Tensor, step: int, order: int = 2) -> torch.Tensor: + """Apply TaylorSeer optimization using a simple derivative cache.""" + cache = _taylor_cache.setdefault(tensor.device, { + "prev": None, + "prev_diff": None, + "prev_step": None, + }) + + if cache["prev"] is None: + cache["prev"] = tensor.detach() + cache["prev_step"] = step + return tensor + + dt = step - cache["prev_step"] + if dt == 0: + return tensor + + diff = (tensor - cache["prev"]) / dt + result = cache["prev"] + diff * dt + if order >= 2 and cache["prev_diff"] is not None: + second = (diff - cache["prev_diff"]) / dt + result = result + 0.5 * second * dt * dt + + cache["prev"] = tensor.detach() + cache["prev_diff"] = diff + cache["prev_step"] = step + return result + + class DenoisingStage(PipelineStage): """ Stage for running the denoising loop in diffusion pipelines. @@ -83,6 +176,9 @@ def forward( Returns: The batch with denoised latents. 
""" + # Reset caches for optional optimizations + _taylor_cache.clear() + # Prepare extra step kwargs for scheduler extra_step_kwargs = self.prepare_extra_func_kwargs( self.scheduler.step, @@ -264,8 +360,31 @@ def forward( **neg_cond_kwargs, ) noise_pred_text = noise_pred - noise_pred = noise_pred_uncond + batch.guidance_scale * ( - noise_pred_text - noise_pred_uncond) + if fastvideo_args.pipeline_config.skip_layer_guidance and ( + i / len(timesteps) + < fastvideo_args.pipeline_config.skip_layer_guidance + ): + noise_pred = noise_pred_text + else: + noise_pred = noise_pred_uncond + batch.guidance_scale * ( + noise_pred_text - noise_pred_uncond) + + if fastvideo_args.pipeline_config.use_normalized_attention: + noise_pred = apply_normalized_attention_guidance( + noise_pred_text, + noise_pred_uncond, + nag_scale=fastvideo_args.pipeline_config.nag_scale * batch.guidance_scale, + nag_tau=fastvideo_args.pipeline_config.nag_tau, + nag_alpha=fastvideo_args.pipeline_config.nag_alpha, + ) + if fastvideo_args.pipeline_config.use_dcm: + noise_pred = apply_dcm(noise_pred) + if fastvideo_args.pipeline_config.use_taylor_seer: + noise_pred = apply_taylor_seer( + noise_pred, + i, + order=fastvideo_args.pipeline_config.taylor_seer_order, + ) # Apply guidance rescale if needed if batch.guidance_rescale > 0.0: @@ -276,6 +395,23 @@ def forward( guidance_rescale=batch.guidance_rescale, ) + if not batch.do_classifier_free_guidance: + if fastvideo_args.pipeline_config.use_normalized_attention: + noise_pred = apply_normalized_attention_guidance( + noise_pred, + nag_scale=fastvideo_args.pipeline_config.nag_scale, + nag_tau=fastvideo_args.pipeline_config.nag_tau, + nag_alpha=fastvideo_args.pipeline_config.nag_alpha, + ) + if fastvideo_args.pipeline_config.use_dcm: + noise_pred = apply_dcm(noise_pred) + if fastvideo_args.pipeline_config.use_taylor_seer: + noise_pred = apply_taylor_seer( + noise_pred, + i, + order=fastvideo_args.pipeline_config.taylor_seer_order, + ) + # Compute the previous noisy sample latents = self.scheduler.step(noise_pred, t,