Skip to content

Commit 07033dc

Browse files
mccullocht and benwtrent
committed
Add a new codec to implement OSQ for 4 and 8 bit quantized vectors (#15169)
Unlike the existing ScalarQuantizer, this format selects a mode based on an enum to quantize to unsigned bytes or packed nibbles, using the same packing scheme as the existing scalar quantized codec. Seven bits is also supported for anyone interested in backward compatibility, but this setting is discouraged. This is separate from Lucene102BinaryQuantizedVectorsFormat because we need a larger value to store the component sum for each vector, owing to larger quantized values. This closes #15064 --------- Co-authored-by: Benjamin Trent <ben.w.trent@gmail.com>
1 parent c925e9b commit 07033dc

15 files changed

+2794
-2
lines changed

lucene/CHANGES.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ New Features
2525
* GITHUB#15176: Add `[Float|Byte]VectorValues#rescorer(element[])` interface to allow optimized rescoring of vectors.
2626
(Ben Trent)
2727

28+
* GITHUB#15169: Add codecs for 4 and 8 bit Optimized Scalar Quantization vectors (Trevor McCulloch)
29+
2830
Improvements
2931
---------------------
3032
# GITHUB#15148: Add support uint8 distance and allow 8 bit scalar quantization (Trevor McCulloch)

lucene/core/src/java/module-info.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,9 @@
8787
org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat,
8888
org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat,
8989
org.apache.lucene.codecs.lucene102.Lucene102HnswBinaryQuantizedVectorsFormat,
90-
org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat;
90+
org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat,
91+
org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat,
92+
org.apache.lucene.codecs.lucene104.Lucene104HnswScalarQuantizedVectorsFormat;
9193
provides org.apache.lucene.codecs.PostingsFormat with
9294
org.apache.lucene.codecs.lucene104.Lucene104PostingsFormat;
9395
provides org.apache.lucene.index.SortFieldProvider with
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.lucene.codecs.lucene104;
18+
19+
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
20+
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
21+
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER;
22+
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH;
23+
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN;
24+
25+
import java.io.IOException;
26+
import java.util.concurrent.ExecutorService;
27+
import org.apache.lucene.codecs.KnnVectorsFormat;
28+
import org.apache.lucene.codecs.KnnVectorsReader;
29+
import org.apache.lucene.codecs.KnnVectorsWriter;
30+
import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding;
31+
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
32+
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
33+
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter;
34+
import org.apache.lucene.index.SegmentReadState;
35+
import org.apache.lucene.index.SegmentWriteState;
36+
import org.apache.lucene.search.TaskExecutor;
37+
import org.apache.lucene.util.hnsw.HnswGraph;
38+
39+
/**
40+
* A vectors format that uses HNSW graph to store and search for vectors. But vectors are binary
41+
* quantized using {@link Lucene104ScalarQuantizedVectorsFormat} before being stored in the graph.
42+
*/
43+
public class Lucene104HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat {
44+
45+
public static final String NAME = "Lucene104HnswBinaryQuantizedVectorsFormat";
46+
47+
/**
48+
* Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to
49+
* {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details.
50+
*/
51+
private final int maxConn;
52+
53+
/**
54+
* The number of candidate neighbors to track while searching the graph for each newly inserted
55+
* node. Defaults to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link HnswGraph}
56+
* for details.
57+
*/
58+
private final int beamWidth;
59+
60+
/** The format for storing, reading, merging vectors on disk */
61+
private final Lucene104ScalarQuantizedVectorsFormat flatVectorsFormat;
62+
63+
private final int numMergeWorkers;
64+
private final TaskExecutor mergeExec;
65+
66+
/** Constructs a format using default graph construction parameters */
67+
public Lucene104HnswScalarQuantizedVectorsFormat() {
68+
this(
69+
ScalarEncoding.UNSIGNED_BYTE,
70+
DEFAULT_MAX_CONN,
71+
DEFAULT_BEAM_WIDTH,
72+
DEFAULT_NUM_MERGE_WORKER,
73+
null);
74+
}
75+
76+
/**
77+
* Constructs a format using the given graph construction parameters.
78+
*
79+
* @param maxConn the maximum number of connections to a node in the HNSW graph
80+
* @param beamWidth the size of the queue maintained during graph construction.
81+
*/
82+
public Lucene104HnswScalarQuantizedVectorsFormat(int maxConn, int beamWidth) {
83+
this(ScalarEncoding.UNSIGNED_BYTE, maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null);
84+
}
85+
86+
/**
87+
* Constructs a format using the given graph construction parameters and scalar quantization.
88+
*
89+
* @param maxConn the maximum number of connections to a node in the HNSW graph
90+
* @param beamWidth the size of the queue maintained during graph construction.
91+
* @param numMergeWorkers number of workers (threads) that will be used when doing merge. If
92+
* larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec
93+
* @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are
94+
* generated by this format to do the merge
95+
*/
96+
public Lucene104HnswScalarQuantizedVectorsFormat(
97+
ScalarEncoding encoding,
98+
int maxConn,
99+
int beamWidth,
100+
int numMergeWorkers,
101+
ExecutorService mergeExec) {
102+
super(NAME);
103+
flatVectorsFormat = new Lucene104ScalarQuantizedVectorsFormat(encoding);
104+
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
105+
throw new IllegalArgumentException(
106+
"maxConn must be positive and less than or equal to "
107+
+ MAXIMUM_MAX_CONN
108+
+ "; maxConn="
109+
+ maxConn);
110+
}
111+
if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) {
112+
throw new IllegalArgumentException(
113+
"beamWidth must be positive and less than or equal to "
114+
+ MAXIMUM_BEAM_WIDTH
115+
+ "; beamWidth="
116+
+ beamWidth);
117+
}
118+
this.maxConn = maxConn;
119+
this.beamWidth = beamWidth;
120+
if (numMergeWorkers == 1 && mergeExec != null) {
121+
throw new IllegalArgumentException(
122+
"No executor service is needed as we'll use single thread to merge");
123+
}
124+
this.numMergeWorkers = numMergeWorkers;
125+
if (mergeExec != null) {
126+
this.mergeExec = new TaskExecutor(mergeExec);
127+
} else {
128+
this.mergeExec = null;
129+
}
130+
}
131+
132+
@Override
133+
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
134+
return new Lucene99HnswVectorsWriter(
135+
state,
136+
maxConn,
137+
beamWidth,
138+
flatVectorsFormat.fieldsWriter(state),
139+
numMergeWorkers,
140+
mergeExec);
141+
}
142+
143+
@Override
144+
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
145+
return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
146+
}
147+
148+
@Override
149+
public int getMaxDimensions(String fieldName) {
150+
return 1024;
151+
}
152+
153+
@Override
154+
public String toString() {
155+
return "Lucene104HnswScalarQuantizedVectorsFormat(name=Lucene104HnswScalarQuantizedVectorsFormat, maxConn="
156+
+ maxConn
157+
+ ", beamWidth="
158+
+ beamWidth
159+
+ ", flatVectorFormat="
160+
+ flatVectorsFormat
161+
+ ")";
162+
}
163+
}
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.lucene.codecs.lucene104;
18+
19+
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
20+
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
21+
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
22+
23+
import java.io.IOException;
24+
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
25+
import org.apache.lucene.index.KnnVectorValues;
26+
import org.apache.lucene.index.VectorSimilarityFunction;
27+
import org.apache.lucene.util.ArrayUtil;
28+
import org.apache.lucene.util.VectorUtil;
29+
import org.apache.lucene.util.hnsw.RandomVectorScorer;
30+
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
31+
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
32+
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
33+
34+
/**
 * Vector scorer over OptimizedScalarQuantized vectors. Scores pairs of quantized vectors with
 * similarity estimated from the quantized dot product plus per-vector corrective terms, and
 * delegates to {@code nonQuantizedDelegate} whenever the vector values are not (yet) quantized.
 */
public class Lucene104ScalarQuantizedVectorScorer implements FlatVectorsScorer {
  /** Fallback scorer used when the supplied values are raw (un-quantized) vectors. */
  private final FlatVectorsScorer nonQuantizedDelegate;

  public Lucene104ScalarQuantizedVectorScorer(FlatVectorsScorer nonQuantizedDelegate) {
    this.nonQuantizedDelegate = nonQuantizedDelegate;
  }

  @Override
  public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
      VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues)
      throws IOException {
    if (vectorValues instanceof QuantizedByteVectorValues qv) {
      return new ScalarQuantizedVectorScorerSupplier(qv, similarityFunction);
    }
    // It is possible to get to this branch during initial indexing and flush
    return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues);
  }

  @Override
  public RandomVectorScorer getRandomVectorScorer(
      VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target)
      throws IOException {
    if (vectorValues instanceof QuantizedByteVectorValues qv) {
      // Quantize the float query once up front, then score it against stored quantized docs.
      OptimizedScalarQuantizer quantizer = qv.getQuantizer();
      byte[] targetQuantized =
          new byte
              [OptimizedScalarQuantizer.discretize(
                  target.length, qv.getScalarEncoding().getDimensionsPerByte())];
      // We make a copy as the quantization process mutates the input
      float[] copy = ArrayUtil.copyOfSubArray(target, 0, target.length);
      if (similarityFunction == COSINE) {
        // Normalize so that cosine reduces to dot product on the quantized representation.
        VectorUtil.l2normalize(copy);
      }
      target = copy;
      var targetCorrectiveTerms =
          quantizer.scalarQuantize(
              target, targetQuantized, qv.getScalarEncoding().getBits(), qv.getCentroid());
      return new RandomVectorScorer.AbstractRandomVectorScorer(qv) {
        @Override
        public float score(int node) throws IOException {
          return quantizedScore(
              targetQuantized, targetCorrectiveTerms, qv, node, similarityFunction);
        }
      };
    }
    // It is possible to get to this branch during initial indexing and flush
    return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
  }

  @Override
  public RandomVectorScorer getRandomVectorScorer(
      VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target)
      throws IOException {
    // Byte-vector queries are never OSQ-quantized here; always delegate.
    return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
  }

  @Override
  public String toString() {
    return "Lucene104ScalarQuantizedVectorScorer(nonQuantizedDelegate="
        + nonQuantizedDelegate
        + ")";
  }

  /**
   * Supplier for scorers that compare one stored quantized vector (selected via {@code
   * setScoringOrdinal}) against the others. Holds two views of the same values: {@code
   * targetValues} to read the query-side vector and {@code values} to read candidates, so the two
   * reads do not clobber each other's reused buffers.
   */
  private static final class ScalarQuantizedVectorScorerSupplier
      implements RandomVectorScorerSupplier {
    private final QuantizedByteVectorValues targetValues;
    private final QuantizedByteVectorValues values;
    private final VectorSimilarityFunction similarity;

    public ScalarQuantizedVectorScorerSupplier(
        QuantizedByteVectorValues values, VectorSimilarityFunction similarity) throws IOException {
      this.targetValues = values.copy();
      this.values = values;
      this.similarity = similarity;
    }

    @Override
    public UpdateableRandomVectorScorer scorer() throws IOException {
      return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(values) {
        // Query-side quantized vector and its corrective terms; set by setScoringOrdinal.
        private byte[] targetVector;
        private OptimizedScalarQuantizer.QuantizationResult targetCorrectiveTerms;

        @Override
        public float score(int node) throws IOException {
          return quantizedScore(targetVector, targetCorrectiveTerms, values, node, similarity);
        }

        @Override
        public void setScoringOrdinal(int node) throws IOException {
          var rawTargetVector = targetValues.vectorValue(node);
          switch (values.getScalarEncoding()) {
            // NOTE(review): for UNSIGNED_BYTE/SEVEN_BIT this aliases the buffer returned by
            // targetValues.vectorValue(node) rather than copying — presumably safe because
            // targetValues is a dedicated copy read only here; confirm vectorValue's reuse
            // contract.
            case UNSIGNED_BYTE -> targetVector = rawTargetVector;
            case SEVEN_BIT -> targetVector = rawTargetVector;
            case PACKED_NIBBLE -> {
              // Unpack two-dims-per-byte nibbles into a scratch buffer, allocated lazily once.
              if (targetVector == null) {
                targetVector = new byte[OptimizedScalarQuantizer.discretize(values.dimension(), 2)];
              }
              OffHeapScalarQuantizedVectorValues.unpackNibbles(rawTargetVector, targetVector);
            }
          }
          targetCorrectiveTerms = targetValues.getCorrectiveTerms(node);
        }
      };
    }

    @Override
    public RandomVectorScorerSupplier copy() throws IOException {
      return new ScalarQuantizedVectorScorerSupplier(values.copy(), similarity);
    }
  }

  // Scale factor per bit width: SCALE_LUT[bits - 1] == 1 / (2^bits - 1), mapping quantized
  // integer levels back to the [0, 1] interval fraction.
  private static final float[] SCALE_LUT =
      new float[] {
        1f,
        1f / ((1 << 2) - 1),
        1f / ((1 << 3) - 1),
        1f / ((1 << 4) - 1),
        1f / ((1 << 5) - 1),
        1f / ((1 << 6) - 1),
        1f / ((1 << 7) - 1),
        1f / ((1 << 8) - 1),
      };

  /**
   * Estimates the similarity between a quantized query and the quantized vector at {@code
   * targetOrd}, reconstructing the float dot product from the integer dot product plus each
   * side's interval bounds and component sums, then mapping it to a bounded score for the given
   * similarity function.
   */
  private static float quantizedScore(
      byte[] quantizedQuery,
      OptimizedScalarQuantizer.QuantizationResult queryCorrections,
      QuantizedByteVectorValues targetVectors,
      int targetOrd,
      VectorSimilarityFunction similarityFunction)
      throws IOException {
    var scalarEncoding = targetVectors.getScalarEncoding();
    byte[] quantizedDoc = targetVectors.vectorValue(targetOrd);
    // Integer dot product in the encoding's native layout (unsigned bytes, signed 7-bit bytes,
    // or single-packed nibbles).
    float qcDist =
        switch (scalarEncoding) {
          case UNSIGNED_BYTE -> VectorUtil.uint8DotProduct(quantizedQuery, quantizedDoc);
          case SEVEN_BIT -> VectorUtil.dotProduct(quantizedQuery, quantizedDoc);
          case PACKED_NIBBLE -> VectorUtil.int4DotProductSinglePacked(quantizedQuery, quantizedDoc);
        };
    OptimizedScalarQuantizer.QuantizationResult indexCorrections =
        targetVectors.getCorrectiveTerms(targetOrd);
    float scale = SCALE_LUT[scalarEncoding.getBits() - 1];
    float x1 = indexCorrections.quantizedComponentSum();
    float ax = indexCorrections.lowerInterval();
    // Here we must scale according to the bits
    float lx = (indexCorrections.upperInterval() - ax) * scale;
    float ay = queryCorrections.lowerInterval();
    float ly = (queryCorrections.upperInterval() - ay) * scale;
    float y1 = queryCorrections.quantizedComponentSum();
    // Expansion of dot(ax + lx*qx, ay + ly*qy) summed over dimensions: constant term, each
    // side's linear term weighted by the other side's component sum, and the quantized product.
    float score =
        ax * ay * targetVectors.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * qcDist;
    // For euclidean, we need to invert the score and apply the additional correction, which is
    // assumed to be the squared l2norm of the centroid centered vectors.
    if (similarityFunction == EUCLIDEAN) {
      // ||q - d||^2 = ||q||^2 + ||d||^2 - 2*q.d, then mapped to (0, 1] via 1/(1+dist).
      score =
          queryCorrections.additionalCorrection()
              + indexCorrections.additionalCorrection()
              - 2 * score;
      return Math.max(1 / (1f + score), 0);
    } else {
      // For cosine and max inner product, we need to apply the additional correction, which is
      // assumed to be the non-centered dot-product between the vector and the centroid
      score +=
          queryCorrections.additionalCorrection()
              + indexCorrections.additionalCorrection()
              - targetVectors.getCentroidDP();
      if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
        return VectorUtil.scaleMaxInnerProductScore(score);
      }
      // Cosine (and dot product on normalized vectors): shift from [-1, 1] to [0, 1].
      return Math.max((1f + score) / 2f, 0);
    }
  }
}

0 commit comments

Comments
 (0)