Description
I want to perform bidirectional attention in the Qwen3 model to train an embedding model, so I passed `is_causal=False` in the model forward (I manually added an `is_causal` argument to all `forward` methods, such as `Qwen3Model` and `Qwen3Attention`, in `modeling_qwen3.py`):
```python
class Qwen3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""
    ...

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        is_causal: Optional[bool] = True,  # I add is_causal here
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        ...
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,  # diff with Llama
            is_causal=is_causal,  # and is_causal from the argument is passed to the attention_interface (e.g. `flash_attention_2`, `sdpa_attention_forward`)
            **kwargs,
        )
```
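For context, this is roughly how I call the patched model when computing embeddings. It is only a sketch: it assumes the `is_causal` kwarg I added is threaded from `Qwen3Model.forward` down to every `Qwen3Attention.forward`, and the checkpoint name and mean pooling are just placeholders.

```python
# Sketch of how I call the patched model for embeddings (assumes my manual
# is_causal plumbing; model name and pooling are only examples).
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
model = AutoModel.from_pretrained(
    "Qwen/Qwen3-0.6B",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # the issue does not appear with "sdpa"
).cuda()

inputs = tokenizer(["a sentence to embed"], return_tensors="pt").to(model.device)
outputs = model(**inputs, is_causal=False)          # request bidirectional attention
embeddings = outputs.last_hidden_state.mean(dim=1)  # mean pooling as an example
```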
I can successfully change the causality of the attention with `sdpa_attention_forward`. However, I realized that it does not change the causality with `flash_attention_forward`. After diving into the implementation, I found the reason in `flash_attention_forward`, located in `transformers/integrations/flash_attention.py`:
```python
def flash_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    sliding_window: Optional[int] = None,
    softcap: Optional[float] = None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    ...
    # FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
    kwargs.pop("is_causal", None)
    attn_output = _flash_attention_forward(
        query,
        key,
        value,
        attention_mask,
        query_length=seq_len,
        is_causal=module.is_causal,  # here module is `Qwen3Attention`
        dropout=dropout,
        softmax_scale=scaling,
        sliding_window=sliding_window,
        softcap=softcap,
        use_top_left_mask=_use_top_left_mask,
        target_dtype=target_dtype,
        attn_implementation=module.config._attn_implementation,
        **kwargs,
    )
```
As you can see, the `is_causal` argument is popped from `kwargs`, and `module.is_causal` (the attribute on `Qwen3Attention`) is used instead. Since `Qwen3Attention.is_causal` is never changed and defaults to `True`, the `is_causal` passed into `_flash_attention_forward` is always `True`, regardless of what the caller requests.
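A quick way to confirm that default (a hypothetical check; the model name is just an example and the attribute path assumes the stock `Qwen3Model` layout):

```python
# Minimal check that the attention modules' is_causal attribute defaults to True.
from transformers import AutoModel

model = AutoModel.from_pretrained("Qwen/Qwen3-0.6B")
print(model.layers[0].self_attn.is_causal)  # expected: True
```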
After I added a line of code to alter `Qwen3Attention.is_causal`, i.e. `self.is_causal = is_causal` before passing the arguments into `attention_interface`, I could change the causality of `flash_attention_forward`, as sketched below.
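For reference, this is roughly the workaround inside my patched `Qwen3Attention.forward` (a sketch of the one-line change, not a proposed fix):

```python
# Inside the patched Qwen3Attention.forward, just before calling attention_interface:
self.is_causal = is_causal  # overwrite the module attribute so flash_attention_forward sees the requested value

attn_output, attn_weights = attention_interface(
    self,
    query_states,
    key_states,
    value_states,
    attention_mask,
    dropout=0.0 if not self.training else self.attention_dropout,
    scaling=self.scaling,
    sliding_window=self.sliding_window,
    is_causal=is_causal,
    **kwargs,
)
```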
So I would like to know: is this intended behavior or a bug? Thank you!!