@@ -147,8 +147,8 @@ def forward(self,
         # Assign new keys/values directly up to current_end
         local_end_index = kv_cache["local_end_index"].item() + current_end - kv_cache["global_end_index"].item()
         local_start_index = local_end_index - num_new_tokens
-        # kv_cache["k"] = kv_cache["k"].detach()
-        # kv_cache["v"] = kv_cache["v"].detach()
+        kv_cache["k"] = kv_cache["k"].detach()
+        kv_cache["v"] = kv_cache["v"].detach()
         # logger.info("kv_cache['k'] is in comp graph: %s", kv_cache["k"].requires_grad or kv_cache["k"].grad_fn is not None)
         kv_cache["k"][:, local_start_index:local_end_index] = roped_key
         kv_cache["v"][:, local_start_index:local_end_index] = v
@@ -262,14 +262,9 @@ def forward(
         # *_msa.shape: [batch_size, num_frames, 1, inner_dim]
         # assert shift_msa.dtype == torch.float32

-        # logger.info("temb sum: %s, dtype: %s", temb.float().sum().item(), temb.dtype)
-        # logger.info("scale_msa sum: %s, dtype: %s", scale_msa.float().sum().item(), scale_msa.dtype)
-        # logger.info("shift_msa sum: %s, dtype: %s", shift_msa.float().sum().item(), shift_msa.dtype)
-
         # 1. Self-attention
         norm_hidden_states = (self.norm1(hidden_states).unflatten(dim=1, sizes=(num_frames, frame_seqlen)) *
                               (1 + scale_msa) + shift_msa).flatten(1, 2)
-        # logger.info("norm_hidden_states sum: %s, shape: %s", norm_hidden_states.float().sum().item(), norm_hidden_states.shape)
         query, _ = self.to_q(norm_hidden_states)
         key, _ = self.to_k(norm_hidden_states)
         value, _ = self.to_v(norm_hidden_states)
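For context on the modulation line this hunk keeps: `norm1(hidden_states)` is unflattened to `[batch, num_frames, frame_seqlen, dim]` so the per-frame `scale_msa`/`shift_msa` tensors of shape `[batch, num_frames, 1, inner_dim]` broadcast over each frame's tokens, and the result is flattened back to a single token axis before the Q/K/V projections. A minimal standalone sketch under assumed shapes; plain `nn.LayerNorm` stands in for the block's `norm1`:

```python
# Minimal sketch of the per-frame scale/shift modulation; shapes are assumed
# and nn.LayerNorm stands in for the block's norm1.
import torch
import torch.nn as nn

batch, num_frames, frame_seqlen, dim = 2, 4, 16, 64
hidden_states = torch.randn(batch, num_frames * frame_seqlen, dim)
scale_msa = torch.randn(batch, num_frames, 1, dim)  # [batch, num_frames, 1, inner_dim]
shift_msa = torch.randn(batch, num_frames, 1, dim)

norm1 = nn.LayerNorm(dim, elementwise_affine=False)
norm_hidden_states = (
    norm1(hidden_states).unflatten(dim=1, sizes=(num_frames, frame_seqlen))  # [B, F, L, D]
    * (1 + scale_msa) + shift_msa                                            # broadcast over L
).flatten(1, 2)                                                              # back to [B, F*L, D]
```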