Conversation

jzhang38 (Collaborator)

Before this PR, we sharded latents along their temporal dimension (number of latent frames) before feeding them to the DiT. This required the number of latent frames to be a multiple of the SP size.
After this PR, we shard tensors inside the DiT (after patchification), so it works as long as dit_sequence_length is a multiple of the SP size.
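
For concreteness, here is a minimal sketch of the new flow. The helper names sequence_model_parallel_shard and sequence_model_parallel_all_gather come from this PR; everything else (patchify, blocks, unpatchify, and the shapes) is illustrative pseudocode, not the actual wanvideo.py forward pass.

import torch

def dit_forward_sketch(latents: torch.Tensor, patchify, blocks, unpatchify,
                       sp_shard, sp_all_gather, sp_world_size: int) -> torch.Tensor:
    # Patchify first: (B, C, T, H, W) latents -> (B, dit_sequence_length, D) tokens.
    hidden_states = patchify(latents)
    # Only the token sequence length has to be divisible by the SP size now,
    # not the number of latent frames.
    assert hidden_states.shape[1] % sp_world_size == 0
    # Shard along the sequence dimension inside the DiT.
    hidden_states = sp_shard(hidden_states, dim=1)
    for block in blocks:
        hidden_states = block(hidden_states)
    # Gather the full sequence back before unpatchifying.
    hidden_states = sp_all_gather(hidden_states, dim=1)
    return unpatchify(hidden_states)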

jzhang38 marked this pull request as draft on July 29, 2025 at 20:21
gemini-code-assist bot (Contributor) left a comment

Summary of Changes

Hello @jzhang38, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

I've refactored the sequence parallelism (SP) implementation to enhance support for arbitrary resolutions. Previously, latent frames were sharded before being fed into the DiT model, requiring the number of latent frames to be a multiple of the SP size. With this change, sharding now occurs inside the DiT model, specifically after the patchification step. This new approach removes the prior constraint, making the system more flexible as long as the dit_sequence_length is a multiple of the SP size.

Highlights

  • Sequence Parallelism Shifting: The mechanism for sharding tensors for Sequence Parallelism (SP) has been moved from an external preprocessing step to an internal one within the DiT model, occurring after patchification.
  • Enhanced Resolution Flexibility: This refactoring eliminates the previous requirement for latent frames to be a multiple of the SP size, allowing for greater flexibility in handling various input resolutions.
  • Rotary Embedding Integration: Rotary embeddings are now applied directly within the DistributedAttention and LocalAttention layers, ensuring proper positional encoding during distributed attention computations (see the sketch after this list).
  • New SP Sharding Utility: A dedicated sequence_model_parallel_shard utility has been introduced to streamline the sharding of tensors for sequence parallelism.
  • Configurable Rotary Embedding Sharding: The get_rotary_pos_embed function now includes a do_sp_sharding parameter, providing explicit control over whether positional embeddings are sharded for sequence parallelism.
  • Multi-GPU Inference Update: The default number of GPUs for the v1_inference_wan_dmd.sh script has been updated from 1 to 2, indicating readiness for multi-GPU distributed inference.
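
To illustrate the rotary-embedding highlight, here is a generic sketch of applying rotary position embeddings to a rank-local query or key shard. This is not FastVideo's actual DistributedAttention code; the tensor shapes and the helper name apply_rotary_emb are assumptions.

import torch

def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (B, S_local, H, D) with D even; cos/sin: (S_local, D // 2) for this rank's positions.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos = cos[None, :, None, :]
    sin = sin[None, :, None, :]
    rotated_even = x1 * cos - x2 * sin
    rotated_odd = x1 * sin + x2 * cos
    # Re-interleave the rotated pairs back into the last dimension.
    return torch.stack((rotated_even, rotated_odd), dim=-1).flatten(-2)

Because each SP rank only holds a slice of the token sequence, the cos/sin tables have to cover exactly that rank's global positions, which is what the new do_sp_sharding flag on get_rotary_pos_embed controls.
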
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature              | Command             | Description
Code Review          | /gemini review      | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary     | Provides a summary of the current pull request in its current state.
Comment              | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments.
Help                 | /gemini help        | Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request refactors sequence parallelism by moving the sharding logic from the pipeline stage into the DiT model. This is a good change that increases modularity and supports more flexible input resolutions. The changes to handle rotary embeddings within the attention layers appear correct. I've identified a few areas for improvement in fastvideo/distributed/communication_op.py and fastvideo/models/dits/wanvideo.py related to code clarity, correctness, and the use of abstractions.


def sequence_model_parallel_shard(input_: torch.Tensor,
                                  dim: int = 1) -> torch.Tensor:
    """All-gather the input tensor across model parallel group."""
gemini-code-assist bot (Contributor) commented:

Severity: medium

The docstring for this function is incorrect. It describes an all-gather operation, but the function performs sharding. Please update the docstring to accurately reflect the function's behavior.

Suggested change:
-    """All-gather the input tensor across model parallel group."""
+    """Shard the input tensor across model parallel group."""

"""All-gather the input tensor across model parallel group."""
sp_rank = get_sp_parallel_rank()
sp_world_size = get_sp_world_size()
assert input_.shape[dim] % sp_world_size == 0, "input tensor dim={dim} must be divisible by sp_world_size"
gemini-code-assist bot (Contributor) commented:

Severity: medium

The string in the assert statement is not an f-string, so {dim} will be treated as a literal string instead of being replaced by the value of the dim variable. This can make debugging more difficult.

Suggested change:
-    assert input_.shape[dim] % sp_world_size == 0, "input tensor dim={dim} must be divisible by sp_world_size"
+    assert input_.shape[dim] % sp_world_size == 0, f"input tensor dim={dim} must be divisible by sp_world_size"

Comment on lines 627 to 631
# hidden_states = sequence_model_parallel_shard(hidden_states, dim=1)
sp_rank = get_sp_parallel_rank()
sp_world_size = get_sp_world_size()
elements_per_rank = hidden_states.shape[1] // sp_world_size
hidden_states = hidden_states[:, sp_rank*elements_per_rank:(sp_rank+1)*elements_per_rank]
gemini-code-assist bot (Contributor) commented:

Severity: medium

The sharding logic is implemented manually here, while there's a sequence_model_parallel_shard function available (and commented out) that encapsulates this logic. Using the existing abstraction would make the code cleaner and more maintainable.

Suggested change:
- # hidden_states = sequence_model_parallel_shard(hidden_states, dim=1)
- sp_rank = get_sp_parallel_rank()
- sp_world_size = get_sp_world_size()
- elements_per_rank = hidden_states.shape[1] // sp_world_size
- hidden_states = hidden_states[:, sp_rank*elements_per_rank:(sp_rank+1)*elements_per_rank]
+ hidden_states = sequence_model_parallel_shard(hidden_states, dim=1)
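
For reference, the body of such a shard helper typically just slices out this rank's contiguous chunk along `dim`. Below is a minimal sketch consistent with the excerpt above; the actual fastvideo implementation pulls the rank and world size from the SP group rather than taking them as arguments, and may differ in detail.

import torch

def sequence_model_parallel_shard_sketch(input_: torch.Tensor, dim: int = 1,
                                         rank: int = 0, world_size: int = 1) -> torch.Tensor:
    """Keep only this rank's contiguous slice of `input_` along `dim`."""
    assert input_.shape[dim] % world_size == 0, (
        f"input tensor dim={dim} must be divisible by world_size")
    chunk = input_.shape[dim] // world_size
    # narrow() returns a view of the rank-local slice without copying.
    return input_.narrow(dim, rank * chunk, chunk)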

Comment on lines 677 to 680
# hidden_states = sequence_model_parallel_all_gather(hidden_states, dim=1)
output_tensor = [torch.empty_like(hidden_states) for _ in range(sp_world_size)]
hidden_states = torch.distributed.all_gather(output_tensor, hidden_states)
hidden_states = torch.cat(output_tensor, dim=1)
gemini-code-assist bot (Contributor) commented:

Severity: medium

The all_gather operation is implemented manually. For consistency and to leverage the existing abstraction, please use the sequence_model_parallel_all_gather function which is already imported and commented out.

Suggested change:
- # hidden_states = sequence_model_parallel_all_gather(hidden_states, dim=1)
- output_tensor = [torch.empty_like(hidden_states) for _ in range(sp_world_size)]
- hidden_states = torch.distributed.all_gather(output_tensor, hidden_states)
- hidden_states = torch.cat(output_tensor, dim=1)
+ hidden_states = sequence_model_parallel_all_gather(hidden_states, dim=1)
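
For context, a gather helper like this is typically a thin wrapper over torch.distributed.all_gather on the SP process group. Below is a minimal sketch; the exact signature and group handling in fastvideo are assumptions.

import torch
import torch.distributed as dist

def sequence_model_parallel_all_gather(input_: torch.Tensor, dim: int = 1,
                                       group=None) -> torch.Tensor:
    """Gather sequence shards from all SP ranks and concatenate them along `dim`."""
    world_size = dist.get_world_size(group=group)
    if world_size == 1:
        return input_
    output = [torch.empty_like(input_) for _ in range(world_size)]
    # all_gather fills `output` in place and returns None (unless async_op=True),
    # so its return value must not be assigned back to the hidden states.
    dist.all_gather(output, input_, group=group)
    return torch.cat(output, dim=dim)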
