-
Notifications
You must be signed in to change notification settings - Fork 178
[Feat] Improve SP to support any resolution. #664
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Summary of Changes
Hello @jzhang38, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
I've refactored the sequence parallelism (SP) implementation to enhance support for arbitrary resolutions. Previously, latent frames were sharded before being fed into the DiT model, requiring the number of latent frames to be a multiple of the SP size. With this change, sharding now occurs inside the DiT model, specifically after the patchification step. This new approach removes the prior constraint, making the system more flexible as long as the dit_sequence_length is a multiple of the SP size.
Highlights
- Sequence Parallelism Shifting: The mechanism for sharding tensors for Sequence Parallelism (SP) has been moved from an external preprocessing step to an internal one within the DiT model, occurring after patchification.
- Enhanced Resolution Flexibility: This refactoring eliminates the previous requirement for latent frames to be a multiple of the SP size, allowing for greater flexibility in handling various input resolutions.
- Rotary Embedding Integration: Rotary embeddings are now directly applied within the DistributedAttention and LocalAttention layers, ensuring proper positional encoding during distributed attention computations.
- New SP Sharding Utility: A dedicated sequence_model_parallel_shard utility has been introduced to streamline the sharding of tensors for sequence parallelism.
- Configurable Rotary Embedding Sharding: The get_rotary_pos_embed function now includes a do_sp_sharding parameter, providing explicit control over whether positional embeddings are sharded for sequence parallelism.
- Multi-GPU Inference Update: The default number of GPUs for the v1_inference_wan_dmd.sh script has been updated from 1 to 2, indicating readiness for multi-GPU distributed inference.
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command>
or @gemini-code-assist <command>
. Below is a summary of the supported commands.
Feature | Command | Description |
---|---|---|
Code Review | /gemini review |
Performs a code review for the current pull request in its current state. |
Pull Request Summary | /gemini summary |
Provides a summary of the current pull request in its current state. |
Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
Help | /gemini help |
Displays a list of available commands. |
Customization
To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/
folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
-
Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution. ↩
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request refactors sequence parallelism by moving the sharding logic from the pipeline stage into the DiT model. This is a good change that increases modularity and supports more flexible input resolutions. The changes to handle rotary embeddings within the attention layers appear correct. I've identified a few areas for improvement in fastvideo/distributed/communication_op.py
and fastvideo/models/dits/wanvideo.py
related to code clarity, correctness, and the use of abstractions.
|
||
def sequence_model_parallel_shard(input_: torch.Tensor, | ||
dim: int = 1) -> torch.Tensor: | ||
"""All-gather the input tensor across model parallel group.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The docstring for this function is incorrect. It describes an all-gather operation, but the function performs sharding. Please update the docstring to accurately reflect the function's behavior.
"""All-gather the input tensor across model parallel group.""" | |
"""Shard the input tensor across model parallel group.""" |
"""All-gather the input tensor across model parallel group.""" | ||
sp_rank = get_sp_parallel_rank() | ||
sp_world_size = get_sp_world_size() | ||
assert input_.shape[dim] % sp_world_size == 0, "input tensor dim={dim} must be divisible by sp_world_size" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The string in the assert
statement is not an f-string, so {dim}
will be treated as a literal string instead of being replaced by the value of the dim
variable. This can make debugging more difficult.
assert input_.shape[dim] % sp_world_size == 0, "input tensor dim={dim} must be divisible by sp_world_size" | |
assert input_.shape[dim] % sp_world_size == 0, f"input tensor dim={dim} must be divisible by sp_world_size" |
fastvideo/models/dits/wanvideo.py
Outdated
# hidden_states = sequence_model_parallel_shard(hidden_states, dim=1) | ||
sp_rank = get_sp_parallel_rank() | ||
sp_world_size = get_sp_world_size() | ||
elements_per_rank = hidden_states.shape[1] // sp_world_size | ||
hidden_states = hidden_states[:, sp_rank*elements_per_rank:(sp_rank+1)*elements_per_rank] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The sharding logic is implemented manually here, while there's a sequence_model_parallel_shard
function available (and commented out) that encapsulates this logic. Using the existing abstraction would make the code cleaner and more maintainable.
# hidden_states = sequence_model_parallel_shard(hidden_states, dim=1) | |
sp_rank = get_sp_parallel_rank() | |
sp_world_size = get_sp_world_size() | |
elements_per_rank = hidden_states.shape[1] // sp_world_size | |
hidden_states = hidden_states[:, sp_rank*elements_per_rank:(sp_rank+1)*elements_per_rank] | |
hidden_states = sequence_model_parallel_shard(hidden_states, dim=1) |
fastvideo/models/dits/wanvideo.py
Outdated
# hidden_states = sequence_model_parallel_all_gather(hidden_states, dim=1) | ||
output_tensor = [torch.empty_like(hidden_states) for _ in range(sp_world_size)] | ||
hidden_states = torch.distributed.all_gather(output_tensor, hidden_states) | ||
hidden_states = torch.cat(output_tensor, dim=1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The all_gather
operation is implemented manually. For consistency and to leverage the existing abstraction, please use the sequence_model_parallel_all_gather
function which is already imported and commented out.
# hidden_states = sequence_model_parallel_all_gather(hidden_states, dim=1) | |
output_tensor = [torch.empty_like(hidden_states) for _ in range(sp_world_size)] | |
hidden_states = torch.distributed.all_gather(output_tensor, hidden_states) | |
hidden_states = torch.cat(output_tensor, dim=1) | |
hidden_states = sequence_model_parallel_all_gather(hidden_states, dim=1) |
Before this PR, we shard latents in their temporal (number of latent frames) dimension before feed them to DiT. This requires the number of latent frames to be multiple of SP size.
After this PR, we shard tensors inside DiT (after pachification). So long the dit_sequence_length is multiple if SP size. It should work.