Changes from all commits
@@ -108,24 +108,24 @@ public void init() {

@Benchmark
public float binaryCosineScalar() {
return VectorUtil.cosine(bytesA, bytesB);
return VectorUtil.cosineBytes(bytesA, bytesB);
}

@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public float binaryCosineVector() {
return VectorUtil.cosine(bytesA, bytesB);
return VectorUtil.cosineBytes(bytesA, bytesB);
}

@Benchmark
public int binarySquareScalar() {
return VectorUtil.squareDistance(bytesA, bytesB);
return VectorUtil.squareDistanceBytes(bytesA, bytesB);
}

@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public int binarySquareVector() {
return VectorUtil.squareDistance(bytesA, bytesB);
return VectorUtil.squareDistanceBytes(bytesA, bytesB);
}

@Benchmark
@@ -187,13 +187,13 @@ public int binaryHalfByteSquareBothPackedVector() {

@Benchmark
public int binaryDotProductScalar() {
return VectorUtil.dotProduct(bytesA, bytesB);
return VectorUtil.dotProductBytes(bytesA, bytesB);
}

@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public int binaryDotProductVector() {
return VectorUtil.dotProduct(bytesA, bytesB);
return VectorUtil.dotProductBytes(bytesA, bytesB);
}

@Benchmark
@@ -277,7 +277,7 @@ public int binaryHalfByteDotProductBothPackedVector() {

@Benchmark
public float floatCosineScalar() {
return VectorUtil.cosine(floatsA, floatsB);
return VectorUtil.cosineFloats(floatsA, floatsB);
}

@Benchmark
@@ -290,7 +290,7 @@ public float[] l2Normalize() {
value = 15,
jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public float floatCosineVector() {
return VectorUtil.cosine(floatsA, floatsB);
return VectorUtil.cosineFloats(floatsA, floatsB);
}

@Benchmark
@@ -303,27 +303,27 @@ public float[] l2NormalizeVector() {

@Benchmark
public float floatDotProductScalar() {
return VectorUtil.dotProduct(floatsA, floatsB);
return VectorUtil.dotProductFloats(floatsA, floatsB);
}

@Benchmark
@Fork(
value = 15,
jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public float floatDotProductVector() {
return VectorUtil.dotProduct(floatsA, floatsB);
return VectorUtil.dotProductFloats(floatsA, floatsB);
}

@Benchmark
public float floatSquareScalar() {
return VectorUtil.squareDistance(floatsA, floatsB);
return VectorUtil.squareDistanceFloats(floatsA, floatsB);
}

@Benchmark
@Fork(
value = 15,
jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public float floatSquareVector() {
return VectorUtil.squareDistance(floatsA, floatsB);
return VectorUtil.squareDistanceFloats(floatsA, floatsB);
}
}
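
Across this benchmark file the pattern is the same: each overloaded VectorUtil call is replaced by a variant whose name states the element type (cosineBytes/cosineFloats, dotProductBytes/dotProductFloats, squareDistanceBytes/squareDistanceFloats), matching what the binary*/float* benchmark names already encode. A minimal sketch of the renamed entry points outside JMH; the sample arrays are made up, only the method names and return types come from this diff:

import org.apache.lucene.util.VectorUtil;

public class VectorUtilRenameSketch {
  public static void main(String[] args) {
    float[] fa = {0.1f, 0.2f, 0.3f};
    float[] fb = {0.3f, 0.2f, 0.1f};
    byte[] ba = {1, 2, 3};
    byte[] bb = {3, 2, 1};

    // float[] variants (previously the dotProduct/cosine/squareDistance overloads)
    float floatDot = VectorUtil.dotProductFloats(fa, fb);
    float floatCos = VectorUtil.cosineFloats(fa, fb);
    float floatSq = VectorUtil.squareDistanceFloats(fa, fb);

    // byte[] variants; note the integer return types, as in the benchmarks above
    int byteDot = VectorUtil.dotProductBytes(ba, bb);
    float byteCos = VectorUtil.cosineBytes(ba, bb);
    int byteSq = VectorUtil.squareDistanceBytes(ba, bb);

    System.out.println(floatDot + " " + floatCos + " " + floatSq);
    System.out.println(byteDot + " " + byteCos + " " + byteSq);
  }
}
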
@@ -79,6 +79,6 @@ int discretizedDimensions() {
float getCentroidDP() throws IOException {
// this only gets executed on-merge
float[] centroid = getCentroid();
return VectorUtil.dotProduct(centroid, centroid);
return VectorUtil.dotProductFloats(centroid, centroid);
}
}
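
getCentroidDP() simply caches the centroid's self-dot-product, i.e. the squared L2 norm of the centroid. For reference, a minimal equivalent of what dotProductFloats(centroid, centroid) computes (the helper name is made up):

final class SquaredNormSketch {
  // dotProductFloats(v, v) is the sum of v[i] * v[i], i.e. the squared L2 norm of v.
  static float squaredNorm(float[] v) {
    float sum = 0f;
    for (float x : v) {
      sum += x * x;
    }
    return sum;
  }
}
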
@@ -176,7 +176,9 @@ private void writeField(
writeBinarizedVectors(fieldData, clusterCenter, quantizer);
long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset;
float centroidDp =
fieldData.getVectors().size() > 0 ? VectorUtil.dotProduct(clusterCenter, clusterCenter) : 0;
fieldData.getVectors().size() > 0
? VectorUtil.dotProductFloats(clusterCenter, clusterCenter)
: 0;

writeMeta(
fieldData.fieldInfo,
@@ -227,7 +229,7 @@ private void writeSortingField(
writeSortedBinarizedVectors(fieldData, clusterCenter, ordMap, scalarQuantizer);
long quantizedVectorLength = binarizedVectorData.getFilePointer() - vectorDataOffset;

float centroidDp = VectorUtil.dotProduct(clusterCenter, clusterCenter);
float centroidDp = VectorUtil.dotProductFloats(clusterCenter, clusterCenter);
writeMeta(
fieldData.fieldInfo,
maxDoc,
@@ -336,7 +338,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE
writeBinarizedVectorData(binarizedVectorData, binarizedVectorValues);
long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset;
float centroidDp =
docsWithField.cardinality() > 0 ? VectorUtil.dotProduct(centroid, centroid) : 0;
docsWithField.cardinality() > 0 ? VectorUtil.dotProductFloats(centroid, centroid) : 0;
writeMeta(
fieldInfo,
segmentWriteState.segmentInfo.maxDoc(),
@@ -426,7 +428,7 @@ public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
// Don't need access to the random vectors, we can just use the merged
rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
centroid = mergedCentroid;
cDotC = vectorCount > 0 ? VectorUtil.dotProduct(centroid, centroid) : 0;
cDotC = vectorCount > 0 ? VectorUtil.dotProductFloats(centroid, centroid) : 0;
if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) {
segmentWriteState.infoStream.message(
BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount);
@@ -696,7 +698,7 @@ public boolean isFinished() {
public void addValue(int docID, float[] vectorValue) throws IOException {
flatFieldVectorsWriter.addValue(docID, vectorValue);
if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
float dp = VectorUtil.dotProduct(vectorValue, vectorValue);
float dp = VectorUtil.dotProductFloats(vectorValue, vectorValue);
float divisor = (float) Math.sqrt(dp);
magnitudes.add(divisor);
for (int i = 0; i < vectorValue.length; i++) {
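
The addValue hunk above is collapsed mid-loop. For COSINE fields the magnitude sqrt(dotProductFloats(v, v)) is stored and, presumably, each component is then divided by it before quantization; a sketch under that assumption (only VectorUtil.dotProductFloats comes from the diff, everything else is illustrative):

import org.apache.lucene.util.VectorUtil;

final class CosineNormalizationSketch {
  // Assumed continuation of the collapsed loop: scale each component by 1 / magnitude.
  static float[] normalizeForCosine(float[] vectorValue) {
    float dp = VectorUtil.dotProductFloats(vectorValue, vectorValue);
    float divisor = (float) Math.sqrt(dp); // the magnitude added to `magnitudes` above
    float[] normalized = new float[vectorValue.length];
    for (int i = 0; i < vectorValue.length; i++) {
      normalized[i] = vectorValue[i] / divisor; // unit-length copy (a zero vector would need guarding)
    }
    return normalized;
  }
}
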
@@ -244,7 +244,7 @@ private static float quantizedScore(
float qcDist =
switch (scalarEncoding) {
case UNSIGNED_BYTE -> VectorUtil.uint8DotProduct(quantizedQuery, quantizedDoc);
case SEVEN_BIT -> VectorUtil.dotProduct(quantizedQuery, quantizedDoc);
case SEVEN_BIT -> VectorUtil.dotProductBytes(quantizedQuery, quantizedDoc);
case PACKED_NIBBLE -> VectorUtil.int4DotProductSinglePacked(quantizedQuery, quantizedDoc);
case SINGLE_BIT_QUERY_NIBBLE ->
VectorUtil.int4BitDotProduct(quantizedQuery, quantizedDoc);
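
The SEVEN_BIT case can keep using the signed-byte kernel because a 7-bit code (0..127) round-trips through a signed Java byte unchanged, while UNSIGNED_BYTE codes above 127 would flip sign and therefore go through the dedicated uint8 kernel. A two-line illustration, not from the patch:

public class ByteRangeSketch {
  public static void main(String[] args) {
    byte sevenBit = (byte) 127; // 7-bit code: unchanged, so dotProductBytes is safe
    byte unsigned = (byte) 200; // 8-bit code: reads back as -56, hence uint8DotProduct
    System.out.println(sevenBit + " " + unsigned);
  }
}
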
@@ -176,7 +176,9 @@ private void writeField(
writeVectors(fieldData, clusterCenter, quantizer);
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
float centroidDp =
!fieldData.getVectors().isEmpty() ? VectorUtil.dotProduct(clusterCenter, clusterCenter) : 0;
!fieldData.getVectors().isEmpty()
? VectorUtil.dotProductFloats(clusterCenter, clusterCenter)
: 0;

writeMeta(
fieldData.fieldInfo,
@@ -234,7 +236,7 @@ private void writeSortingField(
writeSortedVectors(fieldData, clusterCenter, ordMap, scalarQuantizer);
long quantizedVectorLength = vectorData.getFilePointer() - vectorDataOffset;

float centroidDp = VectorUtil.dotProduct(clusterCenter, clusterCenter);
float centroidDp = VectorUtil.dotProductFloats(clusterCenter, clusterCenter);
writeMeta(
fieldData.fieldInfo,
maxDoc,
@@ -355,7 +357,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE
DocsWithFieldSet docsWithField = writeVectorData(vectorData, quantizedVectorValues);
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
float centroidDp =
docsWithField.cardinality() > 0 ? VectorUtil.dotProduct(centroid, centroid) : 0;
docsWithField.cardinality() > 0 ? VectorUtil.dotProductFloats(centroid, centroid) : 0;
writeMeta(
fieldInfo,
segmentWriteState.segmentInfo.maxDoc(),
@@ -447,7 +449,7 @@ public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
// Don't need access to the random vectors, we can just use the merged
rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
centroid = mergedCentroid;
cDotC = vectorCount > 0 ? VectorUtil.dotProduct(centroid, centroid) : 0;
cDotC = vectorCount > 0 ? VectorUtil.dotProductFloats(centroid, centroid) : 0;
if (segmentWriteState.infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) {
segmentWriteState.infoStream.message(
QUANTIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount);
@@ -747,7 +749,7 @@ public boolean isFinished() {
public void addValue(int docID, float[] vectorValue) throws IOException {
flatFieldVectorsWriter.addValue(docID, vectorValue);
if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
float dp = VectorUtil.dotProduct(vectorValue, vectorValue);
float dp = VectorUtil.dotProductFloats(vectorValue, vectorValue);
float divisor = (float) Math.sqrt(dp);
magnitudes.add(divisor);
for (int i = 0; i < vectorValue.length; i++) {
@@ -802,7 +804,7 @@ static class QuantizedFloatVectorValues extends QuantizedByteVectorValues {
new byte[encoding.getDocPackedLength(quantized.length)];
};
this.centroid = centroid;
this.centroidDP = VectorUtil.dotProduct(centroid, centroid);
this.centroidDP = VectorUtil.dotProductFloats(centroid, centroid);
}

@Override
@@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.internal.vectorization;
package org.apache.lucene.codecs.lucene99;

import java.io.IOException;
import java.lang.foreign.MemorySegment;
@@ -25,9 +25,11 @@
import org.apache.lucene.store.FilterIndexInput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.MemorySegmentAccessInput;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorer;

abstract sealed class Lucene99MemorySegmentByteVectorScorer
/** A scorer of vectors whose element size is byte. */
public abstract sealed class Lucene99MemorySegmentByteVectorScorer
extends RandomVectorScorer.AbstractRandomVectorScorer {

final int vectorByteSize;
@@ -98,7 +100,7 @@ static final class CosineScorer extends Lucene99MemorySegmentByteVectorScorer {
@Override
public float score(int node) throws IOException {
checkOrdinal(node);
float raw = PanamaVectorUtilSupport.cosine(query, getSegment(node));
float raw = VectorUtil.cosineBytes(query, getSegment(node));
return (1 + raw) / 2;
}
}
@@ -112,7 +114,7 @@ static final class DotProductScorer extends Lucene99MemorySegmentByteVectorScore
public float score(int node) throws IOException {
checkOrdinal(node);
// divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len
float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node));
float raw = VectorUtil.dotProductBytes(query, getSegment(node));
return 0.5f + raw / (float) (query.length * (1 << 15));
}
}
@@ -125,7 +127,7 @@ static final class EuclideanScorer extends Lucene99MemorySegmentByteVectorScorer {
@Override
public float score(int node) throws IOException {
checkOrdinal(node);
float raw = PanamaVectorUtilSupport.squareDistance(query, getSegment(node));
float raw = VectorUtil.squareDistanceBytes(query, getSegment(node));
return 1 / (1f + raw);
}
}
@@ -138,7 +140,7 @@ static final class MaxInnerProductScorer extends Lucene99MemorySegmentByteVector
@Override
public float score(int node) throws IOException {
checkOrdinal(node);
float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node));
float raw = VectorUtil.dotProductBytes(query, getSegment(node));
if (raw < 0) {
return 1 / (1 + -1 * raw);
}
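
Each scorer above maps a raw VectorUtil similarity into a bounded, higher-is-better score. A sketch collecting those normalizations in one place; the first three formulas are taken from the hunks above, while the consolidating method and the positive max-inner-product branch (collapsed in this diff) are assumptions:

import org.apache.lucene.index.VectorSimilarityFunction;

final class ScoreNormalizationSketch {
  static float normalizedScore(VectorSimilarityFunction type, float raw, int dimension) {
    return switch (type) {
      case COSINE -> (1 + raw) / 2; // [-1, 1] mapped to [0, 1]
      case DOT_PRODUCT -> 0.5f + raw / (float) (dimension * (1 << 15)); // divide by 2 * 2^14 * len, per the comment above
      case EUCLIDEAN -> 1 / (1f + raw); // smaller distance, higher score
      case MAXIMUM_INNER_PRODUCT -> raw < 0 ? 1 / (1 + -1 * raw) : raw + 1; // positive branch assumed; it is collapsed above
    };
  }
}
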
@@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.internal.vectorization;
package org.apache.lucene.codecs.lucene99;

import java.io.IOException;
import java.lang.foreign.MemorySegment;
@@ -25,6 +25,7 @@
import org.apache.lucene.store.FilterIndexInput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.MemorySegmentAccessInput;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;

@@ -41,7 +42,7 @@ public abstract sealed class Lucene99MemorySegmentByteVectorScorerSupplier
* Return an optional whose value, if present, is the scorer supplier. Otherwise, an empty
* optional is returned.
*/
static Optional<RandomVectorScorerSupplier> create(
public static Optional<RandomVectorScorerSupplier> create(
VectorSimilarityFunction type, IndexInput input, KnnVectorValues values) {
assert values instanceof ByteVectorValues;
input = FilterIndexInput.unwrapOnlyTest(input);
@@ -117,8 +118,7 @@ public UpdateableRandomVectorScorer scorer() {
@Override
public float score(int node) throws IOException {
checkOrdinal(node);
float raw =
PanamaVectorUtilSupport.cosine(getFirstSegment(queryOrd), getSecondSegment(node));
float raw = VectorUtil.cosineBytes(getFirstSegment(queryOrd), getSecondSegment(node));
return (1 + raw) / 2;
}

@@ -151,8 +151,7 @@ public UpdateableRandomVectorScorer scorer() {
public float score(int node) throws IOException {
checkOrdinal(node);
// divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len
float raw =
PanamaVectorUtilSupport.dotProduct(getFirstSegment(queryOrd), getSecondSegment(node));
float raw = VectorUtil.dotProductBytes(getFirstSegment(queryOrd), getSecondSegment(node));
return 0.5f + raw / (float) (values.dimension() * (1 << 15));
}

@@ -185,8 +184,7 @@ public UpdateableRandomVectorScorer scorer() {
public float score(int node) throws IOException {
checkOrdinal(node);
float raw =
PanamaVectorUtilSupport.squareDistance(
getFirstSegment(queryOrd), getSecondSegment(node));
VectorUtil.squareDistanceBytes(getFirstSegment(queryOrd), getSecondSegment(node));
return 1 / (1f + raw);
}

@@ -218,8 +216,7 @@ public UpdateableRandomVectorScorer scorer() {
@Override
public float score(int node) throws IOException {
checkOrdinal(node);
float raw =
PanamaVectorUtilSupport.dotProduct(getFirstSegment(queryOrd), getSecondSegment(node));
float raw = VectorUtil.dotProductBytes(getFirstSegment(queryOrd), getSecondSegment(node));
if (raw < 0) {
return 1 / (1 + -1 * raw);
}
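
With create(...) now public, callers outside the package can obtain a memory-segment-backed byte scorer supplier directly. A hedged usage sketch; only the create signature and the types visible in this diff are taken from it, the setup and the scored ordinal are illustrative:

import java.io.IOException;
import java.util.Optional;
import org.apache.lucene.codecs.lucene99.Lucene99MemorySegmentByteVectorScorerSupplier;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;

final class ScorerSupplierUsageSketch {
  static void scoreOne(IndexInput vectorData, KnnVectorValues byteValues) throws IOException {
    Optional<RandomVectorScorerSupplier> supplier =
        Lucene99MemorySegmentByteVectorScorerSupplier.create(
            VectorSimilarityFunction.DOT_PRODUCT, vectorData, byteValues);
    if (supplier.isPresent()) {
      var scorer = supplier.get().scorer(); // an UpdateableRandomVectorScorer, per the overrides above
      float score = scorer.score(1); // normalized score for ordinal 1 against the scorer's current query ordinal
      System.out.println(score);
    }
  }
}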