diff --git a/Cargo.lock b/Cargo.lock index de46032f7b3f3..efc30103e72aa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5436,6 +5436,7 @@ dependencies = [ "bytemuck", "bytes", "cbordata", + "cc", "crc32fast", "databend-common-ast", "databend-common-exception", diff --git a/licenserc.toml b/licenserc.toml index f6fc60c4cb145..046577c2f1828 100644 --- a/licenserc.toml +++ b/licenserc.toml @@ -10,6 +10,7 @@ excludes = [ "benchmark", "src/common/compress/tests", "src/meta/compat.py", + "src/query/storages/common/index/cpp/", # licensed under Elastic License 2.0 "src/binaries/query/ee_main.rs", "src/meta/binaries/meta/ee_main.rs", diff --git a/src/query/storages/common/index/Cargo.toml b/src/query/storages/common/index/Cargo.toml index e40720b047164..e3bd05d865cb4 100644 --- a/src/query/storages/common/index/Cargo.toml +++ b/src/query/storages/common/index/Cargo.toml @@ -9,6 +9,9 @@ edition = { workspace = true } [package.metadata.cargo-machete] ignored = ["xorfilter-rs", "match-template"] +[build-dependencies] +cc = "1.0" + [dependencies] databend-common-ast = { workspace = true } databend-common-exception = { workspace = true } diff --git a/src/query/storages/common/index/build.rs b/src/query/storages/common/index/build.rs new file mode 100644 index 0000000000000..5fc081d2dd6bf --- /dev/null +++ b/src/query/storages/common/index/build.rs @@ -0,0 +1,57 @@ +// Copyright Qdrant +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::env; + +fn main() { + println!("cargo:rerun-if-changed=cpp"); + let mut builder = cc::Build::new(); + + let target_arch = env::var("CARGO_CFG_TARGET_ARCH") + .expect("CARGO_CFG_TARGET_ARCH env-var is not defined or is not UTF-8"); + + // TODO: Is `CARGO_CFG_TARGET_FEATURE` *always* defined? + // + // Cargo docs says that "boolean configurations are present if they are set, + // and not present otherwise", so, what about "target features"? 
+ // + // https://doc.rust-lang.org/cargo/reference/environment-variables.html (Ctrl-F CARGO_CFG_) + let target_feature = env::var("CARGO_CFG_TARGET_FEATURE") + .expect("CARGO_CFG_TARGET_FEATURE env-var is not defined or is not UTF-8"); + + if target_arch == "x86_64" { + builder.file("cpp/sse.c"); + builder.file("cpp/avx2.c"); + + if builder.get_compiler().is_like_msvc() { + builder.flag("/arch:AVX"); + builder.flag("/arch:AVX2"); + builder.flag("/arch:SSE"); + builder.flag("/arch:SSE2"); + } else { + builder.flag("-march=haswell"); + } + + // O3 optimization level + builder.flag("-O3"); + // Use popcnt instruction + builder.flag("-mpopcnt"); + } else if target_arch == "aarch64" && target_feature.split(',').any(|feat| feat == "neon") { + builder.file("cpp/neon.c"); + builder.flag("-O3"); + } + + builder.compile("simd_utils"); +} diff --git a/src/query/storages/common/index/cpp/avx2.c b/src/query/storages/common/index/cpp/avx2.c new file mode 100644 index 0000000000000..4899fff76a1e8 --- /dev/null +++ b/src/query/storages/common/index/cpp/avx2.c @@ -0,0 +1,223 @@ +/* + * Copyright Qdrant + * Copyright 2021 Datafuse Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <immintrin.h> +#include <stdint.h> +#include <stdlib.h> + +#include "export_macro.h" + +#define HSUM256_PS(X, R) \ + float R = 0.0f; \ + { \ + __m128 x128 = _mm_add_ps(_mm256_extractf128_ps(X, 1), _mm256_castps256_ps128(X)); \ + __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128)); \ + __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); \ + R = _mm_cvtss_f32(x32); \ + } + +#define HSUM256_EPI32(X, R) \ + int R = 0; \ + { \ + __m128i x128 = _mm_add_epi32(_mm256_extractf128_si256(X, 1), _mm256_castsi256_si128(X)); \ + __m128i x64 = _mm_add_epi32(x128, _mm_srli_si128(x128, 8)); \ + __m128i x32 = _mm_add_epi32(x64, _mm_srli_si128(x64, 4)); \ + R = _mm_cvtsi128_si32(x32); \ + } + +EXPORT float impl_score_dot_avx( + const uint8_t* query_ptr, + const uint8_t* vector_ptr, + uint32_t dim +) { + const __m256i* v_ptr = (const __m256i*)vector_ptr; + const __m256i* q_ptr = (const __m256i*)query_ptr; + + __m256i mul1 = _mm256_setzero_si256(); + __m256i mask_epu32 = _mm256_set1_epi32(0xFFFF); + for (uint32_t _i = 0; _i < dim / 32; _i++) { + __m256i v = _mm256_loadu_si256(v_ptr); + __m256i q = _mm256_loadu_si256(q_ptr); + v_ptr++; + q_ptr++; + + __m256i s = _mm256_maddubs_epi16(v, q); + __m256i s_low = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(s)); + __m256i s_high = _mm256_cvtepi16_epi32(_mm256_extractf128_si256(s, 1)); + mul1 = _mm256_add_epi32(mul1, s_low); + mul1 = _mm256_add_epi32(mul1, s_high); + } + + // the vector sizes are assumed to be multiples of 16, check if one last 16-element part remaining + if (dim % 32 != 0) { + __m128i v_short = _mm_loadu_si128((const __m128i*)v_ptr); + __m128i q_short = _mm_loadu_si128((const __m128i*)q_ptr); + + __m256i v1 = _mm256_cvtepu8_epi16(v_short); + __m256i q1 = _mm256_cvtepu8_epi16(q_short); + + __m256i s = _mm256_mullo_epi16(v1, q1); + mul1 = _mm256_add_epi32(mul1, _mm256_and_si256(s,
mask_epu32)); + mul1 = _mm256_add_epi32(mul1, _mm256_srli_epi32(s, 16)); + } + __m256 mul_ps = _mm256_cvtepi32_ps(mul1); + HSUM256_PS(mul_ps, mul_scalar); + return mul_scalar; +} + +EXPORT float impl_score_l1_avx( + const uint8_t* query_ptr, + const uint8_t* vector_ptr, + uint32_t dim +) { + const __m256i* v_ptr = (const __m256i*)vector_ptr; + const __m256i* q_ptr = (const __m256i*)query_ptr; + + uint32_t m = dim - (dim % 32); + __m256i sum256 = _mm256_setzero_si256(); + + for (uint32_t i = 0; i < m; i += 32) { + __m256i v = _mm256_loadu_si256(v_ptr); + __m256i q = _mm256_loadu_si256(q_ptr); + v_ptr++; + q_ptr++; + + // Compute the difference in both directions and take the maximum for abs + __m256i diff1 = _mm256_subs_epu8(v, q); + __m256i diff2 = _mm256_subs_epu8(q, v); + + __m256i abs_diff = _mm256_max_epu8(diff1, diff2); + + __m256i abs_diff16_lo = _mm256_unpacklo_epi8(abs_diff, _mm256_setzero_si256()); + __m256i abs_diff16_hi = _mm256_unpackhi_epi8(abs_diff, _mm256_setzero_si256()); + + sum256 = _mm256_add_epi16(sum256, abs_diff16_lo); + sum256 = _mm256_add_epi16(sum256, abs_diff16_hi); + } + + // the vector sizes are assumed to be multiples of 16, check if one last 16-element part remaining + if (m < dim) { + __m128i v_short = _mm_loadu_si128((const __m128i * ) v_ptr); + __m128i q_short = _mm_loadu_si128((const __m128i * ) q_ptr); + + __m128i diff1 = _mm_subs_epu8(v_short, q_short); + __m128i diff2 = _mm_subs_epu8(q_short, v_short); + + __m128i abs_diff = _mm_max_epu8(diff1, diff2); + + __m128i abs_diff16_lo_128 = _mm_unpacklo_epi8(abs_diff, _mm_setzero_si128()); + __m128i abs_diff16_hi_128 = _mm_unpackhi_epi8(abs_diff, _mm_setzero_si128()); + + __m256i abs_diff16_lo = _mm256_cvtepu16_epi32(abs_diff16_lo_128); + __m256i abs_diff16_hi = _mm256_cvtepu16_epi32(abs_diff16_hi_128); + + sum256 = _mm256_add_epi16(sum256, abs_diff16_lo); + sum256 = _mm256_add_epi16(sum256, abs_diff16_hi); + } + + __m256i sum_epi32 = _mm256_add_epi32( + _mm256_unpacklo_epi16(sum256, _mm256_setzero_si256()), + _mm256_unpackhi_epi16(sum256, _mm256_setzero_si256())); + + HSUM256_EPI32(sum_epi32, sum); + + return (float) sum; +} + +EXPORT uint32_t impl_xor_popcnt_scalar8_avx_uint128( + const uint8_t* query_ptr, + const uint8_t* vector_ptr, + uint32_t count +) { + const uint64_t* v_ptr = (const uint64_t*)vector_ptr; + const uint64_t* q_ptr = (const uint64_t*)query_ptr; + + __m256i sum1 = _mm256_set1_epi32(0); + __m256i sum2 = _mm256_set1_epi32(0); + for (uint32_t _i = 0; _i < count; _i++) { + uint64_t v_1 = *v_ptr; + uint64_t v_2 = *(v_ptr + 1); + + __m256i popcnt1 = _mm256_set_epi32( + _mm_popcnt_u64(v_1 ^ *(q_ptr + 0)), + _mm_popcnt_u64(v_1 ^ *(q_ptr + 2)), + _mm_popcnt_u64(v_1 ^ *(q_ptr + 4)), + _mm_popcnt_u64(v_1 ^ *(q_ptr + 6)), + _mm_popcnt_u64(v_2 ^ *(q_ptr + 1)), + _mm_popcnt_u64(v_2 ^ *(q_ptr + 3)), + _mm_popcnt_u64(v_2 ^ *(q_ptr + 5)), + _mm_popcnt_u64(v_2 ^ *(q_ptr + 7)) + ); + sum1 = _mm256_add_epi32(sum1, popcnt1); + + __m256i popcnt2 = _mm256_set_epi32( + _mm_popcnt_u64(v_1 ^ *(q_ptr + 8)), + _mm_popcnt_u64(v_1 ^ *(q_ptr + 10)), + _mm_popcnt_u64(v_1 ^ *(q_ptr + 12)), + _mm_popcnt_u64(v_1 ^ *(q_ptr + 14)), + _mm_popcnt_u64(v_2 ^ *(q_ptr + 9)), + _mm_popcnt_u64(v_2 ^ *(q_ptr + 11)), + _mm_popcnt_u64(v_2 ^ *(q_ptr + 13)), + _mm_popcnt_u64(v_2 ^ *(q_ptr + 15)) + ); + sum2 = _mm256_add_epi32(sum2, popcnt2); + + v_ptr += 2; + q_ptr += 16; + } + __m256i factor1 = _mm256_set_epi32(1, 2, 4, 8, 1, 2, 4, 8); + __m256i factor2 = _mm256_set_epi32(16, 32, 64, 128, 16, 32, 64, 128); + __m256 result1_mm256 = 
_mm256_cvtepi32_ps(_mm256_mullo_epi32(sum1, factor1)); + __m256 result2_mm256 = _mm256_cvtepi32_ps(_mm256_mullo_epi32(sum2, factor2)); + HSUM256_PS(_mm256_add_ps(result1_mm256, result2_mm256), mul_scalar); + return (uint32_t)mul_scalar; +} + +EXPORT uint32_t impl_xor_popcnt_scalar4_avx_uint128( + const uint8_t* query_ptr, + const uint8_t* vector_ptr, + uint32_t count +) { + const uint64_t* v_ptr = (const uint64_t*)vector_ptr; + const uint64_t* q_ptr = (const uint64_t*)query_ptr; + + __m256i sum = _mm256_set1_epi32(0); + for (uint32_t _i = 0; _i < count; _i++) { + uint64_t v_1 = *v_ptr; + uint64_t v_2 = *(v_ptr + 1); + + __m256i popcnt = _mm256_set_epi32( + _mm_popcnt_u64(v_1 ^ *(q_ptr + 0)), + _mm_popcnt_u64(v_1 ^ *(q_ptr + 2)), + _mm_popcnt_u64(v_1 ^ *(q_ptr + 4)), + _mm_popcnt_u64(v_1 ^ *(q_ptr + 6)), + _mm_popcnt_u64(v_2 ^ *(q_ptr + 1)), + _mm_popcnt_u64(v_2 ^ *(q_ptr + 3)), + _mm_popcnt_u64(v_2 ^ *(q_ptr + 5)), + _mm_popcnt_u64(v_2 ^ *(q_ptr + 7)) + ); + sum = _mm256_add_epi32(sum, popcnt); + + v_ptr += 2; + q_ptr += 8; + } + __m256i factor = _mm256_set_epi32(1, 2, 4, 8, 1, 2, 4, 8); + __m256 result_mm256 = _mm256_cvtepi32_ps(_mm256_mullo_epi32(sum, factor)); + HSUM256_PS(result_mm256, mul_scalar); + return (uint32_t)mul_scalar; +} + diff --git a/src/query/storages/common/index/cpp/export_macro.h b/src/query/storages/common/index/cpp/export_macro.h new file mode 100644 index 0000000000000..9432dd0c86bc3 --- /dev/null +++ b/src/query/storages/common/index/cpp/export_macro.h @@ -0,0 +1,23 @@ +/* + * Copyright Qdrant + * Copyright 2021 Datafuse Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if defined(_MSC_VER) +#define EXPORT __declspec(dllexport) +#else +#define EXPORT __attribute__((visibility("default"))) +#endif + diff --git a/src/query/storages/common/index/cpp/neon.c b/src/query/storages/common/index/cpp/neon.c new file mode 100644 index 0000000000000..be04df22bc456 --- /dev/null +++ b/src/query/storages/common/index/cpp/neon.c @@ -0,0 +1,478 @@ +/* + * Copyright Qdrant + * Copyright 2021 Datafuse Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include <arm_neon.h> +#include <stdint.h> + +#include "export_macro.h" + +EXPORT float impl_score_dot_neon( + const uint8_t* query_ptr, + const uint8_t* vector_ptr, + uint32_t dim +) { + uint32x4_t mul1 = vdupq_n_u32(0); + uint32x4_t mul2 = vdupq_n_u32(0); + for (uint32_t _i = 0; _i < dim / 16; _i++) { + uint8x16_t q = vld1q_u8(query_ptr); + uint8x16_t v = vld1q_u8(vector_ptr); + query_ptr += 16; + vector_ptr += 16; + uint16x8_t mul_low = vmull_u8(vget_low_u8(q), vget_low_u8(v)); + uint16x8_t mul_high = vmull_u8(vget_high_u8(q), vget_high_u8(v)); + mul1 = vpadalq_u16(mul1, mul_low); + mul2 = vpadalq_u16(mul2, mul_high); + } + return (float)vaddvq_u32(vaddq_u32(mul1, mul2)); +} + +EXPORT uint32_t impl_xor_popcnt_neon_uint128( + const uint8_t* query_ptr, + const uint8_t* vector_ptr, + uint32_t count +) { + uint32x4_t result = vdupq_n_u32(0); + for (uint32_t _i = 0; _i < count; _i++) { + uint8x16_t v = vld1q_u8(vector_ptr); + uint8x16_t q = vld1q_u8(query_ptr); + + uint8x16_t x = veorq_u8(q, v); + uint8x16_t popcnt = vcntq_u8(x); + uint8x8_t popcnt_low = vget_low_u8(popcnt); + uint8x8_t popcnt_high = vget_high_u8(popcnt); + uint16x8_t sum = vaddl_u8(popcnt_low, popcnt_high); + result = vpadalq_u16(result, sum); + + query_ptr += 16; + vector_ptr += 16; + } + return (uint32_t)vaddvq_u32(result); +} + +EXPORT uint32_t impl_xor_popcnt_neon_uint64( + const uint8_t* query_ptr, + const uint8_t* vector_ptr, + uint32_t count +) { + uint16x4_t result = vdup_n_u16(0); + for (uint32_t _i = 0; _i < count; _i++) { + uint8x8_t v = vld1_u8(vector_ptr); + uint8x8_t q = vld1_u8(query_ptr); + + uint8x8_t x = veor_u8(q, v); + uint8x8_t popcnt = vcnt_u8(x); + result = vpadal_u8(result, popcnt); + + query_ptr += 8; + vector_ptr += 8; + } + return (uint32_t)vaddv_u16(result); +} + +EXPORT uint32_t impl_xor_popcnt_scalar8_neon_uint128( + const uint8_t* q_ptr, + const uint8_t* v_ptr, + uint32_t count +) { + uint16x8_t result1 = vdupq_n_u16(0); + uint16x8_t result2 = vdupq_n_u16(0); + uint16x8_t result3 = vdupq_n_u16(0); + uint16x8_t result4 = vdupq_n_u16(0); + uint16x8_t result5 = vdupq_n_u16(0); + uint16x8_t result6 = vdupq_n_u16(0); + uint16x8_t result7 = vdupq_n_u16(0); + uint16x8_t result8 = vdupq_n_u16(0); + for (uint32_t _i = 0; _i < count; _i++) { + uint8x16_t v = vld1q_u8(v_ptr); + uint8x16_t x1 = veorq_u8(vld1q_u8(q_ptr + 0), v); + uint8x16_t x2 = veorq_u8(vld1q_u8(q_ptr + 16), v); + uint8x16_t x3 = veorq_u8(vld1q_u8(q_ptr + 32), v); + uint8x16_t x4 = veorq_u8(vld1q_u8(q_ptr + 48), v); + uint8x16_t x5 = veorq_u8(vld1q_u8(q_ptr + 64), v); + uint8x16_t x6 = veorq_u8(vld1q_u8(q_ptr + 80), v); + uint8x16_t x7 = veorq_u8(vld1q_u8(q_ptr + 96), v); + uint8x16_t x8 = veorq_u8(vld1q_u8(q_ptr + 112), v); + + result1 = vpadalq_u8(result1, vcntq_u8(x1)); + result2 = vpadalq_u8(result2, vcntq_u8(x2)); + result3 = vpadalq_u8(result3, vcntq_u8(x3)); + result4 = vpadalq_u8(result4, vcntq_u8(x4)); + result5 = vpadalq_u8(result5, vcntq_u8(x5)); + result6 = vpadalq_u8(result6, vcntq_u8(x6)); + result7 = vpadalq_u8(result7, vcntq_u8(x7)); + result8 = vpadalq_u8(result8, vcntq_u8(x8)); + + v_ptr += 16; + q_ptr += 128; + } + + uint32_t r1 = vaddvq_u16(result1); + uint32_t r2 = vaddvq_u16(result2); + uint32_t r3 = vaddvq_u16(result3); + uint32_t r4 = vaddvq_u16(result4); + uint32_t r5 = vaddvq_u16(result5); + uint32_t r6 = vaddvq_u16(result6); + uint32_t r7 = vaddvq_u16(result7); + uint32_t r8 = vaddvq_u16(result8); + + return r1 + (r2 << 1) + (r3 << 2) + (r4 << 3) + (r5 << 4) + (r6 << 5) + (r7 << 6) + (r8 << 7); +} + +EXPORT uint32_t
impl_xor_popcnt_scalar4_neon_uint128( + const uint8_t* q_ptr, + const uint8_t* v_ptr, + uint32_t count +) { + uint16x8_t result1 = vdupq_n_u16(0); + uint16x8_t result2 = vdupq_n_u16(0); + uint16x8_t result3 = vdupq_n_u16(0); + uint16x8_t result4 = vdupq_n_u16(0); + for (uint32_t _i = 0; _i < count; _i++) { + uint8x16_t v = vld1q_u8(v_ptr); + uint8x16_t x1 = veorq_u8(vld1q_u8(q_ptr + 0), v); + uint8x16_t x2 = veorq_u8(vld1q_u8(q_ptr + 16), v); + uint8x16_t x3 = veorq_u8(vld1q_u8(q_ptr + 32), v); + uint8x16_t x4 = veorq_u8(vld1q_u8(q_ptr + 48), v); + + result1 = vpadalq_u8(result1, vcntq_u8(x1)); + result2 = vpadalq_u8(result2, vcntq_u8(x2)); + result3 = vpadalq_u8(result3, vcntq_u8(x3)); + result4 = vpadalq_u8(result4, vcntq_u8(x4)); + + v_ptr += 16; + q_ptr += 64; + } + + uint32_t r1 = vaddvq_u16(result1); + uint32_t r2 = vaddvq_u16(result2); + uint32_t r3 = vaddvq_u16(result3); + uint32_t r4 = vaddvq_u16(result4); + return r1 + (r2 << 1) + (r3 << 2) + (r4 << 3); +} + +EXPORT uint32_t impl_xor_popcnt_scalar8_neon_u8( + const uint8_t* q_ptr, + const uint8_t* v_ptr, + uint32_t count +) { + uint16x4_t result1 = vdup_n_u16(0); + uint16x4_t result2 = vdup_n_u16(0); + uint16x4_t result3 = vdup_n_u16(0); + uint16x4_t result4 = vdup_n_u16(0); + uint16x4_t result5 = vdup_n_u16(0); + uint16x4_t result6 = vdup_n_u16(0); + uint16x4_t result7 = vdup_n_u16(0); + uint16x4_t result8 = vdup_n_u16(0); + for (uint32_t _i = 0; _i < count / 8; _i++) { + uint8x8_t v = vld1_u8(v_ptr); + + uint8_t values1[8] = { + *(q_ptr + 0), + *(q_ptr + 8), + *(q_ptr + 16), + *(q_ptr + 24), + *(q_ptr + 32), + *(q_ptr + 40), + *(q_ptr + 48), + *(q_ptr + 56), + }; + uint8x8_t x1 = veor_u8(vld1_u8(values1), v); + + uint8_t values2[8] = { + *(q_ptr + 1), + *(q_ptr + 9), + *(q_ptr + 17), + *(q_ptr + 25), + *(q_ptr + 33), + *(q_ptr + 41), + *(q_ptr + 49), + *(q_ptr + 57), + }; + uint8x8_t x2 = veor_u8(vld1_u8(values2), v); + + uint8_t values3[8] = { + *(q_ptr + 2), + *(q_ptr + 10), + *(q_ptr + 18), + *(q_ptr + 26), + *(q_ptr + 34), + *(q_ptr + 42), + *(q_ptr + 50), + *(q_ptr + 58), + }; + uint8x8_t x3 = veor_u8(vld1_u8(values3), v); + + uint8_t values4[8] = { + *(q_ptr + 3), + *(q_ptr + 11), + *(q_ptr + 19), + *(q_ptr + 27), + *(q_ptr + 35), + *(q_ptr + 43), + *(q_ptr + 51), + *(q_ptr + 59), + }; + uint8x8_t x4 = veor_u8(vld1_u8(values4), v); + + uint8_t values5[8] = { + *(q_ptr + 4), + *(q_ptr + 12), + *(q_ptr + 20), + *(q_ptr + 28), + *(q_ptr + 36), + *(q_ptr + 44), + *(q_ptr + 52), + *(q_ptr + 60), + }; + uint8x8_t x5 = veor_u8(vld1_u8(values5), v); + + uint8_t values6[8] = { + *(q_ptr + 5), + *(q_ptr + 13), + *(q_ptr + 21), + *(q_ptr + 29), + *(q_ptr + 37), + *(q_ptr + 45), + *(q_ptr + 53), + *(q_ptr + 61), + }; + uint8x8_t x6 = veor_u8(vld1_u8(values6), v); + + uint8_t values7[8] = { + *(q_ptr + 6), + *(q_ptr + 14), + *(q_ptr + 22), + *(q_ptr + 30), + *(q_ptr + 38), + *(q_ptr + 46), + *(q_ptr + 54), + *(q_ptr + 62), + }; + uint8x8_t x7 = veor_u8(vld1_u8(values7), v); + + uint8_t values8[8] = { + *(q_ptr + 7), + *(q_ptr + 15), + *(q_ptr + 23), + *(q_ptr + 31), + *(q_ptr + 39), + *(q_ptr + 47), + *(q_ptr + 55), + *(q_ptr + 63), + }; + uint8x8_t x8 = veor_u8(vld1_u8(values8), v); + + result1 = vpadal_u8(result1, vcnt_u8(x1)); + result2 = vpadal_u8(result2, vcnt_u8(x2)); + result3 = vpadal_u8(result3, vcnt_u8(x3)); + result4 = vpadal_u8(result4, vcnt_u8(x4)); + result5 = vpadal_u8(result5, vcnt_u8(x5)); + result6 = vpadal_u8(result6, vcnt_u8(x6)); + result7 = vpadal_u8(result7, vcnt_u8(x7)); + result8 = vpadal_u8(result8, 
vcnt_u8(x8)); + + v_ptr += 8; + q_ptr += 64; + } + + uint32_t dr1 = 0; + uint32_t dr2 = 0; + uint32_t dr3 = 0; + uint32_t dr4 = 0; + uint32_t dr5 = 0; + uint32_t dr6 = 0; + uint32_t dr7 = 0; + uint32_t dr8 = 0; + for (uint32_t _i = count % 8; _i > 0; _i--) { + uint8_t v = *(v_ptr++); + uint8_t q1 = *(q_ptr++); + uint8_t q2 = *(q_ptr++); + uint8_t q3 = *(q_ptr++); + uint8_t q4 = *(q_ptr++); + uint8_t q5 = *(q_ptr++); + uint8_t q6 = *(q_ptr++); + uint8_t q7 = *(q_ptr++); + uint8_t q8 = *(q_ptr++); + + uint8_t x1 = v ^ q1; + uint8_t x2 = v ^ q2; + uint8_t x3 = v ^ q3; + uint8_t x4 = v ^ q4; + uint8_t x5 = v ^ q5; + uint8_t x6 = v ^ q6; + uint8_t x7 = v ^ q7; + uint8_t x8 = v ^ q8; + + dr1 += __builtin_popcount(x1); + dr2 += __builtin_popcount(x2); + dr3 += __builtin_popcount(x3); + dr4 += __builtin_popcount(x4); + dr5 += __builtin_popcount(x5); + dr6 += __builtin_popcount(x6); + dr7 += __builtin_popcount(x7); + dr8 += __builtin_popcount(x8); + } + + uint32_t r1 = vaddv_u16(result1) + dr1; + uint32_t r2 = vaddv_u16(result2) + dr2; + uint32_t r3 = vaddv_u16(result3) + dr3; + uint32_t r4 = vaddv_u16(result4) + dr4; + uint32_t r5 = vaddv_u16(result5) + dr5; + uint32_t r6 = vaddv_u16(result6) + dr6; + uint32_t r7 = vaddv_u16(result7) + dr7; + uint32_t r8 = vaddv_u16(result8) + dr8; + + return r1 + (r2 << 1) + (r3 << 2) + (r4 << 3) + (r5 << 4) + (r6 << 5) + (r7 << 6) + (r8 << 7); +} + +EXPORT uint32_t impl_xor_popcnt_scalar4_neon_u8( + const uint8_t* q_ptr, + const uint8_t* v_ptr, + uint32_t count +) { + uint16x4_t result1 = vdup_n_u16(0); + uint16x4_t result2 = vdup_n_u16(0); + uint16x4_t result3 = vdup_n_u16(0); + uint16x4_t result4 = vdup_n_u16(0); + for (uint32_t _i = 0; _i < count / 8; _i++) { + uint8x8_t v = vld1_u8(v_ptr); + + uint8_t values1[8] = { + *(q_ptr + 0), + *(q_ptr + 4), + *(q_ptr + 8), + *(q_ptr + 12), + *(q_ptr + 16), + *(q_ptr + 20), + *(q_ptr + 24), + *(q_ptr + 28), + }; + uint8x8_t x1 = veor_u8(vld1_u8(values1), v); + + uint8_t values2[8] = { + *(q_ptr + 1), + *(q_ptr + 5), + *(q_ptr + 9), + *(q_ptr + 13), + *(q_ptr + 17), + *(q_ptr + 21), + *(q_ptr + 25), + *(q_ptr + 29), + }; + uint8x8_t x2 = veor_u8(vld1_u8(values2), v); + + uint8_t values3[8] = { + *(q_ptr + 2), + *(q_ptr + 6), + *(q_ptr + 10), + *(q_ptr + 14), + *(q_ptr + 18), + *(q_ptr + 22), + *(q_ptr + 26), + *(q_ptr + 30), + }; + uint8x8_t x3 = veor_u8(vld1_u8(values3), v); + + uint8_t values4[8] = { + *(q_ptr + 3), + *(q_ptr + 7), + *(q_ptr + 11), + *(q_ptr + 15), + *(q_ptr + 19), + *(q_ptr + 23), + *(q_ptr + 27), + *(q_ptr + 31), + }; + uint8x8_t x4 = veor_u8(vld1_u8(values4), v); + + result1 = vpadal_u8(result1, vcnt_u8(x1)); + result2 = vpadal_u8(result2, vcnt_u8(x2)); + result3 = vpadal_u8(result3, vcnt_u8(x3)); + result4 = vpadal_u8(result4, vcnt_u8(x4)); + + v_ptr += 8; + q_ptr += 32; + } + + uint32_t dr1 = 0; + uint32_t dr2 = 0; + uint32_t dr3 = 0; + uint32_t dr4 = 0; + for (uint32_t _i = count % 8; _i > 0; _i--) { + uint8_t v = *(v_ptr++); + uint8_t q1 = *(q_ptr++); + uint8_t q2 = *(q_ptr++); + uint8_t q3 = *(q_ptr++); + uint8_t q4 = *(q_ptr++); + + uint8_t x1 = v ^ q1; + uint8_t x2 = v ^ q2; + uint8_t x3 = v ^ q3; + uint8_t x4 = v ^ q4; + dr1 += __builtin_popcount(x1); + dr2 += __builtin_popcount(x2); + dr3 += __builtin_popcount(x3); + dr4 += __builtin_popcount(x4); + } + + uint32_t r1 = vaddv_u16(result1) + dr1; + uint32_t r2 = vaddv_u16(result2) + dr2; + uint32_t r3 = vaddv_u16(result3) + dr3; + uint32_t r4 = vaddv_u16(result4) + dr4; + return r1 + (r2 << 1) + (r3 << 2) + (r4 << 3); +} + +EXPORT float 
impl_score_l1_neon( + const uint8_t * query_ptr, + const uint8_t * vector_ptr, + uint32_t dim +) { + const uint8_t* v_ptr = (const uint8_t*)vector_ptr; + const uint8_t* q_ptr = (const uint8_t*)query_ptr; + + uint32_t m = dim - (dim % 16); + uint16x8_t sum16_low = vdupq_n_u16(0); + uint16x8_t sum16_high = vdupq_n_u16(0); + + // the vector sizes are assumed to be multiples of 16, no remaining part here + for (uint32_t i = 0; i < m; i += 16) { + uint8x16_t vec1 = vld1q_u8(v_ptr); + uint8x16_t vec2 = vld1q_u8(q_ptr); + + uint8x16_t abs_diff = vabdq_u8(vec1, vec2); + uint16x8_t abs_diff16_low = vmovl_u8(vget_low_u8(abs_diff)); + uint16x8_t abs_diff16_high = vmovl_u8(vget_high_u8(abs_diff)); + + sum16_low = vaddq_u16(sum16_low, abs_diff16_low); + sum16_high = vaddq_u16(sum16_high, abs_diff16_high); + + v_ptr += 16; + q_ptr += 16; + } + + // Horizontal sum of 16-bit integers + uint32x4_t sum32_low = vpaddlq_u16(sum16_low); + uint32x4_t sum32_high = vpaddlq_u16(sum16_high); + uint32x4_t sum32 = vaddq_u32(sum32_low, sum32_high); + + uint32x2_t sum64_low = vadd_u32(vget_low_u32(sum32), vget_high_u32(sum32)); + uint32x2_t sum64_high = vpadd_u32(sum64_low, sum64_low); + uint32_t sum = vget_lane_u32(sum64_high, 0); + + return (float) sum; +} + diff --git a/src/query/storages/common/index/cpp/sse.c b/src/query/storages/common/index/cpp/sse.c new file mode 100644 index 0000000000000..daf8c03241222 --- /dev/null +++ b/src/query/storages/common/index/cpp/sse.c @@ -0,0 +1,531 @@ +/* + * Copyright Qdrant + * Copyright 2021 Datafuse Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include <immintrin.h> +#include <stdint.h> +#include <stdlib.h> + +#include "export_macro.h" + +#ifdef _MSC_VER +#include <intrin.h> +#define __builtin_popcount __popcnt +#endif + +#define HSUM128_PS(X, R) \ + float R = 0.0f; \ + { \ + __m128 x64 = _mm_add_ps(X, _mm_movehl_ps(X, X)); \ + __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); \ + R = _mm_cvtss_f32(x32); \ + } + +#define HSUM128_EPI16(X, R) \ + int R = 0; \ + { \ + __m128i x64 = _mm_add_epi16(X, _mm_srli_si128(X, 8)); \ + __m128i x32 = _mm_add_epi16(x64, _mm_srli_si128(x64, 4)); \ + R = _mm_extract_epi16(x32, 0) + _mm_extract_epi16(x32, 1); \ + } + +EXPORT float impl_score_dot_sse( + const uint8_t* query_ptr, + const uint8_t* vector_ptr, + uint32_t dim +) { + const __m128i* v_ptr = (const __m128i*)vector_ptr; + const __m128i* q_ptr = (const __m128i*)query_ptr; + + __m128i mul = _mm_setzero_si128(); + for (uint32_t _i = 0; _i < dim / 16; _i++) { + __m128i v = _mm_loadu_si128(v_ptr); + __m128i q = _mm_loadu_si128(q_ptr); + v_ptr++; + q_ptr++; + + __m128i s = _mm_maddubs_epi16(v, q); + __m128i s_low = _mm_cvtepi16_epi32(s); + __m128i s_high = _mm_cvtepi16_epi32(_mm_srli_si128(s, 8)); + mul = _mm_add_epi32(mul, s_low); + mul = _mm_add_epi32(mul, s_high); + } + __m128 mul_ps = _mm_cvtepi32_ps(mul); + HSUM128_PS(mul_ps, mul_scalar); + return mul_scalar; +} + +EXPORT uint32_t impl_xor_popcnt_sse_uint128( + const uint8_t* query_ptr, + const uint8_t* vector_ptr, + uint32_t count +) { + int64_t result = 0; + for (uint32_t _i = 0; _i < count; _i++) { + const uint64_t* v_ptr_1 = (const uint64_t*)vector_ptr; + const uint64_t* q_ptr_1 = (const uint64_t*)query_ptr; + uint64_t x_1 = (*v_ptr_1) ^ (*q_ptr_1); + result += _mm_popcnt_u64(x_1); + + const uint64_t* v_ptr_2 = v_ptr_1 + 1; + const uint64_t* q_ptr_2 = q_ptr_1 + 1; + uint64_t x_2 = (*v_ptr_2) ^ (*q_ptr_2); + result += _mm_popcnt_u64(x_2); + + vector_ptr += 16; + query_ptr += 16; + } + return (uint32_t)result; +} + +EXPORT uint32_t impl_xor_popcnt_sse_uint64( + const uint8_t* query_ptr, + const uint8_t* vector_ptr, + uint32_t count +) { + int64_t result = 0; + for (uint32_t _i = 0; _i < count; _i++) { + const uint64_t* v_ptr = (const uint64_t*)vector_ptr; + const uint64_t* q_ptr = (const uint64_t*)query_ptr; + uint64_t x = (*v_ptr) ^ (*q_ptr); + result += _mm_popcnt_u64(x); + + vector_ptr += 8; + query_ptr += 8; + } + return (uint32_t)result; +} + +EXPORT uint32_t impl_xor_popcnt_sse_uint32( + const uint8_t* query_ptr, + const uint8_t* vector_ptr, + uint32_t count +) { + int result = 0; + for (uint32_t _i = 0; _i < count; _i++) { + const uint32_t* v_ptr = (const uint32_t*)vector_ptr; + const uint32_t* q_ptr = (const uint32_t*)query_ptr; + uint32_t x = (*v_ptr) ^ (*q_ptr); + result += _mm_popcnt_u32(x); + + vector_ptr += 4; + query_ptr += 4; + } + return (uint32_t)result; +} + +EXPORT uint32_t impl_xor_popcnt_scalar8_sse_uint128( + const uint8_t* query_ptr, + const uint8_t* vector_ptr, + uint32_t count +) { + const uint64_t* v_ptr = (const uint64_t*)vector_ptr; + const uint64_t* q_ptr = (const uint64_t*)query_ptr; + + __m128i sum1 = _mm_set1_epi32(0); + __m128i sum2 = _mm_set1_epi32(0); + for (uint32_t _i = 0; _i < count; _i++) { + uint64_t v_1 = *v_ptr; + uint64_t v_2 = *(v_ptr + 1); + + __m128i popcnt1 = _mm_set_epi32( + _mm_popcnt_u64(v_1 ^ *(q_ptr + 0)), + _mm_popcnt_u64(v_1 ^ *(q_ptr + 2)), + _mm_popcnt_u64(v_1 ^ *(q_ptr + 4)), + _mm_popcnt_u64(v_1 ^ *(q_ptr + 6)) + ); + sum1 = _mm_add_epi32(sum1, popcnt1); + + __m128i popcnt2 = _mm_set_epi32( + _mm_popcnt_u64(v_1 ^ *(q_ptr + 8)), + _mm_popcnt_u64(v_1 ^ *(q_ptr
+ 10)), + _mm_popcnt_u64(v_1 ^ *(q_ptr + 12)), + _mm_popcnt_u64(v_1 ^ *(q_ptr + 14)) + ); + sum2 = _mm_add_epi32(sum2, popcnt2); + + __m128i popcnt3 = _mm_set_epi32( + _mm_popcnt_u64(v_2 ^ *(q_ptr + 1)), + _mm_popcnt_u64(v_2 ^ *(q_ptr + 3)), + _mm_popcnt_u64(v_2 ^ *(q_ptr + 5)), + _mm_popcnt_u64(v_2 ^ *(q_ptr + 7)) + ); + sum1 = _mm_add_epi32(sum1, popcnt3); + + __m128i popcnt4 = _mm_set_epi32( + _mm_popcnt_u64(v_2 ^ *(q_ptr + 9)), + _mm_popcnt_u64(v_2 ^ *(q_ptr + 11)), + _mm_popcnt_u64(v_2 ^ *(q_ptr + 13)), + _mm_popcnt_u64(v_2 ^ *(q_ptr + 15)) + ); + sum2 = _mm_add_epi32(sum2, popcnt4); + + v_ptr += 2; + q_ptr += 16; + } + __m128i factor1 = _mm_set_epi32(1, 2, 4, 8); + __m128i factor2 = _mm_set_epi32(16, 32, 64, 128); + __m128 result1_mm128 = _mm_cvtepi32_ps(_mm_mullo_epi32(sum1, factor1)); + __m128 result2_mm128 = _mm_cvtepi32_ps(_mm_mullo_epi32(sum2, factor2)); + HSUM128_PS(_mm_add_ps(result1_mm128, result2_mm128), mul_scalar); + return (uint32_t)mul_scalar; +} + +EXPORT uint32_t impl_xor_popcnt_scalar4_sse_uint128( + const uint8_t* query_ptr, + const uint8_t* vector_ptr, + uint32_t count +) { + const uint64_t* v_ptr = (const uint64_t*)vector_ptr; + const uint64_t* q_ptr = (const uint64_t*)query_ptr; + + __m128i sum = _mm_set1_epi32(0); + for (uint32_t _i = 0; _i < count; _i++) { + uint64_t v_1 = *v_ptr; + uint64_t v_2 = *(v_ptr + 1); + + __m128i popcnt1 = _mm_set_epi32( + _mm_popcnt_u64(v_1 ^ *(q_ptr + 0)), + _mm_popcnt_u64(v_1 ^ *(q_ptr + 2)), + _mm_popcnt_u64(v_1 ^ *(q_ptr + 4)), + _mm_popcnt_u64(v_1 ^ *(q_ptr + 6)) + ); + sum = _mm_add_epi32(sum, popcnt1); + + __m128i popcnt2 = _mm_set_epi32( + _mm_popcnt_u64(v_2 ^ *(q_ptr + 1)), + _mm_popcnt_u64(v_2 ^ *(q_ptr + 3)), + _mm_popcnt_u64(v_2 ^ *(q_ptr + 5)), + _mm_popcnt_u64(v_2 ^ *(q_ptr + 7)) + ); + sum = _mm_add_epi32(sum, popcnt2); + + v_ptr += 2; + q_ptr += 8; + } + __m128i factor = _mm_set_epi32(1, 2, 4, 8); + __m128 result_mm128 = _mm_cvtepi32_ps(_mm_mullo_epi32(sum, factor)); + HSUM128_PS(result_mm128, mul_scalar); + return (uint32_t)mul_scalar; +} + +EXPORT uint32_t impl_xor_popcnt_scalar8_sse_u8( + const uint8_t* q_ptr, + const uint8_t* v_ptr, + uint32_t count +) { + __m128i sum1 = _mm_set1_epi32(0); + __m128i sum2 = _mm_set1_epi32(0); + for (uint32_t _i = 0; _i < count / 8; _i++) { + uint64_t v = *((const uint64_t*)v_ptr); + + uint8_t values1[8] = { + *(q_ptr + 0), + *(q_ptr + 8), + *(q_ptr + 16), + *(q_ptr + 24), + *(q_ptr + 32), + *(q_ptr + 40), + *(q_ptr + 48), + *(q_ptr + 56), + }; + uint8_t values2[8] = { + *(q_ptr + 1), + *(q_ptr + 9), + *(q_ptr + 17), + *(q_ptr + 25), + *(q_ptr + 33), + *(q_ptr + 41), + *(q_ptr + 49), + *(q_ptr + 57), + }; + uint8_t values3[8] = { + *(q_ptr + 2), + *(q_ptr + 10), + *(q_ptr + 18), + *(q_ptr + 26), + *(q_ptr + 34), + *(q_ptr + 42), + *(q_ptr + 50), + *(q_ptr + 58), + }; + uint8_t values4[8] = { + *(q_ptr + 3), + *(q_ptr + 11), + *(q_ptr + 19), + *(q_ptr + 27), + *(q_ptr + 35), + *(q_ptr + 43), + *(q_ptr + 51), + *(q_ptr + 59), + }; + uint8_t values5[8] = { + *(q_ptr + 4), + *(q_ptr + 12), + *(q_ptr + 20), + *(q_ptr + 28), + *(q_ptr + 36), + *(q_ptr + 44), + *(q_ptr + 52), + *(q_ptr + 60), + }; + uint8_t values6[8] = { + *(q_ptr + 5), + *(q_ptr + 13), + *(q_ptr + 21), + *(q_ptr + 29), + *(q_ptr + 37), + *(q_ptr + 45), + *(q_ptr + 53), + *(q_ptr + 61), + }; + uint8_t values7[8] = { + *(q_ptr + 6), + *(q_ptr + 14), + *(q_ptr + 22), + *(q_ptr + 30), + *(q_ptr + 38), + *(q_ptr + 46), + *(q_ptr + 54), + *(q_ptr + 62), + }; + uint8_t values8[8] = { + *(q_ptr + 7), + *(q_ptr + 15), + *(q_ptr + 
23), + *(q_ptr + 31), + *(q_ptr + 39), + *(q_ptr + 47), + *(q_ptr + 55), + *(q_ptr + 63), + }; + + uint64_t values1_u64 = *((const uint64_t*)values1); + uint64_t values2_u64 = *((const uint64_t*)values2); + uint64_t values3_u64 = *((const uint64_t*)values3); + uint64_t values4_u64 = *((const uint64_t*)values4); + uint64_t values5_u64 = *((const uint64_t*)values5); + uint64_t values6_u64 = *((const uint64_t*)values6); + uint64_t values7_u64 = *((const uint64_t*)values7); + uint64_t values8_u64 = *((const uint64_t*)values8); + + __m128i popcnt1 = _mm_set_epi32( + _mm_popcnt_u64(v ^ values1_u64), + _mm_popcnt_u64(v ^ values2_u64), + _mm_popcnt_u64(v ^ values3_u64), + _mm_popcnt_u64(v ^ values4_u64) + ); + sum1 = _mm_add_epi32(sum1, popcnt1); + + __m128i popcnt2 = _mm_set_epi32( + _mm_popcnt_u64(v ^ values5_u64), + _mm_popcnt_u64(v ^ values6_u64), + _mm_popcnt_u64(v ^ values7_u64), + _mm_popcnt_u64(v ^ values8_u64) + ); + sum2 = _mm_add_epi32(sum2, popcnt2); + + v_ptr += 8; + q_ptr += 64; + } + + uint32_t dr1 = 0; + uint32_t dr2 = 0; + uint32_t dr3 = 0; + uint32_t dr4 = 0; + uint32_t dr5 = 0; + uint32_t dr6 = 0; + uint32_t dr7 = 0; + uint32_t dr8 = 0; + for (uint32_t _i = count % 8; _i > 0; _i--) { + uint8_t v = *(v_ptr++); + uint8_t q1 = *(q_ptr++); + uint8_t q2 = *(q_ptr++); + uint8_t q3 = *(q_ptr++); + uint8_t q4 = *(q_ptr++); + uint8_t q5 = *(q_ptr++); + uint8_t q6 = *(q_ptr++); + uint8_t q7 = *(q_ptr++); + uint8_t q8 = *(q_ptr++); + + uint8_t x1 = v ^ q1; + uint8_t x2 = v ^ q2; + uint8_t x3 = v ^ q3; + uint8_t x4 = v ^ q4; + uint8_t x5 = v ^ q5; + uint8_t x6 = v ^ q6; + uint8_t x7 = v ^ q7; + uint8_t x8 = v ^ q8; + + dr1 += __builtin_popcount(x1); + dr2 += __builtin_popcount(x2); + dr3 += __builtin_popcount(x3); + dr4 += __builtin_popcount(x4); + dr5 += __builtin_popcount(x5); + dr6 += __builtin_popcount(x6); + dr7 += __builtin_popcount(x7); + dr8 += __builtin_popcount(x8); + } + + __m128i factor1 = _mm_set_epi32(1, 2, 4, 8); + __m128i factor2 = _mm_set_epi32(16, 32, 64, 128); + __m128 result_mm128 = _mm_cvtepi32_ps( + _mm_add_epi32( + _mm_mullo_epi32(sum1, factor1), + _mm_mullo_epi32(sum2, factor2) + ) + ); + HSUM128_PS(result_mm128, mul_scalar); + return (uint32_t)mul_scalar + dr1 + (dr2 << 1) + (dr3 << 2) + (dr4 << 3) + (dr5 << 4) + (dr6 << 5) + (dr7 << 6) + (dr8 << 7); +} + +EXPORT uint32_t impl_xor_popcnt_scalar4_sse_u8( + const uint8_t* q_ptr, + const uint8_t* v_ptr, + uint32_t count +) { + __m128i sum = _mm_set1_epi32(0); + for (uint32_t _i = 0; _i < count / 8; _i++) { + uint64_t v = *((const uint64_t*)v_ptr); + + uint8_t values1[8] = { + *(q_ptr + 0), + *(q_ptr + 4), + *(q_ptr + 8), + *(q_ptr + 12), + *(q_ptr + 16), + *(q_ptr + 20), + *(q_ptr + 24), + *(q_ptr + 28), + }; + uint8_t values2[8] = { + *(q_ptr + 1), + *(q_ptr + 5), + *(q_ptr + 9), + *(q_ptr + 13), + *(q_ptr + 17), + *(q_ptr + 21), + *(q_ptr + 25), + *(q_ptr + 29), + }; + uint8_t values3[8] = { + *(q_ptr + 2), + *(q_ptr + 6), + *(q_ptr + 10), + *(q_ptr + 14), + *(q_ptr + 18), + *(q_ptr + 22), + *(q_ptr + 26), + *(q_ptr + 30), + }; + uint8_t values4[8] = { + *(q_ptr + 3), + *(q_ptr + 7), + *(q_ptr + 11), + *(q_ptr + 15), + *(q_ptr + 19), + *(q_ptr + 23), + *(q_ptr + 27), + *(q_ptr + 31), + }; + + uint64_t values1_u64 = *((const uint64_t*)values1); + uint64_t values2_u64 = *((const uint64_t*)values2); + uint64_t values3_u64 = *((const uint64_t*)values3); + uint64_t values4_u64 = *((const uint64_t*)values4); + + __m128i popcnt = _mm_set_epi32( + _mm_popcnt_u64(v ^ values1_u64), + _mm_popcnt_u64(v ^ values2_u64), + 
_mm_popcnt_u64(v ^ values3_u64), + _mm_popcnt_u64(v ^ values4_u64) + ); + sum = _mm_add_epi32(sum, popcnt); + + v_ptr += 8; + q_ptr += 32; + } + + uint32_t dr1 = 0; + uint32_t dr2 = 0; + uint32_t dr3 = 0; + uint32_t dr4 = 0; + for (uint32_t _i = count % 8; _i > 0; _i--) { + uint8_t v = *(v_ptr++); + uint8_t q1 = *(q_ptr++); + uint8_t q2 = *(q_ptr++); + uint8_t q3 = *(q_ptr++); + uint8_t q4 = *(q_ptr++); + + uint8_t x1 = v ^ q1; + uint8_t x2 = v ^ q2; + uint8_t x3 = v ^ q3; + uint8_t x4 = v ^ q4; + dr1 += __builtin_popcount(x1); + dr2 += __builtin_popcount(x2); + dr3 += __builtin_popcount(x3); + dr4 += __builtin_popcount(x4); + } + + __m128i factor = _mm_set_epi32(1, 2, 4, 8); + __m128 result_mm128 = _mm_cvtepi32_ps(_mm_mullo_epi32(sum, factor)); + HSUM128_PS(result_mm128, mul_scalar); + return (uint32_t)mul_scalar + dr1 + (dr2 << 1) + (dr3 << 2) + (dr4 << 3); +} + +EXPORT float impl_score_l1_sse( + const uint8_t* query_ptr, + const uint8_t* vector_ptr, + uint32_t dim +) { + const __m128i* v_ptr = (const __m128i*)vector_ptr; + const __m128i* q_ptr = (const __m128i*)query_ptr; + + uint32_t m = dim - (dim % 16); + __m128i sum128 = _mm_setzero_si128(); + + // the vector sizes are assumed to be multiples of 16, no remaining part here + for (uint32_t i = 0; i < m; i += 16) { + __m128i vec2 = _mm_loadu_si128(v_ptr); + __m128i vec1 = _mm_loadu_si128(q_ptr); + v_ptr++; + q_ptr++; + + // Compute the difference in both directions + __m128i diff1 = _mm_subs_epu8(vec1, vec2); + __m128i diff2 = _mm_subs_epu8(vec2, vec1); + + // Take the maximum + __m128i abs_diff = _mm_max_epu8(diff1, diff2); + + __m128i abs_diff16_low = _mm_unpacklo_epi8(abs_diff, _mm_setzero_si128()); + __m128i abs_diff16_high = _mm_unpackhi_epi8(abs_diff, _mm_setzero_si128()); + + sum128 = _mm_add_epi16(sum128, abs_diff16_low); + sum128 = _mm_add_epi16(sum128, abs_diff16_high); + } + + // Convert 16-bit sums to 32-bit and sum them up + __m128i sum_epi32 = _mm_add_epi32( + _mm_unpacklo_epi16(sum128, _mm_setzero_si128()), + _mm_unpackhi_epi16(sum128, _mm_setzero_si128())); + + // Horizontal sum using the macro + HSUM128_EPI16(sum_epi32, sum); + + return (float) sum; +} + diff --git a/src/query/storages/common/index/src/hnsw_index/quantization/encoded_vectors_u8.rs b/src/query/storages/common/index/src/hnsw_index/quantization/encoded_vectors_u8.rs index 652d689571462..29cc2da494c0c 100644 --- a/src/query/storages/common/index/src/hnsw_index/quantization/encoded_vectors_u8.rs +++ b/src/query/storages/common/index/src/hnsw_index/quantization/encoded_vectors_u8.rs @@ -163,6 +163,55 @@ impl EncodedVectorsU8 { pub fn score_point_simple(&self, query: &EncodedQueryU8, i: u32) -> f32 { let (vector_offset, v_ptr) = self.get_vec_ptr(i); + let q_ptr = query.encoded_query.as_ptr(); + #[cfg(target_arch = "x86_64")] + if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") { + unsafe { + let score = match self.metadata.vector_parameters.distance_type { + DistanceType::Dot | DistanceType::L2 => { + impl_score_dot_avx(q_ptr, v_ptr, self.metadata.actual_dim as u32) + } + DistanceType::L1 => { + impl_score_l1_avx(q_ptr, v_ptr, self.metadata.actual_dim as u32) + } + }; + + return self.metadata.multiplier * score as f32 + query.offset + vector_offset; + } + } + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + if is_x86_feature_detected!("sse4.1") { + unsafe { + let score = match self.metadata.vector_parameters.distance_type { + DistanceType::Dot | DistanceType::L2 => { + impl_score_dot_sse(q_ptr, v_ptr, self.metadata.actual_dim 
as u32) + } + DistanceType::L1 => { + impl_score_l1_sse(q_ptr, v_ptr, self.metadata.actual_dim as u32) + } + }; + + return self.metadata.multiplier * score as f32 + query.offset + vector_offset; + } + } + + #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] + if std::arch::is_aarch64_feature_detected!("neon") { + unsafe { + let score = match self.metadata.vector_parameters.distance_type { + DistanceType::Dot | DistanceType::L2 => { + impl_score_dot_neon(q_ptr, v_ptr, self.metadata.actual_dim as u32) + } + DistanceType::L1 => { + impl_score_l1_neon(q_ptr, v_ptr, self.metadata.actual_dim as u32) + } + }; + + return self.metadata.multiplier * score as f32 + query.offset + vector_offset; + } + } + let score = match self.metadata.vector_parameters.distance_type { DistanceType::Dot | DistanceType::L2 => impl_score_dot( query.encoded_query.as_ptr(), @@ -362,3 +411,18 @@ fn impl_score_l1(q_ptr: *const u8, v_ptr: *const u8, actual_dim: usize) -> i32 { score } } + +#[cfg(target_arch = "x86_64")] +unsafe extern "C" { + fn impl_score_dot_avx(query_ptr: *const u8, vector_ptr: *const u8, dim: u32) -> f32; + fn impl_score_l1_avx(query_ptr: *const u8, vector_ptr: *const u8, dim: u32) -> f32; + + fn impl_score_dot_sse(query_ptr: *const u8, vector_ptr: *const u8, dim: u32) -> f32; + fn impl_score_l1_sse(query_ptr: *const u8, vector_ptr: *const u8, dim: u32) -> f32; +} + +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +unsafe extern "C" { + fn impl_score_dot_neon(query_ptr: *const u8, vector_ptr: *const u8, dim: u32) -> f32; + fn impl_score_l1_neon(query_ptr: *const u8, vector_ptr: *const u8, dim: u32) -> f32; +}
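Regarding the TODO in build.rs about whether CARGO_CFG_TARGET_FEATURE is always defined: if it turns out to be unset on some targets, a more defensive build script could treat the missing variable as "no target features" instead of panicking via expect(). A minimal sketch under that assumption follows; it is not part of this diff, and the helper name has_target_feature is hypothetical.

// Hypothetical, defensive variant of the feature check in build.rs.
use std::env;

// Returns true if `feature` appears in the comma-separated CARGO_CFG_TARGET_FEATURE list;
// a missing variable is treated as an empty list rather than an error.
fn has_target_feature(feature: &str) -> bool {
    env::var("CARGO_CFG_TARGET_FEATURE")
        .unwrap_or_default()
        .split(',')
        .any(|feat| feat == feature)
}

fn main() {
    println!("cargo:rerun-if-changed=cpp");
    let target_arch = env::var("CARGO_CFG_TARGET_ARCH").unwrap_or_default();
    if target_arch == "aarch64" && has_target_feature("neon") {
        // cc::Build::new().file("cpp/neon.c").flag("-O3").compile("simd_utils");
        // (same cc invocation as the build.rs in this diff)
    }
}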
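For readers tracing the bit-plane weighting in the impl_xor_popcnt_scalar8_*_u8 kernels above: each vector byte is XORed against eight query planes stored at byte stride 8 (plane k of vector byte i sits at q[i * 8 + k]), and plane k contributes its popcount shifted left by k; the factor vectors (1, 2, 4, ..., 128) and the final shifted sum implement exactly that weighting. A scalar Rust reference of this layout, included only as an illustration and not part of this diff:

// Scalar reference for the *_scalar8_*_u8 kernels: sum over vector bytes i and planes k
// of popcount(v[i] ^ q[i * 8 + k]) weighted by 2^k.
fn xor_popcnt_scalar8_reference(q: &[u8], v: &[u8]) -> u32 {
    assert_eq!(q.len(), v.len() * 8);
    let mut total = 0u32;
    for (i, &vb) in v.iter().enumerate() {
        for k in 0..8 {
            total += (vb ^ q[i * 8 + k]).count_ones() << k;
        }
    }
    total
}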