
Commit 65d7f4b

Remove extra logger statement
1 parent fd5cbab commit 65d7f4b

2 files changed (+10 −14 lines)

2 files changed

+10
-14
lines changed

fastvideo/models/dits/causal_wanvideo.py

Lines changed: 2 additions & 7 deletions
@@ -147,8 +147,8 @@ def forward(self,
         # Assign new keys/values directly up to current_end
         local_end_index = kv_cache["local_end_index"].item() + current_end - kv_cache["global_end_index"].item()
         local_start_index = local_end_index - num_new_tokens
-        # kv_cache["k"] = kv_cache["k"].detach()
-        # kv_cache["v"] = kv_cache["v"].detach()
+        kv_cache["k"] = kv_cache["k"].detach()
+        kv_cache["v"] = kv_cache["v"].detach()
         # logger.info("kv_cache['k'] is in comp graph: %s", kv_cache["k"].requires_grad or kv_cache["k"].grad_fn is not None)
         kv_cache["k"][:, local_start_index:local_end_index] = roped_key
         kv_cache["v"][:, local_start_index:local_end_index] = v
@@ -262,14 +262,9 @@ def forward(
         # *_msa.shape: [batch_size, num_frames, 1, inner_dim]
         # assert shift_msa.dtype == torch.float32

-        # logger.info("temb sum: %s, dtype: %s", temb.float().sum().item(), temb.dtype)
-        # logger.info("scale_msa sum: %s, dtype: %s", scale_msa.float().sum().item(), scale_msa.dtype)
-        # logger.info("shift_msa sum: %s, dtype: %s", shift_msa.float().sum().item(), shift_msa.dtype)
-
         # 1. Self-attention
         norm_hidden_states = (self.norm1(hidden_states).unflatten(dim=1, sizes=(num_frames, frame_seqlen)) *
                               (1 + scale_msa) + shift_msa).flatten(1, 2)
-        # logger.info("norm_hidden_states sum: %s, shape: %s", norm_hidden_states.float().sum().item(), norm_hidden_states.shape)
         query, _ = self.to_q(norm_hidden_states)
         key, _ = self.to_k(norm_hidden_states)
         value, _ = self.to_v(norm_hidden_states)
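The re-enabled detach() calls keep the KV cache from dragging the autograd graph of earlier chunks into each new in-place write. Below is a minimal sketch of that detach-before-write pattern; the tensor shapes, the write_chunk helper, and the projection layer are illustrative assumptions, not the repository's API.

```python
import torch

cache = {"k": torch.zeros(2, 16, 8), "v": torch.zeros(2, 16, 8)}
proj = torch.nn.Linear(8, 8)  # stand-in for the attention projections

def write_chunk(cache, new_k, new_v, start, end):
    # Drop any graph history the cache picked up from earlier writes,
    # so each step only records its own slice assignment.
    cache["k"] = cache["k"].detach()
    cache["v"] = cache["v"].detach()
    cache["k"][:, start:end] = new_k
    cache["v"][:, start:end] = new_v

x = torch.randn(2, 4, 8)
write_chunk(cache, proj(x), proj(x), 0, 4)  # cache now carries a grad_fn
write_chunk(cache, proj(x), proj(x), 4, 8)  # detach() keeps history per step
```

Without the detach(), each subsequent write would extend the same graph through the cache and grow memory across steps.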

fastvideo/models/dits/wanvideo.py

Lines changed: 8 additions & 7 deletions
@@ -283,16 +283,16 @@ def __init__(self,
         # 2. Cross-attention
         if added_kv_proj_dim is not None:
             # I2V
-            self.attn2 = WanI2VCrossAttention(dim,
-                                              num_heads,
-                                              qk_norm=qk_norm,
+            self.attn2 = WanI2VCrossAttention(dim,
+                                              num_heads,
+                                              qk_norm=qk_norm,
                                               eps=eps)

         else:
             # T2V
-            self.attn2 = WanT2VCrossAttention(dim,
-                                              num_heads,
-                                              qk_norm=qk_norm,
+            self.attn2 = WanT2VCrossAttention(dim,
+                                              num_heads,
+                                              qk_norm=qk_norm,
                                               eps=eps)

         self.cross_attn_residual_norm = ScaleResidualLayerNormScaleShift(
@@ -835,4 +835,5 @@ def retrieve_cached_states(self,
         if self.is_even:
             return hidden_states + self.previous_residual_even
         else:
-            return hidden_states + self.previous_residual_odd
+            return hidden_states + self.previous_residual_odd
+
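The second hunk touches retrieve_cached_states, which re-adds a residual cached for either even or odd steps instead of recomputing the block. A self-contained sketch of that alternating-residual pattern follows; the attribute and method names come from the diff, while the storage side and the is_even toggle are assumptions for illustration.

```python
import torch

class AlternatingResidualCache:
    def __init__(self):
        self.is_even = True
        self.previous_residual_even = None
        self.previous_residual_odd = None

    def store_residual(self, residual: torch.Tensor) -> None:
        # Remember the residual under the parity of the current step.
        if self.is_even:
            self.previous_residual_even = residual
        else:
            self.previous_residual_odd = residual

    def retrieve_cached_states(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Skip recomputation by re-adding the residual cached for this parity.
        if self.is_even:
            return hidden_states + self.previous_residual_even
        else:
            return hidden_states + self.previous_residual_odd
```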
