
Why is is_causal not used in flash_attention_forward? #39554

@lucaswychan

Description


I want to perform bidirectional attention in the Qwen3 model to train an embedding model, so I passed is_causal=False in the model forward (I manually added an is_causal argument to all the forward methods, such as Qwen3Model and Qwen3Attention, in modeling_qwen3.py):

class Qwen3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    ...

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        is_causal: Optional[bool] = True,   # I added is_causal here
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        ...

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,  # diff with Llama
            is_causal=is_causal,  # the is_causal argument is forwarded to the attention_interface (e.g. `flash_attention_2`, `sdpa_attention_forward`)
            **kwargs,
        )
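
For context, this is roughly how I call the patched model (a minimal sketch; the checkpoint name is just an example, and the is_causal kwarg is only accepted because of my manual edits above):

import torch
from transformers import AutoModel, AutoTokenizer

model_name = "Qwen/Qwen3-0.6B"  # example checkpoint, any Qwen3 model applies
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(
    model_name,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
).to("cuda")

inputs = tokenizer("a sentence to embed", return_tensors="pt").to(model.device)
# Request bidirectional attention; this kwarg only exists because of the patch above
outputs = model(**inputs, is_causal=False)
embeddings = outputs.last_hidden_state.mean(dim=1)  # e.g. mean pooling for the embedding model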

This successfully changes the causality of the attention with sdpa_attention_forward. However, I realized that it does not change the causality with flash_attention_forward. After diving into the implementation, I found the reason in flash_attention_forward, located at transformers/integrations/flash_attention.py:

def flash_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    sliding_window: Optional[int] = None,
    softcap: Optional[float] = None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    ...

    # FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
    kwargs.pop("is_causal", None)

    attn_output = _flash_attention_forward(
        query,
        key,
        value,
        attention_mask,
        query_length=seq_len,
        is_causal=module.is_causal,  # here module is `Qwen3Attention`
        dropout=dropout,
        softmax_scale=scaling,
        sliding_window=sliding_window,
        softcap=softcap,
        use_top_left_mask=_use_top_left_mask,
        target_dtype=target_dtype,
        attn_implementation=module.config._attn_implementation,
        **kwargs,
    )

As you can see, the is_causal argument is popped from kwargs and the is_causal attribute of Qwen3Attention is used instead. Note that Qwen3Attention.is_causal is never changed and its default value is True, so the is_causal argument passed into _flash_attention_forward will always be True, regardless of what the caller passes.
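
To illustrate the effect with a toy sketch (this is just the pattern, not the actual transformers code):

class ToyAttention:
    is_causal = True  # mirrors Qwen3Attention.is_causal, which is never updated

def toy_flash_forward(module, **kwargs):
    kwargs.pop("is_causal", None)  # the caller's value is discarded here
    return module.is_causal        # the module attribute is what actually gets used

print(toy_flash_forward(ToyAttention(), is_causal=False))  # prints True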

After adding a line of code to alter Qwen3Attention.is_causal, i.e. self.is_causal = is_causal, right before calling attention_interface (see the sketch below), I can change the causality of flash_attention_forward. So I would like to know whether this is intended behavior or a bug. Thank you!!
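
For reference, this is the workaround I currently use inside Qwen3Attention.forward (a sketch of my local edit, not upstream code):

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        is_causal: Optional[bool] = True,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        ...

        # workaround: flash_attention_forward reads module.is_causal, not the kwarg,
        # so sync the attribute before dispatching to the attention interface
        self.is_causal = is_causal

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            is_causal=is_causal,  # still respected by sdpa_attention_forward
            **kwargs,
        )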
