From 9607cde51d6cdf53d7579180bb91ec8590a5c0df Mon Sep 17 00:00:00 2001 From: Peiyuan Date: Tue, 29 Jul 2025 20:14:35 +0000 Subject: [PATCH 01/11] stash --- fastvideo/attention/layer.py | 23 ++++++++-- fastvideo/distributed/communication_op.py | 18 +++++++- fastvideo/layers/rotary_embedding.py | 13 ++++-- fastvideo/models/dits/wanvideo.py | 37 ++++++++-------- fastvideo/pipelines/pipeline_registry.py | 1 + fastvideo/pipelines/stages/denoising.py | 51 ++--------------------- scripts/inference/v1_inference_wan_dmd.sh | 2 +- 7 files changed, 69 insertions(+), 76 deletions(-) diff --git a/fastvideo/attention/layer.py b/fastvideo/attention/layer.py index 8161d83f2..0546d2752 100644 --- a/fastvideo/attention/layer.py +++ b/fastvideo/attention/layer.py @@ -11,7 +11,7 @@ from fastvideo.forward_context import ForwardContext, get_forward_context from fastvideo.platforms import AttentionBackendEnum from fastvideo.utils import get_compute_dtype - +from fastvideo.layers.rotary_embedding import _apply_rotary_emb class DistributedAttention(nn.Module): """Distributed attention layer. @@ -63,6 +63,7 @@ def forward( replicated_q: torch.Tensor | None = None, replicated_k: torch.Tensor | None = None, replicated_v: torch.Tensor | None = None, + freqs_cis: tuple[torch.Tensor, torch.Tensor] | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """Forward pass for distributed attention. @@ -96,6 +97,11 @@ def forward( qkv = sequence_model_parallel_all_to_all_4D(qkv, scatter_dim=2, gather_dim=1) + if freqs_cis is not None: + cos, sin = freqs_cis + # apply to q and k + qkv[:batch_size*2] = _apply_rotary_emb(qkv[:batch_size*2], cos, sin, is_neox_style=False) + # Apply backend-specific preprocess_qkv qkv = self.attn_impl.preprocess_qkv(qkv, ctx_attn_metadata) @@ -145,6 +151,7 @@ def forward( replicated_k: torch.Tensor | None = None, replicated_v: torch.Tensor | None = None, gate_compress: torch.Tensor | None = None, + freqs_cis: tuple[torch.Tensor, torch.Tensor] | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """Forward pass for distributed attention. @@ -170,7 +177,7 @@ def forward( forward_context: ForwardContext = get_forward_context() ctx_attn_metadata = forward_context.attn_metadata - + batch_size, seq_len, num_heads, head_dim = q.shape # Stack QKV qkvg = torch.cat([q, k, v, gate_compress], dim=0) # [3, seq_len, num_heads, head_dim] @@ -180,9 +187,14 @@ def forward( scatter_dim=2, gather_dim=1) - qkvg = self.attn_impl.preprocess_qkv(qkvg, ctx_attn_metadata) + if freqs_cis is not None: + cos, sin = freqs_cis + qkvg[:batch_size*2] = _apply_rotary_emb(qkvg[:batch_size*2], cos, sin, is_neox_style=False) + qkvg = self.attn_impl.preprocess_qkv(qkvg, ctx_attn_metadata) + q, k, v, gate_compress = qkvg.chunk(4, dim=0) + output = self.attn_impl.forward( q, k, v, gate_compress, ctx_attn_metadata) # type: ignore[call-arg] @@ -242,6 +254,7 @@ def forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + freqs_cis: tuple[torch.Tensor, torch.Tensor] | None = None, ) -> torch.Tensor: """ Apply local attention between query, key and value tensors. @@ -260,6 +273,10 @@ def forward( forward_context: ForwardContext = get_forward_context() ctx_attn_metadata = forward_context.attn_metadata + if freqs_cis is not None: + cos, sin = freqs_cis + q = _apply_rotary_emb(q, cos, sin, is_neox_style=False) + k = _apply_rotary_emb(k, cos, sin, is_neox_style=False) output = self.attn_impl.forward(q, k, v, ctx_attn_metadata) return output diff --git a/fastvideo/distributed/communication_op.py b/fastvideo/distributed/communication_op.py index c1cad53c4..fe7681c9f 100644 --- a/fastvideo/distributed/communication_op.py +++ b/fastvideo/distributed/communication_op.py @@ -4,7 +4,10 @@ import torch import torch.distributed -from fastvideo.distributed.parallel_state import get_sp_group, get_tp_group +from fastvideo.distributed.parallel_state import (get_sp_group, + get_sp_parallel_rank, + get_sp_world_size, + get_tp_group) def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: @@ -30,3 +33,16 @@ def sequence_model_parallel_all_gather(input_: torch.Tensor, dim: int = -1) -> torch.Tensor: """All-gather the input tensor across model parallel group.""" return get_sp_group().all_gather(input_, dim) + +def sequence_model_parallel_shard(input_: torch.Tensor, + dim: int = 1) -> torch.Tensor: + """All-gather the input tensor across model parallel group.""" + sp_rank = get_sp_parallel_rank() + sp_world_size = get_sp_world_size() + assert input_.shape[dim] % sp_world_size == 0, "input tensor dim={dim} must be divisible by sp_world_size" + elements_per_rank = input_.shape[dim] // sp_world_size + # sharding dim + input_ = input_.movedim(dim, 0) + input_ = input_[sp_rank*elements_per_rank:(sp_rank+1)*elements_per_rank] + input_ = input_.movedim(0, dim) + return input_ diff --git a/fastvideo/layers/rotary_embedding.py b/fastvideo/layers/rotary_embedding.py index 6abe90609..5c81ace2c 100644 --- a/fastvideo/layers/rotary_embedding.py +++ b/fastvideo/layers/rotary_embedding.py @@ -369,6 +369,7 @@ def get_rotary_pos_embed( theta_rescale_factor=1.0, interpolation_factor=1.0, shard_dim: int = 0, + do_sp_sharding: bool = False, dtype: torch.dtype = torch.float32, ) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -383,7 +384,7 @@ def get_rotary_pos_embed( theta_rescale_factor: Rescale factor for theta. Defaults to 1.0 interpolation_factor: Factor to scale positions. Defaults to 1.0 shard_dim: Which dimension to shard for sequence parallelism. Defaults to 0. - + do_sp_sharding: Whether to shard the positional embeddings for sequence parallelism. Defaults to False. Returns: Tuple of (cos, sin) tensors for rotary embeddings """ @@ -399,9 +400,13 @@ def get_rotary_pos_embed( ) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer" # Get SP info - sp_group = get_sp_group() - sp_rank = sp_group.rank_in_group - sp_world_size = sp_group.world_size + if do_sp_sharding: + sp_group = get_sp_group() + sp_rank = sp_group.rank_in_group + sp_world_size = sp_group.world_size + else: + sp_rank = 0 + sp_world_size = 1 freqs_cos, freqs_sin = get_nd_rotary_pos_embed( rope_dim_list, diff --git a/fastvideo/models/dits/wanvideo.py b/fastvideo/models/dits/wanvideo.py index 2ca840622..8fa79cc70 100644 --- a/fastvideo/models/dits/wanvideo.py +++ b/fastvideo/models/dits/wanvideo.py @@ -12,7 +12,8 @@ LocalAttention) from fastvideo.configs.models.dits import WanVideoConfig from fastvideo.configs.sample.wan import WanTeaCacheParams -from fastvideo.distributed.parallel_state import get_sp_world_size +from fastvideo.distributed.communication_op import sequence_model_parallel_all_gather, sequence_model_parallel_shard +from fastvideo.distributed.parallel_state import get_sp_parallel_rank, get_sp_world_size from fastvideo.forward_context import get_forward_context from fastvideo.layers.layernorm import (FP32LayerNorm, LayerNormScaleShift, RMSNorm, ScaleResidual, @@ -21,8 +22,7 @@ # from torch.nn import RMSNorm # TODO: RMSNorm .... from fastvideo.layers.mlp import MLP -from fastvideo.layers.rotary_embedding import (_apply_rotary_emb, - get_rotary_pos_embed) +from fastvideo.layers.rotary_embedding import get_rotary_pos_embed from fastvideo.layers.visual_embedding import (ModulateProjection, PatchEmbed, TimestepEmbedder) from fastvideo.logger import init_logger @@ -328,13 +328,7 @@ def forward( key = key.squeeze(1).unflatten(2, (self.num_attention_heads, -1)) value = value.squeeze(1).unflatten(2, (self.num_attention_heads, -1)) - # Apply rotary embeddings - cos, sin = freqs_cis - query, key = _apply_rotary_emb(query, cos, sin, - is_neox_style=False), _apply_rotary_emb( - key, cos, sin, is_neox_style=False) - - attn_output, _ = self.attn1(query, key, value) + attn_output, _ = self.attn1(query, key, value, freqs_cis=freqs_cis) attn_output = attn_output.flatten(2) attn_output, _ = self.to_out(attn_output) attn_output = attn_output.squeeze(1) @@ -476,15 +470,11 @@ def forward( gate_compress = gate_compress.squeeze(1).unflatten( 2, (self.num_attention_heads, -1)) - # Apply rotary embeddings - cos, sin = freqs_cis - query, key = _apply_rotary_emb(query, cos, sin, - is_neox_style=False), _apply_rotary_emb( - key, cos, sin, is_neox_style=False) attn_output, _ = self.attn1(query, key, value, + freqs_cis=freqs_cis, gate_compress=gate_compress) attn_output = attn_output.flatten(2) attn_output, _ = self.to_out(attn_output) @@ -622,20 +612,23 @@ def forward(self, d = self.hidden_size // self.num_attention_heads rope_dim_list = [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)] freqs_cos, freqs_sin = get_rotary_pos_embed( - (post_patch_num_frames * get_sp_world_size(), post_patch_height, + (post_patch_num_frames, post_patch_height, post_patch_width), self.hidden_size, self.num_attention_heads, rope_dim_list, dtype=torch.float32 if current_platform.is_mps() else torch.float64, rope_theta=10000) - freqs_cos = freqs_cos.to(hidden_states.device) - freqs_sin = freqs_sin.to(hidden_states.device) - freqs_cis = (freqs_cos.float(), - freqs_sin.float()) if freqs_cos is not None else None + freqs_cis = (freqs_cos.to(hidden_states.device).float(), + freqs_sin.to(hidden_states.device).float()) hidden_states = self.patch_embedding(hidden_states) hidden_states = hidden_states.flatten(2).transpose(1, 2) + # hidden_states = sequence_model_parallel_shard(hidden_states, dim=1) + sp_rank = get_sp_parallel_rank() + sp_world_size = get_sp_world_size() + elements_per_rank = hidden_states.shape[1] // sp_world_size + hidden_states = hidden_states[:, sp_rank*elements_per_rank:(sp_rank+1)*elements_per_rank] temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( timestep, encoder_hidden_states, encoder_hidden_states_image) @@ -681,6 +674,10 @@ def forward(self, shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) hidden_states = self.norm_out(hidden_states, shift, scale) + # hidden_states = sequence_model_parallel_all_gather(hidden_states, dim=1) + output_tensor = [torch.empty_like(hidden_states) for _ in range(sp_world_size)] + hidden_states = torch.distributed.all_gather(output_tensor, hidden_states) + hidden_states = torch.cat(output_tensor, dim=1) hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, diff --git a/fastvideo/pipelines/pipeline_registry.py b/fastvideo/pipelines/pipeline_registry.py index 08b57d161..90b4deca0 100644 --- a/fastvideo/pipelines/pipeline_registry.py +++ b/fastvideo/pipelines/pipeline_registry.py @@ -20,6 +20,7 @@ _PIPELINE_NAME_TO_ARCHITECTURE_NAME: dict[str, str] = { "WanPipeline": "wan", "WanImageToVideoPipeline": "wan", + "WanDMDPipeline": "wan", "StepVideoPipeline": "stepvideo", "HunyuanVideoPipeline": "hunyuan", } diff --git a/fastvideo/pipelines/stages/denoising.py b/fastvideo/pipelines/stages/denoising.py index 07d627c4a..02ef8dd7f 100644 --- a/fastvideo/pipelines/stages/denoising.py +++ b/fastvideo/pipelines/stages/denoising.py @@ -15,10 +15,7 @@ from fastvideo.attention import get_attn_backend from fastvideo.configs.pipelines.base import STA_Mode -from fastvideo.distributed import (get_local_torch_device, get_sp_parallel_rank, - get_sp_world_size, get_world_group) -from fastvideo.distributed.communication_op import ( - sequence_model_parallel_all_gather) +from fastvideo.distributed import (get_local_torch_device, get_world_group) from fastvideo.fastvideo_args import FastVideoArgs from fastvideo.forward_context import set_forward_context from fastvideo.logger import init_logger @@ -113,22 +110,7 @@ def forward( autocast_enabled = (target_dtype != torch.float32 ) and not fastvideo_args.disable_autocast - # Handle sequence parallelism if enabled - sp_world_size, rank_in_sp_group = get_sp_world_size( - ), get_sp_parallel_rank() - sp_group = sp_world_size > 1 - if sp_group: - latents = rearrange(batch.latents, - "b c (n t) h w -> b c n t h w", - n=sp_world_size).contiguous() - latents = latents[:, :, rank_in_sp_group, :, :, :] - batch.latents = latents - if batch.image_latent is not None: - image_latent = rearrange(batch.image_latent, - "b c (n t) h w -> b c n t h w", - n=sp_world_size).contiguous() - image_latent = image_latent[:, :, rank_in_sp_group, :, :, :] - batch.image_latent = image_latent + # Get timesteps and calculate warmup steps timesteps = batch.timesteps # TODO(will): remove this once we add input/output validation for stages @@ -311,9 +293,7 @@ def forward( and progress_bar is not None): progress_bar.update() - # Gather results if using sequence parallelism - if sp_group: - latents = sequence_model_parallel_all_gather(latents, dim=2) + # Update batch with final latents batch.latents = latents @@ -664,22 +644,6 @@ def forward( dtype=torch.long, device=get_local_torch_device()) - # Handle sequence parallelism if enabled - sp_world_size, rank_in_sp_group = get_sp_world_size( - ), get_sp_parallel_rank() - sp_group = sp_world_size > 1 - if sp_group: - latents = rearrange(latents, - "b (n t) c h w -> b n t c h w", - n=sp_world_size).contiguous() - latents = latents[:, rank_in_sp_group, :, :, :, :] - if batch.image_latent is not None: - image_latent = rearrange(batch.image_latent, - "b c (n t) h w -> b c n t h w", - n=sp_world_size).contiguous() - - image_latent = image_latent[:, :, rank_in_sp_group, :, :, :] - batch.image_latent = image_latent # Run denoising loop with self.progress_bar(total=len(timesteps)) as progress_bar: @@ -776,11 +740,6 @@ def forward( noise = torch.randn(video_raw_latent_shape, device=self.device, dtype=pred_video.dtype) - if sp_group: - noise = rearrange(noise, - "b (n t) c h w -> b n t c h w", - n=sp_world_size).contiguous() - noise = noise[:, rank_in_sp_group, :, :, :, :] latents = self.scheduler.add_noise( pred_video.flatten(0, 1), noise.flatten(0, 1), next_timestep).unflatten(0, pred_video.shape[:2]) @@ -794,9 +753,7 @@ def forward( and progress_bar is not None): progress_bar.update() - # Gather results if using sequence parallelism - if sp_group: - latents = sequence_model_parallel_all_gather(latents, dim=1) + latents = latents.permute(0, 2, 1, 3, 4) # Update batch with final latents batch.latents = latents diff --git a/scripts/inference/v1_inference_wan_dmd.sh b/scripts/inference/v1_inference_wan_dmd.sh index 59b5048f2..d1dfa7877 100755 --- a/scripts/inference/v1_inference_wan_dmd.sh +++ b/scripts/inference/v1_inference_wan_dmd.sh @@ -1,6 +1,6 @@ #!/bin/bash -num_gpus=1 +num_gpus=2 export FASTVIDEO_ATTENTION_BACKEND=VIDEO_SPARSE_ATTN export MODEL_BASE=FastVideo/FastWan2.1-T2V-1.3B-Diffusers # export MODEL_BASE=hunyuanvideo-community/HunyuanVideo From 1c49486cf1e91eb9013cb0b8f51bc9d34d91b868 Mon Sep 17 00:00:00 2001 From: Peiyuan Date: Tue, 29 Jul 2025 23:04:15 +0000 Subject: [PATCH 02/11] update --- fastvideo/distributed/communication_op.py | 4 ++-- fastvideo/models/dits/wanvideo.py | 17 +++++------------ fastvideo/pipelines/stages/denoising.py | 5 +---- scripts/inference/v1_inference_wan.sh | 6 +++--- 4 files changed, 11 insertions(+), 21 deletions(-) diff --git a/fastvideo/distributed/communication_op.py b/fastvideo/distributed/communication_op.py index fe7681c9f..676e0bc52 100644 --- a/fastvideo/distributed/communication_op.py +++ b/fastvideo/distributed/communication_op.py @@ -36,10 +36,10 @@ def sequence_model_parallel_all_gather(input_: torch.Tensor, def sequence_model_parallel_shard(input_: torch.Tensor, dim: int = 1) -> torch.Tensor: - """All-gather the input tensor across model parallel group.""" + """Shard the input tensor across model parallel group.""" sp_rank = get_sp_parallel_rank() sp_world_size = get_sp_world_size() - assert input_.shape[dim] % sp_world_size == 0, "input tensor dim={dim} must be divisible by sp_world_size" + assert input_.shape[dim] % sp_world_size == 0, f"input tensor dim={dim} must be divisible by sp_world_size={sp_world_size}" elements_per_rank = input_.shape[dim] // sp_world_size # sharding dim input_ = input_.movedim(dim, 0) diff --git a/fastvideo/models/dits/wanvideo.py b/fastvideo/models/dits/wanvideo.py index 8fa79cc70..d1014e39d 100644 --- a/fastvideo/models/dits/wanvideo.py +++ b/fastvideo/models/dits/wanvideo.py @@ -621,14 +621,11 @@ def forward(self, rope_theta=10000) freqs_cis = (freqs_cos.to(hidden_states.device).float(), freqs_sin.to(hidden_states.device).float()) - + + hidden_states = self.patch_embedding(hidden_states) hidden_states = hidden_states.flatten(2).transpose(1, 2) - # hidden_states = sequence_model_parallel_shard(hidden_states, dim=1) - sp_rank = get_sp_parallel_rank() - sp_world_size = get_sp_world_size() - elements_per_rank = hidden_states.shape[1] // sp_world_size - hidden_states = hidden_states[:, sp_rank*elements_per_rank:(sp_rank+1)*elements_per_rank] + hidden_states = sequence_model_parallel_shard(hidden_states, dim=1) temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( timestep, encoder_hidden_states, encoder_hidden_states_image) @@ -674,19 +671,15 @@ def forward(self, shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) hidden_states = self.norm_out(hidden_states, shift, scale) - # hidden_states = sequence_model_parallel_all_gather(hidden_states, dim=1) - output_tensor = [torch.empty_like(hidden_states) for _ in range(sp_world_size)] - hidden_states = torch.distributed.all_gather(output_tensor, hidden_states) - hidden_states = torch.cat(output_tensor, dim=1) + hidden_states = sequence_model_parallel_all_gather(hidden_states, dim=1) hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1) hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) - + return output def maybe_cache_states(self, hidden_states: torch.Tensor, diff --git a/fastvideo/pipelines/stages/denoising.py b/fastvideo/pipelines/stages/denoising.py index 02ef8dd7f..f34055edf 100644 --- a/fastvideo/pipelines/stages/denoising.py +++ b/fastvideo/pipelines/stages/denoising.py @@ -293,8 +293,6 @@ def forward( and progress_bar is not None): progress_bar.update() - - # Update batch with final latents batch.latents = latents @@ -753,9 +751,8 @@ def forward( and progress_bar is not None): progress_bar.update() - latents = latents.permute(0, 2, 1, 3, 4) # Update batch with final latents batch.latents = latents - return batch + return batch \ No newline at end of file diff --git a/scripts/inference/v1_inference_wan.sh b/scripts/inference/v1_inference_wan.sh index 60f96843e..229abbdd1 100755 --- a/scripts/inference/v1_inference_wan.sh +++ b/scripts/inference/v1_inference_wan.sh @@ -1,6 +1,6 @@ #!/bin/bash -num_gpus=1 +num_gpus=2 export FASTVIDEO_ATTENTION_BACKEND= export MODEL_BASE=Wan-AI/Wan2.1-T2V-1.3B-Diffusers # export MODEL_BASE=hunyuanvideo-community/HunyuanVideo @@ -11,8 +11,8 @@ fastvideo generate \ --tp-size 1 \ --num-gpus $num_gpus \ --height 480 \ - --width 832 \ - --num-frames 77 \ + --width 848 \ + --num-frames 81 \ --num-inference-steps 50 \ --fps 16 \ --guidance-scale 6.0 \ From 46921af4d9161d1ac337fe8d2c6b99ec4837dda5 Mon Sep 17 00:00:00 2001 From: Peiyuan Date: Tue, 29 Jul 2025 23:26:09 +0000 Subject: [PATCH 03/11] fix seed issue in DMD stage --- fastvideo/models/vaes/common.py | 2 +- fastvideo/pipelines/stages/denoising.py | 14 +++---- fastvideo/utils.py | 51 ++++++++++++++++++++++++- 3 files changed, 58 insertions(+), 9 deletions(-) diff --git a/fastvideo/models/vaes/common.py b/fastvideo/models/vaes/common.py index 9730c8ff8..4218d8bbe 100644 --- a/fastvideo/models/vaes/common.py +++ b/fastvideo/models/vaes/common.py @@ -8,7 +8,7 @@ import numpy as np import torch import torch.distributed as dist -from diffusers.utils.torch_utils import randn_tensor +from fastvideo.utils import randn_tensor from fastvideo.configs.models import VAEConfig from fastvideo.distributed import get_sp_parallel_rank, get_sp_world_size diff --git a/fastvideo/pipelines/stages/denoising.py b/fastvideo/pipelines/stages/denoising.py index f34055edf..062c4eb83 100644 --- a/fastvideo/pipelines/stages/denoising.py +++ b/fastvideo/pipelines/stages/denoising.py @@ -28,6 +28,8 @@ from fastvideo.pipelines.stages.validators import VerificationResult from fastvideo.platforms import AttentionBackendEnum from fastvideo.utils import dict_to_3d_list +from fastvideo.utils import randn_tensor + try: from fastvideo.attention.backends.sliding_tile_attn import ( @@ -629,11 +631,7 @@ def forward( assert batch.latents is not None, "latents must be provided" latents = batch.latents # TODO(yongqi) hard code prepare latents - latents = torch.randn( - latents.permute(0, 2, 1, 3, 4).shape, - dtype=torch.bfloat16, - device="cuda", - generator=torch.Generator(device="cuda").manual_seed(42)) + latents = latents.permute(0, 2, 1, 3, 4) video_raw_latent_shape = latents.shape prompt_embeds = batch.prompt_embeds assert torch.isnan(prompt_embeds[0]).sum() == 0 @@ -735,9 +733,11 @@ def forward( if i < len(timesteps) - 1: next_timestep = timesteps[i + 1] * torch.ones( [1], dtype=torch.long, device=pred_video.device) - noise = torch.randn(video_raw_latent_shape, + noise = randn_tensor(video_raw_latent_shape, device=self.device, - dtype=pred_video.dtype) + dtype=pred_video.dtype, + generator=batch.generator) + latents = self.scheduler.add_noise( pred_video.flatten(0, 1), noise.flatten(0, 1), next_timestep).unflatten(0, pred_video.shape[:2]) diff --git a/fastvideo/utils.py b/fastvideo/utils.py index 9bbad98a4..dcc19962a 100644 --- a/fastvideo/utils.py +++ b/fastvideo/utils.py @@ -20,7 +20,7 @@ from dataclasses import dataclass, fields, is_dataclass from functools import lru_cache, partial, wraps from typing import Any, TypeVar, cast - +from typing import Tuple, List, Optional, Union import cloudpickle import filelock import torch @@ -812,3 +812,52 @@ def set_random_seed(seed: int) -> None: @lru_cache(maxsize=1) def is_vsa_available() -> bool: return importlib.util.find_spec("vsa") is not None + + +# copy from https://github.com/huggingface/diffusers/blob/v0.19.2/src/diffusers/utils/torch_utils.py#L36 +def randn_tensor( + shape: Union[Tuple, List], + generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, + device: Optional["torch.device"] = None, + dtype: Optional["torch.dtype"] = None, + layout: Optional["torch.layout"] = None, +)->torch.Tensor: + """A helper function to create random tensors on the desired `device` with the desired `dtype`. When + passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor + is always created on the CPU. + """ + # device on which tensor is created defaults to device + rand_device = device + batch_size = shape[0] + + layout = layout or torch.strided + device = device or torch.device("cpu") + + if generator is not None: + gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type + if gen_device_type != device.type and gen_device_type == "cpu": + rand_device = "cpu" + if device != "mps": + logger.info( + f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." + f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" + f" slighly speed up this function by passing a generator that was created on the {device} device." + ) + elif gen_device_type != device.type and gen_device_type == "cuda": + raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") + + # make sure generator list of length 1 is treated like a non-list + if isinstance(generator, list) and len(generator) == 1: + generator = generator[0] + + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) + else: + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) + + return latents \ No newline at end of file From 5bad70a11b6b7648ddc6020fe809bda7d25a0eb3 Mon Sep 17 00:00:00 2001 From: Peiyuan Date: Tue, 29 Jul 2025 23:37:05 +0000 Subject: [PATCH 04/11] stash --- fastvideo/models/dits/wanvideo.py | 1 - fastvideo/training/training_pipeline.py | 17 +++++++---------- fastvideo/training/training_utils.py | 11 ----------- 3 files changed, 7 insertions(+), 22 deletions(-) diff --git a/fastvideo/models/dits/wanvideo.py b/fastvideo/models/dits/wanvideo.py index d1014e39d..b75d592fd 100644 --- a/fastvideo/models/dits/wanvideo.py +++ b/fastvideo/models/dits/wanvideo.py @@ -13,7 +13,6 @@ from fastvideo.configs.models.dits import WanVideoConfig from fastvideo.configs.sample.wan import WanTeaCacheParams from fastvideo.distributed.communication_op import sequence_model_parallel_all_gather, sequence_model_parallel_shard -from fastvideo.distributed.parallel_state import get_sp_parallel_rank, get_sp_world_size from fastvideo.forward_context import get_forward_context from fastvideo.layers.layernorm import (FP32LayerNorm, LayerNormScaleShift, RMSNorm, ScaleResidual, diff --git a/fastvideo/training/training_pipeline.py b/fastvideo/training/training_pipeline.py index ad5f1093a..521fee1f2 100644 --- a/fastvideo/training/training_pipeline.py +++ b/fastvideo/training/training_pipeline.py @@ -39,7 +39,7 @@ from fastvideo.training.training_utils import ( clip_grad_norm_while_handling_failing_dtensor_cases, compute_density_for_timestep_sampling, get_sigmas, load_checkpoint, - normalize_dit_input, save_checkpoint, shard_latents_across_sp) + normalize_dit_input, save_checkpoint) from fastvideo.utils import is_vsa_available, set_random_seed, shallow_asdict import wandb # isort: skip @@ -318,6 +318,9 @@ def _transformer_forward_and_compute_loss( assert model_pred.shape == target.shape, f"model_pred.shape: {model_pred.shape}, target.shape: {target.shape}" loss = (torch.mean((model_pred.float() - target.float())**2) / self.training_args.gradient_accumulation_steps) + + # THIS IS IMPORTANT: DIVIDE LOSS BY SP SIZE + loss = loss / self.sp_world_size loss.backward() avg_loss = loss.detach().clone() @@ -363,17 +366,11 @@ def train_one_step(self, training_batch: TrainingBatch) -> TrainingBatch: training_batch = self._prepare_dit_inputs(training_batch) # Shard latents across sp groups - training_batch.latents = shard_latents_across_sp( - training_batch.latents, - num_latent_t=self.training_args.num_latent_t) + training_batch.latents = training_batch.latents[:, :, :self.training_args.num_latent_t] # shard noisy_model_input to match - training_batch.noisy_model_input = shard_latents_across_sp( - training_batch.noisy_model_input, - num_latent_t=self.training_args.num_latent_t) + training_batch.noisy_model_input = training_batch.noisy_model_input[:, :, :self.training_args.num_latent_t] # shard noise to match latents - training_batch.noise = shard_latents_across_sp( - training_batch.noise, - num_latent_t=self.training_args.num_latent_t) + training_batch.noise = training_batch.noise[:, :, :self.training_args.num_latent_t] training_batch = self._build_attention_metadata(training_batch) training_batch = self._build_input_kwargs(training_batch) diff --git a/fastvideo/training/training_utils.py b/fastvideo/training/training_utils.py index cfdde78ef..4ebe7b165 100644 --- a/fastvideo/training/training_utils.py +++ b/fastvideo/training/training_utils.py @@ -267,17 +267,6 @@ def normalize_dit_input(model_type, latents, args=None) -> torch.Tensor: raise NotImplementedError(f"model_type {model_type} not supported") -def shard_latents_across_sp(latents: torch.Tensor, - num_latent_t: int) -> torch.Tensor: - sp_world_size = get_sp_world_size() - rank_in_sp_group = get_sp_parallel_rank() - latents = latents[:, :, :num_latent_t] - if sp_world_size > 1: - latents = rearrange(latents, - "b c (n s) h w -> b c n s h w", - n=sp_world_size).contiguous() - latents = latents[:, :, rank_in_sp_group, :, :, :] - return latents def clip_grad_norm_while_handling_failing_dtensor_cases( From 4ac0603344fcdb3e1a26151b5147a0d0ea93a76f Mon Sep 17 00:00:00 2001 From: Peiyuan Date: Wed, 30 Jul 2025 00:12:43 +0000 Subject: [PATCH 05/11] only divide for backward --- fastvideo/training/training_pipeline.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/fastvideo/training/training_pipeline.py b/fastvideo/training/training_pipeline.py index 521fee1f2..201fa6e86 100644 --- a/fastvideo/training/training_pipeline.py +++ b/fastvideo/training/training_pipeline.py @@ -319,10 +319,8 @@ def _transformer_forward_and_compute_loss( loss = (torch.mean((model_pred.float() - target.float())**2) / self.training_args.gradient_accumulation_steps) - # THIS IS IMPORTANT: DIVIDE LOSS BY SP SIZE - loss = loss / self.sp_world_size - loss.backward() + (loss / self.sp_world_size).backward() avg_loss = loss.detach().clone() # logger.info(f"rank: {self.rank}, avg_loss: {avg_loss.item()}", From 9658f735b79f5d9a7013705c8a73c7bdae162074 Mon Sep 17 00:00:00 2001 From: Peiyuan Date: Wed, 30 Jul 2025 00:33:15 +0000 Subject: [PATCH 06/11] upd --- .../training/finetune/wan_t2v_1.3B/crush_smol/finetune_t2v.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/training/finetune/wan_t2v_1.3B/crush_smol/finetune_t2v.sh b/examples/training/finetune/wan_t2v_1.3B/crush_smol/finetune_t2v.sh index 871b17510..027401fd2 100644 --- a/examples/training/finetune/wan_t2v_1.3B/crush_smol/finetune_t2v.sh +++ b/examples/training/finetune/wan_t2v_1.3B/crush_smol/finetune_t2v.sh @@ -7,8 +7,8 @@ export TOKENIZERS_PARALLELISM=false MODEL_PATH="Wan-AI/Wan2.1-T2V-1.3B-Diffusers" DATA_DIR="data/crush-smol_processed_t2v/combined_parquet_dataset/" -VALIDATION_DATASET_FILE="examples/training/finetune/wan_t2v_1_3b/crush_smol/validation.json" -NUM_GPUS=4 +VALIDATION_DATASET_FILE="examples/training/finetune/wan_t2v_1.3B/crush_smol/validation.json" +NUM_GPUS=2 # export CUDA_VISIBLE_DEVICES=4,5 From 61c17c84c84599ba176a8f15626d66486070638e Mon Sep 17 00:00:00 2001 From: Peiyuan Zhang Date: Fri, 8 Aug 2025 01:10:56 +0000 Subject: [PATCH 07/11] rm num_frame auto-ajust --- fastvideo/entrypoints/video_generator.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/fastvideo/entrypoints/video_generator.py b/fastvideo/entrypoints/video_generator.py index 2e6f63696..cb16e3bc6 100644 --- a/fastvideo/entrypoints/video_generator.py +++ b/fastvideo/entrypoints/video_generator.py @@ -232,27 +232,20 @@ def _generate_single_video( orig_latent_num_frames = sampling_param.num_frames // 17 * 3 if orig_latent_num_frames % fastvideo_args.num_gpus != 0: - # Adjust latent frames to be divisible by number of GPUs - if sampling_param.num_frames_round_down: - # Ensure we have at least 1 batch per GPU - new_latent_num_frames = max( - 1, (orig_latent_num_frames // num_gpus)) * num_gpus - else: - new_latent_num_frames = math.ceil( - orig_latent_num_frames / num_gpus) * num_gpus + if use_temporal_scaling_frames: # Convert back to number of frames, ensuring num_frames-1 is a multiple of temporal_scale_factor - new_num_frames = (new_latent_num_frames - + new_num_frames = (orig_latent_num_frames - 1) * temporal_scale_factor + 1 else: # stepvideo only # Find the least common multiple of 3 and num_gpus divisor = math.lcm(3, num_gpus) # Round up to the nearest multiple of this LCM - new_latent_num_frames = ( - (new_latent_num_frames + divisor - 1) // divisor) * divisor + orig_latent_num_frames = ( + (orig_latent_num_frames + divisor - 1) // divisor) * divisor # Convert back to actual frames using the StepVideo formula - new_num_frames = new_latent_num_frames // 3 * 17 + new_num_frames = orig_latent_num_frames // 3 * 17 logger.info( "Adjusting number of frames from %s to %s based on number of GPUs (%s)", From f3fc2df2c8c798c0f0230e4fa596c555fa738554 Mon Sep 17 00:00:00 2001 From: Peiyuan Zhang Date: Fri, 8 Aug 2025 02:33:25 +0000 Subject: [PATCH 08/11] huyuan done? --- fastvideo/models/dits/hunyuanvideo.py | 27 ++++++++------------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/fastvideo/models/dits/hunyuanvideo.py b/fastvideo/models/dits/hunyuanvideo.py index da14322ba..f49728fbf 100644 --- a/fastvideo/models/dits/hunyuanvideo.py +++ b/fastvideo/models/dits/hunyuanvideo.py @@ -23,6 +23,7 @@ from fastvideo.models.dits.base import CachableDiT from fastvideo.models.utils import modulate from fastvideo.platforms import AttentionBackendEnum +from fastvideo.distributed.communication_op import sequence_model_parallel_shard, sequence_model_parallel_all_gather class HunyuanRMSNorm(nn.Module): @@ -239,14 +240,7 @@ def forward( img_q = self.img_attn_q_norm(img_q).to(img_v) img_k = self.img_attn_k_norm(img_k).to(img_v) - # Apply rotary embeddings - cos, sin = freqs_cis - img_q, img_k = _apply_rotary_emb( - img_q, cos, sin, - is_neox_style=False), _apply_rotary_emb(img_k, - cos, - sin, - is_neox_style=False) + # Prepare text for attention using fused operation txt_attn_input = self.txt_attn_norm(txt, txt_attn_shift, txt_attn_scale) @@ -265,7 +259,7 @@ def forward( txt_k = self.txt_attn_k_norm(txt_k).to(txt_k.dtype) # Run distributed attention - img_attn, txt_attn = self.attn(img_q, img_k, img_v, txt_q, txt_k, txt_v) + img_attn, txt_attn = self.attn(img_q, img_k, img_v, txt_q, txt_k, txt_v, freqs_cis=freqs_cis) img_attn_out, _ = self.img_attn_proj( img_attn.view(batch_size, image_seq_len, -1)) # Use fused operation for residual connection, normalization, and modulation @@ -395,18 +389,11 @@ def forward( img_q, txt_q = q[:, :-txt_len], q[:, -txt_len:] img_k, txt_k = k[:, :-txt_len], k[:, -txt_len:] img_v, txt_v = v[:, :-txt_len], v[:, -txt_len:] - # Apply rotary embeddings to image parts - cos, sin = freqs_cis - img_q, img_k = _apply_rotary_emb( - img_q, cos, sin, - is_neox_style=False), _apply_rotary_emb(img_k, - cos, - sin, - is_neox_style=False) + # Run distributed attention img_attn_output, txt_attn_output = self.attn(img_q, img_k, img_v, txt_q, - txt_k, txt_v) + txt_k, txt_v, freqs_cis=freqs_cis) attn_output = torch.cat((img_attn_output, txt_attn_output), dim=1).view(batch_size, seq_len, -1) # Process MLP activation @@ -593,7 +580,7 @@ def forward(self, # Get rotary embeddings freqs_cos, freqs_sin = get_rotary_pos_embed( - (tt * get_sp_world_size(), th, tw), self.hidden_size, + (tt, th, tw), self.hidden_size, self.num_attention_heads, self.rope_dim_list, self.rope_theta) freqs_cos = freqs_cos.to(x.device) freqs_sin = freqs_sin.to(x.device) @@ -608,6 +595,7 @@ def forward(self, vec = vec + self.guidance_in(guidance) # Embed image and text img = self.img_in(img) + img = sequence_model_parallel_shard(img, dim=1) txt = self.txt_in(txt, t) txt_seq_len = txt.shape[1] img_seq_len = img.shape[1] @@ -648,6 +636,7 @@ def forward(self, self.maybe_cache_states(img, original_img) # Final layer processing + img = sequence_model_parallel_all_gather(img, dim=1) img = self.final_layer(img, vec) # Unpatchify to get original shape img = unpatchify(img, tt, th, tw, self.patch_size, self.out_channels) From bd19f3dc8b6a9129602f614dc2925bb93f43a96d Mon Sep 17 00:00:00 2001 From: Peiyuan Zhang Date: Fri, 8 Aug 2025 18:47:32 +0000 Subject: [PATCH 09/11] stash --- fastvideo/training/training_pipeline.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fastvideo/training/training_pipeline.py b/fastvideo/training/training_pipeline.py index 63c4c98bd..bd337216c 100644 --- a/fastvideo/training/training_pipeline.py +++ b/fastvideo/training/training_pipeline.py @@ -319,11 +319,11 @@ def _transformer_forward_and_compute_loss( # make sure no implicit broadcasting happens assert model_pred.shape == target.shape, f"model_pred.shape: {model_pred.shape}, target.shape: {target.shape}" - loss = (torch.mean((model_pred.float() - target.float())**2) / - self.training_args.gradient_accumulation_steps) + + loss = torch.mean((model_pred.float() - target.float())**2) + loss /= self.training_args.gradient_accumulation_steps - (loss / self.sp_world_size).backward() avg_loss = loss.detach().clone() # logger.info(f"rank: {self.rank}, avg_loss: {avg_loss.item()}", From 66249ee44809efac3623f760cd187a6fd3818f75 Mon Sep 17 00:00:00 2001 From: Peiyuan Zhang Date: Fri, 8 Aug 2025 18:53:12 +0000 Subject: [PATCH 10/11] stash --- fastvideo/training/distillation_pipeline.py | 38 ++------------------- 1 file changed, 3 insertions(+), 35 deletions(-) diff --git a/fastvideo/training/distillation_pipeline.py b/fastvideo/training/distillation_pipeline.py index 086c75288..575831c3a 100644 --- a/fastvideo/training/distillation_pipeline.py +++ b/fastvideo/training/distillation_pipeline.py @@ -206,11 +206,7 @@ def _generator_forward(self, training_batch: TrainingBatch) -> torch.Tensor: noise = torch.randn(self.video_latent_shape, device=self.device, dtype=dtype) - if self.sp_world_size > 1: - noise = rearrange(noise, - "b (n t) c h w -> b n t c h w", - n=self.sp_world_size).contiguous() - noise = noise[:, self.rank_in_sp_group, :, :, :, :] + noisy_latent = self.noise_scheduler.add_noise(latents.flatten(0, 1), noise.flatten(0, 1), timestep).unflatten( @@ -248,13 +244,6 @@ def _generator_multi_step_simulation_forward( current_noise_latents = torch.randn(self.video_latent_shape, device=self.device, dtype=dtype) - if self.sp_world_size > 1: - current_noise_latents = rearrange( - current_noise_latents, - "b (n t) c h w -> b n t c h w", - n=self.sp_world_size).contiguous() - current_noise_latents = current_noise_latents[:, self. - rank_in_sp_group, :, :, :, :] # Only run intermediate steps if target_timestep_idx > 0 max_target_idx = len(self.denoising_step_list) - 1 @@ -286,11 +275,7 @@ def _generator_multi_step_simulation_forward( noise = torch.randn(self.video_latent_shape, device=self.device, dtype=pred_clean.dtype) - if self.sp_world_size > 1: - noise = rearrange(noise, - "b (n t) c h w -> b n t c h w", - n=self.sp_world_size).contiguous() - noise = noise[:, self.rank_in_sp_group, :, :, :, :] + current_noise_latents = self.noise_scheduler.add_noise( pred_clean.flatten(0, 1), noise.flatten(0, 1), next_timestep_tensor).unflatten(0, pred_clean.shape[:2]) @@ -334,11 +319,6 @@ def _dmd_forward(self, generator_pred_video: torch.Tensor, noise = torch.randn(self.video_latent_shape, device=self.device, dtype=generator_pred_video.dtype) - if self.sp_world_size > 1: - noise = rearrange(noise, - "b (n t) c h w -> b n t c h w", - n=self.sp_world_size).contiguous() - noise = noise[:, self.rank_in_sp_group, :, :, :, :] noisy_latent = self.noise_scheduler.add_noise( generator_pred_video.flatten(0, 1), noise.flatten(0, 1), @@ -441,12 +421,6 @@ def faker_score_forward( fake_score_noise = torch.randn(self.video_latent_shape, device=self.device, dtype=generator_pred_video.dtype) - if self.sp_world_size > 1: - fake_score_noise = rearrange(fake_score_noise, - "b (n t) c h w -> b n t c h w", - n=self.sp_world_size).contiguous() - fake_score_noise = fake_score_noise[:, self. - rank_in_sp_group, :, :, :, :] noisy_generator_pred_video = self.noise_scheduler.add_noise( generator_pred_video.flatten(0, 1), fake_score_noise.flatten(0, 1), @@ -515,13 +489,7 @@ def _prepare_dit_inputs(self, training_batch.latents = training_batch.latents.permute(0, 2, 1, 3, 4) self.video_latent_shape = training_batch.latents.shape - if self.sp_world_size > 1: - training_batch.latents = rearrange( - training_batch.latents, - "b (n t) c h w -> b n t c h w", - n=self.sp_world_size).contiguous() - training_batch.latents = training_batch.latents[:, self. - rank_in_sp_group, :, :, :, :] + self.video_latent_shape_sp = training_batch.latents.shape From f3f94e2989586f625feb37e3d89131093d3a335e Mon Sep 17 00:00:00 2001 From: Peiyuan Zhang Date: Fri, 8 Aug 2025 21:00:37 +0000 Subject: [PATCH 11/11] stash --- fastvideo/training/training_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fastvideo/training/training_pipeline.py b/fastvideo/training/training_pipeline.py index bd337216c..ce471b2f3 100644 --- a/fastvideo/training/training_pipeline.py +++ b/fastvideo/training/training_pipeline.py @@ -322,6 +322,7 @@ def _transformer_forward_and_compute_loss( loss = torch.mean((model_pred.float() - target.float())**2) loss /= self.training_args.gradient_accumulation_steps + loss.backward() avg_loss = loss.detach().clone() @@ -331,7 +332,6 @@ def _transformer_forward_and_compute_loss( world_group = get_world_group() world_group.all_reduce(avg_loss, op=dist.ReduceOp.AVG) training_batch.total_loss += avg_loss.item() - return training_batch def _clip_grad_norm(self, training_batch: TrainingBatch) -> TrainingBatch: @@ -603,7 +603,7 @@ def _log_validation(self, transformer, training_args, global_step) -> None: validation_dataloader = DataLoader(validation_dataset, batch_size=None, num_workers=0) - + return transformer.eval() validation_steps = training_args.validation_sampling_steps.split(",")