diff --git a/modeling/bagel/qwen2_navit.py b/modeling/bagel/qwen2_navit.py index f818565..4d873f6 100644 --- a/modeling/bagel/qwen2_navit.py +++ b/modeling/bagel/qwen2_navit.py @@ -23,10 +23,10 @@ from flash_attn import flash_attn_varlen_func from modeling.qwen2.modeling_qwen2 import ( - Qwen2Attention, - Qwen2MLP, - Qwen2PreTrainedModel, - Qwen2RMSNorm, + Qwen2Attention, + Qwen2MLP, + Qwen2PreTrainedModel, + Qwen2RMSNorm, Qwen2RotaryEmbedding, apply_rotary_pos_emb, ) @@ -253,35 +253,67 @@ def forward_train( attention_mask: List[torch.Tensor], packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor], ): - packed_query_states = self.q_proj(packed_sequence).view(-1, self.num_heads, self.head_dim) - packed_key_states = self.k_proj(packed_sequence).view(-1, self.num_key_value_heads, self.head_dim) - packed_value_states = self.v_proj(packed_sequence).view(-1, self.num_key_value_heads, self.head_dim) + packed_query_states = self.q_proj(packed_sequence).view( + -1, self.num_heads, self.head_dim + ) + packed_key_states = self.k_proj(packed_sequence).view( + -1, self.num_key_value_heads, self.head_dim + ) + packed_value_states = self.v_proj(packed_sequence).view( + -1, self.num_key_value_heads, self.head_dim + ) packed_query_states = self.q_norm(packed_query_states) packed_key_states = self.k_norm(packed_key_states) packed_cos, packed_sin = packed_position_embeddings packed_query_states, packed_key_states = apply_rotary_pos_emb( - packed_query_states, packed_key_states, packed_cos, packed_sin, unsqueeze_dim=1 + packed_query_states, + packed_key_states, + packed_cos, + packed_sin, + unsqueeze_dim=1, ) if isinstance(attention_mask, List): - packed_key_states = packed_key_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1) - packed_key_states = packed_key_states.reshape(-1, self.num_heads, self.head_dim) - packed_value_states = packed_value_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1) - packed_value_states = packed_value_states.reshape(-1, self.num_heads, self.head_dim) - - unpacked_query_states = packed_query_states.transpose(0, 1).split(sample_lens, dim=1) - unpacked_key_states = packed_key_states.transpose(0, 1).split(sample_lens, dim=1) - unpacked_value_states = packed_value_states.transpose(0, 1).split(sample_lens, dim=1) + packed_key_states = packed_key_states[:, :, None, :].repeat( + 1, 1, self.num_key_value_groups, 1 + ) + packed_key_states = packed_key_states.reshape( + -1, self.num_heads, self.head_dim + ) + packed_value_states = packed_value_states[:, :, None, :].repeat( + 1, 1, self.num_key_value_groups, 1 + ) + packed_value_states = packed_value_states.reshape( + -1, self.num_heads, self.head_dim + ) + + unpacked_query_states = packed_query_states.transpose(0, 1).split( + sample_lens, dim=1 + ) + unpacked_key_states = packed_key_states.transpose(0, 1).split( + sample_lens, dim=1 + ) + unpacked_value_states = packed_value_states.transpose(0, 1).split( + sample_lens, dim=1 + ) upacked_attn_output = [] - for query_states, key_states, value_states, attention_mask_per_sample in zip( - unpacked_query_states, unpacked_key_states, unpacked_value_states, attention_mask + for ( + query_states, + key_states, + value_states, + attention_mask_per_sample, + ) in zip( + unpacked_query_states, + unpacked_key_states, + unpacked_value_states, + attention_mask, ): with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): attn_output = scaled_dot_product_attention( - query_states.to(torch.bfloat16).unsqueeze(0), - key_states.to(torch.bfloat16).unsqueeze(0), + 
query_states.to(torch.bfloat16).unsqueeze(0), + key_states.to(torch.bfloat16).unsqueeze(0), value_states.to(torch.bfloat16).unsqueeze(0), attention_mask_per_sample.to(torch.bfloat16).unsqueeze(0), ) @@ -289,20 +321,28 @@ def forward_train( packed_attn_output = torch.cat(upacked_attn_output, dim=1) else: pad_size = sum(sample_lens) - packed_query_states.shape[0] - packed_query_states = pad_sequence(packed_query_states.permute(1, 0, 2), pad_size) - packed_key_states = pad_sequence(packed_key_states.permute(1, 0, 2), pad_size) - packed_value_states = pad_sequence(packed_value_states.permute(1, 0, 2), pad_size) + packed_query_states = pad_sequence( + packed_query_states.permute(1, 0, 2), pad_size + ) + packed_key_states = pad_sequence( + packed_key_states.permute(1, 0, 2), pad_size + ) + packed_value_states = pad_sequence( + packed_value_states.permute(1, 0, 2), pad_size + ) packed_attn_output = flex_attention( - packed_query_states.unsqueeze(0), - packed_key_states.unsqueeze(0), - packed_value_states.unsqueeze(0), + packed_query_states.unsqueeze(0), + packed_key_states.unsqueeze(0), + packed_value_states.unsqueeze(0), enable_gqa=True, block_mask=attention_mask, ) end_index = packed_attn_output.shape[2] - pad_size packed_attn_output = packed_attn_output[0, :, :end_index, :] - packed_attn_output = packed_attn_output.transpose(0, 1).reshape(-1, self.hidden_size) + packed_attn_output = packed_attn_output.transpose(0, 1).reshape( + -1, self.hidden_size + ) packed_attn_output = self.o_proj(packed_attn_output) return packed_attn_output @@ -319,29 +359,46 @@ def forward_inference( update_past_key_values=True, is_causal=True, ): - packed_query_states = self.q_proj(packed_query_sequence).view(-1, self.num_heads, self.head_dim) - packed_key_states = self.k_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim) - packed_value_states = self.v_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim) + packed_query_states = self.q_proj(packed_query_sequence).view( + -1, self.num_heads, self.head_dim + ) + packed_key_states = self.k_proj(packed_query_sequence).view( + -1, self.num_key_value_heads, self.head_dim + ) + packed_value_states = self.v_proj(packed_query_sequence).view( + -1, self.num_key_value_heads, self.head_dim + ) packed_query_states = self.q_norm(packed_query_states) packed_key_states = self.k_norm(packed_key_states) packed_cos, packed_sin = packed_query_position_embeddings packed_query_states, packed_key_states = apply_rotary_pos_emb( - packed_query_states, packed_key_states, packed_cos, packed_sin, unsqueeze_dim=1 + packed_query_states, + packed_key_states, + packed_cos, + packed_sin, + unsqueeze_dim=1, ) packed_query_states = packed_query_states.to(torch.bfloat16) packed_key_states = packed_key_states.to(torch.bfloat16) packed_value_states = packed_value_states.to(torch.bfloat16) - if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None: + if ( + past_key_values is not None + and past_key_values.key_cache[self.layer_idx] is not None + ): past_key_states = past_key_values.key_cache[self.layer_idx] past_value_states = past_key_values.value_cache[self.layer_idx] seqlens = sum(query_lens) + sum(key_values_lens) - merged_key_states = past_key_states.new_zeros((seqlens, self.num_key_value_heads, self.head_dim)) - merged_value_states = past_key_states.new_zeros((seqlens, self.num_key_value_heads, self.head_dim)) + merged_key_states = past_key_states.new_zeros( + (seqlens, self.num_key_value_heads, self.head_dim) + ) + 
merged_value_states = past_key_states.new_zeros( + (seqlens, self.num_key_value_heads, self.head_dim) + ) merged_key_states[packed_query_indexes] = packed_key_states merged_key_states[packed_key_value_indexes] = past_key_states merged_value_states[packed_query_indexes] = packed_value_states @@ -353,7 +410,9 @@ def forward_inference( key_values_lens = query_lens cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0)) - cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0)) + cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum(key_values_lens, dim=0), (1, 0) + ) packed_attn_output = flash_attn_varlen_func( q=packed_query_states, @@ -389,10 +448,18 @@ def __init__(self, config, layer_idx: Optional[int] = None): self.q_norm_moe_gen = nn.Identity() self.k_norm_moe_gen = nn.Identity() - self.q_proj_moe_gen = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) - self.k_proj_moe_gen = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.v_proj_moe_gen = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.o_proj_moe_gen = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.q_proj_moe_gen = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=True + ) + self.k_proj_moe_gen = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True + ) + self.v_proj_moe_gen = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True + ) + self.o_proj_moe_gen = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=False + ) def forward(self, *args, **kwargs): if self.training: @@ -409,63 +476,121 @@ def forward_train( packed_und_token_indexes: torch.LongTensor, packed_gen_token_indexes: torch.LongTensor, ): - packed_query_states = packed_sequence.new_zeros((packed_sequence.shape[0], self.num_heads * self.head_dim)) - packed_key_states = packed_sequence.new_zeros((packed_sequence.shape[0], self.num_key_value_heads * self.head_dim)) - packed_value_states = packed_sequence.new_zeros((packed_sequence.shape[0], self.num_key_value_heads * self.head_dim)) + packed_query_states = packed_sequence.new_zeros( + (packed_sequence.shape[0], self.num_heads * self.head_dim) + ) + packed_key_states = packed_sequence.new_zeros( + (packed_sequence.shape[0], self.num_key_value_heads * self.head_dim) + ) + packed_value_states = packed_sequence.new_zeros( + (packed_sequence.shape[0], self.num_key_value_heads * self.head_dim) + ) packed_sequence_und = packed_sequence[packed_und_token_indexes] packed_sequence_gen = packed_sequence[packed_gen_token_indexes] packed_query_states[packed_und_token_indexes] = self.q_proj(packed_sequence_und) - packed_query_states[packed_gen_token_indexes] = self.q_proj_moe_gen(packed_sequence_gen) + packed_query_states[packed_gen_token_indexes] = self.q_proj_moe_gen( + packed_sequence_gen + ) packed_key_states[packed_und_token_indexes] = self.k_proj(packed_sequence_und) - packed_key_states[packed_gen_token_indexes] = self.k_proj_moe_gen(packed_sequence_gen) + packed_key_states[packed_gen_token_indexes] = self.k_proj_moe_gen( + packed_sequence_gen + ) packed_value_states[packed_und_token_indexes] = self.v_proj(packed_sequence_und) - packed_value_states[packed_gen_token_indexes] = self.v_proj_moe_gen(packed_sequence_gen) + packed_value_states[packed_gen_token_indexes] = self.v_proj_moe_gen( + packed_sequence_gen + ) - packed_query_states = 
packed_query_states.view(-1, self.num_heads, self.head_dim) - packed_key_states = packed_key_states.view(-1, self.num_key_value_heads, self.head_dim) - packed_value_states = packed_value_states.view(-1, self.num_key_value_heads, self.head_dim) + packed_query_states = packed_query_states.view( + -1, self.num_heads, self.head_dim + ) + packed_key_states = packed_key_states.view( + -1, self.num_key_value_heads, self.head_dim + ) + packed_value_states = packed_value_states.view( + -1, self.num_key_value_heads, self.head_dim + ) if self.config.freeze_und: - packed_value_states[packed_und_token_indexes] = packed_value_states[packed_und_token_indexes].detach() + packed_value_states[packed_und_token_indexes] = packed_value_states[ + packed_und_token_indexes + ].detach() packed_query_states_ = packed_query_states.new_zeros(packed_query_states.shape) packed_key_states_ = packed_key_states.new_zeros(packed_key_states.shape) - packed_query_states_[packed_und_token_indexes] = self.q_norm(packed_query_states[packed_und_token_indexes]) + packed_query_states_[packed_und_token_indexes] = self.q_norm( + packed_query_states[packed_und_token_indexes] + ) if self.config.freeze_und: - packed_query_states_[packed_und_token_indexes] = packed_query_states_[packed_und_token_indexes].detach() - packed_query_states_[packed_gen_token_indexes] = self.q_norm_moe_gen(packed_query_states[packed_gen_token_indexes]) + packed_query_states_[packed_und_token_indexes] = packed_query_states_[ + packed_und_token_indexes + ].detach() + packed_query_states_[packed_gen_token_indexes] = self.q_norm_moe_gen( + packed_query_states[packed_gen_token_indexes] + ) - packed_key_states_[packed_und_token_indexes] = self.k_norm(packed_key_states[packed_und_token_indexes]) + packed_key_states_[packed_und_token_indexes] = self.k_norm( + packed_key_states[packed_und_token_indexes] + ) if self.config.freeze_und: - packed_key_states_[packed_und_token_indexes] = packed_key_states_[packed_und_token_indexes].detach() - packed_key_states_[packed_gen_token_indexes] = self.k_norm_moe_gen(packed_key_states[packed_gen_token_indexes]) + packed_key_states_[packed_und_token_indexes] = packed_key_states_[ + packed_und_token_indexes + ].detach() + packed_key_states_[packed_gen_token_indexes] = self.k_norm_moe_gen( + packed_key_states[packed_gen_token_indexes] + ) packed_cos, packed_sin = packed_position_embeddings packed_query_states_, packed_key_states_ = apply_rotary_pos_emb( - packed_query_states_, packed_key_states_, packed_cos, packed_sin, unsqueeze_dim=1 + packed_query_states_, + packed_key_states_, + packed_cos, + packed_sin, + unsqueeze_dim=1, ) if isinstance(attention_mask, List): - packed_key_states_ = packed_key_states_[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1) - packed_key_states_ = packed_key_states_.reshape(-1, self.num_heads, self.head_dim) - packed_value_states = packed_value_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1) - packed_value_states = packed_value_states.reshape(-1, self.num_heads, self.head_dim) - - unpacked_query_states = packed_query_states_.transpose(0, 1).split(sample_lens, dim=1) - unpacked_key_states = packed_key_states_.transpose(0, 1).split(sample_lens, dim=1) - unpacked_value_states = packed_value_states.transpose(0, 1).split(sample_lens, dim=1) + packed_key_states_ = packed_key_states_[:, :, None, :].repeat( + 1, 1, self.num_key_value_groups, 1 + ) + packed_key_states_ = packed_key_states_.reshape( + -1, self.num_heads, self.head_dim + ) + packed_value_states = packed_value_states[:, :, 
None, :].repeat( + 1, 1, self.num_key_value_groups, 1 + ) + packed_value_states = packed_value_states.reshape( + -1, self.num_heads, self.head_dim + ) + + unpacked_query_states = packed_query_states_.transpose(0, 1).split( + sample_lens, dim=1 + ) + unpacked_key_states = packed_key_states_.transpose(0, 1).split( + sample_lens, dim=1 + ) + unpacked_value_states = packed_value_states.transpose(0, 1).split( + sample_lens, dim=1 + ) upacked_attn_output = [] - for query_states, key_states, value_states, attention_mask_per_sample in zip( - unpacked_query_states, unpacked_key_states, unpacked_value_states, attention_mask + for ( + query_states, + key_states, + value_states, + attention_mask_per_sample, + ) in zip( + unpacked_query_states, + unpacked_key_states, + unpacked_value_states, + attention_mask, ): with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): attn_output = scaled_dot_product_attention( - query_states.to(torch.bfloat16).unsqueeze(0), - key_states.to(torch.bfloat16).unsqueeze(0), + query_states.to(torch.bfloat16).unsqueeze(0), + key_states.to(torch.bfloat16).unsqueeze(0), value_states.to(torch.bfloat16).unsqueeze(0), attention_mask_per_sample.to(torch.bfloat16).unsqueeze(0), ) @@ -473,23 +598,35 @@ def forward_train( packed_attn_output = torch.cat(upacked_attn_output, dim=1) else: pad_size = sum(sample_lens) - packed_query_states.shape[0] - packed_query_states_ = pad_sequence(packed_query_states_.permute(1, 0, 2), pad_size) - packed_key_states_ = pad_sequence(packed_key_states_.permute(1, 0, 2), pad_size) - packed_value_states = pad_sequence(packed_value_states.permute(1, 0, 2), pad_size) + packed_query_states_ = pad_sequence( + packed_query_states_.permute(1, 0, 2), pad_size + ) + packed_key_states_ = pad_sequence( + packed_key_states_.permute(1, 0, 2), pad_size + ) + packed_value_states = pad_sequence( + packed_value_states.permute(1, 0, 2), pad_size + ) packed_attn_output = flex_attention( - packed_query_states_.unsqueeze(0), # 1, num_head, L, head_dim - packed_key_states_.unsqueeze(0), - packed_value_states.unsqueeze(0), + packed_query_states_.unsqueeze(0), # 1, num_head, L, head_dim + packed_key_states_.unsqueeze(0), + packed_value_states.unsqueeze(0), enable_gqa=True, block_mask=attention_mask, ) end_index = packed_attn_output.shape[2] - pad_size packed_attn_output = packed_attn_output[0, :, :end_index, :] - packed_attn_output = packed_attn_output.transpose(0, 1).reshape(-1, self.num_heads * self.head_dim) + packed_attn_output = packed_attn_output.transpose(0, 1).reshape( + -1, self.num_heads * self.head_dim + ) packed_attn_output_ = packed_attn_output.new_zeros(packed_attn_output.shape) - packed_attn_output_[packed_und_token_indexes] = self.o_proj(packed_attn_output[packed_und_token_indexes]) - packed_attn_output_[packed_gen_token_indexes] = self.o_proj_moe_gen(packed_attn_output[packed_gen_token_indexes]) + packed_attn_output_[packed_und_token_indexes] = self.o_proj( + packed_attn_output[packed_und_token_indexes] + ) + packed_attn_output_[packed_gen_token_indexes] = self.o_proj_moe_gen( + packed_attn_output[packed_gen_token_indexes] + ) return packed_attn_output_ @@ -508,58 +645,113 @@ def forward_inference( packed_vae_token_indexes=None, packed_text_indexes=None, ): - if mode == 'und': - packed_query_states = self.q_proj(packed_query_sequence).view(-1, self.num_heads, self.head_dim) - packed_key_states = self.k_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim) - packed_value_states = self.v_proj(packed_query_sequence).view(-1, 
self.num_key_value_heads, self.head_dim) + if mode == "und": + packed_query_states = self.q_proj(packed_query_sequence).view( + -1, self.num_heads, self.head_dim + ) + packed_key_states = self.k_proj(packed_query_sequence).view( + -1, self.num_key_value_heads, self.head_dim + ) + packed_value_states = self.v_proj(packed_query_sequence).view( + -1, self.num_key_value_heads, self.head_dim + ) packed_query_states = self.q_norm(packed_query_states) packed_key_states = self.k_norm(packed_key_states) - elif mode == 'gen': + elif mode == "gen": packed_query_sequence = packed_query_sequence.to(torch.bfloat16) - packed_query_states = packed_query_sequence.new_zeros((packed_query_sequence.shape[0], self.num_heads * self.head_dim)) - packed_key_states = packed_query_sequence.new_zeros((packed_query_sequence.shape[0], self.num_key_value_heads * self.head_dim)) - packed_value_states = packed_query_sequence.new_zeros((packed_query_sequence.shape[0], self.num_key_value_heads * self.head_dim)) + packed_query_states = packed_query_sequence.new_zeros( + (packed_query_sequence.shape[0], self.num_heads * self.head_dim) + ) + packed_key_states = packed_query_sequence.new_zeros( + ( + packed_query_sequence.shape[0], + self.num_key_value_heads * self.head_dim, + ) + ) + packed_value_states = packed_query_sequence.new_zeros( + ( + packed_query_sequence.shape[0], + self.num_key_value_heads * self.head_dim, + ) + ) packed_text_query_sequence = packed_query_sequence[packed_text_indexes] packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes] - packed_query_states[packed_text_indexes] = self.q_proj(packed_text_query_sequence) - packed_query_states[packed_vae_token_indexes] = self.q_proj_moe_gen(packed_vae_query_sequence) + packed_query_states[packed_text_indexes] = self.q_proj( + packed_text_query_sequence + ) + packed_query_states[packed_vae_token_indexes] = self.q_proj_moe_gen( + packed_vae_query_sequence + ) - packed_key_states[packed_text_indexes] = self.k_proj(packed_text_query_sequence) - packed_key_states[packed_vae_token_indexes] = self.k_proj_moe_gen(packed_vae_query_sequence) + packed_key_states[packed_text_indexes] = self.k_proj( + packed_text_query_sequence + ) + packed_key_states[packed_vae_token_indexes] = self.k_proj_moe_gen( + packed_vae_query_sequence + ) - packed_value_states[packed_text_indexes] = self.v_proj(packed_text_query_sequence) - packed_value_states[packed_vae_token_indexes] = self.v_proj_moe_gen(packed_vae_query_sequence) + packed_value_states[packed_text_indexes] = self.v_proj( + packed_text_query_sequence + ) + packed_value_states[packed_vae_token_indexes] = self.v_proj_moe_gen( + packed_vae_query_sequence + ) - packed_query_states = packed_query_states.view(-1, self.num_heads, self.head_dim) - packed_key_states = packed_key_states.view(-1, self.num_key_value_heads, self.head_dim) - packed_value_states = packed_value_states.view(-1, self.num_key_value_heads, self.head_dim) + packed_query_states = packed_query_states.view( + -1, self.num_heads, self.head_dim + ) + packed_key_states = packed_key_states.view( + -1, self.num_key_value_heads, self.head_dim + ) + packed_value_states = packed_value_states.view( + -1, self.num_key_value_heads, self.head_dim + ) packed_query_states = packed_query_states.to(torch.float32) - packed_query_states[packed_text_indexes] = self.q_norm(packed_query_states[packed_text_indexes]) - packed_query_states[packed_vae_token_indexes] = self.q_norm_moe_gen(packed_query_states[packed_vae_token_indexes]) + packed_query_states[packed_text_indexes] 
= self.q_norm( + packed_query_states[packed_text_indexes] + ) + packed_query_states[packed_vae_token_indexes] = self.q_norm_moe_gen( + packed_query_states[packed_vae_token_indexes] + ) packed_key_states = packed_key_states.to(torch.float32) - packed_key_states[packed_text_indexes] = self.k_norm(packed_key_states[packed_text_indexes]) - packed_key_states[packed_vae_token_indexes] = self.k_norm_moe_gen(packed_key_states[packed_vae_token_indexes]) + packed_key_states[packed_text_indexes] = self.k_norm( + packed_key_states[packed_text_indexes] + ) + packed_key_states[packed_vae_token_indexes] = self.k_norm_moe_gen( + packed_key_states[packed_vae_token_indexes] + ) packed_cos, packed_sin = packed_query_position_embeddings packed_query_states, packed_key_states = apply_rotary_pos_emb( - packed_query_states, packed_key_states, packed_cos, packed_sin, unsqueeze_dim=1 + packed_query_states, + packed_key_states, + packed_cos, + packed_sin, + unsqueeze_dim=1, ) packed_query_states = packed_query_states.to(torch.bfloat16) packed_key_states = packed_key_states.to(torch.bfloat16) packed_value_states = packed_value_states.to(torch.bfloat16) - if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None: + if ( + past_key_values is not None + and past_key_values.key_cache[self.layer_idx] is not None + ): past_key_states = past_key_values.key_cache[self.layer_idx] past_value_states = past_key_values.value_cache[self.layer_idx] seqlens = sum(query_lens) + sum(key_values_lens) - merged_key_states = past_key_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim]) - merged_value_states = past_key_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim]) + merged_key_states = past_key_states.new_zeros( + size=[seqlens, self.num_key_value_heads, self.head_dim] + ) + merged_value_states = past_key_states.new_zeros( + size=[seqlens, self.num_key_value_heads, self.head_dim] + ) merged_key_states[packed_query_indexes] = packed_key_states merged_key_states[packed_key_value_indexes] = past_key_states merged_value_states[packed_query_indexes] = packed_value_states @@ -571,7 +763,9 @@ def forward_inference( key_values_lens = query_lens cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0)) - cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0)) + cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum(key_values_lens, dim=0), (1, 0) + ) packed_attn_output = flash_attn_varlen_func( q=packed_query_states, @@ -584,11 +778,15 @@ def forward_inference( causal=is_causal, ) packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size) - if mode == 'und': + if mode == "und": packed_attn_output = self.o_proj(packed_attn_output) - elif mode == 'gen': - packed_attn_output[packed_text_indexes] = self.o_proj(packed_attn_output[packed_text_indexes]) - packed_attn_output[packed_vae_token_indexes] = self.o_proj_moe_gen(packed_attn_output[packed_vae_token_indexes]) + elif mode == "gen": + packed_attn_output[packed_text_indexes] = self.o_proj( + packed_attn_output[packed_text_indexes] + ) + packed_attn_output[packed_vae_token_indexes] = self.o_proj_moe_gen( + packed_attn_output[packed_vae_token_indexes] + ) if update_past_key_values: past_key_values.key_cache[self.layer_idx] = merged_key_states @@ -606,7 +804,9 @@ def __init__(self, config, layer_idx: Optional[int] = None): self.mlp = Qwen2MLP(config) self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = 
Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward(self, *args, **kwargs): if self.training: @@ -683,9 +883,9 @@ def forward_inference( class Qwen2MoTDecoderLayer(nn.Module): def __init__( - self, - config, - layer_idx: Optional[int] = None, + self, + config, + layer_idx: Optional[int] = None, attn_module: Optional[Qwen2Attention] = PackedAttentionMoT, ): super().__init__() @@ -697,9 +897,15 @@ def __init__( self.mlp = Qwen2MLP(config) self.mlp_moe_gen = Qwen2MLP(config) self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.input_layernorm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm_moe_gen = Qwen2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = Qwen2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm_moe_gen = Qwen2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward(self, *args, **kwargs): if self.training: @@ -719,8 +925,12 @@ def forward_train( residual = packed_sequence packed_sequence_ = packed_sequence.new_zeros(packed_sequence.shape) - packed_sequence_[packed_und_token_indexes] = self.input_layernorm(packed_sequence[packed_und_token_indexes]) - packed_sequence_[packed_gen_token_indexes] = self.input_layernorm_moe_gen(packed_sequence[packed_gen_token_indexes]) + packed_sequence_[packed_und_token_indexes] = self.input_layernorm( + packed_sequence[packed_und_token_indexes] + ) + packed_sequence_[packed_gen_token_indexes] = self.input_layernorm_moe_gen( + packed_sequence[packed_gen_token_indexes] + ) # Self Attention packed_sequence_ = self.self_attn( @@ -732,7 +942,9 @@ def forward_train( packed_gen_token_indexes=packed_gen_token_indexes, ) if self.freeze_und: - packed_sequence_[packed_und_token_indexes] = packed_sequence_[packed_und_token_indexes].detach() + packed_sequence_[packed_und_token_indexes] = packed_sequence_[ + packed_und_token_indexes + ].detach() packed_sequence = residual + packed_sequence_ # Fully Connected @@ -742,10 +954,14 @@ def forward_train( self.post_attention_layernorm(packed_sequence[packed_und_token_indexes]) ) if self.freeze_und: - packed_sequence_[packed_und_token_indexes] = packed_sequence_[packed_und_token_indexes].detach() - + packed_sequence_[packed_und_token_indexes] = packed_sequence_[ + packed_und_token_indexes + ].detach() + packed_sequence_[packed_gen_token_indexes] = self.mlp_moe_gen( - self.post_attention_layernorm_moe_gen(packed_sequence[packed_gen_token_indexes]) + self.post_attention_layernorm_moe_gen( + packed_sequence[packed_gen_token_indexes] + ) ) packed_sequence = residual + packed_sequence_ @@ -772,8 +988,14 @@ def forward_inference( packed_query_sequence = self.input_layernorm(packed_query_sequence) elif mode == "gen": packed_query_sequence_ = torch.zeros_like(packed_query_sequence) - packed_query_sequence_[packed_text_indexes] = self.input_layernorm(packed_query_sequence[packed_text_indexes]) - packed_query_sequence_[packed_vae_token_indexes] = self.input_layernorm_moe_gen(packed_query_sequence[packed_vae_token_indexes]) + packed_query_sequence_[packed_text_indexes] = self.input_layernorm( + packed_query_sequence[packed_text_indexes] + ) + 
packed_query_sequence_[packed_vae_token_indexes] = ( + self.input_layernorm_moe_gen( + packed_query_sequence[packed_vae_token_indexes] + ) + ) packed_query_sequence = packed_query_sequence_ # Self Attention @@ -801,12 +1023,22 @@ def forward_inference( elif mode == "gen": packed_text_query_sequence = packed_query_sequence[packed_text_indexes] packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes] - packed_text_query_sequence = self.post_attention_layernorm(packed_text_query_sequence).to(torch.bfloat16) - packed_vae_query_sequence = self.post_attention_layernorm_moe_gen(packed_vae_query_sequence).to(torch.bfloat16) - - packed_query_sequence_ = torch.zeros_like(packed_query_sequence).to(torch.bfloat16) - packed_query_sequence_[packed_text_indexes] = self.mlp(packed_text_query_sequence) - packed_query_sequence_[packed_vae_token_indexes] = self.mlp_moe_gen(packed_vae_query_sequence) + packed_text_query_sequence = self.post_attention_layernorm( + packed_text_query_sequence + ).to(torch.bfloat16) + packed_vae_query_sequence = self.post_attention_layernorm_moe_gen( + packed_vae_query_sequence + ).to(torch.bfloat16) + + packed_query_sequence_ = torch.zeros_like(packed_query_sequence).to( + torch.bfloat16 + ) + packed_query_sequence_[packed_text_indexes] = self.mlp( + packed_text_query_sequence + ) + packed_query_sequence_[packed_vae_token_indexes] = self.mlp_moe_gen( + packed_vae_query_sequence + ) packed_query_sequence = packed_query_sequence_ packed_query_sequence = residual + packed_query_sequence @@ -823,7 +1055,9 @@ def __init__(self, config, layer_idx: Optional[int] = None): self.mlp = Qwen2MLP(config) self.mlp_moe_gen = Qwen2MLP(config) self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward(self, *args, **kwargs): if self.training: @@ -859,7 +1093,9 @@ def forward_train( packed_sequence_new = packed_sequence.new_zeros(packed_sequence.shape) packed_sequence_und = self.mlp(packed_sequence[packed_und_token_indexes]) - packed_sequence_gen = self.mlp_moe_gen(packed_sequence[packed_gen_token_indexes]) + packed_sequence_gen = self.mlp_moe_gen( + packed_sequence[packed_gen_token_indexes] + ) packed_sequence_new[packed_und_token_indexes] = packed_sequence_und packed_sequence_new[packed_gen_token_indexes] = packed_sequence_gen @@ -906,9 +1142,15 @@ def forward_inference( if mode == "und": packed_query_sequence = self.mlp(packed_query_sequence) elif mode == "gen": - packed_query_sequence_ = torch.zeros_like(packed_query_sequence).to(torch.bfloat16) - packed_query_sequence_[packed_text_indexes] = self.mlp(packed_query_sequence[packed_text_indexes]) - packed_query_sequence_[packed_vae_token_indexes] = self.mlp_moe_gen(packed_query_sequence[packed_vae_token_indexes]) + packed_query_sequence_ = torch.zeros_like(packed_query_sequence).to( + torch.bfloat16 + ) + packed_query_sequence_[packed_text_indexes] = self.mlp( + packed_query_sequence[packed_text_indexes] + ) + packed_query_sequence_[packed_vae_token_indexes] = self.mlp_moe_gen( + packed_query_sequence[packed_vae_token_indexes] + ) packed_query_sequence = packed_query_sequence_ packed_query_sequence = residual + packed_query_sequence @@ -918,7 +1160,9 @@ def forward_inference( Decoder_layer_dict = { "Qwen2DecoderLayer": Qwen2DecoderLayer, "Qwen2MoEDecoderLayer": Qwen2MoEDecoderLayer, - 
"Qwen2MoTDecoderLayer": partial(Qwen2MoTDecoderLayer, attn_module=PackedAttentionMoT), + "Qwen2MoTDecoderLayer": partial( + Qwen2MoTDecoderLayer, attn_module=PackedAttentionMoT + ), } @@ -927,17 +1171,24 @@ def __init__(self, config): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.use_moe = 'Mo' in config.layer_module + self.use_moe = "Mo" in config.layer_module - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) layer_module = Decoder_layer_dict[config.layer_module] self.layers = nn.ModuleList( - [layer_module(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + [ + layer_module(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] ) self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) if self.use_moe: - self.norm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm_moe_gen = Qwen2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) self.rotary_emb = Qwen2RotaryEmbedding(config=config) # Initialize weights and apply final processing @@ -960,7 +1211,9 @@ def forward_train( ) -> torch.Tensor: if self.config.freeze_und: - packed_sequence[packed_und_token_indexes] = packed_sequence[packed_und_token_indexes].detach() + packed_sequence[packed_und_token_indexes] = packed_sequence[ + packed_und_token_indexes + ].detach() # create position embeddings to be shared across the decoder layers cos, sin = self.rotary_emb(packed_sequence, packed_position_ids.unsqueeze(0)) @@ -984,15 +1237,21 @@ def forward_train( sample_lens=sample_lens, attention_mask=attention_mask, packed_position_embeddings=packed_position_embeddings, - **extra_inputs + **extra_inputs, ) if self.use_moe: packed_sequence_ = torch.zeros_like(packed_sequence) - packed_sequence_[packed_und_token_indexes] = self.norm(packed_sequence[packed_und_token_indexes]) + packed_sequence_[packed_und_token_indexes] = self.norm( + packed_sequence[packed_und_token_indexes] + ) if self.config.freeze_und: - packed_sequence_[packed_und_token_indexes] = packed_sequence_[packed_und_token_indexes].detach() - packed_sequence_[packed_gen_token_indexes] = self.norm_moe_gen(packed_sequence[packed_gen_token_indexes]) + packed_sequence_[packed_und_token_indexes] = packed_sequence_[ + packed_und_token_indexes + ].detach() + packed_sequence_[packed_gen_token_indexes] = self.norm_moe_gen( + packed_sequence[packed_gen_token_indexes] + ) return packed_sequence_ else: return self.norm(packed_sequence) @@ -1014,7 +1273,9 @@ def forward_inference( ) -> BaseNavitOutputWithPast: # create position embeddings to be shared across the decoder layers - cos, sin = self.rotary_emb(packed_query_sequence, packed_query_position_ids.unsqueeze(0)) + cos, sin = self.rotary_emb( + packed_query_sequence, packed_query_position_ids.unsqueeze(0) + ) cos = cos.squeeze(0) sin = sin.squeeze(0) packed_query_position_embeddings = (cos, sin) @@ -1022,7 +1283,7 @@ def forward_inference( extra_inputs = {} if self.use_moe: extra_inputs.update(mode=mode) - if mode == 'gen': + if mode == "gen": assert packed_vae_token_indexes is not None assert packed_text_indexes is not None extra_inputs.update( @@ -1049,8 +1310,12 @@ def forward_inference( packed_query_sequence = self.norm(packed_query_sequence) elif mode == "gen": packed_query_sequence_ = torch.zeros_like(packed_query_sequence) - 
packed_query_sequence_[packed_text_indexes] = self.norm(packed_query_sequence[packed_text_indexes]) - packed_query_sequence_[packed_vae_token_indexes] = self.norm_moe_gen(packed_query_sequence[packed_vae_token_indexes]) + packed_query_sequence_[packed_text_indexes] = self.norm( + packed_query_sequence[packed_text_indexes] + ) + packed_query_sequence_[packed_vae_token_indexes] = self.norm_moe_gen( + packed_query_sequence[packed_vae_token_indexes] + ) packed_query_sequence = packed_query_sequence_ else: packed_query_sequence = self.norm(packed_query_sequence) diff --git a/node.py b/node.py index f5e6ee2..81d2ec7 100644 --- a/node.py +++ b/node.py @@ -7,6 +7,11 @@ from typing import Dict, Tuple, Optional, Any, Union from PIL import Image from folder_paths import folder_names_and_paths +from accelerate import ( + infer_auto_device_map, + load_checkpoint_and_dispatch, + init_empty_weights, +) # Add current directory to path current_dir = os.path.dirname(os.path.abspath(__file__)) @@ -14,11 +19,6 @@ # Import BAGEL related modules try: - from accelerate import ( - infer_auto_device_map, - load_checkpoint_and_dispatch, - init_empty_weights, - ) from data.data_utils import add_special_tokens, pil_img2rgb from data.transforms import ImageTransform from inferencer import InterleaveInferencer @@ -35,7 +35,9 @@ from modeling.qwen2 import Qwen2Tokenizer except ImportError as e: print(f"Error importing BAGEL modules: {e}") - print("Please ensure BAGEL model files are properly installed.") + print( + "Please ensure BAGEL model files are properly installed and accessible in the Python path." + ) # Register the BAGEL model folder models_dir = os.path.join(os.getcwd(), "models") @@ -256,9 +258,13 @@ def load_model(self, model_path: str) -> Tuple[Dict[str, Any]]: ) # Load configuration files - llm_config = Qwen2Config.from_json_file( - os.path.join(local_model_dir, "llm_config.json") - ) + try: + llm_config = Qwen2Config.from_json_file( + os.path.join(local_model_dir, "llm_config.json") + ) + except Exception as e: + print(f"Error loading Qwen2Config: {e}") + raise llm_config.qk_norm = True llm_config.tie_word_embeddings = False llm_config.layer_module = "Qwen2MoTDecoderLayer"
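The recurring pattern in the PackedAttentionMoT and MoT decoder-layer hunks above is index-based expert routing: understanding tokens pass through the original Qwen2 modules (q_proj, q_norm, input_layernorm, mlp, and so on), generation tokens pass through the parallel *_moe_gen modules, and both results are scattered back into one packed buffer. The sketch below is a minimal, hypothetical distillation of that single routing step; the helper name route_packed_tokens, the toy sizes, and the standalone setup are assumptions for illustration, not code from this patch.

import torch
import torch.nn as nn


def route_packed_tokens(
    packed_sequence: torch.Tensor,        # (L, hidden_size) tokens from all samples, packed
    und_token_indexes: torch.LongTensor,  # positions routed to the understanding expert
    gen_token_indexes: torch.LongTensor,  # positions routed to the generation expert
    expert_und: nn.Linear,                # plays the role of e.g. self.q_proj
    expert_gen: nn.Linear,                # plays the role of e.g. self.q_proj_moe_gen
) -> torch.Tensor:
    """Run each disjoint index set through its own expert, then scatter back."""
    out = packed_sequence.new_zeros(packed_sequence.shape[0], expert_und.out_features)
    out[und_token_indexes] = expert_und(packed_sequence[und_token_indexes])
    out[gen_token_indexes] = expert_gen(packed_sequence[gen_token_indexes])
    return out


# Illustrative usage with made-up sizes: 6 packed tokens, hidden size 8.
packed = torch.randn(6, 8)
routed = route_packed_tokens(
    packed,
    torch.tensor([0, 1, 4]),
    torch.tensor([2, 3, 5]),
    nn.Linear(8, 8),
    nn.Linear(8, 8),
)
print(routed.shape)  # torch.Size([6, 8])

The same scatter-by-index dispatch is applied to the q/k/v/o projections, the q/k norms, the layernorms, and the MLPs, which is why the qwen2_navit.py hunks touch so many call sites with what appears to be formatting-only re-wrapping of the calls.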