23 changes: 20 additions & 3 deletions fastvideo/attention/layer.py
@@ -11,7 +11,7 @@
from fastvideo.forward_context import ForwardContext, get_forward_context
from fastvideo.platforms import AttentionBackendEnum
from fastvideo.utils import get_compute_dtype

from fastvideo.layers.rotary_embedding import _apply_rotary_emb

class DistributedAttention(nn.Module):
"""Distributed attention layer.
@@ -64,6 +64,7 @@ def forward(
replicated_q: torch.Tensor | None = None,
replicated_k: torch.Tensor | None = None,
replicated_v: torch.Tensor | None = None,
freqs_cis: tuple[torch.Tensor, torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Forward pass for distributed attention.

Expand Down Expand Up @@ -97,6 +98,11 @@ def forward(
qkv = sequence_model_parallel_all_to_all_4D(qkv,
scatter_dim=2,
gather_dim=1)
if freqs_cis is not None:
cos, sin = freqs_cis
# apply to q and k
qkv[:batch_size*2] = _apply_rotary_emb(qkv[:batch_size*2], cos, sin, is_neox_style=False)

# Apply backend-specific preprocess_qkv
qkv = self.attn_impl.preprocess_qkv(qkv, ctx_attn_metadata)

@@ -147,6 +153,7 @@ def forward(
replicated_k: torch.Tensor | None = None,
replicated_v: torch.Tensor | None = None,
gate_compress: torch.Tensor | None = None,
freqs_cis: tuple[torch.Tensor, torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Forward pass for distributed attention.

@@ -172,7 +179,7 @@

forward_context: ForwardContext = get_forward_context()
ctx_attn_metadata = forward_context.attn_metadata

batch_size, seq_len, num_heads, head_dim = q.shape
# Stack QKV
qkvg = torch.cat([q, k, v, gate_compress],
                          dim=0)  # [4 * batch_size, seq_len, num_heads, head_dim]
@@ -182,9 +189,14 @@
scatter_dim=2,
gather_dim=1)

qkvg = self.attn_impl.preprocess_qkv(qkvg, ctx_attn_metadata)
if freqs_cis is not None:
cos, sin = freqs_cis
qkvg[:batch_size*2] = _apply_rotary_emb(qkvg[:batch_size*2], cos, sin, is_neox_style=False)

qkvg = self.attn_impl.preprocess_qkv(qkvg, ctx_attn_metadata)

q, k, v, gate_compress = qkvg.chunk(4, dim=0)

output = self.attn_impl.forward(
q, k, v, gate_compress, ctx_attn_metadata) # type: ignore[call-arg]

@@ -244,6 +256,7 @@ def forward(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
freqs_cis: tuple[torch.Tensor, torch.Tensor] | None = None,
) -> torch.Tensor:
"""
Apply local attention between query, key and value tensors.
@@ -262,6 +275,10 @@

forward_context: ForwardContext = get_forward_context()
ctx_attn_metadata = forward_context.attn_metadata
if freqs_cis is not None:
cos, sin = freqs_cis
q = _apply_rotary_emb(q, cos, sin, is_neox_style=False)
k = _apply_rotary_emb(k, cos, sin, is_neox_style=False)

output = self.attn_impl.forward(q, k, v, ctx_attn_metadata)
return output
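Note on the hunks above: the rotary embedding is now applied inside the attention layers, on the first `2 * batch_size` entries of the stacked tensor, so it hits q and k but never v (or `gate_compress`). A self-contained shape sketch of that indexing, not FastVideo code; `apply_rot` is a simplified half-split stand-in for `_apply_rotary_emb` (which additionally supports the interleaved `is_neox_style=False` layout):

```python
# Why qkv[:batch_size * 2] selects exactly q and k: the tensors are
# concatenated along dim 0, so rows [0, 2 * batch_size) are q followed by k.
import torch


def apply_rot(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Simplified half-split rotation used only for this sketch.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)


batch_size, seq_len, num_heads, head_dim = 2, 8, 4, 16
q = torch.randn(batch_size, seq_len, num_heads, head_dim)
k = torch.randn_like(q)
v = torch.randn_like(q)
qkv = torch.cat([q, k, v], dim=0)  # [3 * batch_size, seq_len, num_heads, head_dim]

pos = torch.arange(seq_len, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
angles = torch.outer(pos, inv_freq)  # [seq_len, head_dim // 2]
cos, sin = angles.cos()[:, None, :], angles.sin()[:, None, :]  # broadcast over heads

qkv[:batch_size * 2] = apply_rot(qkv[:batch_size * 2], cos, sin)

q_rot, k_rot, v_out = qkv.chunk(3, dim=0)
assert torch.equal(v_out, v) and not torch.equal(q_rot, q)  # v untouched, q rotated
```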
18 changes: 17 additions & 1 deletion fastvideo/distributed/communication_op.py
@@ -4,7 +4,10 @@
import torch
import torch.distributed

from fastvideo.distributed.parallel_state import get_sp_group, get_tp_group
from fastvideo.distributed.parallel_state import (get_sp_group,
get_sp_parallel_rank,
get_sp_world_size,
get_tp_group)


def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
@@ -30,3 +33,16 @@ def sequence_model_parallel_all_gather(input_,
dim: int = -1) -> torch.Tensor:
"""All-gather the input tensor across model parallel group."""
return get_sp_group().all_gather(input_, dim)

def sequence_model_parallel_shard(input_: torch.Tensor,
dim: int = 1) -> torch.Tensor:
"""Shard the input tensor across model parallel group."""
sp_rank = get_sp_parallel_rank()
sp_world_size = get_sp_world_size()
assert input_.shape[dim] % sp_world_size == 0, f"input tensor dim={dim} must be divisible by sp_world_size={sp_world_size}"
elements_per_rank = input_.shape[dim] // sp_world_size
# sharding dim
input_ = input_.movedim(dim, 0)
input_ = input_[sp_rank*elements_per_rank:(sp_rank+1)*elements_per_rank]
input_ = input_.movedim(0, dim)
return input_
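For intuition, a single-process sketch of the slicing that `sequence_model_parallel_shard` performs for one rank (`shard_for_rank` is a hypothetical helper with no process group, not part of this PR), and why a later `all_gather` along the same dim restores the original tensor:

```python
# Single-rank view of the shard: take this rank's contiguous block of
# `elements_per_rank` entries along `dim`.
import torch


def shard_for_rank(x: torch.Tensor, rank: int, world_size: int, dim: int = 1) -> torch.Tensor:
    assert x.shape[dim] % world_size == 0
    per_rank = x.shape[dim] // world_size
    x = x.movedim(dim, 0)
    x = x[rank * per_rank:(rank + 1) * per_rank]
    return x.movedim(0, dim)


x = torch.arange(2 * 12 * 3, dtype=torch.float32).reshape(2, 12, 3)
world_size = 4
shards = [shard_for_rank(x, r, world_size, dim=1) for r in range(world_size)]
assert all(s.shape == (2, 3, 3) for s in shards)

# all_gather(dim=1) concatenates the per-rank shards in rank order, so the
# shard -> all-gather round trip reproduces the original tensor exactly.
assert torch.equal(torch.cat(shards, dim=1), x)
```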
17 changes: 5 additions & 12 deletions fastvideo/entrypoints/video_generator.py
@@ -232,27 +232,20 @@ def _generate_single_video(
orig_latent_num_frames = sampling_param.num_frames // 17 * 3

if orig_latent_num_frames % fastvideo_args.num_gpus != 0:
# Adjust latent frames to be divisible by number of GPUs
if sampling_param.num_frames_round_down:
# Ensure we have at least 1 batch per GPU
new_latent_num_frames = max(
1, (orig_latent_num_frames // num_gpus)) * num_gpus
else:
new_latent_num_frames = math.ceil(
orig_latent_num_frames / num_gpus) * num_gpus


if use_temporal_scaling_frames:
# Convert back to number of frames, ensuring num_frames-1 is a multiple of temporal_scale_factor
new_num_frames = (new_latent_num_frames -
new_num_frames = (orig_latent_num_frames -
1) * temporal_scale_factor + 1
else: # stepvideo only
# Find the least common multiple of 3 and num_gpus
divisor = math.lcm(3, num_gpus)
# Round up to the nearest multiple of this LCM
new_latent_num_frames = (
(new_latent_num_frames + divisor - 1) // divisor) * divisor
orig_latent_num_frames = (
(orig_latent_num_frames + divisor - 1) // divisor) * divisor
# Convert back to actual frames using the StepVideo formula
new_num_frames = new_latent_num_frames // 3 * 17
new_num_frames = orig_latent_num_frames // 3 * 17

logger.info(
"Adjusting number of frames from %s to %s based on number of GPUs (%s)",
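A worked example of the StepVideo rounding path above, with illustrative values rather than a real config:

```python
# Latent frames are rounded up to a multiple of lcm(3, num_gpus) so they split
# evenly across GPUs, then mapped back to video frames with // 3 * 17.
import math

num_frames = 102  # requested frames (illustrative)
num_gpus = 4
orig_latent_num_frames = num_frames // 17 * 3                                     # 18

divisor = math.lcm(3, num_gpus)                                                   # 12
latent_num_frames = (orig_latent_num_frames + divisor - 1) // divisor * divisor   # 24
new_num_frames = latent_num_frames // 3 * 17                                      # 136
print(orig_latent_num_frames, latent_num_frames, new_num_frames)
```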
13 changes: 9 additions & 4 deletions fastvideo/layers/rotary_embedding.py
@@ -369,6 +369,7 @@ def get_rotary_pos_embed(
theta_rescale_factor=1.0,
interpolation_factor=1.0,
shard_dim: int = 0,
do_sp_sharding: bool = False,
dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
@@ -383,7 +384,7 @@
theta_rescale_factor: Rescale factor for theta. Defaults to 1.0
interpolation_factor: Factor to scale positions. Defaults to 1.0
shard_dim: Which dimension to shard for sequence parallelism. Defaults to 0.

do_sp_sharding: Whether to shard the positional embeddings for sequence parallelism. Defaults to False.
Returns:
Tuple of (cos, sin) tensors for rotary embeddings
"""
@@ -399,9 +400,13 @@
) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"

# Get SP info
sp_group = get_sp_group()
sp_rank = sp_group.rank_in_group
sp_world_size = sp_group.world_size
if do_sp_sharding:
sp_group = get_sp_group()
sp_rank = sp_group.rank_in_group
sp_world_size = sp_group.world_size
else:
sp_rank = 0
sp_world_size = 1

freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
rope_dim_list,
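Sketch of what the new `do_sp_sharding` flag toggles. The assumption here is that `get_nd_rotary_pos_embed` (its call is truncated in this diff) uses `sp_rank` / `sp_world_size` to keep only the local slice of the position grid; with the default `do_sp_sharding=False` the function acts as a single rank and returns full-length (cos, sin) tables, which matches the DiT callers below that now pass unscaled grid sizes:

```python
# Hypothetical reduction of the branch above: which (rank, world_size) pair the
# positional-embedding builder ends up seeing.
def resolve_sp(do_sp_sharding: bool, rank_in_group: int, world_size: int) -> tuple[int, int]:
    if do_sp_sharding:
        return rank_in_group, world_size
    return 0, 1  # behave as a single rank -> full-length cos/sin tables


assert resolve_sp(True, rank_in_group=3, world_size=4) == (3, 4)
assert resolve_sp(False, rank_in_group=3, world_size=4) == (0, 1)
```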
27 changes: 8 additions & 19 deletions fastvideo/models/dits/hunyuanvideo.py
@@ -23,6 +23,7 @@
from fastvideo.models.dits.base import CachableDiT
from fastvideo.models.utils import modulate
from fastvideo.platforms import AttentionBackendEnum
from fastvideo.distributed.communication_op import sequence_model_parallel_shard, sequence_model_parallel_all_gather


class HunyuanRMSNorm(nn.Module):
@@ -239,14 +240,7 @@ def forward(

img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
# Apply rotary embeddings
cos, sin = freqs_cis
img_q, img_k = _apply_rotary_emb(
img_q, cos, sin,
is_neox_style=False), _apply_rotary_emb(img_k,
cos,
sin,
is_neox_style=False)

# Prepare text for attention using fused operation
txt_attn_input = self.txt_attn_norm(txt, txt_attn_shift, txt_attn_scale)

@@ -265,7 +259,7 @@ def forward(
txt_k = self.txt_attn_k_norm(txt_k).to(txt_k.dtype)

# Run distributed attention
img_attn, txt_attn = self.attn(img_q, img_k, img_v, txt_q, txt_k, txt_v)
img_attn, txt_attn = self.attn(img_q, img_k, img_v, txt_q, txt_k, txt_v, freqs_cis=freqs_cis)
img_attn_out, _ = self.img_attn_proj(
img_attn.view(batch_size, image_seq_len, -1))
# Use fused operation for residual connection, normalization, and modulation
@@ -395,18 +389,11 @@ def forward(
img_q, txt_q = q[:, :-txt_len], q[:, -txt_len:]
img_k, txt_k = k[:, :-txt_len], k[:, -txt_len:]
img_v, txt_v = v[:, :-txt_len], v[:, -txt_len:]
# Apply rotary embeddings to image parts
cos, sin = freqs_cis
img_q, img_k = _apply_rotary_emb(
img_q, cos, sin,
is_neox_style=False), _apply_rotary_emb(img_k,
cos,
sin,
is_neox_style=False)


# Run distributed attention
img_attn_output, txt_attn_output = self.attn(img_q, img_k, img_v, txt_q,
txt_k, txt_v)
txt_k, txt_v, freqs_cis=freqs_cis)
attn_output = torch.cat((img_attn_output, txt_attn_output),
dim=1).view(batch_size, seq_len, -1)
# Process MLP activation
@@ -593,7 +580,7 @@ def forward(self,

# Get rotary embeddings
freqs_cos, freqs_sin = get_rotary_pos_embed(
(tt * get_sp_world_size(), th, tw), self.hidden_size,
(tt, th, tw), self.hidden_size,
self.num_attention_heads, self.rope_dim_list, self.rope_theta)
freqs_cos = freqs_cos.to(x.device)
freqs_sin = freqs_sin.to(x.device)
@@ -608,6 +595,7 @@ def forward(self,
vec = vec + self.guidance_in(guidance)
# Embed image and text
img = self.img_in(img)
img = sequence_model_parallel_shard(img, dim=1)
txt = self.txt_in(txt, t)
txt_seq_len = txt.shape[1]
img_seq_len = img.shape[1]
@@ -648,6 +636,7 @@ def forward(self,
self.maybe_cache_states(img, original_img)

# Final layer processing
img = sequence_model_parallel_all_gather(img, dim=1)
img = self.final_layer(img, vec)
# Unpatchify to get original shape
img = unpatchify(img, tt, th, tw, self.patch_size, self.out_channels)
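The HunyuanVideo forward now follows a shard → transformer blocks → all-gather pattern along the sequence dimension (the WanVideo model below does the same around `proj_out`). A shape-only sketch with single-process stand-ins for the two collectives (hypothetical helpers, not the real distributed ops):

```python
import torch


def fake_shard(x: torch.Tensor, rank: int, world: int, dim: int = 1) -> torch.Tensor:
    # Stand-in for sequence_model_parallel_shard.
    return x.chunk(world, dim=dim)[rank]


def fake_all_gather(parts: list[torch.Tensor], dim: int = 1) -> torch.Tensor:
    # Stand-in for sequence_model_parallel_all_gather.
    return torch.cat(parts, dim=dim)


world = 2
img = torch.randn(1, 8, 64)  # [batch, img_seq_len, hidden_size] after img_in
per_rank = [fake_shard(img, r, world) for r in range(world)]  # each [1, 4, 64]
# ... each rank runs the double/single-stream blocks on its local shard ...
img_full = fake_all_gather(per_rank)  # back to [1, 8, 64] for final_layer
assert img_full.shape == img.shape
```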
35 changes: 12 additions & 23 deletions fastvideo/models/dits/wanvideo.py
@@ -12,7 +12,7 @@
LocalAttention)
from fastvideo.configs.models.dits import WanVideoConfig
from fastvideo.configs.sample.wan import WanTeaCacheParams
from fastvideo.distributed.parallel_state import get_sp_world_size
from fastvideo.distributed.communication_op import sequence_model_parallel_all_gather, sequence_model_parallel_shard
from fastvideo.forward_context import get_forward_context
from fastvideo.layers.layernorm import (FP32LayerNorm, LayerNormScaleShift,
RMSNorm, ScaleResidual,
@@ -21,8 +21,7 @@
# from torch.nn import RMSNorm
# TODO: RMSNorm ....
from fastvideo.layers.mlp import MLP
from fastvideo.layers.rotary_embedding import (_apply_rotary_emb,
get_rotary_pos_embed)
from fastvideo.layers.rotary_embedding import get_rotary_pos_embed
from fastvideo.layers.visual_embedding import (ModulateProjection, PatchEmbed,
TimestepEmbedder)
from fastvideo.logger import init_logger
@@ -328,13 +327,7 @@ def forward(
key = key.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
value = value.squeeze(1).unflatten(2, (self.num_attention_heads, -1))

# Apply rotary embeddings
cos, sin = freqs_cis
query, key = _apply_rotary_emb(query, cos, sin,
is_neox_style=False), _apply_rotary_emb(
key, cos, sin, is_neox_style=False)

attn_output, _ = self.attn1(query, key, value)
attn_output, _ = self.attn1(query, key, value, freqs_cis=freqs_cis)
attn_output = attn_output.flatten(2)
attn_output, _ = self.to_out(attn_output)
attn_output = attn_output.squeeze(1)
@@ -476,15 +469,11 @@ def forward(
gate_compress = gate_compress.squeeze(1).unflatten(
2, (self.num_attention_heads, -1))

# Apply rotary embeddings
cos, sin = freqs_cis
query, key = _apply_rotary_emb(query, cos, sin,
is_neox_style=False), _apply_rotary_emb(
key, cos, sin, is_neox_style=False)

attn_output, _ = self.attn1(query,
key,
value,
freqs_cis=freqs_cis,
gate_compress=gate_compress)
attn_output = attn_output.flatten(2)
attn_output, _ = self.to_out(attn_output)
@@ -622,20 +611,20 @@ def forward(self,
d = self.hidden_size // self.num_attention_heads
rope_dim_list = [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]
freqs_cos, freqs_sin = get_rotary_pos_embed(
(post_patch_num_frames * get_sp_world_size(), post_patch_height,
(post_patch_num_frames, post_patch_height,
post_patch_width),
self.hidden_size,
self.num_attention_heads,
rope_dim_list,
dtype=torch.float32 if current_platform.is_mps() else torch.float64,
rope_theta=10000)
freqs_cos = freqs_cos.to(hidden_states.device)
freqs_sin = freqs_sin.to(hidden_states.device)
freqs_cis = (freqs_cos.float(),
freqs_sin.float()) if freqs_cos is not None else None

freqs_cis = (freqs_cos.to(hidden_states.device).float(),
freqs_sin.to(hidden_states.device).float())


hidden_states = self.patch_embedding(hidden_states)
hidden_states = hidden_states.flatten(2).transpose(1, 2)
hidden_states = sequence_model_parallel_shard(hidden_states, dim=1)

temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
timestep, encoder_hidden_states, encoder_hidden_states_image)
@@ -681,15 +670,15 @@ def forward(self,
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2,
dim=1)
hidden_states = self.norm_out(hidden_states, shift, scale)
hidden_states = sequence_model_parallel_all_gather(hidden_states, dim=1)
hidden_states = self.proj_out(hidden_states)

hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames,
post_patch_height,
post_patch_width, p_t, p_h, p_w,
-1)
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

return output

def maybe_cache_states(self, hidden_states: torch.Tensor,
2 changes: 1 addition & 1 deletion fastvideo/models/vaes/common.py
@@ -8,7 +8,7 @@
import numpy as np
import torch
import torch.distributed as dist
from diffusers.utils.torch_utils import randn_tensor
from fastvideo.utils import randn_tensor

from fastvideo.configs.models import VAEConfig
from fastvideo.distributed import get_sp_parallel_rank, get_sp_world_size
1 change: 1 addition & 0 deletions fastvideo/pipelines/pipeline_registry.py
@@ -21,6 +21,7 @@
"WanPipeline": "wan",
"WanDMDPipeline": "wan",
"WanImageToVideoPipeline": "wan",
"WanDMDPipeline": "wan",
"StepVideoPipeline": "stepvideo",
"HunyuanVideoPipeline": "hunyuan",
}