Skip to content

Commit e72e701

Browse files
committed
chore(query): Accelerate vector index quantization score calculation with SIMD
1 parent ac1fee9 commit e72e701

File tree

9 files changed

+1381
-0
lines changed

9 files changed

+1381
-0
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

licenserc.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ excludes = [
1010
"benchmark",
1111
"src/common/compress/tests",
1212
"src/meta/compat.py",
13+
"src/query/storages/common/index/cpp/"
1314
# licensed under Elastic License 2.0
1415
"src/binaries/query/ee_main.rs",
1516
"src/meta/binaries/meta/ee_main.rs",

src/query/storages/common/index/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ edition = { workspace = true }
99
[package.metadata.cargo-machete]
1010
ignored = ["xorfilter-rs", "match-template"]
1111

12+
[build-dependencies]
13+
cc = "1.0"
14+
1215
[dependencies]
1316
databend-common-ast = { workspace = true }
1417
databend-common-exception = { workspace = true }
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// Copyright Qdrant
2+
// Copyright 2021 Datafuse Labs
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
use std::env;
17+
18+
fn main() {
19+
println!("cargo:rerun-if-changed=cpp");
20+
let mut builder = cc::Build::new();
21+
22+
let target_arch = env::var("CARGO_CFG_TARGET_ARCH")
23+
.expect("CARGO_CFG_TARGET_ARCH env-var is not defined or is not UTF-8");
24+
25+
// TODO: Is `CARGO_CFG_TARGET_FEATURE` *always* defined?
26+
//
27+
// Cargo docs says that "boolean configurations are present if they are set,
28+
// and not present otherwise", so, what about "target features"?
29+
//
30+
// https://doc.rust-lang.org/cargo/reference/environment-variables.html (Ctrl-F CARGO_CFG_<cfg>)
31+
let target_feature = env::var("CARGO_CFG_TARGET_FEATURE")
32+
.expect("CARGO_CFG_TARGET_FEATURE env-var is not defined or is not UTF-8");
33+
34+
if target_arch == "x86_64" {
35+
builder.file("cpp/sse.c");
36+
builder.file("cpp/avx2.c");
37+
38+
if builder.get_compiler().is_like_msvc() {
39+
builder.flag("/arch:AVX");
40+
builder.flag("/arch:AVX2");
41+
builder.flag("/arch:SSE");
42+
builder.flag("/arch:SSE2");
43+
} else {
44+
builder.flag("-march=haswell");
45+
}
46+
47+
// O3 optimization level
48+
builder.flag("-O3");
49+
// Use popcnt instruction
50+
builder.flag("-mpopcnt");
51+
} else if target_arch == "aarch64" && target_feature.split(',').any(|feat| feat == "neon") {
52+
builder.file("cpp/neon.c");
53+
builder.flag("-O3");
54+
}
55+
56+
builder.compile("simd_utils");
57+
}
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
/*
2+
* Copyright Qdrant
3+
* Copyright 2021 Datafuse Labs
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
#include <stdlib.h>
19+
#include <stdint.h>
20+
#include <immintrin.h>
21+
22+
#include "export_macro.h"
23+
24+
#define HSUM256_PS(X, R) \
25+
float R = 0.0f; \
26+
{ \
27+
__m128 x128 = _mm_add_ps(_mm256_extractf128_ps(X, 1), _mm256_castps256_ps128(X)); \
28+
__m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128)); \
29+
__m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); \
30+
R = _mm_cvtss_f32(x32); \
31+
}
32+
33+
#define HSUM256_EPI32(X, R) \
34+
int R = 0; \
35+
{ \
36+
__m128i x128 = _mm_add_epi32(_mm256_extractf128_si256(X, 1), _mm256_castsi256_si128(X)); \
37+
__m128i x64 = _mm_add_epi32(x128, _mm_srli_si128(x128, 8)); \
38+
__m128i x32 = _mm_add_epi32(x64, _mm_srli_si128(x64, 4)); \
39+
R = _mm_cvtsi128_si32(x32); \
40+
}
41+
42+
EXPORT float impl_score_dot_avx(
43+
const uint8_t* query_ptr,
44+
const uint8_t* vector_ptr,
45+
uint32_t dim
46+
) {
47+
const __m256i* v_ptr = (const __m256i*)vector_ptr;
48+
const __m256i* q_ptr = (const __m256i*)query_ptr;
49+
50+
__m256i mul1 = _mm256_setzero_si256();
51+
__m256i mask_epu32 = _mm256_set1_epi32(0xFFFF);
52+
for (uint32_t _i = 0; _i < dim / 32; _i++) {
53+
__m256i v = _mm256_loadu_si256(v_ptr);
54+
__m256i q = _mm256_loadu_si256(q_ptr);
55+
v_ptr++;
56+
q_ptr++;
57+
58+
__m256i s = _mm256_maddubs_epi16(v, q);
59+
__m256i s_low = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(s));
60+
__m256i s_high = _mm256_cvtepi16_epi32(_mm256_extractf128_si256(s, 1));
61+
mul1 = _mm256_add_epi32(mul1, s_low);
62+
mul1 = _mm256_add_epi32(mul1, s_high);
63+
}
64+
65+
// the vector sizes are assumed to be multiples of 16, check if one last 16-element part remaining
66+
if (dim % 32 != 0) {
67+
__m128i v_short = _mm_loadu_si128((const __m128i*)v_ptr);
68+
__m128i q_short = _mm_loadu_si128((const __m128i*)q_ptr);
69+
70+
__m256i v1 = _mm256_cvtepu8_epi16(v_short);
71+
__m256i q1 = _mm256_cvtepu8_epi16(q_short);
72+
73+
__m256i s = _mm256_mullo_epi16(v1, q1);
74+
mul1 = _mm256_add_epi32(mul1, _mm256_and_si256(s, mask_epu32));
75+
mul1 = _mm256_add_epi32(mul1, _mm256_srli_epi32(s, 16));
76+
}
77+
__m256 mul_ps = _mm256_cvtepi32_ps(mul1);
78+
HSUM256_PS(mul_ps, mul_scalar);
79+
return mul_scalar;
80+
}
81+
82+
// L1 (Manhattan) distance between two unsigned byte vectors of length `dim`,
// returned as a float.
//
// Full 32-byte blocks are processed first; if `dim` is not a multiple of 32,
// exactly one extra 16-byte block is processed, so `dim` is expected to be a
// multiple of 16 (both buffers must be readable that far).
//
// The sum of absolute differences is accumulated in 64-bit lanes via
// `_mm256_sad_epu8`, so it cannot wrap for any `dim` representable in 32 bits
// (a 16-bit lane accumulator would overflow once `dim` exceeds ~4096).
EXPORT float impl_score_l1_avx(
    const uint8_t* query_ptr,
    const uint8_t* vector_ptr,
    uint32_t dim
) {
    const __m256i* v_ptr = (const __m256i*)vector_ptr;
    const __m256i* q_ptr = (const __m256i*)query_ptr;
    const uint32_t full_blocks = dim / 32;

    // Four 64-bit partial sums, one per 8-byte group of a block.
    __m256i sad_acc = _mm256_setzero_si256();

    for (uint32_t i = 0; i < full_blocks; i++) {
        const __m256i v = _mm256_loadu_si256(v_ptr + i);
        const __m256i q = _mm256_loadu_si256(q_ptr + i);

        // sum(|v[j] - q[j]|) over each group of 8 bytes, zero-extended to 64 bits.
        sad_acc = _mm256_add_epi64(sad_acc, _mm256_sad_epu8(v, q));
    }

    // Fold the 256-bit accumulator down to two 64-bit lanes.
    __m128i sad128 = _mm_add_epi64(
        _mm256_castsi256_si128(sad_acc),
        _mm256_extractf128_si256(sad_acc, 1));

    // One trailing 16-byte block, if any.
    if (dim % 32 != 0) {
        const __m128i v_short = _mm_loadu_si128((const __m128i*)(v_ptr + full_blocks));
        const __m128i q_short = _mm_loadu_si128((const __m128i*)(q_ptr + full_blocks));

        sad128 = _mm_add_epi64(sad128, _mm_sad_epu8(v_short, q_short));
    }

    // Add the high 64-bit lane onto the low one and extract the total.
    sad128 = _mm_add_epi64(sad128, _mm_unpackhi_epi64(sad128, sad128));
    const uint64_t total = (uint64_t)_mm_cvtsi128_si64(sad128);

    return (float)total;
}
140+
141+
// Weighted XOR/popcount score over `count` 128-bit vector chunks.
//
// Per iteration, one 128-bit vector chunk (two uint64 words) is XORed against
// eight consecutive 128-bit query chunks (sixteen uint64 words). The popcounts
// for the k-th query chunk (k = 0..7) are weighted by 2^k and everything is
// summed into a single uint32_t.
//
// NOTE(review): the final reduction goes through single-precision floats, so
// the result is exact only while the total stays below 2^24.
EXPORT uint32_t impl_xor_popcnt_scalar8_avx_uint128(
    const uint8_t* query_ptr,
    const uint8_t* vector_ptr,
    uint32_t count
) {
    const uint64_t* vec_words = (const uint64_t*)vector_ptr;
    const uint64_t* qry_words = (const uint64_t*)query_ptr;

    // Popcount totals for query chunks 0..3 and 4..7 respectively.
    __m256i acc_chunks_0_3 = _mm256_setzero_si256();
    __m256i acc_chunks_4_7 = _mm256_setzero_si256();

    for (uint32_t block = 0; block < count; block++, vec_words += 2, qry_words += 16) {
        const uint64_t v_lo = vec_words[0];
        const uint64_t v_hi = vec_words[1];

        // Lane order (highest lane first): low words of chunks 0,1,2,3 and
        // then high words of chunks 0,1,2,3.
        acc_chunks_0_3 = _mm256_add_epi32(acc_chunks_0_3, _mm256_set_epi32(
            _mm_popcnt_u64(v_lo ^ qry_words[0]),
            _mm_popcnt_u64(v_lo ^ qry_words[2]),
            _mm_popcnt_u64(v_lo ^ qry_words[4]),
            _mm_popcnt_u64(v_lo ^ qry_words[6]),
            _mm_popcnt_u64(v_hi ^ qry_words[1]),
            _mm_popcnt_u64(v_hi ^ qry_words[3]),
            _mm_popcnt_u64(v_hi ^ qry_words[5]),
            _mm_popcnt_u64(v_hi ^ qry_words[7])
        ));

        // Same layout for chunks 4,5,6,7.
        acc_chunks_4_7 = _mm256_add_epi32(acc_chunks_4_7, _mm256_set_epi32(
            _mm_popcnt_u64(v_lo ^ qry_words[8]),
            _mm_popcnt_u64(v_lo ^ qry_words[10]),
            _mm_popcnt_u64(v_lo ^ qry_words[12]),
            _mm_popcnt_u64(v_lo ^ qry_words[14]),
            _mm_popcnt_u64(v_hi ^ qry_words[9]),
            _mm_popcnt_u64(v_hi ^ qry_words[11]),
            _mm_popcnt_u64(v_hi ^ qry_words[13]),
            _mm_popcnt_u64(v_hi ^ qry_words[15])
        ));
    }

    // Per-lane weights matching the lane order used above.
    const __m256i weights_0_3 = _mm256_set_epi32(1, 2, 4, 8, 1, 2, 4, 8);
    const __m256i weights_4_7 = _mm256_set_epi32(16, 32, 64, 128, 16, 32, 64, 128);

    const __m256 weighted_0_3 = _mm256_cvtepi32_ps(_mm256_mullo_epi32(acc_chunks_0_3, weights_0_3));
    const __m256 weighted_4_7 = _mm256_cvtepi32_ps(_mm256_mullo_epi32(acc_chunks_4_7, weights_4_7));
    const __m256 weighted = _mm256_add_ps(weighted_0_3, weighted_4_7);

    HSUM256_PS(weighted, score);
    return (uint32_t)score;
}
189+
190+
// Weighted XOR/popcount score over `count` 128-bit vector chunks.
//
// Per iteration, one 128-bit vector chunk (two uint64 words) is XORed against
// four consecutive 128-bit query chunks (eight uint64 words). The popcounts
// for the k-th query chunk (k = 0..3) are weighted by 2^k and everything is
// summed into a single uint32_t.
//
// NOTE(review): the final reduction goes through single-precision floats, so
// the result is exact only while the total stays below 2^24.
EXPORT uint32_t impl_xor_popcnt_scalar4_avx_uint128(
    const uint8_t* query_ptr,
    const uint8_t* vector_ptr,
    uint32_t count
) {
    const uint64_t* vec_words = (const uint64_t*)vector_ptr;
    const uint64_t* qry_words = (const uint64_t*)query_ptr;

    __m256i acc = _mm256_setzero_si256();

    for (uint32_t block = 0; block < count; block++, vec_words += 2, qry_words += 8) {
        const uint64_t v_lo = vec_words[0];
        const uint64_t v_hi = vec_words[1];

        // Lane order (highest lane first): low words of chunks 0,1,2,3 and
        // then high words of chunks 0,1,2,3.
        acc = _mm256_add_epi32(acc, _mm256_set_epi32(
            _mm_popcnt_u64(v_lo ^ qry_words[0]),
            _mm_popcnt_u64(v_lo ^ qry_words[2]),
            _mm_popcnt_u64(v_lo ^ qry_words[4]),
            _mm_popcnt_u64(v_lo ^ qry_words[6]),
            _mm_popcnt_u64(v_hi ^ qry_words[1]),
            _mm_popcnt_u64(v_hi ^ qry_words[3]),
            _mm_popcnt_u64(v_hi ^ qry_words[5]),
            _mm_popcnt_u64(v_hi ^ qry_words[7])
        ));
    }

    // Per-lane weights matching the lane order used above.
    const __m256i weights = _mm256_set_epi32(1, 2, 4, 8, 1, 2, 4, 8);
    const __m256 weighted = _mm256_cvtepi32_ps(_mm256_mullo_epi32(acc, weights));

    HSUM256_PS(weighted, score);
    return (uint32_t)score;
}
223+
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
/*
2+
* Copyright Qdrant
3+
* Copyright 2021 Datafuse Labs
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
/*
 * EXPORT: prefix for function definitions that must remain visible outside
 * the compiled object. Expands to `__declspec(dllexport)` when `_MSC_VER` is
 * defined and to the GCC/Clang default-visibility attribute otherwise.
 */
#ifdef _MSC_VER
#  define EXPORT __declspec(dllexport)
#else
#  define EXPORT __attribute__((visibility("default")))
#endif
23+

0 commit comments

Comments
(0)