Support for relative position embedding in the Encoder-Decoder attention (Context Attention) #2599

@frankang

I was reading the code and noticed that the initialization of the MultiHeadedAttention module for the decoder layer's context attention doesn't include the relative position embedding arguments that are passed to the self-attention.

Since relative position embeddings typically replace the standard sinusoidal positional encoding, I'm concerned that omitting them from the encoder-decoder attention could leave the decoder's hidden states without the positional information needed to relate words across the two sequences.
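
To make the concern concrete, here is a minimal sketch of a Shaw-style relative position bias (illustrative only, assuming that scheme; this is not OpenNMT-py's actual implementation). When max_relative_positions is set, the self-attention logits get an extra position-dependent term; the context attention as constructed today has no such term, so once the sinusoidal encoding is dropped its scores depend only on content. One open point is what the relative distance j - i would even mean there, since in the context attention i indexes the target sequence and j the source.

import torch

# Illustrative Shaw-style relative position bias (not the OpenNMT-py code).
def attn_logits_with_rel_pos(q, k, rel_emb, max_rel_pos):
    """q, k: (batch, heads, length, d_head); rel_emb: (2*max_rel_pos+1, d_head)."""
    content_logits = torch.matmul(q, k.transpose(-2, -1))   # content-content term
    q_len, k_len = q.size(-2), k.size(-2)
    # Clipped relative distances j - i, shifted into embedding indices.
    rel_idx = (torch.arange(k_len)[None, :] - torch.arange(q_len)[:, None]).clamp(
        -max_rel_pos, max_rel_pos
    ) + max_rel_pos
    rel_k = rel_emb[rel_idx]                                 # (q_len, k_len, d_head)
    rel_logits = torch.einsum("bhqd,qkd->bhqk", q, rel_k)    # content-position term
    return content_logits + rel_logits                       # softmax over last dim as usual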

1. Could this hurt the accuracy of the encoder-decoder attention, given that the decoder's hidden states may struggle to distinguish adjacent words by position?
2. Are there any experiments or results on the impact of using relative position embeddings in the context attention?
3. Has any synergy been observed when combining fixed positional embeddings (such as the sinusoidal one) with relative position embeddings?

Any clarification or experimental insights would be greatly appreciated!

For reference, here is the relevant code for the context attention and the self-attention:

# encoder-decoder attention
        self.context_attn = MultiHeadedAttention(
            heads,
            d_model,
            dropout=attention_dropout,
            attn_type="context",
            self_attn_type=self.self_attn_type,
            add_qkvbias=add_qkvbias,
            num_kv=num_kv,
            use_ckpting=use_ckpting,
            parallel_gpu=parallel_gpu,
        )
# self attention
        if self_attn_type in ["scaled-dot", "scaled-dot-flash"]:
            self.self_attn = MultiHeadedAttention(
                heads,
                d_model,
                dropout=attention_dropout,
                max_relative_positions=max_relative_positions,
                relative_positions_buckets=relative_positions_buckets,
                rotary_interleave=rotary_interleave,
                rotary_theta=rotary_theta,
                rotary_dim=rotary_dim,
                attn_type="self",
                self_attn_type=self_attn_type,
                add_qkvbias=add_qkvbias,
                num_kv=num_kv,
                use_ckpting=use_ckpting,
                parallel_gpu=parallel_gpu,
            )
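
For illustration only, the change being asked about would presumably look something like the snippet below, with the relative-position arguments simply copied over from the self-attention call above. Whether MultiHeadedAttention's "context" path would actually consume these arguments, and how target-to-source distances would be defined, is exactly what's unclear.

# Hypothetical only: forwarding the relative-position arguments to the
# context attention as well, mirroring the self-attention call above.
        self.context_attn = MultiHeadedAttention(
            heads,
            d_model,
            dropout=attention_dropout,
            max_relative_positions=max_relative_positions,
            relative_positions_buckets=relative_positions_buckets,
            attn_type="context",
            self_attn_type=self.self_attn_type,
            add_qkvbias=add_qkvbias,
            num_kv=num_kv,
            use_ckpting=use_ckpting,
            parallel_gpu=parallel_gpu,
        )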
