23 changes: 20 additions & 3 deletions fastvideo/attention/layer.py
@@ -11,7 +11,7 @@
from fastvideo.forward_context import ForwardContext, get_forward_context
from fastvideo.platforms import AttentionBackendEnum
from fastvideo.utils import get_compute_dtype

from fastvideo.layers.rotary_embedding import _apply_rotary_emb

class DistributedAttention(nn.Module):
"""Distributed attention layer.
@@ -63,6 +63,7 @@ def forward(
replicated_q: torch.Tensor | None = None,
replicated_k: torch.Tensor | None = None,
replicated_v: torch.Tensor | None = None,
freqs_cis: tuple[torch.Tensor, torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Forward pass for distributed attention.

@@ -96,6 +97,11 @@ def forward(
qkv = sequence_model_parallel_all_to_all_4D(qkv,
scatter_dim=2,
gather_dim=1)
if freqs_cis is not None:
cos, sin = freqs_cis
# apply to q and k
qkv[:batch_size*2] = _apply_rotary_emb(qkv[:batch_size*2], cos, sin, is_neox_style=False)

# Apply backend-specific preprocess_qkv
qkv = self.attn_impl.preprocess_qkv(qkv, ctx_attn_metadata)

@@ -145,6 +151,7 @@ def forward(
replicated_k: torch.Tensor | None = None,
replicated_v: torch.Tensor | None = None,
gate_compress: torch.Tensor | None = None,
freqs_cis: tuple[torch.Tensor, torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Forward pass for distributed attention.

@@ -170,7 +177,7 @@ def forward(

forward_context: ForwardContext = get_forward_context()
ctx_attn_metadata = forward_context.attn_metadata

batch_size, seq_len, num_heads, head_dim = q.shape
# Stack QKV
qkvg = torch.cat([q, k, v, gate_compress],
dim=0)  # [4 * batch_size, seq_len, num_heads, head_dim]
@@ -180,9 +187,14 @@
scatter_dim=2,
gather_dim=1)

qkvg = self.attn_impl.preprocess_qkv(qkvg, ctx_attn_metadata)
if freqs_cis is not None:
cos, sin = freqs_cis
qkvg[:batch_size*2] = _apply_rotary_emb(qkvg[:batch_size*2], cos, sin, is_neox_style=False)

qkvg = self.attn_impl.preprocess_qkv(qkvg, ctx_attn_metadata)

q, k, v, gate_compress = qkvg.chunk(4, dim=0)

output = self.attn_impl.forward(
q, k, v, gate_compress, ctx_attn_metadata) # type: ignore[call-arg]

@@ -242,6 +254,7 @@ def forward(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
freqs_cis: tuple[torch.Tensor, torch.Tensor] | None = None,
) -> torch.Tensor:
"""
Apply local attention between query, key and value tensors.
@@ -260,6 +273,10 @@

forward_context: ForwardContext = get_forward_context()
ctx_attn_metadata = forward_context.attn_metadata
if freqs_cis is not None:
cos, sin = freqs_cis
q = _apply_rotary_emb(q, cos, sin, is_neox_style=False)
k = _apply_rotary_emb(k, cos, sin, is_neox_style=False)

output = self.attn_impl.forward(q, k, v, ctx_attn_metadata)
return output
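The RoPE hook added above relies on q, k, and v being stacked along dim 0, so qkv[:batch_size * 2] addresses exactly the query and key tensors and the values are left untouched. A minimal shape-only sketch of that invariant (illustrative tensors and sizes, not the actual distributed attention path):

import torch

batch_size, seq_len, num_heads, head_dim = 2, 8, 4, 16
q = torch.randn(batch_size, seq_len, num_heads, head_dim)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Same stacking as DistributedAttention: [3 * batch_size, seq_len, num_heads, head_dim]
qkv = torch.cat([q, k, v], dim=0)

# The first 2 * batch_size entries cover q and k only, so RoPE never touches v.
assert torch.equal(qkv[:batch_size * 2], torch.cat([q, k], dim=0))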
18 changes: 17 additions & 1 deletion fastvideo/distributed/communication_op.py
@@ -4,7 +4,10 @@
import torch
import torch.distributed

from fastvideo.distributed.parallel_state import get_sp_group, get_tp_group
from fastvideo.distributed.parallel_state import (get_sp_group,
get_sp_parallel_rank,
get_sp_world_size,
get_tp_group)


def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
@@ -30,3 +33,16 @@ def sequence_model_parallel_all_gather(input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
"""All-gather the input tensor across model parallel group."""
return get_sp_group().all_gather(input_, dim)

def sequence_model_parallel_shard(input_: torch.Tensor,
dim: int = 1) -> torch.Tensor:
"""All-gather the input tensor across model parallel group."""
Contributor (medium): The docstring for this function is incorrect. It describes an all-gather operation, but the function performs sharding. Please update the docstring to accurately reflect the function's behavior.

Suggested change:
-    """All-gather the input tensor across model parallel group."""
+    """Shard the input tensor across model parallel group."""

sp_rank = get_sp_parallel_rank()
sp_world_size = get_sp_world_size()
assert input_.shape[dim] % sp_world_size == 0, "input tensor dim={dim} must be divisible by sp_world_size"
Contributor (medium): The string in the assert statement is not an f-string, so {dim} will be treated as a literal string instead of being replaced by the value of the dim variable. This can make debugging more difficult.

Suggested change:
-    assert input_.shape[dim] % sp_world_size == 0, "input tensor dim={dim} must be divisible by sp_world_size"
+    assert input_.shape[dim] % sp_world_size == 0, f"input tensor dim={dim} must be divisible by sp_world_size"

elements_per_rank = input_.shape[dim] // sp_world_size
# move the sharded dim to the front, take this rank's slice, then restore the layout
input_ = input_.movedim(dim, 0)
input_ = input_[sp_rank*elements_per_rank:(sp_rank+1)*elements_per_rank]
input_ = input_.movedim(0, dim)
return input_
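For context, a minimal sketch of how the new shard helper is expected to pair with the existing all-gather (this assumes an initialized sequence-parallel process group and a sequence length divisible by the SP world size; the shapes are illustrative only):

import torch
from fastvideo.distributed.communication_op import (
    sequence_model_parallel_all_gather, sequence_model_parallel_shard)

# hidden_states: [batch, seq_len, hidden], with seq_len divisible by sp_world_size
hidden_states = torch.randn(1, 8, 16)

# Each rank keeps only its contiguous chunk of the sequence dimension.
local = sequence_model_parallel_shard(hidden_states, dim=1)

# ... run the sequence-parallel model on the local chunk ...

# Afterwards every rank reassembles the full sequence.
full = sequence_model_parallel_all_gather(local, dim=1)  # [1, 8, 16] again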
13 changes: 9 additions & 4 deletions fastvideo/layers/rotary_embedding.py
@@ -369,6 +369,7 @@ def get_rotary_pos_embed(
theta_rescale_factor=1.0,
interpolation_factor=1.0,
shard_dim: int = 0,
do_sp_sharding: bool = False,
dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
@@ -383,7 +384,7 @@
theta_rescale_factor: Rescale factor for theta. Defaults to 1.0
interpolation_factor: Factor to scale positions. Defaults to 1.0
shard_dim: Which dimension to shard for sequence parallelism. Defaults to 0.

do_sp_sharding: Whether to shard the positional embeddings for sequence parallelism. Defaults to False.
Returns:
Tuple of (cos, sin) tensors for rotary embeddings
"""
@@ -399,9 +400,13 @@
) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"

# Get SP info
sp_group = get_sp_group()
sp_rank = sp_group.rank_in_group
sp_world_size = sp_group.world_size
if do_sp_sharding:
sp_group = get_sp_group()
sp_rank = sp_group.rank_in_group
sp_world_size = sp_group.world_size
else:
sp_rank = 0
sp_world_size = 1

freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
rope_dim_list,
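For reference, roughly how the two call patterns compare under the new flag (a hedged sketch; num_frames, height, width, hidden_size, num_attention_heads and rope_dim_list stand in for the caller's values):

# Previous behaviour, now opt-in via do_sp_sharding=True: pass the global frame
# count and let the function slice out this rank's portion of the cos/sin tables.
cos, sin = get_rotary_pos_embed(
    (num_frames * get_sp_world_size(), height, width),
    hidden_size, num_attention_heads, rope_dim_list,
    do_sp_sharding=True)

# New default (do_sp_sharding=False): compute full-sequence tables once and let
# DistributedAttention apply them to q/k after the all-to-all via freqs_cis.
cos, sin = get_rotary_pos_embed(
    (num_frames, height, width),
    hidden_size, num_attention_heads, rope_dim_list)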
37 changes: 17 additions & 20 deletions fastvideo/models/dits/wanvideo.py
@@ -12,7 +12,8 @@
LocalAttention)
from fastvideo.configs.models.dits import WanVideoConfig
from fastvideo.configs.sample.wan import WanTeaCacheParams
from fastvideo.distributed.parallel_state import get_sp_world_size
from fastvideo.distributed.communication_op import sequence_model_parallel_all_gather, sequence_model_parallel_shard
from fastvideo.distributed.parallel_state import get_sp_parallel_rank, get_sp_world_size
from fastvideo.forward_context import get_forward_context
from fastvideo.layers.layernorm import (FP32LayerNorm, LayerNormScaleShift,
RMSNorm, ScaleResidual,
@@ -21,8 +22,7 @@
# from torch.nn import RMSNorm
# TODO: RMSNorm ....
from fastvideo.layers.mlp import MLP
from fastvideo.layers.rotary_embedding import (_apply_rotary_emb,
get_rotary_pos_embed)
from fastvideo.layers.rotary_embedding import get_rotary_pos_embed
from fastvideo.layers.visual_embedding import (ModulateProjection, PatchEmbed,
TimestepEmbedder)
from fastvideo.logger import init_logger
@@ -328,13 +328,7 @@ def forward(
key = key.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
value = value.squeeze(1).unflatten(2, (self.num_attention_heads, -1))

# Apply rotary embeddings
cos, sin = freqs_cis
query, key = _apply_rotary_emb(query, cos, sin,
is_neox_style=False), _apply_rotary_emb(
key, cos, sin, is_neox_style=False)

attn_output, _ = self.attn1(query, key, value)
attn_output, _ = self.attn1(query, key, value, freqs_cis=freqs_cis)
attn_output = attn_output.flatten(2)
attn_output, _ = self.to_out(attn_output)
attn_output = attn_output.squeeze(1)
@@ -476,15 +470,11 @@ def forward(
gate_compress = gate_compress.squeeze(1).unflatten(
2, (self.num_attention_heads, -1))

# Apply rotary embeddings
cos, sin = freqs_cis
query, key = _apply_rotary_emb(query, cos, sin,
is_neox_style=False), _apply_rotary_emb(
key, cos, sin, is_neox_style=False)

attn_output, _ = self.attn1(query,
key,
value,
freqs_cis=freqs_cis,
gate_compress=gate_compress)
attn_output = attn_output.flatten(2)
attn_output, _ = self.to_out(attn_output)
@@ -622,20 +612,23 @@ def forward(self,
d = self.hidden_size // self.num_attention_heads
rope_dim_list = [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]
freqs_cos, freqs_sin = get_rotary_pos_embed(
(post_patch_num_frames * get_sp_world_size(), post_patch_height,
(post_patch_num_frames, post_patch_height,
post_patch_width),
self.hidden_size,
self.num_attention_heads,
rope_dim_list,
dtype=torch.float32 if current_platform.is_mps() else torch.float64,
rope_theta=10000)
freqs_cos = freqs_cos.to(hidden_states.device)
freqs_sin = freqs_sin.to(hidden_states.device)
freqs_cis = (freqs_cos.float(),
freqs_sin.float()) if freqs_cos is not None else None
freqs_cis = (freqs_cos.to(hidden_states.device).float(),
freqs_sin.to(hidden_states.device).float())

hidden_states = self.patch_embedding(hidden_states)
hidden_states = hidden_states.flatten(2).transpose(1, 2)
# hidden_states = sequence_model_parallel_shard(hidden_states, dim=1)
sp_rank = get_sp_parallel_rank()
sp_world_size = get_sp_world_size()
elements_per_rank = hidden_states.shape[1] // sp_world_size
hidden_states = hidden_states[:, sp_rank*elements_per_rank:(sp_rank+1)*elements_per_rank]
Contributor (medium): The sharding logic is implemented manually here, while there's a sequence_model_parallel_shard function available (and commented out) that encapsulates this logic. Using the existing abstraction would make the code cleaner and more maintainable.

Suggested change:
-    # hidden_states = sequence_model_parallel_shard(hidden_states, dim=1)
-    sp_rank = get_sp_parallel_rank()
-    sp_world_size = get_sp_world_size()
-    elements_per_rank = hidden_states.shape[1] // sp_world_size
-    hidden_states = hidden_states[:, sp_rank*elements_per_rank:(sp_rank+1)*elements_per_rank]
+    hidden_states = sequence_model_parallel_shard(hidden_states, dim=1)


temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
timestep, encoder_hidden_states, encoder_hidden_states_image)
@@ -681,6 +674,10 @@ def forward(self,
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2,
dim=1)
hidden_states = self.norm_out(hidden_states, shift, scale)
# hidden_states = sequence_model_parallel_all_gather(hidden_states, dim=1)
output_tensor = [torch.empty_like(hidden_states) for _ in range(sp_world_size)]
hidden_states = torch.distributed.all_gather(output_tensor, hidden_states)
hidden_states = torch.cat(output_tensor, dim=1)
Contributor (medium): The all_gather operation is implemented manually. For consistency and to leverage the existing abstraction, please use the sequence_model_parallel_all_gather function which is already imported and commented out.

Suggested change:
-    # hidden_states = sequence_model_parallel_all_gather(hidden_states, dim=1)
-    output_tensor = [torch.empty_like(hidden_states) for _ in range(sp_world_size)]
-    hidden_states = torch.distributed.all_gather(output_tensor, hidden_states)
-    hidden_states = torch.cat(output_tensor, dim=1)
+    hidden_states = sequence_model_parallel_all_gather(hidden_states, dim=1)

hidden_states = self.proj_out(hidden_states)

hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames,
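To make the reviewer's suggestion concrete, the manual slice and gather in this forward pass implement the same per-rank chunking that sequence_model_parallel_shard and sequence_model_parallel_all_gather encapsulate. A single-process emulation of that round trip (no process group required; the list index stands in for the SP rank):

import torch

sp_world_size = 2
hidden_states = torch.randn(1, 8, 16)                        # [batch, seq_len, hidden]
elements_per_rank = hidden_states.shape[1] // sp_world_size

# The chunk each rank keeps after the manual slice in the forward pass.
shards = [hidden_states[:, r * elements_per_rank:(r + 1) * elements_per_rank]
          for r in range(sp_world_size)]

# The gather before proj_out reassembles the full sequence from those chunks.
assert torch.equal(torch.cat(shards, dim=1), hidden_states)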
1 change: 1 addition & 0 deletions fastvideo/pipelines/pipeline_registry.py
@@ -20,6 +20,7 @@
_PIPELINE_NAME_TO_ARCHITECTURE_NAME: dict[str, str] = {
"WanPipeline": "wan",
"WanImageToVideoPipeline": "wan",
"WanDMDPipeline": "wan",
"StepVideoPipeline": "stepvideo",
"HunyuanVideoPipeline": "hunyuan",
}
51 changes: 4 additions & 47 deletions fastvideo/pipelines/stages/denoising.py
@@ -15,10 +15,7 @@

from fastvideo.attention import get_attn_backend
from fastvideo.configs.pipelines.base import STA_Mode
from fastvideo.distributed import (get_local_torch_device, get_sp_parallel_rank,
get_sp_world_size, get_world_group)
from fastvideo.distributed.communication_op import (
sequence_model_parallel_all_gather)
from fastvideo.distributed import (get_local_torch_device, get_world_group)
from fastvideo.fastvideo_args import FastVideoArgs
from fastvideo.forward_context import set_forward_context
from fastvideo.logger import init_logger
@@ -113,22 +110,7 @@ def forward(
autocast_enabled = (target_dtype != torch.float32
) and not fastvideo_args.disable_autocast

# Handle sequence parallelism if enabled
sp_world_size, rank_in_sp_group = get_sp_world_size(
), get_sp_parallel_rank()
sp_group = sp_world_size > 1
if sp_group:
latents = rearrange(batch.latents,
"b c (n t) h w -> b c n t h w",
n=sp_world_size).contiguous()
latents = latents[:, :, rank_in_sp_group, :, :, :]
batch.latents = latents
if batch.image_latent is not None:
image_latent = rearrange(batch.image_latent,
"b c (n t) h w -> b c n t h w",
n=sp_world_size).contiguous()
image_latent = image_latent[:, :, rank_in_sp_group, :, :, :]
batch.image_latent = image_latent

# Get timesteps and calculate warmup steps
timesteps = batch.timesteps
# TODO(will): remove this once we add input/output validation for stages
@@ -311,9 +293,7 @@ def forward(
and progress_bar is not None):
progress_bar.update()

# Gather results if using sequence parallelism
if sp_group:
latents = sequence_model_parallel_all_gather(latents, dim=2)


# Update batch with final latents
batch.latents = latents
@@ -664,22 +644,6 @@ def forward(
dtype=torch.long,
device=get_local_torch_device())

# Handle sequence parallelism if enabled
sp_world_size, rank_in_sp_group = get_sp_world_size(
), get_sp_parallel_rank()
sp_group = sp_world_size > 1
if sp_group:
latents = rearrange(latents,
"b (n t) c h w -> b n t c h w",
n=sp_world_size).contiguous()
latents = latents[:, rank_in_sp_group, :, :, :, :]
if batch.image_latent is not None:
image_latent = rearrange(batch.image_latent,
"b c (n t) h w -> b c n t h w",
n=sp_world_size).contiguous()

image_latent = image_latent[:, :, rank_in_sp_group, :, :, :]
batch.image_latent = image_latent

# Run denoising loop
with self.progress_bar(total=len(timesteps)) as progress_bar:
Expand Down Expand Up @@ -776,11 +740,6 @@ def forward(
noise = torch.randn(video_raw_latent_shape,
device=self.device,
dtype=pred_video.dtype)
if sp_group:
noise = rearrange(noise,
"b (n t) c h w -> b n t c h w",
n=sp_world_size).contiguous()
noise = noise[:, rank_in_sp_group, :, :, :, :]
latents = self.scheduler.add_noise(
pred_video.flatten(0, 1), noise.flatten(0, 1),
next_timestep).unflatten(0, pred_video.shape[:2])
@@ -794,9 +753,7 @@
and progress_bar is not None):
progress_bar.update()

# Gather results if using sequence parallelism
if sp_group:
latents = sequence_model_parallel_all_gather(latents, dim=1)

latents = latents.permute(0, 2, 1, 3, 4)
# Update batch with final latents
batch.latents = latents
2 changes: 1 addition & 1 deletion scripts/inference/v1_inference_wan_dmd.sh
@@ -1,6 +1,6 @@
#!/bin/bash

num_gpus=1
num_gpus=2
export FASTVIDEO_ATTENTION_BACKEND=VIDEO_SPARSE_ATTN
export MODEL_BASE=FastVideo/FastWan2.1-T2V-1.3B-Diffusers
# export MODEL_BASE=hunyuanvideo-community/HunyuanVideo