diff --git a/src/core/reference/include/openvino/reference/xattention.hpp b/src/core/reference/include/openvino/reference/xattention.hpp index 6e847d57053ffb..9ff31a21a2e91d 100644 --- a/src/core/reference/include/openvino/reference/xattention.hpp +++ b/src/core/reference/include/openvino/reference/xattention.hpp @@ -9,6 +9,7 @@ #include #include +#include "openvino/core/type/element_type_traits.hpp" #include "openvino/reference/divide.hpp" #include "openvino/reference/matmul.hpp" #include "openvino/reference/softmax.hpp" @@ -28,10 +29,11 @@ template class XAttentionBlockSelector { public: /** @param threshold Defines a threshold for introduced block sparsity - XAttention attempts to preserve the - * smallest subset of attention score matrix blocks so that the ratio of the attention score sum to the total sum of - * attention score matrix elements is no less than `threshold`. In other words, `threshold` defines a fraction of - * the attention score mass which is to be preserved by most "important" blocks. Valid range is 0.0-1.0, with 0.0 - * corresponding to 0% of the blocks retained, and 1.0 corresponding to 100% of the blocks retained. + * smallest subset of causal non-diagonal attention score matrix blocks so that the ratio of their attention score + * sum to the total sum of causal non-diagonal attention score matrix blocks in the same K-row is no less than + * `threshold`. In other words, `threshold` defines a fraction of the block non-diagonal causal attention score mass + * which is to be preserved by most "important" blocks. Valid range is 0.0-1.0, with 0.0 corresponding to 0% of the + * non-diagonal causal blocks retained, and 1.0 corresponding to 100% of the non-diagonal causal blocks retained. * @param block_size The size of blocks into which the attention score matrix [num_heads, query_token_dimension, * key_token_dimension] will be subdivided for purposes of determining the subset of the most important blocks * according to `threshold`. This subdivision occurs on query and key dimensions of the attention score matrix with @@ -76,17 +78,17 @@ class XAttentionBlockSelector { OPENVINO_ASSERT(input_shape[1] / m_stride == out_shape[1]); OPENVINO_ASSERT(input_shape[2] * m_stride == out_shape[2]); - size_t num_stride_steps = input_shape[1] / m_stride; + size_t num_elts_in_strided_slice = input_shape[1] / m_stride; for (size_t head_idx = 0; head_idx < input_shape[0]; head_idx++) { size_t head_offset = head_idx * input_shape[1] * input_shape[2]; - for (size_t slice_idx = 0; slice_idx < m_stride; slice_idx++) { - for (size_t stride_idx = 0; stride_idx < num_stride_steps; stride_idx++) { + for (size_t stride_num = 0; stride_num < m_stride; stride_num++) { + for (size_t intra_slice_step = 0; intra_slice_step < num_elts_in_strided_slice; intra_slice_step++) { size_t input_offset = head_offset; - size_t output_offset = head_offset + stride_idx * out_shape[2] + slice_idx * input_shape[2]; + size_t output_offset = head_offset + intra_slice_step * out_shape[2] + stride_num * input_shape[2]; if (is_antidiagonal) { - input_offset += (input_shape[1] - 1 - slice_idx - stride_idx * m_stride) * input_shape[2]; + input_offset += (m_stride - 1 - stride_num + intra_slice_step * m_stride) * input_shape[2]; } else { - input_offset += (slice_idx + stride_idx * m_stride) * input_shape[2]; + input_offset += (stride_num + intra_slice_step * m_stride) * input_shape[2]; } std::memcpy(output_data + output_offset, input_data + input_offset, input_shape[2] * sizeof(T)); } @@ -142,6 +144,28 @@ class XAttentionBlockSelector { } } + /** Applies the softmax causal mask along the last two dimensions of the rank-3 input tensor in-place. + * @param in_out_data Pointer to the softmax input values (logits). + * @param in_out_shape Shape of the input tensor. Expected shape is [num_heads, num_query_tokens / + * stride, num_key_tokens / stride]. + */ + void apply_causal_mask_(T* in_out_data, const Shape& in_out_shape) { + OPENVINO_ASSERT(in_out_shape.size() == 3); + OPENVINO_ASSERT(in_out_shape[1] <= in_out_shape[2]); + size_t query_dim = in_out_shape[1]; + size_t key_dim = in_out_shape[2]; + for (size_t head_idx = 0; head_idx < in_out_shape[0]; head_idx++) { + size_t head_offset = head_idx * in_out_shape[1] * in_out_shape[2]; + for (size_t query_dim_idx = 0; query_dim_idx < in_out_shape[1]; query_dim_idx++) { + size_t query_dim_offset = query_dim_idx * in_out_shape[2]; + for (size_t key_dim_idx = key_dim - query_dim + query_dim_idx + 1; key_dim_idx < key_dim; + key_dim_idx++) { + in_out_data[head_offset + query_dim_offset + key_dim_idx] = -INFINITY; + } + } + } + } + /** Performs a softmax operation on the last dimension of the rank-3 input tensor. * @param reshaped_qk_product_data Pointer to the reshaped query-key product input (attention logits pre-softmax). * @param reshaped_qk_product_shape Shape of the input tensor. Expected shape is [num_heads, num_query_tokens / @@ -203,9 +227,12 @@ class XAttentionBlockSelector { } } - /** Selects the elements of the input tensor along the last two dimensions, independently along the first dimension, - * so that the elements constitute a smallest subset constituting a sum portion no less than `threshold` of the - * total element sum. + /** Selects the elements of the input tensor along the last dimension, independently along the first two dimensions, + * so that the selected elements constitute a smallest subset amounting to a sum portion no less than `threshold` + * of the total "causal" element sum. "Causal" is understood in the sense of the last two dimensions being + * treated as the query-block and key-block dimensions in the context of attention matrix scores. The + * first-in-row, the "diagonal" and "non-causal" elements are disregarded when calculating the sum. "Non-causal" + * elements are never preserved, while "diagonal" and first-in-row elements are always preserved. * @param blocked_scores_data Pointer to the blocked score input. * @param blocked_attention_scores_shape Shape of the blocked score input tensor. Expected shape is [num_heads, * num_query_tokens / block_size, num_key_tokens / block_size] @@ -217,6 +244,8 @@ class XAttentionBlockSelector { const Shape& blocked_attention_scores_shape) { OPENVINO_ASSERT(blocked_attention_scores_shape.size() == 3); // [num_heads, num_blocks_in_query, num_blocks_in_key] + // + OPENVINO_ASSERT(blocked_attention_scores_shape[1] <= blocked_attention_scores_shape[2]); auto retval = XAttentionRetainedBlockIndicesForAllHeads(blocked_attention_scores_shape[0]); @@ -230,23 +259,40 @@ class XAttentionBlockSelector { for (size_t head_idx = 0; head_idx < blocked_attention_scores_shape[0]; head_idx++) { size_t head_offset = head_idx * blocked_attention_scores_shape[1] * blocked_attention_scores_shape[2]; - std::priority_queue indices_and_scores_queue; - double total_sum = 0.0; for (size_t q_block_idx = 0; q_block_idx < blocked_attention_scores_shape[1]; q_block_idx++) { + std::priority_queue indices_and_scores_queue; + double total_sum = 0.0; + double cumsum = 0.0; for (size_t k_block_idx = 0; k_block_idx < blocked_attention_scores_shape[2]; k_block_idx++) { + if (k_block_idx > + (blocked_attention_scores_shape[2] - blocked_attention_scores_shape[1] + q_block_idx)) { + // Disregard non-causal blocks entirely + continue; + } size_t target_offset = head_offset + blocked_attention_scores_shape[2] * q_block_idx + k_block_idx; T current_score = *(blocked_attention_scores_data + target_offset); - indices_and_scores_queue.push({{q_block_idx, k_block_idx}, current_score}); total_sum += current_score; + + if ((k_block_idx == + (blocked_attention_scores_shape[2] - blocked_attention_scores_shape[1] + q_block_idx)) || + k_block_idx == 0) { + // We preserve first-in-row and diagonal blocks always, and include their score in the + // cumulative sum. The target for the rest of the blocks in row is to fill up the + // rest of the attention mass fraction so that with the diagonal and first blocks they + // comprise the `threshold` portion of the entire causal attention mass in this row + retval[head_idx].insert({q_block_idx, k_block_idx}); + cumsum += current_score; + } else { + indices_and_scores_queue.push({{q_block_idx, k_block_idx}, current_score}); + } + } + double required_sum = m_threshold * total_sum; + while (cumsum < required_sum && !indices_and_scores_queue.empty()) { + auto index_and_largest_score = indices_and_scores_queue.top(); + indices_and_scores_queue.pop(); + cumsum += index_and_largest_score.score; + retval[head_idx].insert(index_and_largest_score.idx); } - } - double cumsum = 0.0; - double required_sum = m_threshold * total_sum; - while (cumsum < required_sum && !indices_and_scores_queue.empty()) { - auto index_and_largest_score = indices_and_scores_queue.top(); - indices_and_scores_queue.pop(); - cumsum += index_and_largest_score.score; - retval[head_idx].insert(index_and_largest_score.idx); } } return retval; @@ -303,6 +349,8 @@ class XAttentionBlockSelector { q_buf.reset(); k_buf.reset(); + apply_causal_mask_(qk_buf.get(), transpose_matmul_scaled_shape); + Shape attention_scores_shape = transpose_matmul_scaled_shape; auto attn_score_buf = allocate_buf(attention_scores_shape); softmax(qk_buf.get(), transpose_matmul_scaled_shape, attn_score_buf.get(), attention_scores_shape); diff --git a/src/core/tests/reference/xattention.cpp b/src/core/tests/reference/xattention.cpp index 281a4b781c58b6..e5598d45bd7020 100644 --- a/src/core/tests/reference/xattention.cpp +++ b/src/core/tests/reference/xattention.cpp @@ -5,73 +5,23 @@ #include #include +#include #include double DEFAULT_THRESHOLD = 0.8; size_t DEFAULT_BLOCK_SIZE = 32; size_t DEFAULT_STRIDE = 8; -struct E2EBlockSelectTestData { - ov::Shape q_shape; - std::vector q_data; - ov::Shape k_shape; - std::vector k_data; - double threshold; - size_t block_size; - size_t stride; -}; - -using XAttentionE2EBlockSelectTest = ::testing::TestWithParam; - -std::vector E2E_BLOCK_SELECT_TEST_CASES = {{ - {2, 4, 4}, - // clang-format off - { - 3.144, 8.512, 8.518, -8.386, - 7.889, -5.721, 5.507, 4.295, - -6.624, -8.463, 7.474, 9.879, - 4.534, -5.908, -9.388, 2.356, +TEST(XAttentionBasicTest, SelectsBlocksWithoutThrowing) { + ov::reference::XAttentionBlockSelector selector(DEFAULT_THRESHOLD, DEFAULT_BLOCK_SIZE, DEFAULT_STRIDE); - 7.497, 8.186, -8.658, -4.796, - -8.248, -9.797, -7.907, -4.513, - 3.469, 7.633, 7.244, -6.844, - -7.173, 4.450, 6.705, -7.035 - }, - // clang-format on - {2, 4, 4}, - // clang-format off - { - 3.144, 8.512, 8.518, -8.386, - 7.889, -5.721, 5.507, 4.295, - -6.624, -8.463, 7.474, 9.879, - 4.534, -5.908, -9.388, 2.356, - - 7.497, 8.186, -8.658, -4.796, - -8.248, -9.797, -7.907, -4.513, - 3.469, 7.633, 7.244, -6.844, - -7.173, 4.450, 6.705, -7.035 - }, - // clang-format on - - /* threshold = */ 0.8, - /* block_size = */ 2, - /* stride = */ 2, -}}; - -TEST_P(XAttentionE2EBlockSelectTest, SelectsBlocksWithoutThrowing) { - auto test_struct = GetParam(); - ov::reference::XAttentionBlockSelector selector(test_struct.threshold, - test_struct.block_size, - test_struct.stride); - - EXPECT_NO_THROW(selector.select_blocks(test_struct.q_data.data(), - test_struct.q_shape, - test_struct.k_data.data(), - test_struct.k_shape)); + ov::Shape q_shape = {2, 64, 32}; + ov::Shape k_shape = {2, 128, 32}; + std::vector q_data(ov::shape_size(q_shape), 1.0); + std::vector k_data(ov::shape_size(k_shape), 1.0); + EXPECT_NO_THROW(selector.select_blocks(q_data.data(), q_shape, k_data.data(), k_shape)); }; -INSTANTIATE_TEST_SUITE_P(VariousInputs, XAttentionE2EBlockSelectTest, ::testing::ValuesIn(E2E_BLOCK_SELECT_TEST_CASES)); - struct DiagonalReshapeTestData { ov::Shape in_shape; std::vector in_data; @@ -108,11 +58,11 @@ std::vector DIAGONAL_RESHAPE_TEST_CASES = { // clang-format off { - 4.534, -5.908, -9.388, 2.356, -6.624, -8.463, 7.474, 9.879, 7.889, -5.721, 5.507, 4.295, 3.144, 8.512, 8.518, -8.386, + 4.534, -5.908, -9.388, 2.356, -6.624, -8.463, 7.474, 9.879, - -7.173, 4.450, 6.705, -7.035, 3.469, 7.633, 7.244, -6.844, -8.248, -9.797, -7.907, -4.513, 7.497, 8.186, -8.658, -4.796, + -7.173, 4.450, 6.705, -7.035, 3.469, 7.633, 7.244, -6.844, }, // clang-format on }, @@ -180,13 +130,13 @@ std::vector DIAGONAL_RESHAPE_TEST_CASES = { // clang-format off { - -8.410, 6.247, 0.264, 7.095, 1.354, -7.748, - -7.413, 5.855, -4.142, 2.837, 3.930, -2.122, 3.664, -2.459, 3.530, -1.083, 1.110, -4.244, + -7.413, 5.855, -4.142, 2.837, 3.930, -2.122, + -8.410, 6.247, 0.264, 7.095, 1.354, -7.748, - -9.869, -7.636, -5.892, 7.820, 9.438, -2.421, - 3.568, 8.530, -0.841, 1.935, 1.767, 5.950, -5.429, 7.854, -7.414, -3.682, -7.832, 9.163, + 3.568, 8.530, -0.841, 1.935, 1.767, 5.950, + -9.869, -7.636, -5.892, 7.820, 9.438, -2.421, }, // clang-format on }, @@ -388,6 +338,114 @@ INSTANTIATE_TEST_SUITE_P(VariousInputs, XAttentionSoftmaxTest, ::testing::ValuesIn(SOFTMAX_TEST_CASES)); +struct CausalMaskTestData { + ov::Shape in_shape; + std::vector in_data; + std::vector ref_out_data; +}; + +using XAttentionCausalMaskTest = ::testing::TestWithParam; + +std::vector CAUSAL_MASK_TEST_CASES = { + { + {2, 4, 4}, + // clang-format off + { + 4.534, -5.908, -9.388, 2.356, + 7.889, -5.721, 5.507, 4.295, + 4.534, -5.908, -9.388, 2.356, + 7.889, -5.721, 5.507, 4.295, + + -7.173, 4.450, 6.705, -7.035, + -8.248, -9.797, -7.907, -4.513, + -7.173, 4.450, 6.705, -7.035, + -8.248, -9.797, -7.907, -4.513 + }, + // clang-format on + + // clang-format off + { + 4.534, -INFINITY, -INFINITY, -INFINITY, + 7.889, -5.721, -INFINITY, -INFINITY, + 4.534, -5.908, -9.388, -INFINITY, + 7.889, -5.721, 5.507, 4.295, + + -7.173, -INFINITY, -INFINITY, -INFINITY, + -8.248, -9.797, -INFINITY, -INFINITY, + -7.173, 4.450, 6.705, -INFINITY, + -8.248, -9.797, -7.907, -4.513 + }, + }, + { + {2, 2, 4}, + // clang-format off + { + 4.534, -5.908, -9.388, 2.356, + 7.889, -5.721, 5.507, 4.295, + + -7.173, 4.450, 6.705, -7.035, + -8.248, -9.797, -7.907, -4.513 + }, + // clang-format on + + // clang-format off + { + 4.534, -5.908, -9.388, -INFINITY, + 7.889, -5.721, 5.507, 4.295, + + -7.173, 4.450, 6.705, -INFINITY, + -8.248, -9.797, -7.907, -4.513 + }, + }, + { + {2, 4, 6}, + // clang-format off + { + 4.534, -5.908, -9.388, 2.356, -5.908, -9.388, + 7.889, -5.721, 5.507, 4.295, -5.721, 5.507, + 4.534, -5.908, -9.388, 2.356, -5.908, -9.388, + 7.889, -5.721, 5.507, 4.295, -5.721, 5.507, + + -7.173, 4.450, 6.705, -7.035, 4.450, 6.705, + -8.248, -9.797, -7.907, -4.513, -9.797, -7.907, + -7.173, 4.450, 6.705, -7.035, 4.450, 6.705, + -8.248, -9.797, -7.907, -4.513, -9.797, -7.907, + }, + // clang-format on + + // clang-format off + { + 4.534, -5.908, -9.388, -INFINITY, -INFINITY, -INFINITY, + 7.889, -5.721, 5.507, 4.295, -INFINITY, -INFINITY, + 4.534, -5.908, -9.388, 2.356, -5.908, -INFINITY, + 7.889, -5.721, 5.507, 4.295, -5.721, 5.507, + + -7.173, 4.450, 6.705, -INFINITY, -INFINITY, -INFINITY, + -8.248, -9.797, -7.907, -4.513, -INFINITY, -INFINITY, + -7.173, 4.450, 6.705, -7.035, 4.450, -INFINITY, + -8.248, -9.797, -7.907, -4.513, -9.797, -7.907, + }, + }, +}; + +TEST_P(XAttentionCausalMaskTest, CausalMaskIsCorrect) { + auto test_struct = GetParam(); + ASSERT_EQ(test_struct.in_data.size(), ov::shape_size(test_struct.in_shape)); + ASSERT_EQ(test_struct.ref_out_data.size(), ov::shape_size(test_struct.in_shape)); + + ov::reference::XAttentionBlockSelector selector(DEFAULT_THRESHOLD, + DEFAULT_BLOCK_SIZE, + DEFAULT_STRIDE); + std::vector test_out_data = test_struct.in_data; + selector.apply_causal_mask_(test_out_data.data(), test_struct.in_shape); + + EXPECT_THAT(test_out_data, ::testing::Pointwise(::testing::DoubleNear(1e-5), test_struct.ref_out_data)); +} + +INSTANTIATE_TEST_SUITE_P(VariousInputs, + XAttentionCausalMaskTest, + ::testing::ValuesIn(CAUSAL_MASK_TEST_CASES)); + struct BlockSumTestData { ov::Shape in_shape; std::vector in_data; @@ -459,81 +517,82 @@ using XAttentionBlockSelectTest = ::testing::TestWithParam; std::vector BLOCK_SELECT_TEST_CASES = { { - {2, 2, 4}, + {2, 2, 5}, // clang-format off { - 0.5151, 0.4323, 0.5014, 0.5513, - 0.4557, 0.4876, 0.5870, 0.4697, + 0.0000, 0.5151, 0.4323, 0.5014, 0.5513, + 100.0, 0.4557, 0.4876, 0.5870, 0.4697, - 0.5118, 0.4507, 0.5180, 0.5194, - 0.5315, 0.4446, 0.4929, 0.5310 + 1.7491, 0.3118, 0.4507, 0.5180, 0.5194, + 0.3123, 0.5315, 0.4446, 0.4929, 0.5310 }, // clang-format on /* threshold = */ 0.25, { - {{1, 2}, {0, 3}}, - {{1, 0}, {1, 3}}, + {{0, 0}, {1, 0}, {0, 3}, {1, 4}}, + {{0, 0}, {1, 0}, {0, 3}, {1, 4}}, }}, - {{2, 2, 4}, + {{2, 2, 5}, // clang-format off { - 0.5151, 0.4323, 0.5014, 0.5513, - 0.4557, 0.4876, 0.5870, 0.4697, + // larger values in non-causal area should have no impact + 0.4729, 0.5151, 0.4323, 0.5014, 1337.0, + 0.5267, 0.4557, 0.4876, 0.5870, 0.4697, - 0.5118, 0.4507, 0.5180, 0.5194, - 0.5315, 0.4446, 0.4929, 0.5310 + 0.4647, 0.5118, 0.4507, 0.5180, 42.0, + 0.0000, 0.5315, 0.4446, 0.4929, 0.5310 }, // clang-format on - /* threshold = */ 0.35, + /* threshold = */ 0.45, { - {{1, 2}, {0, 3}, {0, 0}}, - {{1, 0}, {1, 3}, {0, 3}}, + {{0, 0}, {1, 0}, {0, 3}, {1, 4}, {1, 3}}, + {{0, 0}, {1, 0}, {0, 3}, {1, 4}, {1, 1}}, }}, - {{2, 2, 4}, + {{2, 2, 5}, // clang-format off { - 0.5151, 0.4323, 0.5014, 0.5513, - 0.4557, 0.4876, 0.5870, 0.4697, + 0.4729, 0.5151, 0.4323, 0.5014, 0.5513, + 0.5267, 0.4557, 0.4876, 0.5870, 0.4697, - 0.5118, 0.4507, 0.5180, 0.5194, - 0.5315, 0.4446, 0.4929, 0.5310 + 0.4647, 0.5118, 0.4507, 0.5180, 0.5194, + 0.0000, 0.4446, 0.1234, 0.4929, 0.5310 }, // clang-format on - /* threshold = */ 0.1, + /* threshold = */ 0.8, { - {{1, 2}}, - {{1, 0}}, + {{0, 0}, {1, 0}, {0, 3}, {1, 4}, {0, 1}, {0, 2}, {1, 3}, {1, 2}}, + {{0, 0}, {1, 0}, {0, 3}, {1, 4}, {0, 1}, {0, 2}, {1, 3}, {1, 1}}, }}, - {{2, 2, 4}, + {{2, 2, 5}, // clang-format off { - 0.5151, 0.4323, 0.5014, 0.5513, - 0.4557, 0.4876, 0.5870, 0.4697, + 0.4729, 0.5151, 0.4323, 0.5014, 0.5513, + 0.5267, 0.4557, 0.4876, 0.5870, 0.4697, - 0.5118, 0.4507, 0.5180, 0.5194, - 0.5315, 0.4446, 0.4929, 0.5310 + 0.4647, 0.5118, 0.4507, 0.5180, 0.5194, + 0.0000, 0.5315, 0.4446, 0.4929, 0.5310 }, // clang-format on /* threshold = */ 0.0, { - {}, - {}, + {{0, 0}, {1, 0}, {0, 3}, {1, 4}}, + {{0, 0}, {1, 0}, {0, 3}, {1, 4}}, }}, - {{2, 2, 4}, + {{2, 2, 5}, // clang-format off { - 0.5151, 0.4323, 0.5014, 0.5513, - 0.4557, 0.4876, 0.5870, 0.4697, + 0.4729, 0.5151, 0.4323, 0.5014, 0.5513, + 0.5267, 0.4557, 0.4876, 0.5870, 0.4697, - 0.5118, 0.4507, 0.5180, 0.5194, - 0.5315, 0.4446, 0.4929, 0.5310 + 0.4647, 0.5118, 0.4507, 0.5180, 0.5194, + 0.0000, 0.5315, 0.4446, 0.4929, 0.5310 }, // clang-format on /* threshold = */ 1.0, { - {{1, 2}, {0, 3}, {0, 0}, {0, 2}, {1, 1}, {1, 3}, {1, 0}, {0, 1}}, - {{1, 0}, {1, 3}, {0, 3}, {0, 2}, {0, 0}, {1, 2}, {0, 1}, {1, 1}}, + {{0, 0}, {1, 0}, {1, 3}, {0, 1}, {0, 3}, {1, 2}, {1, 4}, {1, 1}, {0, 2}}, + {{0, 0}, {1, 0}, {1, 1}, {1, 4}, {0, 3}, {0, 1}, {1, 3}, {0, 2}, {1, 2}}, }}, }; @@ -548,3 +607,265 @@ TEST_P(XAttentionBlockSelectTest, BlockSelectionIsCorrect) { } INSTANTIATE_TEST_SUITE_P(VariousInputs, XAttentionBlockSelectTest, ::testing::ValuesIn(BLOCK_SELECT_TEST_CASES)); + +struct E2EBlockSelectTestData { + ov::Shape q_shape; + std::vector q_data; + ov::Shape k_shape; + std::vector k_data; + double threshold; + size_t block_size; + size_t stride; + ov::reference::XAttentionRetainedBlockIndicesForAllHeads ref_retained_block_indices; +}; + +using XAttentionE2EBlockSelectTest = ::testing::TestWithParam; + +ov::Shape E2E_Q_SHAPE_8 = {2, 8, 2}; +std::vector E2E_Q_DATA_8 = { + // clang-format off + -1.2870, -1.2179, 0.0316, 0.0080, -0.6171, 1.0622, 0.3085, -0.7751, + -1.3612, 0.9485, -0.0803, 0.5752, 0.1925, -0.1113, 1.4693, 0.0673, + 0.7422, 0.7149, -1.7684, -0.0651, -0.1925, -1.4169, 1.0030, -0.8091, + -0.7934, 0.5160, -0.2543, 0.1729, -0.0687, -1.4245, 0.0758, 1.1613 + // clang-format on +}; + +ov::Shape E2E_K_SHAPE_8 = {2, 8, 2}; +std::vector E2E_K_DATA_8 = { + // clang-format off + 0.2980, 0.4959, -0.0834, 0.7015, 1.2516, 0.6656, -2.7873, 1.9731, + -0.4817, 1.1117, -0.8096, -0.5397, -1.0528, 0.2869, -1.1274, 1.4849, + -0.2468, -1.0449, -1.0085, -0.3389, 0.6750, 0.9095, 0.4674, 2.2321, + 1.3183, -0.3513, -0.3717, 0.0176, -0.2545, -0.6729, -1.1547, 0.0279 + // clang-format on +}; + +ov::Shape E2E_K_SHAPE_16 = {2, 16, 2}; +std::vector E2E_K_DATA_16 = { + // clang-format off + -0.9049, -1.9274, -0.3687, -1.1156, 0.1343, 1.1119, 0.7139, 1.0958, + 0.7644, 1.9416, 0.9911, 0.8628, 0.4935, -0.3232, -1.1748, 0.0462, + 0.0488, -0.4271, 1.6657, 0.4596, 1.3253, -1.3023, 0.4961, 1.3707, + -0.1723, -1.1623, -0.6218, -0.5510, 0.1900, 0.2679, -1.0627, 0.6976, + 0.0737, 0.7033, 1.5972, -0.7547, 0.2586, -0.7601, -0.3851, -0.7056, + -1.2970, -0.2983, 0.9817, 0.0878, 1.1081, -0.9637, 0.4593, -0.2039, + -0.3805, 0.1023, -0.2613, -0.5791, 0.2056, -1.1121, -0.0553, -2.4382, + 0.0129, -0.6673, -1.2580, -0.5264, 1.0097, -0.7766, 0.9379, 0.7274 + // clang-format on +}; + +std::vector E2E_BLOCK_SELECT_TEST_CASES = { + { + {2, 4, 4}, + // clang-format off + { + 3.144, 8.512, 8.518, -8.386, + 7.889, -5.721, 5.507, 4.295, + -6.624, -8.463, 7.474, 9.879, + 4.534, -5.908, -9.388, 2.356, + + 7.497, 8.186, -8.658, -4.796, + -8.248, -9.797, -7.907, -4.513, + 3.469, 7.633, 7.244, -6.844, + -7.173, 4.450, 6.705, -7.035 + }, + // clang-format on + {2, 4, 4}, + // clang-format off + { + 3.144, 8.512, 8.518, -8.386, + 7.889, -5.721, 5.507, 4.295, + -6.624, -8.463, 7.474, 9.879, + 4.534, -5.908, -9.388, 2.356, + + 7.497, 8.186, -8.658, -4.796, + -8.248, -9.797, -7.907, -4.513, + 3.469, 7.633, 7.244, -6.844, + -7.173, 4.450, 6.705, -7.035 + }, + // clang-format on + + /* threshold = */ 0.8, + /* block_size = */ 2, + /* stride = */ 2, + + // clang-format off + { + {{0, 0}, {1, 1}, {1, 0}}, + {{0, 0}, {1, 1}, {1, 0}} + } + // clang-format on + }, + + { + E2E_Q_SHAPE_8, + E2E_Q_DATA_8, + E2E_K_SHAPE_16, + E2E_K_DATA_16, + /* threshold = */ 0.0, + /* block_size = */ 2, + /* stride = */ 2, + {{{0, 0}, {0, 4}, {1, 0}, {1, 5}, {2, 0}, {2, 6}, {3, 0}, {3, 7}}, + {{0, 0}, {0, 4}, {1, 0}, {1, 5}, {2, 0}, {2, 6}, {3, 0}, {3, 7}}}, + }, + { + E2E_Q_SHAPE_8, + E2E_Q_DATA_8, + E2E_K_SHAPE_16, + E2E_K_DATA_16, + /* threshold = */ 1.0, + /* block_size = */ 2, + /* stride = */ 2, + + // clang-format off + { + {{0, 0}, {0, 1}, {0, 2}, {0, 3}, {0, 4}, {1, 0}, {1, 1}, {1, 2}, {1, 3}, {1, 4}, {1, 5}, {2, 0}, {2, 1}, {2, 2}, {2, 3}, {2, 4}, {2, 5}, {2, 6}, {3, 0}, {3, 1}, {3, 2}, {3, 3}, {3, 4}, {3, 5}, {3, 6}, {3, 7}}, + {{0, 0}, {0, 1}, {0, 2}, {0, 3}, {0, 4}, {1, 0}, {1, 1}, {1, 2}, {1, 3}, {1, 4}, {1, 5}, {2, 0}, {2, 1}, {2, 2}, {2, 3}, {2, 4}, {2, 5}, {2, 6}, {3, 0}, {3, 1}, {3, 2}, {3, 3}, {3, 4}, {3, 5}, {3, 6}, {3, 7}} + } + // clang-format on + }, + { + E2E_Q_SHAPE_8, + E2E_Q_DATA_8, + E2E_K_SHAPE_16, + E2E_K_DATA_16, + /* threshold = */ 0.8, + /* block_size = */ 2, + /* stride = */ 2, + + // clang-format off + { + + {{0, 0}, {0, 3}, {0, 4}, {1, 0}, {1, 1}, {1, 3}, {1, 4}, {1, 5}, {2, 0}, {2, 1}, {2, 2}, {2, 3}, {2, 5}, {2, 6}, {3, 0}, {3, 1}, {3, 2}, {3, 3}, {3, 4}, {3, 5}, {3, 7}}, + {{0, 0}, {0, 2}, {0, 4}, {1, 0}, {1, 1}, {1, 3}, {1, 5}, {2, 0}, {2, 1}, {2, 2}, {2, 3}, {2, 4}, {2, 6}, {3, 0}, {3, 1}, {3, 4}, {3, 5}, {3, 6}, {3, 7}} + } + // clang-format on + }, + { + E2E_Q_SHAPE_8, + E2E_Q_DATA_8, + E2E_K_SHAPE_16, + E2E_K_DATA_16, + /* threshold = */ 0.45, + /* block_size = */ 2, + /* stride = */ 2, + + // clang-format off + { + {{0, 0}, {0, 4}, {1, 0}, {1, 5}, {2, 0}, {2, 1}, {2, 3}, {2, 6}, {3, 0}, {3, 2}, {3, 5}, {3, 7}}, + {{0, 0}, {0, 2}, {0, 4}, {1, 0}, {1, 5}, {2, 0}, {2, 4}, {2, 6}, {3, 0}, {3, 5}, {3, 7}} + } + // clang-format on + }, + { + E2E_Q_SHAPE_8, + E2E_Q_DATA_8, + E2E_K_SHAPE_16, + E2E_K_DATA_16, + /* threshold = */ 0.45, + /* block_size = */ 4, + /* stride = */ 2, + + // clang-format off + { + {{0, 0}, {0, 2}, {1, 0}, {1, 1}, {1, 3}}, + {{0, 0}, {0, 2}, {1, 0}, {1, 3}} + } + // clang-format on + }, + { + E2E_Q_SHAPE_8, + E2E_Q_DATA_8, + E2E_K_SHAPE_16, + E2E_K_DATA_16, + /* threshold = */ 0.45, + /* block_size = */ 4, + /* stride = */ 4, + + // clang-format off + { + {{0, 0}, {0, 2}, {1, 0}, {1, 3}}, + {{0, 0}, {0, 2}, {1, 0}, {1, 3}} + } + // clang-format on + }, + { + E2E_Q_SHAPE_8, + E2E_Q_DATA_8, + E2E_K_SHAPE_8, + E2E_K_DATA_8, + /* threshold = */ 0.5, + /* block_size = */ 2, + /* stride = */ 2, + + // clang-format off + { + {{0, 0}, {1, 0}, {1, 1}, {2, 0}, {2, 1}, {2, 2}, {3, 0}, {3, 1}, {3, 3}}, + {{0, 0}, {1, 0}, {1, 1}, {2, 0}, {2, 2}, {3, 0}, {3, 3}} + } + // clang-format on + }, + { + E2E_Q_SHAPE_8, + E2E_Q_DATA_8, + E2E_K_SHAPE_8, + E2E_K_DATA_8, + /* threshold = */ 0.2, + /* block_size = */ 2, + /* stride = */ 2, + + // clang-format off + { + {{0, 0}, {1, 0}, {1, 1}, {2, 0}, {2, 2}, {3, 0}, {3, 3}}, + {{0, 0}, {1, 0}, {1, 1}, {2, 0}, {2, 2}, {3, 0}, {3, 3}} + } + // clang-format on + }}; + +TEST_P(XAttentionE2EBlockSelectTest, SelectsBlocksCorrectlyFromQKData) { + auto test_struct = GetParam(); + ov::reference::XAttentionBlockSelector selector(test_struct.threshold, + test_struct.block_size, + test_struct.stride); + + auto test_result = selector.select_blocks(test_struct.q_data.data(), + test_struct.q_shape, + test_struct.k_data.data(), + test_struct.k_shape); + + ASSERT_EQ(test_result.size(), test_struct.ref_retained_block_indices.size()); + EXPECT_EQ(test_result, test_struct.ref_retained_block_indices); + for (size_t head_idx = 0; head_idx < test_result.size(); head_idx++) { + if (test_result != test_struct.ref_retained_block_indices) { + std::cout << "Head " << head_idx << std::endl; + const auto& ref_set = test_struct.ref_retained_block_indices[head_idx]; + const auto& test_set = test_result[head_idx]; + std::cout << "ref has " << ref_set.size() << " elements, test has " << test_set.size() << std::endl; + std::vector> intersection; + std::set_intersection(ref_set.begin(), + ref_set.end(), + test_set.begin(), + test_set.end(), + std::back_inserter(intersection)); + + std::cout << "only ref has "; + for (const auto& idx : ref_set) { + if (test_set.find(idx) == test_set.end()) { + std::cout << "(" << idx.first << ", " << idx.second << ")" << std::endl; + } + } + std::cout << std::endl; + std::cout << "only test has "; + for (const auto& idx : test_set) { + if (ref_set.find(idx) == ref_set.end()) { + std::cout << "(" << idx.first << ", " << idx.second << ")" << std::endl; + } + } + std::cout << std::endl; + std::cout << std::endl; + } + } +} + +INSTANTIATE_TEST_SUITE_P(VariousInputs, XAttentionE2EBlockSelectTest, ::testing::ValuesIn(E2E_BLOCK_SELECT_TEST_CASES));