[Feat] Improve SP to support any resolution. #664
base: main
Changes shown from 1 commit (PR commits: 9607cde, 1c49486, 46921af, 5bad70a, 4ac0603, 9658f73, caad406, 61c17c8, f3fc2df, bd19f3d, 66249ee, f3f94e2)
Changes to the sequence-parallel communication helpers (`fastvideo.distributed.communication_op`):

```diff
@@ -4,7 +4,10 @@
 import torch
 import torch.distributed

-from fastvideo.distributed.parallel_state import get_sp_group, get_tp_group
+from fastvideo.distributed.parallel_state import (get_sp_group,
+                                                  get_sp_parallel_rank,
+                                                  get_sp_world_size,
+                                                  get_tp_group)


 def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
@@ -30,3 +33,16 @@ def sequence_model_parallel_all_gather(input_: torch.Tensor,
                                        dim: int = -1) -> torch.Tensor:
     """All-gather the input tensor across model parallel group."""
     return get_sp_group().all_gather(input_, dim)
+
+
+def sequence_model_parallel_shard(input_: torch.Tensor,
+                                  dim: int = 1) -> torch.Tensor:
+    """All-gather the input tensor across model parallel group."""
+    sp_rank = get_sp_parallel_rank()
+    sp_world_size = get_sp_world_size()
+    assert input_.shape[dim] % sp_world_size == 0, "input tensor dim={dim} must be divisible by sp_world_size"
```
A review suggestion on the new `assert` makes the message an f-string so that the `{dim}` placeholder is actually interpolated:

```diff
-    assert input_.shape[dim] % sp_world_size == 0, "input tensor dim={dim} must be divisible by sp_world_size"
+    assert input_.shape[dim] % sp_world_size == 0, f"input tensor dim={dim} must be divisible by sp_world_size"
```
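The remainder of the new helper is cut off in this excerpt. A minimal sketch of what such a shard helper could look like, under the assumption that it simply keeps each rank's contiguous slice along `dim` (which is what the model file below does by hand); this is an illustration, not the PR's actual body:

```python
import torch

from fastvideo.distributed.parallel_state import (get_sp_parallel_rank,
                                                  get_sp_world_size)


def sequence_model_parallel_shard(input_: torch.Tensor,
                                  dim: int = 1) -> torch.Tensor:
    """Shard the input tensor along `dim` across the sequence-parallel group."""
    sp_rank = get_sp_parallel_rank()
    sp_world_size = get_sp_world_size()
    assert input_.shape[dim] % sp_world_size == 0, (
        f"input tensor dim={dim} must be divisible by sp_world_size")
    # Each rank keeps only its own contiguous chunk of the sharded dimension.
    chunk = input_.shape[dim] // sp_world_size
    return input_.narrow(dim, sp_rank * chunk, chunk).contiguous()
```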
Changes to the Wan video DiT model:

```diff
@@ -12,7 +12,8 @@
                                  LocalAttention)
 from fastvideo.configs.models.dits import WanVideoConfig
 from fastvideo.configs.sample.wan import WanTeaCacheParams
-from fastvideo.distributed.parallel_state import get_sp_world_size
+from fastvideo.distributed.communication_op import sequence_model_parallel_all_gather, sequence_model_parallel_shard
+from fastvideo.distributed.parallel_state import get_sp_parallel_rank, get_sp_world_size
 from fastvideo.forward_context import get_forward_context
 from fastvideo.layers.layernorm import (FP32LayerNorm, LayerNormScaleShift,
                                         RMSNorm, ScaleResidual,
@@ -21,8 +22,7 @@
 # from torch.nn import RMSNorm
 # TODO: RMSNorm ....
 from fastvideo.layers.mlp import MLP
-from fastvideo.layers.rotary_embedding import (_apply_rotary_emb,
-                                               get_rotary_pos_embed)
+from fastvideo.layers.rotary_embedding import get_rotary_pos_embed
 from fastvideo.layers.visual_embedding import (ModulateProjection, PatchEmbed,
                                                TimestepEmbedder)
 from fastvideo.logger import init_logger
```
```diff
@@ -328,13 +328,7 @@ def forward(
         key = key.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
         value = value.squeeze(1).unflatten(2, (self.num_attention_heads, -1))

-        # Apply rotary embeddings
-        cos, sin = freqs_cis
-        query, key = _apply_rotary_emb(query, cos, sin,
-                                       is_neox_style=False), _apply_rotary_emb(
-                                           key, cos, sin, is_neox_style=False)
-
-        attn_output, _ = self.attn1(query, key, value)
+        attn_output, _ = self.attn1(query, key, value, freqs_cis=freqs_cis)
         attn_output = attn_output.flatten(2)
         attn_output, _ = self.to_out(attn_output)
         attn_output = attn_output.squeeze(1)
@@ -476,15 +470,11 @@ def forward(
         gate_compress = gate_compress.squeeze(1).unflatten(
             2, (self.num_attention_heads, -1))

-        # Apply rotary embeddings
-        cos, sin = freqs_cis
-        query, key = _apply_rotary_emb(query, cos, sin,
-                                       is_neox_style=False), _apply_rotary_emb(
-                                           key, cos, sin, is_neox_style=False)

         attn_output, _ = self.attn1(query,
                                     key,
                                     value,
+                                    freqs_cis=freqs_cis,
                                     gate_compress=gate_compress)
         attn_output = attn_output.flatten(2)
         attn_output, _ = self.to_out(attn_output)
```
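With this change the rotation is applied inside the attention layer, which now receives the cos/sin tables through the `freqs_cis` keyword, instead of being applied to `query`/`key` beforehand. For context, a rough standalone sketch of the interleaved (`is_neox_style=False`) rotation that the removed `_apply_rotary_emb` calls performed; the helper name and broadcasting details here are illustrative, not FastVideo's implementation:

```python
import torch


def apply_rotary_interleaved(x: torch.Tensor, cos: torch.Tensor,
                             sin: torch.Tensor) -> torch.Tensor:
    """Rotate interleaved pairs (x0, x1), (x2, x3), ... by the given angles."""
    x_even = x[..., 0::2]
    x_odd = x[..., 1::2]
    rotated = torch.stack((x_even * cos - x_odd * sin,
                           x_even * sin + x_odd * cos), dim=-1)
    return rotated.flatten(-2).type_as(x)


# Usage sketch: query and key are rotated with the same cos/sin tables,
# e.g. query, key = apply_rotary_interleaved(query, cos, sin), apply_rotary_interleaved(key, cos, sin)
```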
```diff
@@ -622,20 +612,23 @@ def forward(self,
         d = self.hidden_size // self.num_attention_heads
         rope_dim_list = [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]
         freqs_cos, freqs_sin = get_rotary_pos_embed(
-            (post_patch_num_frames * get_sp_world_size(), post_patch_height,
+            (post_patch_num_frames, post_patch_height,
              post_patch_width),
             self.hidden_size,
             self.num_attention_heads,
             rope_dim_list,
             dtype=torch.float32 if current_platform.is_mps() else torch.float64,
             rope_theta=10000)
-        freqs_cos = freqs_cos.to(hidden_states.device)
-        freqs_sin = freqs_sin.to(hidden_states.device)
-        freqs_cis = (freqs_cos.float(),
-                     freqs_sin.float()) if freqs_cos is not None else None
+        freqs_cis = (freqs_cos.to(hidden_states.device).float(),
+                     freqs_sin.to(hidden_states.device).float())

         hidden_states = self.patch_embedding(hidden_states)
         hidden_states = hidden_states.flatten(2).transpose(1, 2)
+        # hidden_states = sequence_model_parallel_shard(hidden_states, dim=1)
+        sp_rank = get_sp_parallel_rank()
+        sp_world_size = get_sp_world_size()
+        elements_per_rank = hidden_states.shape[1] // sp_world_size
+        hidden_states = hidden_states[:, sp_rank*elements_per_rank:(sp_rank+1)*elements_per_rank]
```
A review comment suggests replacing the manual slicing with the new helper, which is imported and currently commented out just above:

```diff
-        # hidden_states = sequence_model_parallel_shard(hidden_states, dim=1)
-        sp_rank = get_sp_parallel_rank()
-        sp_world_size = get_sp_world_size()
-        elements_per_rank = hidden_states.shape[1] // sp_world_size
-        hidden_states = hidden_states[:, sp_rank*elements_per_rank:(sp_rank+1)*elements_per_rank]
+        hidden_states = sequence_model_parallel_shard(hidden_states, dim=1)
```
Review comment (outdated): The `all_gather` operation is implemented manually. For consistency and to leverage the existing abstraction, please use the `sequence_model_parallel_all_gather` function, which is already imported and commented out.

```diff
-        # hidden_states = sequence_model_parallel_all_gather(hidden_states, dim=1)
-        output_tensor = [torch.empty_like(hidden_states) for _ in range(sp_world_size)]
-        hidden_states = torch.distributed.all_gather(output_tensor, hidden_states)
-        hidden_states = torch.cat(output_tensor, dim=1)
+        hidden_states = sequence_model_parallel_all_gather(hidden_states, dim=1)
```
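The helper and the manual pattern perform the same collective: gather every rank's shard and concatenate along the sequence dimension (the helper routes it through the sequence-parallel group via `get_sp_group().all_gather(input_, dim)`, as shown in the first file). A minimal standalone sketch using raw `torch.distributed` on the default process group, assuming it has been initialized:

```python
import torch
import torch.distributed as dist


def all_gather_along_dim(t: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """Collect equally shaped shards from all ranks and concatenate along `dim`."""
    world_size = dist.get_world_size()
    shards = [torch.empty_like(t) for _ in range(world_size)]
    dist.all_gather(shards, t)  # fills `shards` in rank order; returns None when async_op=False
    return torch.cat(shards, dim=dim)
```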
Review comment (on `sequence_model_parallel_shard`): The docstring for this function is incorrect. It describes an all-gather operation, but the function performs sharding. Please update the docstring to accurately reflect the function's behavior.
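A possible replacement wording for that docstring (illustrative, not taken from the PR):

```suggestion
    """Shard the input tensor along `dim` across the sequence-parallel group."""
```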