10 changes: 5 additions & 5 deletions fastvideo/layers/layernorm.py
@@ -212,9 +212,9 @@ def forward(self, residual: torch.Tensor, x: torch.Tensor,
frame_seqlen = normalized.shape[1] // num_frames
modulated = (
normalized.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) *
- (1.0 + scale) + shift).flatten(1, 2)
+ (1 + scale) + shift).flatten(1, 2)
else:
- modulated = normalized * (1.0 + scale) + shift
+ modulated = normalized * (1 + scale) + shift
return modulated, residual_output


@@ -267,13 +267,13 @@ def forward(self, x: torch.Tensor, shift: torch.Tensor,
frame_seqlen = normalized.shape[1] // num_frames
output = (
normalized.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) *
- (1.0 + scale) + shift).flatten(1, 2)
+ (1 + scale) + shift).flatten(1, 2)
else:
# scale.shape: [batch_size, 1, inner_dim]
# shift.shape: [batch_size, 1, inner_dim]
- output = normalized * (1.0 + scale) + shift
+ output = normalized * (1 + scale) + shift

if self.compute_dtype == torch.float32:
output = output.to(x.dtype)

- return output
+ return output
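Note: writing the modulation as (1 + scale) instead of (1.0 + scale) should not change results here, since scale is a floating-point tensor and Python int/float scalars promote identically against it. A quick check with illustrative shapes (not taken from this PR):

import torch

normalized = torch.randn(2, 8, 4)
scale = torch.randn(2, 1, 4)
shift = torch.randn(2, 1, 4)

a = normalized * (1.0 + scale) + shift
b = normalized * (1 + scale) + shift
assert a.dtype == b.dtype and torch.equal(a, b)  # identical modulation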
91 changes: 54 additions & 37 deletions fastvideo/models/dits/causal_wanvideo.py
@@ -179,7 +179,7 @@ def __init__(self,
super().__init__()

# 1. Self-attention
- self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+ self.norm1 = nn.LayerNorm(dim, eps, elementwise_affine=False)
self.to_q = ReplicatedLinear(dim, dim, bias=True)
self.to_k = ReplicatedLinear(dim, dim, bias=True)
self.to_v = ReplicatedLinear(dim, dim, bias=True)
@@ -212,8 +212,7 @@ def __init__(self,
norm_type="layer",
eps=eps,
elementwise_affine=True,
- dtype=torch.float32,
- compute_dtype=torch.float32)
+ dtype=torch.float32)

# 2. Cross-attention
# Only T2V for now
@@ -226,8 +225,7 @@ def __init__(self,
norm_type="layer",
eps=eps,
elementwise_affine=False,
- dtype=torch.float32,
- compute_dtype=torch.float32)
+ dtype=torch.float32)

# 3. Feed-forward
self.ffn = MLP(dim, ffn_dim, act_type="gelu_pytorch_tanh")
@@ -252,29 +250,29 @@ def forward(
if hidden_states.dim() == 4:
hidden_states = hidden_states.squeeze(1)
num_frames = temb.shape[1]
- frame_seqlen = hidden_states.shape[1] // num_frames
+             frame_seqlen = hidden_states.shape[1] // num_frames
bs, seq_length, _ = hidden_states.shape
orig_dtype = hidden_states.dtype
# assert orig_dtype != torch.float32
- e = self.scale_shift_table + temb.float()
+ e = self.scale_shift_table + temb
# e.shape: [batch_size, num_frames, 6, inner_dim]
assert e.shape == (bs, num_frames, 6, self.hidden_dim)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = e.chunk(
6, dim=2)
# *_msa.shape: [batch_size, num_frames, 1, inner_dim]
- assert shift_msa.dtype == torch.float32
+ # assert shift_msa.dtype == torch.float32

# 1. Self-attention
- norm_hidden_states = (self.norm1(hidden_states.float()).unflatten(dim=1, sizes=(num_frames, frame_seqlen)) *
- (1 + scale_msa) + shift_msa).flatten(1, 2).to(orig_dtype)
+ norm_hidden_states = (self.norm1(hidden_states).unflatten(dim=1, sizes=(num_frames, frame_seqlen)) *
+ (1 + scale_msa) + shift_msa).flatten(1, 2)
query, _ = self.to_q(norm_hidden_states)
key, _ = self.to_k(norm_hidden_states)
value, _ = self.to_v(norm_hidden_states)

if self.norm_q is not None:
- query = self.norm_q(query)
+ query = self.norm_q.forward_native(query)
if self.norm_k is not None:
- key = self.norm_k(key)
+ key = self.norm_k.forward_native(key)

query = query.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
key = key.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
@@ -288,8 +286,6 @@ def forward(
null_shift = null_scale = torch.tensor([0], device=hidden_states.device)
norm_hidden_states, hidden_states = self.self_attn_residual_norm(
hidden_states, attn_output, gate_msa, null_shift, null_scale)
- norm_hidden_states, hidden_states = norm_hidden_states.to(
- orig_dtype), hidden_states.to(orig_dtype)

# 2. Cross-attention
attn_output = self.attn2(norm_hidden_states,
@@ -298,13 +294,10 @@ def forward(
crossattn_cache=crossattn_cache)
norm_hidden_states, hidden_states = self.cross_attn_residual_norm(
hidden_states, attn_output, 1, c_shift_msa, c_scale_msa)
- norm_hidden_states, hidden_states = norm_hidden_states.to(
- orig_dtype), hidden_states.to(orig_dtype)

# 3. Feed-forward
ff_output = self.ffn(norm_hidden_states)
hidden_states = self.mlp_residual(hidden_states, ff_output, c_gate_msa)
- hidden_states = hidden_states.to(orig_dtype)

return hidden_states

@@ -367,8 +360,7 @@ def __init__(self, config: WanVideoConfig, hf_config: dict[str,
norm_type="layer",
eps=config.eps,
elementwise_affine=False,
- dtype=torch.float32,
- compute_dtype=torch.float32)
+ dtype=torch.float32)
self.proj_out = nn.Linear(
inner_dim, config.out_channels * math.prod(config.patch_size))
self.scale_shift_table = nn.Parameter(
@@ -491,12 +483,16 @@ def _forward_inference(
)
freqs_cos = freqs_cos.to(hidden_states.device)
freqs_sin = freqs_sin.to(hidden_states.device)
- freqs_cis = (freqs_cos.float(),
- freqs_sin.float()) if freqs_cos is not None else None
+ freqs_cis = (freqs_cos,
+ freqs_sin) if freqs_cos is not None else None

hidden_states = self.patch_embedding(hidden_states)
+ grid_sizes = torch.stack(
+ [torch.tensor(hidden_states[0].shape[1:], dtype=torch.long)])
hidden_states = hidden_states.flatten(2).transpose(1, 2)

+ encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states.new_zeros(1, self.text_len - encoder_hidden_states.size(1), encoder_hidden_states.size(2))], dim=1)

temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
timestep.flatten(), encoder_hidden_states, encoder_hidden_states_image)
timestep_proj = timestep_proj.unflatten(1, (6, self.hidden_size)).unflatten(dim=0, sizes=timestep.shape)
@@ -543,14 +539,9 @@ def _forward_inference(
hidden_states = self.norm_out(hidden_states, shift, scale)
hidden_states = self.proj_out(hidden_states)

- hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames,
- post_patch_height,
- post_patch_width, p_t, p_h, p_w,
- -1)
- hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
- output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+ output = self.unpatchify(hidden_states, grid_sizes)

- return output
+ return torch.stack(output)

def _forward_train(self,
hidden_states: torch.Tensor,
@@ -591,8 +582,8 @@ def _forward_train(self,
)
freqs_cos = freqs_cos.to(hidden_states.device)
freqs_sin = freqs_sin.to(hidden_states.device)
- freqs_cis = (freqs_cos.float(),
- freqs_sin.float()) if freqs_cos is not None else None
+ freqs_cis = (freqs_cos,
+ freqs_sin) if freqs_cos is not None else None

# Construct blockwise causal attn mask
if self.block_mask is None:
@@ -605,8 +596,12 @@
)

hidden_states = self.patch_embedding(hidden_states)
+ grid_sizes = torch.stack(
+ [torch.tensor(hidden_states[0].shape[1:], dtype=torch.long)])
hidden_states = hidden_states.flatten(2).transpose(1, 2)

+ encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states.new_zeros(1, self.text_len - encoder_hidden_states.size(1), encoder_hidden_states.size(2))], dim=1)

temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
timestep.flatten(), encoder_hidden_states, encoder_hidden_states_image)
timestep_proj = timestep_proj.unflatten(1, (6, self.hidden_size)).unflatten(dim=0, sizes=timestep.shape)
@@ -641,14 +636,9 @@ def _forward_train(self,
hidden_states = self.norm_out(hidden_states, shift, scale)
hidden_states = self.proj_out(hidden_states)

- hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames,
- post_patch_height,
- post_patch_width, p_t, p_h, p_w,
- -1)
- hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
- output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+ output = self.unpatchify(hidden_states, grid_sizes)

- return output
+ return torch.stack(output)

def forward(
self,
@@ -659,3 +649,30 @@ def forward(
return self._forward_inference(*args, **kwargs)
else:
return self._forward_train(*args, **kwargs)


+    def unpatchify(self, x, grid_sizes):
+        r"""
+        Reconstruct video tensors from patch embeddings.
+
+        Args:
+            x (List[Tensor]):
+                List of patchified features, each with shape [L, C_out * prod(patch_size)]
+            grid_sizes (Tensor):
+                Original spatial-temporal grid dimensions before patching,
+                shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
+
+        Returns:
+            List[Tensor]:
+                Reconstructed video tensors, each with shape [C_out, F, H / 8, W / 8]
+        """
+        c = self.out_channels
+        out = []
+        for u, v in zip(x, grid_sizes.tolist()):
+            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
+            u = u.permute(6, 0, 3, 1, 4, 2, 5)
+            # u = torch.einsum('fhwpqrc->cfphqwr', u.contiguous())
+            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
+            out.append(u)
+        return out
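A minimal standalone sketch of the shape bookkeeping that the new unpatchify performs; the patch size, channel count, and grid below are illustrative values only, not taken from this PR:

import math
import torch

patch_size = (1, 2, 2)      # (p_t, p_h, p_w), assumed for illustration
c = 16                      # out_channels, assumed for illustration
grid = (3, 30, 52)          # (F_patches, H_patches, W_patches)

# one sample's patchified features: [L, C_out * prod(patch_size)]
x = torch.randn(math.prod(grid), c * math.prod(patch_size))

u = x.view(*grid, *patch_size, c)        # [F, H, W, p_t, p_h, p_w, C]
u = u.permute(6, 0, 3, 1, 4, 2, 5)       # [C, F, p_t, H, p_h, W, p_w]
u = u.reshape(c, *[g * p for g, p in zip(grid, patch_size)])
print(u.shape)                           # torch.Size([16, 3, 60, 104])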