Merged
96 changes: 72 additions & 24 deletions src/core/reference/include/openvino/reference/xattention.hpp
@@ -9,6 +9,7 @@
 #include <memory>
 #include <queue>
 
+#include "openvino/core/type/element_type_traits.hpp"
 #include "openvino/reference/divide.hpp"
 #include "openvino/reference/matmul.hpp"
 #include "openvino/reference/softmax.hpp"
@@ -28,10 +29,11 @@ template <typename T>
 class XAttentionBlockSelector {
 public:
     /** @param threshold Defines a threshold for introduced block sparsity - XAttention attempts to preserve the
-     * smallest subset of attention score matrix blocks so that the ratio of the attention score sum to the total sum of
-     * attention score matrix elements is no less than `threshold`. In other words, `threshold` defines a fraction of
-     * the attention score mass which is to be preserved by most "important" blocks. Valid range is 0.0-1.0, with 0.0
-     * corresponding to 0% of the blocks retained, and 1.0 corresponding to 100% of the blocks retained.
+     * smallest subset of causal attention score matrix blocks (always including the first-in-row and "diagonal"
+     * blocks) so that the ratio of their attention score sum to the total sum of causal attention score blocks in
+     * the same query row is no less than `threshold`. In other words, `threshold` defines the fraction of each
+     * row's causal attention score mass to be preserved by the most "important" blocks. Valid range is 0.0-1.0,
+     * with 0.0 retaining none of the non-diagonal causal blocks and 1.0 retaining all of them.
     * @param block_size The size of blocks into which the attention score matrix [num_heads, query_token_dimension,
     * key_token_dimension] will be subdivided for purposes of determining the subset of the most important blocks
     * according to `threshold`. This subdivision occurs on query and key dimensions of the attention score matrix with
@@ -76,17 +78,17 @@
         OPENVINO_ASSERT(input_shape[1] / m_stride == out_shape[1]);
         OPENVINO_ASSERT(input_shape[2] * m_stride == out_shape[2]);
 
-        size_t num_stride_steps = input_shape[1] / m_stride;
+        size_t num_elts_in_strided_slice = input_shape[1] / m_stride;
         for (size_t head_idx = 0; head_idx < input_shape[0]; head_idx++) {
             size_t head_offset = head_idx * input_shape[1] * input_shape[2];
-            for (size_t slice_idx = 0; slice_idx < m_stride; slice_idx++) {
-                for (size_t stride_idx = 0; stride_idx < num_stride_steps; stride_idx++) {
+            for (size_t stride_num = 0; stride_num < m_stride; stride_num++) {
+                for (size_t intra_slice_step = 0; intra_slice_step < num_elts_in_strided_slice; intra_slice_step++) {
                     size_t input_offset = head_offset;
-                    size_t output_offset = head_offset + stride_idx * out_shape[2] + slice_idx * input_shape[2];
+                    size_t output_offset = head_offset + intra_slice_step * out_shape[2] + stride_num * input_shape[2];
                     if (is_antidiagonal) {
-                        input_offset += (input_shape[1] - 1 - slice_idx - stride_idx * m_stride) * input_shape[2];
+                        input_offset += (m_stride - 1 - stride_num + intra_slice_step * m_stride) * input_shape[2];
                     } else {
-                        input_offset += (slice_idx + stride_idx * m_stride) * input_shape[2];
+                        input_offset += (stride_num + intra_slice_step * m_stride) * input_shape[2];
                     }
                     std::memcpy(output_data + output_offset, input_data + input_offset, input_shape[2] * sizeof(T));
                 }
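To make the updated antidiagonal indexing concrete, here is a minimal standalone sketch (one head, toy dimensions; not part of the PR). With a stride of 2 and 4 input rows, output row t concatenates input rows 2t+1 and 2t, i.e. each group of `m_stride` consecutive rows is gathered in reverse order, whereas the old formula counted rows back from the end of the whole query dimension.

#include <cstddef>
#include <cstring>
#include <iostream>
#include <vector>

int main() {
    const size_t stride = 2, rows = 4, width = 3;  // toy shapes: one head, 4 rows of width 3
    std::vector<float> in(rows * width), out(in.size());
    for (size_t i = 0; i < in.size(); ++i) in[i] = static_cast<float>(i);
    const size_t num_elts_in_strided_slice = rows / stride;
    for (size_t stride_num = 0; stride_num < stride; ++stride_num) {
        for (size_t step = 0; step < num_elts_in_strided_slice; ++step) {
            // Output row `step`, column block `stride_num` ...
            size_t dst = step * width * stride + stride_num * width;
            // ... gathers input row (stride - 1 - stride_num + step * stride),
            // i.e. the rows of group `step` in reverse order.
            size_t src = (stride - 1 - stride_num + step * stride) * width;
            std::memcpy(&out[dst], &in[src], width * sizeof(float));
        }
    }
    for (float v : out) std::cout << v << ' ';  // 3 4 5 0 1 2 9 10 11 6 7 8
    std::cout << '\n';
}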
@@ -142,6 +144,28 @@
         }
     }
 
+    /** Applies the softmax causal mask along the last two dimensions of the rank-3 input tensor in-place.
+     * @param in_out_data Pointer to the softmax input values (logits).
+     * @param in_out_shape Shape of the input tensor. Expected shape is [num_heads, num_query_tokens /
+     * stride, num_key_tokens / stride].
+     */
+    void apply_causal_mask_(T* in_out_data, const Shape& in_out_shape) {
+        OPENVINO_ASSERT(in_out_shape.size() == 3);
+        OPENVINO_ASSERT(in_out_shape[1] <= in_out_shape[2]);
+        size_t query_dim = in_out_shape[1];
+        size_t key_dim = in_out_shape[2];
+        for (size_t head_idx = 0; head_idx < in_out_shape[0]; head_idx++) {
+            size_t head_offset = head_idx * in_out_shape[1] * in_out_shape[2];
+            for (size_t query_dim_idx = 0; query_dim_idx < in_out_shape[1]; query_dim_idx++) {
+                size_t query_dim_offset = query_dim_idx * in_out_shape[2];
+                for (size_t key_dim_idx = key_dim - query_dim + query_dim_idx + 1; key_dim_idx < key_dim;
+                     key_dim_idx++) {
+                    in_out_data[head_offset + query_dim_offset + key_dim_idx] = -INFINITY;
+                }
+            }
+        }
+    }
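For intuition about the masking pattern, a minimal standalone sketch (toy shape [1, 2, 4]; not part of the PR): the causal diagonal is aligned to the bottom-right corner of the [query, key] matrix, so query row 0 keeps keys 0..2 and query row 1 keeps all keys 0..3.

#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
    const size_t heads = 1, q_dim = 2, k_dim = 4;  // toy shape [1, 2, 4]
    std::vector<float> logits(heads * q_dim * k_dim, 1.0f);
    for (size_t h = 0; h < heads; ++h)
        for (size_t q = 0; q < q_dim; ++q)
            // The first masked key for query row q lies strictly beyond the
            // bottom-right-aligned diagonal element (k_dim - q_dim + q).
            for (size_t k = k_dim - q_dim + q + 1; k < k_dim; ++k)
                logits[h * q_dim * k_dim + q * k_dim + k] = -INFINITY;
    for (size_t q = 0; q < q_dim; ++q) {
        for (size_t k = 0; k < k_dim; ++k) std::cout << logits[q * k_dim + k] << ' ';
        std::cout << '\n';  // prints "1 1 1 -inf" then "1 1 1 1"
    }
}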
+
     /** Performs a softmax operation on the last dimension of the rank-3 input tensor.
      * @param reshaped_qk_product_data Pointer to the reshaped query-key product input (attention logits pre-softmax).
      * @param reshaped_qk_product_shape Shape of the input tensor. Expected shape is [num_heads, num_query_tokens /
@@ -203,9 +227,12 @@
         }
     }
 
-    /** Selects the elements of the input tensor along the last two dimensions, independently along the first dimension,
-     * so that the elements constitute a smallest subset constituting a sum portion no less than `threshold` of the
-     * total element sum.
+    /** Selects elements of the input tensor along the last dimension, independently along the first two dimensions,
+     * so that the selected elements constitute a smallest subset amounting to a sum portion no less than `threshold`
+     * of the total "causal" element sum in that row. "Causal" is understood in the sense of the last two dimensions
+     * being treated as the query-block and key-block dimensions of attention matrix scores. "Non-causal" elements
+     * are disregarded entirely and never preserved, while "diagonal" and first-in-row elements are always preserved,
+     * with their scores counted toward the preserved sum portion.
      * @param blocked_scores_data Pointer to the blocked score input.
      * @param blocked_attention_scores_shape Shape of the blocked score input tensor. Expected shape is [num_heads,
      * num_query_tokens / block_size, num_key_tokens / block_size]
@@ -217,6 +244,8 @@
                                                              const Shape& blocked_attention_scores_shape) {
         OPENVINO_ASSERT(blocked_attention_scores_shape.size() ==
                         3);  // [num_heads, num_blocks_in_query, num_blocks_in_key]
+        //
+        OPENVINO_ASSERT(blocked_attention_scores_shape[1] <= blocked_attention_scores_shape[2]);
 
         auto retval = XAttentionRetainedBlockIndicesForAllHeads(blocked_attention_scores_shape[0]);
 
@@ -230,23 +259,40 @@
 
         for (size_t head_idx = 0; head_idx < blocked_attention_scores_shape[0]; head_idx++) {
             size_t head_offset = head_idx * blocked_attention_scores_shape[1] * blocked_attention_scores_shape[2];
-            std::priority_queue<IndexAndScore> indices_and_scores_queue;
-            double total_sum = 0.0;
             for (size_t q_block_idx = 0; q_block_idx < blocked_attention_scores_shape[1]; q_block_idx++) {
+                std::priority_queue<IndexAndScore> indices_and_scores_queue;
+                double total_sum = 0.0;
+                double cumsum = 0.0;
                 for (size_t k_block_idx = 0; k_block_idx < blocked_attention_scores_shape[2]; k_block_idx++) {
+                    if (k_block_idx >
+                        (blocked_attention_scores_shape[2] - blocked_attention_scores_shape[1] + q_block_idx)) {
+                        // Disregard non-causal blocks entirely
+                        continue;
+                    }
                     size_t target_offset = head_offset + blocked_attention_scores_shape[2] * q_block_idx + k_block_idx;
                     T current_score = *(blocked_attention_scores_data + target_offset);
-                    indices_and_scores_queue.push({{q_block_idx, k_block_idx}, current_score});
                     total_sum += current_score;
+
+                    if ((k_block_idx ==
+                         (blocked_attention_scores_shape[2] - blocked_attention_scores_shape[1] + q_block_idx)) ||
+                        k_block_idx == 0) {
+                        // We always preserve first-in-row and diagonal blocks, and include their score in the
+                        // cumulative sum. The target for the rest of the blocks in the row is to fill up the
+                        // rest of the attention mass fraction so that, together with the diagonal and first
+                        // blocks, they comprise the `threshold` portion of the entire causal attention mass
+                        // in this row.
+                        retval[head_idx].insert({q_block_idx, k_block_idx});
+                        cumsum += current_score;
[Comment on lines +276 to +284]
Contributor: @luo-cheng2021 can you confirm that the GPU uses the same logic?
cc @peterchen-intel
Contributor @WeldonWangwang (Oct 19, 2025): Confirmed, this keeps the same logic as the GPU kernel, thanks.
+                    } else {
+                        indices_and_scores_queue.push({{q_block_idx, k_block_idx}, current_score});
+                    }
                 }
+                double required_sum = m_threshold * total_sum;
+                while (cumsum < required_sum && !indices_and_scores_queue.empty()) {
+                    auto index_and_largest_score = indices_and_scores_queue.top();
+                    indices_and_scores_queue.pop();
+                    cumsum += index_and_largest_score.score;
+                    retval[head_idx].insert(index_and_largest_score.idx);
+                }
             }
-            double cumsum = 0.0;
-            double required_sum = m_threshold * total_sum;
-            while (cumsum < required_sum && !indices_and_scores_queue.empty()) {
-                auto index_and_largest_score = indices_and_scores_queue.top();
-                indices_and_scores_queue.pop();
-                cumsum += index_and_largest_score.score;
-                retval[head_idx].insert(index_and_largest_score.idx);
-            }
         }
         return retval;
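A minimal standalone sketch of the per-row selection above (hypothetical scores; not part of the PR): the first-in-row and diagonal blocks are retained unconditionally and seed the cumulative sum, then the remaining causal blocks are drained from a max-heap until the retained mass reaches `threshold` of the row's total causal mass.

#include <cstddef>
#include <iostream>
#include <queue>
#include <utility>
#include <vector>

int main() {
    const double threshold = 0.8;
    // Causal block scores of one query row; index 0 (first-in-row) and the
    // last index (diagonal) are always retained.
    std::vector<double> row = {0.30, 0.05, 0.20, 0.10, 0.35};
    double total_sum = 0.0;
    for (double s : row) total_sum += s;
    double cumsum = row.front() + row.back();  // always-retained blocks count first
    std::vector<size_t> retained = {0, row.size() - 1};
    std::priority_queue<std::pair<double, size_t>> queue;  // max-heap keyed on score
    for (size_t k = 1; k + 1 < row.size(); ++k) queue.push({row[k], k});
    while (cumsum < threshold * total_sum && !queue.empty()) {
        auto [score, k_idx] = queue.top();
        queue.pop();
        cumsum += score;
        retained.push_back(k_idx);
    }
    for (size_t k : retained) std::cout << k << ' ';  // prints: 0 4 2
    std::cout << '\n';
}

Here cumsum starts at 0.30 + 0.35 = 0.65 against a required 0.8 * 1.0, so a single pop of the largest remaining score (0.20, index 2) suffices.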
@@ -303,6 +349,8 @@
         q_buf.reset();
         k_buf.reset();
 
+        apply_causal_mask_(qk_buf.get(), transpose_matmul_scaled_shape);
+
         Shape attention_scores_shape = transpose_matmul_scaled_shape;
         auto attn_score_buf = allocate_buf(attention_scores_shape);
         softmax(qk_buf.get(), transpose_matmul_scaled_shape, attn_score_buf.get(), attention_scores_shape);
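The ordering matters here: the -INFINITY entries written by apply_causal_mask_ must be in place before the softmax so that masked positions receive exactly zero probability and the row mass renormalizes over causal positions only. A minimal standalone illustration (toy numbers; not part of the PR):

#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
    // One logits row with the last position causally masked.
    std::vector<double> row = {1.0, 2.0, -INFINITY};
    const double max_logit = 2.0;  // subtract the max for numerical stability
    double denom = 0.0;
    for (double v : row) denom += std::exp(v - max_logit);  // exp(-inf) == 0
    std::vector<double> probs(row.size());
    for (size_t i = 0; i < row.size(); ++i) probs[i] = std::exp(row[i] - max_logit) / denom;
    for (double p : probs) std::cout << p << ' ';  // ~0.269 0.731 0
    std::cout << '\n';
}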