-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Fix XAttention reference code for better alignment with the original #32451
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix XAttention reference code for better alignment with the original #32451
Conversation
6373407
to
6a73e15
Compare
1b1ebc5
to
586f3f2
Compare
if ((k_block_idx == | ||
(blocked_attention_scores_shape[2] - blocked_attention_scores_shape[1] + q_block_idx)) || | ||
k_block_idx == 0) { | ||
// We preserve first-in-row and diagonal blocks always, and include their score in the | ||
// cumulative sum. The target for the rest of the blocks in row is to fill up the | ||
// rest of the attention mass fraction so that with the diagonal and first blocks they | ||
// comprise the `threshold` portion of the entire causal attention mass in this row | ||
retval[head_idx].insert({q_block_idx, k_block_idx}); | ||
cumsum += current_score; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@luo-cheng2021 can you confirm that the GPU uses the same logic?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Confirmed, this keep same logic with GPU kernel, thanks.
The original PR adding XAttention reference code was not completely aligned with the original code.
Fixed were:
threshold
portion of the total K-row block sum and always include diagonal/first-in-row blocks.Added more E2E tests that compare against the original code results, which are fixed in the test code as reference. Note that the original non-Triton code contained bugs, which had to be fixed in order to obtain correct references (see vshampor/x-attention@fdc5c34).