diff --git a/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorUtilBenchmark.java b/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorUtilBenchmark.java index 4c8253fdab9f..db35164ee9f5 100644 --- a/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorUtilBenchmark.java +++ b/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorUtilBenchmark.java @@ -108,24 +108,24 @@ public void init() { @Benchmark public float binaryCosineScalar() { - return VectorUtil.cosine(bytesA, bytesB); + return VectorUtil.cosineBytes(bytesA, bytesB); } @Benchmark @Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"}) public float binaryCosineVector() { - return VectorUtil.cosine(bytesA, bytesB); + return VectorUtil.cosineBytes(bytesA, bytesB); } @Benchmark public int binarySquareScalar() { - return VectorUtil.squareDistance(bytesA, bytesB); + return VectorUtil.squareDistanceBytes(bytesA, bytesB); } @Benchmark @Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"}) public int binarySquareVector() { - return VectorUtil.squareDistance(bytesA, bytesB); + return VectorUtil.squareDistanceBytes(bytesA, bytesB); } @Benchmark @@ -187,13 +187,13 @@ public int binaryHalfByteSquareBothPackedVector() { @Benchmark public int binaryDotProductScalar() { - return VectorUtil.dotProduct(bytesA, bytesB); + return VectorUtil.dotProductBytes(bytesA, bytesB); } @Benchmark @Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"}) public int binaryDotProductVector() { - return VectorUtil.dotProduct(bytesA, bytesB); + return VectorUtil.dotProductBytes(bytesA, bytesB); } @Benchmark @@ -277,7 +277,7 @@ public int binaryHalfByteDotProductBothPackedVector() { @Benchmark public float floatCosineScalar() { - return VectorUtil.cosine(floatsA, floatsB); + return VectorUtil.cosineFloats(floatsA, floatsB); } @Benchmark @@ -290,7 +290,7 @@ public float[] l2Normalize() { value = 15, jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"}) 
public float floatCosineVector() { - return VectorUtil.cosine(floatsA, floatsB); + return VectorUtil.cosineFloats(floatsA, floatsB); } @Benchmark @@ -303,7 +303,7 @@ public float[] l2NormalizeVector() { @Benchmark public float floatDotProductScalar() { - return VectorUtil.dotProduct(floatsA, floatsB); + return VectorUtil.dotProductFloats(floatsA, floatsB); } @Benchmark @@ -311,12 +311,12 @@ public float floatDotProductScalar() { value = 15, jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"}) public float floatDotProductVector() { - return VectorUtil.dotProduct(floatsA, floatsB); + return VectorUtil.dotProductFloats(floatsA, floatsB); } @Benchmark public float floatSquareScalar() { - return VectorUtil.squareDistance(floatsA, floatsB); + return VectorUtil.squareDistanceFloats(floatsA, floatsB); } @Benchmark @@ -324,6 +324,6 @@ public float floatSquareScalar() { value = 15, jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"}) public float floatSquareVector() { - return VectorUtil.squareDistance(floatsA, floatsB); + return VectorUtil.squareDistanceFloats(floatsA, floatsB); } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/BinarizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/BinarizedByteVectorValues.java index 62eff4d72ac4..bf43d1d43950 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/BinarizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/BinarizedByteVectorValues.java @@ -79,6 +79,6 @@ int discretizedDimensions() { float getCentroidDP() throws IOException { // this only gets executed on-merge float[] centroid = getCentroid(); - return VectorUtil.dotProduct(centroid, centroid); + return VectorUtil.dotProductFloats(centroid, centroid); } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsWriter.java 
b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsWriter.java index cb0a2e41da17..f22bc44830d2 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsWriter.java @@ -176,7 +176,9 @@ private void writeField( writeBinarizedVectors(fieldData, clusterCenter, quantizer); long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset; float centroidDp = - fieldData.getVectors().size() > 0 ? VectorUtil.dotProduct(clusterCenter, clusterCenter) : 0; + fieldData.getVectors().size() > 0 + ? VectorUtil.dotProductFloats(clusterCenter, clusterCenter) + : 0; writeMeta( fieldData.fieldInfo, @@ -227,7 +229,7 @@ private void writeSortingField( writeSortedBinarizedVectors(fieldData, clusterCenter, ordMap, scalarQuantizer); long quantizedVectorLength = binarizedVectorData.getFilePointer() - vectorDataOffset; - float centroidDp = VectorUtil.dotProduct(clusterCenter, clusterCenter); + float centroidDp = VectorUtil.dotProductFloats(clusterCenter, clusterCenter); writeMeta( fieldData.fieldInfo, maxDoc, @@ -336,7 +338,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE writeBinarizedVectorData(binarizedVectorData, binarizedVectorValues); long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset; float centroidDp = - docsWithField.cardinality() > 0 ? VectorUtil.dotProduct(centroid, centroid) : 0; + docsWithField.cardinality() > 0 ? VectorUtil.dotProductFloats(centroid, centroid) : 0; writeMeta( fieldInfo, segmentWriteState.segmentInfo.maxDoc(), @@ -426,7 +428,7 @@ public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex( // Don't need access to the random vectors, we can just use the merged rawVectorDelegate.mergeOneField(fieldInfo, mergeState); centroid = mergedCentroid; - cDotC = vectorCount > 0 ? 
VectorUtil.dotProduct(centroid, centroid) : 0; + cDotC = vectorCount > 0 ? VectorUtil.dotProductFloats(centroid, centroid) : 0; if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) { segmentWriteState.infoStream.message( BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); @@ -696,7 +698,7 @@ public boolean isFinished() { public void addValue(int docID, float[] vectorValue) throws IOException { flatFieldVectorsWriter.addValue(docID, vectorValue); if (fieldInfo.getVectorSimilarityFunction() == COSINE) { - float dp = VectorUtil.dotProduct(vectorValue, vectorValue); + float dp = VectorUtil.dotProductFloats(vectorValue, vectorValue); float divisor = (float) Math.sqrt(dp); magnitudes.add(divisor); for (int i = 0; i < vectorValue.length; i++) { diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java index db2021a2a0e5..b6ea5c39dc85 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java @@ -244,7 +244,7 @@ private static float quantizedScore( float qcDist = switch (scalarEncoding) { case UNSIGNED_BYTE -> VectorUtil.uint8DotProduct(quantizedQuery, quantizedDoc); - case SEVEN_BIT -> VectorUtil.dotProduct(quantizedQuery, quantizedDoc); + case SEVEN_BIT -> VectorUtil.dotProductBytes(quantizedQuery, quantizedDoc); case PACKED_NIBBLE -> VectorUtil.int4DotProductSinglePacked(quantizedQuery, quantizedDoc); case SINGLE_BIT_QUERY_NIBBLE -> VectorUtil.int4BitDotProduct(quantizedQuery, quantizedDoc); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java index 69d23bc95df3..d0ab6c77864e 
100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java @@ -176,7 +176,9 @@ private void writeField( writeVectors(fieldData, clusterCenter, quantizer); long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; float centroidDp = - !fieldData.getVectors().isEmpty() ? VectorUtil.dotProduct(clusterCenter, clusterCenter) : 0; + !fieldData.getVectors().isEmpty() + ? VectorUtil.dotProductFloats(clusterCenter, clusterCenter) + : 0; writeMeta( fieldData.fieldInfo, @@ -234,7 +236,7 @@ private void writeSortingField( writeSortedVectors(fieldData, clusterCenter, ordMap, scalarQuantizer); long quantizedVectorLength = vectorData.getFilePointer() - vectorDataOffset; - float centroidDp = VectorUtil.dotProduct(clusterCenter, clusterCenter); + float centroidDp = VectorUtil.dotProductFloats(clusterCenter, clusterCenter); writeMeta( fieldData.fieldInfo, maxDoc, @@ -355,7 +357,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE DocsWithFieldSet docsWithField = writeVectorData(vectorData, quantizedVectorValues); long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; float centroidDp = - docsWithField.cardinality() > 0 ? VectorUtil.dotProduct(centroid, centroid) : 0; + docsWithField.cardinality() > 0 ? VectorUtil.dotProductFloats(centroid, centroid) : 0; writeMeta( fieldInfo, segmentWriteState.segmentInfo.maxDoc(), @@ -447,7 +449,7 @@ public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex( // Don't need access to the random vectors, we can just use the merged rawVectorDelegate.mergeOneField(fieldInfo, mergeState); centroid = mergedCentroid; - cDotC = vectorCount > 0 ? VectorUtil.dotProduct(centroid, centroid) : 0; + cDotC = vectorCount > 0 ? 
VectorUtil.dotProductFloats(centroid, centroid) : 0; if (segmentWriteState.infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) { segmentWriteState.infoStream.message( QUANTIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); @@ -747,7 +749,7 @@ public boolean isFinished() { public void addValue(int docID, float[] vectorValue) throws IOException { flatFieldVectorsWriter.addValue(docID, vectorValue); if (fieldInfo.getVectorSimilarityFunction() == COSINE) { - float dp = VectorUtil.dotProduct(vectorValue, vectorValue); + float dp = VectorUtil.dotProductFloats(vectorValue, vectorValue); float divisor = (float) Math.sqrt(dp); magnitudes.add(divisor); for (int i = 0; i < vectorValue.length; i++) { @@ -802,7 +804,7 @@ static class QuantizedFloatVectorValues extends QuantizedByteVectorValues { new byte[encoding.getDocPackedLength(quantized.length)]; }; this.centroid = centroid; - this.centroidDP = VectorUtil.dotProduct(centroid, centroid); + this.centroidDP = VectorUtil.dotProductFloats(centroid, centroid); } @Override diff --git a/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99MemorySegmentByteVectorScorer.java similarity index 91% rename from lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorer.java rename to lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99MemorySegmentByteVectorScorer.java index a8799c25a30a..e2fab89a1b0b 100644 --- a/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99MemorySegmentByteVectorScorer.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.lucene.internal.vectorization; +package org.apache.lucene.codecs.lucene99; import java.io.IOException; import java.lang.foreign.MemorySegment; @@ -25,9 +25,11 @@ import org.apache.lucene.store.FilterIndexInput; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.MemorySegmentAccessInput; +import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.RandomVectorScorer; -abstract sealed class Lucene99MemorySegmentByteVectorScorer +/** A scorer of vectors whose element size is byte. */ +public abstract sealed class Lucene99MemorySegmentByteVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer { final int vectorByteSize; @@ -98,7 +100,7 @@ static final class CosineScorer extends Lucene99MemorySegmentByteVectorScorer { @Override public float score(int node) throws IOException { checkOrdinal(node); - float raw = PanamaVectorUtilSupport.cosine(query, getSegment(node)); + float raw = VectorUtil.cosineBytes(query, getSegment(node)); return (1 + raw) / 2; } } @@ -112,7 +114,7 @@ static final class DotProductScorer extends Lucene99MemorySegmentByteVectorScore public float score(int node) throws IOException { checkOrdinal(node); // divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len - float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node)); + float raw = VectorUtil.dotProductBytes(query, getSegment(node)); return 0.5f + raw / (float) (query.length * (1 << 15)); } } @@ -125,7 +127,7 @@ static final class EuclideanScorer extends Lucene99MemorySegmentByteVectorScorer @Override public float score(int node) throws IOException { checkOrdinal(node); - float raw = PanamaVectorUtilSupport.squareDistance(query, getSegment(node)); + float raw = VectorUtil.squareDistanceBytes(query, getSegment(node)); return 1 / (1f + raw); } } @@ -138,7 +140,7 @@ static final class MaxInnerProductScorer extends Lucene99MemorySegmentByteVector @Override public float score(int node) throws 
IOException { checkOrdinal(node); - float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node)); + float raw = VectorUtil.dotProductBytes(query, getSegment(node)); if (raw < 0) { return 1 / (1 + -1 * raw); } diff --git a/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorerSupplier.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99MemorySegmentByteVectorScorerSupplier.java similarity index 94% rename from lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorerSupplier.java rename to lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99MemorySegmentByteVectorScorerSupplier.java index 74b31e1bf0c8..139dc4ac6163 100644 --- a/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorerSupplier.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99MemorySegmentByteVectorScorerSupplier.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.lucene.internal.vectorization; +package org.apache.lucene.codecs.lucene99; import java.io.IOException; import java.lang.foreign.MemorySegment; @@ -25,6 +25,7 @@ import org.apache.lucene.store.FilterIndexInput; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.MemorySegmentAccessInput; +import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; @@ -41,7 +42,7 @@ public abstract sealed class Lucene99MemorySegmentByteVectorScorerSupplier * Return an optional whose value, if present, is the scorer supplier. Otherwise, an empty * optional is returned. 
*/ - static Optional create( + public static Optional create( VectorSimilarityFunction type, IndexInput input, KnnVectorValues values) { assert values instanceof ByteVectorValues; input = FilterIndexInput.unwrapOnlyTest(input); @@ -117,8 +118,7 @@ public UpdateableRandomVectorScorer scorer() { @Override public float score(int node) throws IOException { checkOrdinal(node); - float raw = - PanamaVectorUtilSupport.cosine(getFirstSegment(queryOrd), getSecondSegment(node)); + float raw = VectorUtil.cosineBytes(getFirstSegment(queryOrd), getSecondSegment(node)); return (1 + raw) / 2; } @@ -151,8 +151,7 @@ public UpdateableRandomVectorScorer scorer() { public float score(int node) throws IOException { checkOrdinal(node); // divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len - float raw = - PanamaVectorUtilSupport.dotProduct(getFirstSegment(queryOrd), getSecondSegment(node)); + float raw = VectorUtil.dotProductBytes(getFirstSegment(queryOrd), getSecondSegment(node)); return 0.5f + raw / (float) (values.dimension() * (1 << 15)); } @@ -185,8 +184,7 @@ public UpdateableRandomVectorScorer scorer() { public float score(int node) throws IOException { checkOrdinal(node); float raw = - PanamaVectorUtilSupport.squareDistance( - getFirstSegment(queryOrd), getSecondSegment(node)); + VectorUtil.squareDistanceBytes(getFirstSegment(queryOrd), getSecondSegment(node)); return 1 / (1f + raw); } @@ -218,8 +216,7 @@ public UpdateableRandomVectorScorer scorer() { @Override public float score(int node) throws IOException { checkOrdinal(node); - float raw = - PanamaVectorUtilSupport.dotProduct(getFirstSegment(queryOrd), getSecondSegment(node)); + float raw = VectorUtil.dotProductBytes(getFirstSegment(queryOrd), getSecondSegment(node)); if (raw < 0) { return 1 / (1 + -1 * raw); } diff --git a/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java b/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java index 
a692d917e606..9fd372c0a8a7 100644 --- a/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java +++ b/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java @@ -16,13 +16,16 @@ */ package org.apache.lucene.index; -import static org.apache.lucene.util.VectorUtil.cosine; -import static org.apache.lucene.util.VectorUtil.dotProduct; +import static org.apache.lucene.util.VectorUtil.cosineBytes; +import static org.apache.lucene.util.VectorUtil.cosineFloats; +import static org.apache.lucene.util.VectorUtil.dotProductBytes; +import static org.apache.lucene.util.VectorUtil.dotProductFloats; import static org.apache.lucene.util.VectorUtil.dotProductScore; import static org.apache.lucene.util.VectorUtil.normalizeDistanceToUnitInterval; import static org.apache.lucene.util.VectorUtil.normalizeToUnitInterval; import static org.apache.lucene.util.VectorUtil.scaleMaxInnerProductScore; -import static org.apache.lucene.util.VectorUtil.squareDistance; +import static org.apache.lucene.util.VectorUtil.squareDistanceBytes; +import static org.apache.lucene.util.VectorUtil.squareDistanceFloats; /** * Vector similarity function; used in search to return top K most similar vectors to a target @@ -35,12 +38,12 @@ public enum VectorSimilarityFunction { EUCLIDEAN { @Override public float compare(float[] v1, float[] v2) { - return normalizeDistanceToUnitInterval(squareDistance(v1, v2)); + return normalizeDistanceToUnitInterval(squareDistanceFloats(v1, v2)); } @Override public float compare(byte[] v1, byte[] v2) { - return 1 / (1f + squareDistance(v1, v2)); + return 1 / (1f + squareDistanceBytes(v1, v2)); } }, @@ -54,7 +57,7 @@ public float compare(byte[] v1, byte[] v2) { DOT_PRODUCT { @Override public float compare(float[] v1, float[] v2) { - return normalizeToUnitInterval(dotProduct(v1, v2)); + return normalizeToUnitInterval(dotProductFloats(v1, v2)); } @Override @@ -72,12 +75,12 @@ public float compare(byte[] v1, byte[] v2) { COSINE { @Override public 
float compare(float[] v1, float[] v2) { - return normalizeToUnitInterval(cosine(v1, v2)); + return normalizeToUnitInterval(cosineFloats(v1, v2)); } @Override public float compare(byte[] v1, byte[] v2) { - return (1 + cosine(v1, v2)) / 2; + return (1 + cosineBytes(v1, v2)) / 2; } }, @@ -89,12 +92,12 @@ public float compare(byte[] v1, byte[] v2) { MAXIMUM_INNER_PRODUCT { @Override public float compare(float[] v1, float[] v2) { - return scaleMaxInnerProductScore(dotProduct(v1, v2)); + return scaleMaxInnerProductScore(dotProductFloats(v1, v2)); } @Override public float compare(byte[] v1, byte[] v2) { - return scaleMaxInnerProductScore(dotProduct(v1, v2)); + return scaleMaxInnerProductScore(dotProductBytes(v1, v2)); } }; diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java index 3ec646288cdd..8e63579b35ad 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java @@ -17,13 +17,121 @@ package org.apache.lucene.internal.vectorization; +import static java.lang.foreign.ValueLayout.JAVA_BYTE; +import static java.lang.foreign.ValueLayout.JAVA_FLOAT_UNALIGNED; +import static java.lang.foreign.ValueLayout.JAVA_INT_UNALIGNED; +import static java.nio.ByteOrder.LITTLE_ENDIAN; import static org.apache.lucene.util.VectorUtil.EPSILON; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; import org.apache.lucene.util.BitUtil; import org.apache.lucene.util.Constants; import org.apache.lucene.util.SuppressForbidden; -final class DefaultVectorUtilSupport implements VectorUtilSupport { +final class DefaultVectorUtilSupport + implements VectorUtilSupport< + DefaultVectorUtilSupport.IByteVector, DefaultVectorUtilSupport.IFloatVector> { + interface IByteVector { + int length(); + 
+ byte get(int index); + + int getInt(int index); + } + + static class ArrayByteVector implements IByteVector { + protected final byte[] array; + + ArrayByteVector(byte[] array) { + this.array = array; + } + + @Override + public final int length() { + return array.length; + } + + @Override + public final byte get(int index) { + return array[index]; + } + + @Override + public final int getInt(int index) { + return (int) BitUtil.VH_NATIVE_INT.get(array, index); + } + } + + static class MemorySegmentByteVector implements IByteVector { + private static final ValueLayout.OfInt INT_LE = JAVA_INT_UNALIGNED.withOrder(LITTLE_ENDIAN); + + protected final MemorySegment segment; + + MemorySegmentByteVector(MemorySegment segment) { + this.segment = segment; + } + + @Override + public final int length() { + return Math.toIntExact(segment.byteSize()); + } + + @Override + public final byte get(int index) { + return segment.getAtIndex(JAVA_BYTE, index); + } + + @Override + public final int getInt(int index) { + return segment.get(INT_LE, index); + } + } + + interface IFloatVector { + int length(); + + float get(int index); + } + + static class ArrayFloatVector implements IFloatVector { + protected final float[] array; + + ArrayFloatVector(float[] array) { + this.array = array; + } + + @Override + public final int length() { + return array.length; + } + + @Override + public final float get(int index) { + return array[index]; + } + } + + static class MemorySegmentFloatVector implements IFloatVector { + private static final ValueLayout.OfFloat FLOAT_LE = + JAVA_FLOAT_UNALIGNED.withOrder(LITTLE_ENDIAN); + + protected final MemorySegment segment; + + MemorySegmentFloatVector(MemorySegment segment) { + this.segment = segment; + } + + @Override + public final int length() { + return Math.toIntExact(segment.byteSize() / Float.BYTES); /* floats, not bytes */ + } + + @Override + public final float get(int index) { + return segment.getAtIndex(FLOAT_LE, index); + } + } DefaultVectorUtilSupport() {} @@ -38,41 +146,63 @@ private
static float fma(float a, float b, float c) { } @Override - public float dotProduct(float[] a, float[] b) { + public IByteVector bytesFromArray(byte[] array) { + return new ArrayByteVector(array); + } + + @Override + public IByteVector bytesFromMemorySegment(MemorySegment segment) { + return new MemorySegmentByteVector(segment); + } + + @Override + public IFloatVector floatsFromArray(float[] array) { + return new ArrayFloatVector(array); + } + + @Override + public IFloatVector floatsFromMemorySegment(MemorySegment segment) { + return new MemorySegmentFloatVector(segment); + } + + @Override + public float dotProductFloats(IFloatVector a, IFloatVector b) { float res = 0f; int i = 0; + int length = a.length(); // if the array is big, unroll it - if (a.length > 32) { + if (length > 32) { float acc1 = 0; float acc2 = 0; float acc3 = 0; float acc4 = 0; - int upperBound = a.length & ~(4 - 1); + int upperBound = length & ~(4 - 1); for (; i < upperBound; i += 4) { - acc1 = fma(a[i], b[i], acc1); - acc2 = fma(a[i + 1], b[i + 1], acc2); - acc3 = fma(a[i + 2], b[i + 2], acc3); - acc4 = fma(a[i + 3], b[i + 3], acc4); + acc1 = fma(a.get(i), b.get(i), acc1); + acc2 = fma(a.get(i + 1), b.get(i + 1), acc2); + acc3 = fma(a.get(i + 2), b.get(i + 2), acc3); + acc4 = fma(a.get(i + 3), b.get(i + 3), acc4); } res += acc1 + acc2 + acc3 + acc4; } - for (; i < a.length; i++) { - res = fma(a[i], b[i], res); + for (; i < length; i++) { + res = fma(a.get(i), b.get(i), res); } return res; } @Override - public float cosine(float[] a, float[] b) { + public float cosineFloats(IFloatVector a, IFloatVector b) { float sum = 0.0f; float norm1 = 0.0f; float norm2 = 0.0f; int i = 0; + int length = a.length(); // if the array is big, unroll it - if (a.length > 32) { + if (length > 32) { float sum1 = 0; float sum2 = 0; float norm1_1 = 0; @@ -80,101 +210,102 @@ public float cosine(float[] a, float[] b) { float norm2_1 = 0; float norm2_2 = 0; - int upperBound = a.length & ~(2 - 1); + int upperBound = length 
& ~(2 - 1); for (; i < upperBound; i += 2) { // one - sum1 = fma(a[i], b[i], sum1); - norm1_1 = fma(a[i], a[i], norm1_1); - norm2_1 = fma(b[i], b[i], norm2_1); + sum1 = fma(a.get(i), b.get(i), sum1); + norm1_1 = fma(a.get(i), a.get(i), norm1_1); + norm2_1 = fma(b.get(i), b.get(i), norm2_1); // two - sum2 = fma(a[i + 1], b[i + 1], sum2); - norm1_2 = fma(a[i + 1], a[i + 1], norm1_2); - norm2_2 = fma(b[i + 1], b[i + 1], norm2_2); + sum2 = fma(a.get(i + 1), b.get(i + 1), sum2); + norm1_2 = fma(a.get(i + 1), a.get(i + 1), norm1_2); + norm2_2 = fma(b.get(i + 1), b.get(i + 1), norm2_2); } sum += sum1 + sum2; norm1 += norm1_1 + norm1_2; norm2 += norm2_1 + norm2_2; } - for (; i < a.length; i++) { - sum = fma(a[i], b[i], sum); - norm1 = fma(a[i], a[i], norm1); - norm2 = fma(b[i], b[i], norm2); + for (; i < length; i++) { + sum = fma(a.get(i), b.get(i), sum); + norm1 = fma(a.get(i), a.get(i), norm1); + norm2 = fma(b.get(i), b.get(i), norm2); } return (float) (sum / Math.sqrt((double) norm1 * (double) norm2)); } @Override - public float squareDistance(float[] a, float[] b) { + public float squareDistanceFloats(IFloatVector a, IFloatVector b) { float res = 0; int i = 0; + int length = a.length(); // if the array is big, unroll it - if (a.length > 32) { + if (length > 32) { float acc1 = 0; float acc2 = 0; float acc3 = 0; float acc4 = 0; - int upperBound = a.length & ~(4 - 1); + int upperBound = length & ~(4 - 1); for (; i < upperBound; i += 4) { // one - float diff1 = a[i] - b[i]; + float diff1 = a.get(i) - b.get(i); acc1 = fma(diff1, diff1, acc1); // two - float diff2 = a[i + 1] - b[i + 1]; + float diff2 = a.get(i + 1) - b.get(i + 1); acc2 = fma(diff2, diff2, acc2); // three - float diff3 = a[i + 2] - b[i + 2]; + float diff3 = a.get(i + 2) - b.get(i + 2); acc3 = fma(diff3, diff3, acc3); // four - float diff4 = a[i + 3] - b[i + 3]; + float diff4 = a.get(i + 3) - b.get(i + 3); acc4 = fma(diff4, diff4, acc4); } res += acc1 + acc2 + acc3 + acc4; } - for (; i < a.length; i++) { - 
float diff = a[i] - b[i]; + for (; i < length; i++) { + float diff = a.get(i) - b.get(i); res = fma(diff, diff, res); } return res; } @Override - public int dotProduct(byte[] a, byte[] b) { + public int dotProductBytes(IByteVector a, IByteVector b) { int total = 0; - for (int i = 0; i < a.length; i++) { - total += a[i] * b[i]; + for (int i = 0, length = a.length(); i < length; i++) { + total += a.get(i) * b.get(i); } return total; } @Override - public int uint8DotProduct(byte[] a, byte[] b) { + public int uint8DotProduct(IByteVector a, IByteVector b) { int total = 0; - for (int i = 0; i < a.length; i++) { - total += Byte.toUnsignedInt(a[i]) * Byte.toUnsignedInt(b[i]); + for (int i = 0, length = a.length(); i < length; i++) { + total += Byte.toUnsignedInt(a.get(i)) * Byte.toUnsignedInt(b.get(i)); } return total; } @Override - public int int4DotProduct(byte[] a, byte[] b) { - return dotProduct(a, b); + public int int4DotProduct(IByteVector a, IByteVector b) { + return dotProductBytes(a, b); } @Override - public int int4DotProductSinglePacked(byte[] unpacked, byte[] packed) { + public int int4DotProductSinglePacked(IByteVector unpacked, IByteVector packed) { int total = 0; - for (int i = 0; i < packed.length; i++) { - byte packedByte = packed[i]; - byte unpacked1 = unpacked[i]; - byte unpacked2 = unpacked[i + packed.length]; + for (int i = 0, length = packed.length(); i < length; i++) { + byte packedByte = packed.get(i); + byte unpacked1 = unpacked.get(i); + byte unpacked2 = unpacked.get(i + length); total += (packedByte & 0x0F) * unpacked2; total += ((packedByte & 0xFF) >> 4) * unpacked1; } @@ -182,11 +313,11 @@ public int int4DotProductSinglePacked(byte[] unpacked, byte[] packed) { } @Override - public int int4DotProductBothPacked(byte[] a, byte[] b) { + public int int4DotProductBothPacked(IByteVector a, IByteVector b) { int total = 0; - for (int i = 0; i < a.length; i++) { - byte aByte = a[i]; - byte bByte = b[i]; + for (int i = 0, length = a.length(); i < length; 
i++) { + byte aByte = a.get(i); + byte bByte = b.get(i); total += (aByte & 0x0F) * (bByte & 0x0F); total += ((aByte & 0xFF) >> 4) * ((bByte & 0xFF) >> 4); } @@ -194,15 +325,15 @@ public int int4DotProductBothPacked(byte[] a, byte[] b) { } @Override - public float cosine(byte[] a, byte[] b) { + public float cosineBytes(IByteVector a, IByteVector b) { // Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14. int sum = 0; int norm1 = 0; int norm2 = 0; - for (int i = 0; i < a.length; i++) { - byte elem1 = a[i]; - byte elem2 = b[i]; + for (int i = 0, length = a.length(); i < length; i++) { + byte elem1 = a.get(i); + byte elem2 = b.get(i); sum += elem1 * elem2; norm1 += elem1 * elem1; norm2 += elem2 * elem2; @@ -211,28 +342,28 @@ public float cosine(byte[] a, byte[] b) { } @Override - public int squareDistance(byte[] a, byte[] b) { + public int squareDistanceBytes(IByteVector a, IByteVector b) { // Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14. int squareSum = 0; - for (int i = 0; i < a.length; i++) { - int diff = a[i] - b[i]; + for (int i = 0, length = a.length(); i < length; i++) { + int diff = a.get(i) - b.get(i); squareSum += diff * diff; } return squareSum; } @Override - public int int4SquareDistance(byte[] a, byte[] b) { - return squareDistance(a, b); + public int int4SquareDistance(IByteVector a, IByteVector b) { + return squareDistanceBytes(a, b); } @Override - public int int4SquareDistanceSinglePacked(byte[] unpacked, byte[] packed) { + public int int4SquareDistanceSinglePacked(IByteVector unpacked, IByteVector packed) { int total = 0; - for (int i = 0; i < packed.length; i++) { - byte packedByte = packed[i]; - byte unpacked1 = unpacked[i]; - byte unpacked2 = unpacked[i + packed.length]; + for (int i = 0, length = packed.length(); i < length; i++) { + byte packedByte = packed.get(i); + byte unpacked1 = unpacked.get(i); + byte unpacked2 = unpacked.get(i + length); int diff1 = (packedByte & 0x0F) - unpacked2; int 
diff2 = ((packedByte & 0xFF) >> 4) - unpacked1; @@ -243,11 +374,11 @@ public int int4SquareDistanceSinglePacked(byte[] unpacked, byte[] packed) { } @Override - public int int4SquareDistanceBothPacked(byte[] a, byte[] b) { + public int int4SquareDistanceBothPacked(IByteVector a, IByteVector b) { int total = 0; - for (int i = 0; i < a.length; i++) { - byte aByte = a[i]; - byte bByte = b[i]; + for (int i = 0, length = a.length(); i < length; i++) { + byte aByte = a.get(i); + byte bByte = b.get(i); int diff1 = (aByte & 0x0F) - (bByte & 0x0F); int diff2 = ((aByte & 0xFF) >> 4) - ((bByte & 0xFF) >> 4); @@ -258,11 +389,11 @@ public int int4SquareDistanceBothPacked(byte[] a, byte[] b) { } @Override - public int uint8SquareDistance(byte[] a, byte[] b) { + public int uint8SquareDistance(IByteVector a, IByteVector b) { // Note: this will not overflow if dim < 2^16, since max(ubyte * ubyte) = 2^16. int squareSum = 0; - for (int i = 0; i < a.length; i++) { - int diff = Byte.toUnsignedInt(a[i]) - Byte.toUnsignedInt(b[i]); + for (int i = 0, length = a.length(); i < length; i++) { + int diff = Byte.toUnsignedInt(a.get(i)) - Byte.toUnsignedInt(b.get(i)); squareSum += diff * diff; } return squareSum; @@ -279,25 +410,22 @@ public int findNextGEQ(int[] buffer, int target, int from, int to) { } @Override - public long int4BitDotProduct(byte[] int4Quantized, byte[] binaryQuantized) { + public long int4BitDotProduct(IByteVector int4Quantized, IByteVector binaryQuantized) { return int4BitDotProductImpl(int4Quantized, binaryQuantized); } - public static long int4BitDotProductImpl(byte[] q, byte[] d) { - assert q.length == d.length * 4; + public static long int4BitDotProductImpl(IByteVector q, IByteVector d) { + assert q.length() == d.length() * 4; long ret = 0; - int size = d.length; + int size = d.length(); for (int i = 0; i < 4; i++) { int r = 0; long subRet = 0; - for (final int upperBound = d.length & -Integer.BYTES; r < upperBound; r += Integer.BYTES) { - subRet += - Integer.bitCount( 
- (int) BitUtil.VH_NATIVE_INT.get(q, i * size + r) - & (int) BitUtil.VH_NATIVE_INT.get(d, r)); + for (final int upperBound = size & -Integer.BYTES; r < upperBound; r += Integer.BYTES) { + subRet += Integer.bitCount(q.getInt(i * size + r) & d.getInt(r)); } - for (; r < d.length; r++) { - subRet += Integer.bitCount((q[i * size + r] & d[r]) & 0xFF); + for (; r < size; r++) { + subRet += Integer.bitCount((q.get(i * size + r) & d.get(r)) & 0xFF); } ret += subRet << i; } @@ -397,7 +525,8 @@ public int filterByScore( @Override public float[] l2normalize(float[] v, boolean throwOnZero) { - double l1norm = this.dotProduct(v, v); + IFloatVector vector = new ArrayFloatVector(v); + double l1norm = this.dotProductFloats(vector, vector); if (l1norm == 0) { if (throwOnZero) { throw new IllegalArgumentException("Cannot normalize a zero-length vector"); diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java index 21977fa3dc77..1aa33b09ef2c 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java @@ -25,14 +25,14 @@ /** Default provider returning scalar implementations. 
*/ final class DefaultVectorizationProvider extends VectorizationProvider { - private final VectorUtilSupport vectorUtilSupport; + private final VectorUtilSupport vectorUtilSupport; DefaultVectorizationProvider() { vectorUtilSupport = new DefaultVectorUtilSupport(); } @Override - public VectorUtilSupport getVectorUtilSupport() { + public VectorUtilSupport getVectorUtilSupport() { return vectorUtilSupport; } diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java index f92a0b653caa..209ba1ba685f 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java @@ -17,61 +17,214 @@ package org.apache.lucene.internal.vectorization; +import java.lang.foreign.MemorySegment; + /** * Interface for implementations of VectorUtil support. * * @lucene.internal */ -public interface VectorUtilSupport { +public interface VectorUtilSupport { + IByteVector bytesFromArray(byte[] array); + + IByteVector bytesFromMemorySegment(MemorySegment segment); + + IFloatVector floatsFromArray(float[] array); + + IFloatVector floatsFromMemorySegment(MemorySegment segment); + + default float dotProductFloats(float[] a, float[] b) { + return dotProductFloats(floatsFromArray(a), floatsFromArray(b)); + } + + default float dotProductFloats(float[] a, MemorySegment b) { + return dotProductFloats(floatsFromArray(a), floatsFromMemorySegment(b)); + } + + default float dotProductFloats(MemorySegment a, MemorySegment b) { + return dotProductFloats(floatsFromMemorySegment(a), floatsFromMemorySegment(b)); + } /** Calculates the dot product of the given float arrays. 
*/ - float dotProduct(float[] a, float[] b); + float dotProductFloats(IFloatVector a, IFloatVector b); + + default float cosineFloats(float[] a, float[] b) { + return cosineFloats(floatsFromArray(a), floatsFromArray(b)); + } /** Returns the cosine similarity between the two vectors. */ - float cosine(float[] v1, float[] v2); + float cosineFloats(IFloatVector v1, IFloatVector v2); + + default float squareDistanceFloats(float[] a, float[] b) { + return squareDistanceFloats(floatsFromArray(a), floatsFromArray(b)); + } + + default float squareDistanceFloats(float[] a, MemorySegment b) { + return squareDistanceFloats(floatsFromArray(a), floatsFromMemorySegment(b)); + } + + default float squareDistanceFloats(MemorySegment a, MemorySegment b) { + return squareDistanceFloats(floatsFromMemorySegment(a), floatsFromMemorySegment(b)); + } /** Returns the sum of squared differences of the two vectors. */ - float squareDistance(float[] a, float[] b); + float squareDistanceFloats(IFloatVector a, IFloatVector b); + + default int dotProductBytes(byte[] a, byte[] b) { + return dotProductBytes(bytesFromArray(a), bytesFromArray(b)); + } + + default int dotProductBytes(byte[] a, MemorySegment b) { + return dotProductBytes(bytesFromArray(a), bytesFromMemorySegment(b)); + } + + default int dotProductBytes(MemorySegment a, MemorySegment b) { + return dotProductBytes(bytesFromMemorySegment(a), bytesFromMemorySegment(b)); + } /** Returns the dot product computed over signed bytes. 
*/ - int dotProduct(byte[] a, byte[] b); + int dotProductBytes(IByteVector a, IByteVector b); + + default int int4DotProduct(byte[] a, byte[] b) { + return int4DotProduct(bytesFromArray(a), bytesFromArray(b)); + } + + default int int4DotProduct(byte[] a, MemorySegment b) { + return int4DotProduct(bytesFromArray(a), bytesFromMemorySegment(b)); + } + + default int int4DotProduct(MemorySegment a, MemorySegment b) { + return int4DotProduct(bytesFromMemorySegment(a), bytesFromMemorySegment(b)); + } /** Returns the dot product computed over unsigned half-bytes, both uncompressed. */ - int int4DotProduct(byte[] a, byte[] b); + int int4DotProduct(IByteVector a, IByteVector b); + + default int int4DotProductSinglePacked(byte[] a, byte[] b) { + return int4DotProductSinglePacked(bytesFromArray(a), bytesFromArray(b)); + } + + default int int4DotProductSinglePacked(byte[] a, MemorySegment b) { + return int4DotProductSinglePacked(bytesFromArray(a), bytesFromMemorySegment(b)); + } /** Returns the dot product computed over unsigned half-bytes, one compressed. */ - int int4DotProductSinglePacked(byte[] unpacked, byte[] packed); + int int4DotProductSinglePacked(IByteVector unpacked, IByteVector packed); + + default int int4DotProductBothPacked(byte[] a, byte[] b) { + return int4DotProductBothPacked(bytesFromArray(a), bytesFromArray(b)); + } + + default int int4DotProductBothPacked(MemorySegment a, MemorySegment b) { + return int4DotProductBothPacked(bytesFromMemorySegment(a), bytesFromMemorySegment(b)); + } /** Returns the dot product computed over unsigned half-bytes, both compressed. 
*/ - int int4DotProductBothPacked(byte[] a, byte[] b); + int int4DotProductBothPacked(IByteVector a, IByteVector b); + + default int uint8DotProduct(byte[] a, byte[] b) { + return uint8DotProduct(bytesFromArray(a), bytesFromArray(b)); + } + + default int uint8DotProduct(byte[] a, MemorySegment b) { + return uint8DotProduct(bytesFromArray(a), bytesFromMemorySegment(b)); + } + + default int uint8DotProduct(MemorySegment a, MemorySegment b) { + return uint8DotProduct(bytesFromMemorySegment(a), bytesFromMemorySegment(b)); + } /** Returns the dot product computed as though the bytes were unsigned. */ - int uint8DotProduct(byte[] a, byte[] b); + int uint8DotProduct(IByteVector a, IByteVector b); + + default float cosineBytes(byte[] a, byte[] b) { + return cosineBytes(bytesFromArray(a), bytesFromArray(b)); + } + + default float cosineBytes(byte[] a, MemorySegment b) { + return cosineBytes(bytesFromArray(a), bytesFromMemorySegment(b)); + } + + default float cosineBytes(MemorySegment a, MemorySegment b) { + return cosineBytes(bytesFromMemorySegment(a), bytesFromMemorySegment(b)); + } /** Returns the cosine similarity between the two byte vectors. */ - float cosine(byte[] a, byte[] b); + float cosineBytes(IByteVector a, IByteVector b); + + default int squareDistanceBytes(byte[] a, byte[] b) { + return squareDistanceBytes(bytesFromArray(a), bytesFromArray(b)); + } + + default int squareDistanceBytes(byte[] a, MemorySegment b) { + return squareDistanceBytes(bytesFromArray(a), bytesFromMemorySegment(b)); + } + + default int squareDistanceBytes(MemorySegment a, MemorySegment b) { + return squareDistanceBytes(bytesFromMemorySegment(a), bytesFromMemorySegment(b)); + } /** Returns the sum of squared differences of the two byte vectors. 
*/ - int squareDistance(byte[] a, byte[] b); + int squareDistanceBytes(IByteVector a, IByteVector b); + + default int int4SquareDistance(byte[] a, byte[] b) { + return int4SquareDistance(bytesFromArray(a), bytesFromArray(b)); + } + + default int int4SquareDistance(byte[] a, MemorySegment b) { + return int4SquareDistance(bytesFromArray(a), bytesFromMemorySegment(b)); + } + + default int int4SquareDistance(MemorySegment a, MemorySegment b) { + return int4SquareDistance(bytesFromMemorySegment(a), bytesFromMemorySegment(b)); + } /** * Returns the sum of squared differences between two unsigned half-byte vectors, both * uncompressed. */ - int int4SquareDistance(byte[] a, byte[] b); + int int4SquareDistance(IByteVector a, IByteVector b); + + default int int4SquareDistanceSinglePacked(byte[] a, byte[] b) { + return int4SquareDistanceSinglePacked(bytesFromArray(a), bytesFromArray(b)); + } + + default int int4SquareDistanceSinglePacked(byte[] a, MemorySegment b) { + return int4SquareDistanceSinglePacked(bytesFromArray(a), bytesFromMemorySegment(b)); + } /** * Returns the sum of squared differences between two unsigned half-byte vectors, one compressed. */ - int int4SquareDistanceSinglePacked(byte[] unpacked, byte[] packed); + int int4SquareDistanceSinglePacked(IByteVector unpacked, IByteVector packed); + + default int int4SquareDistanceBothPacked(byte[] a, byte[] b) { + return int4SquareDistanceBothPacked(bytesFromArray(a), bytesFromArray(b)); + } + + default int int4SquareDistanceBothPacked(MemorySegment a, MemorySegment b) { + return int4SquareDistanceBothPacked(bytesFromMemorySegment(a), bytesFromMemorySegment(b)); + } /** * Returns the sum of squared differences between two unsigned half-byte vectors, both compressed. 
*/ - int int4SquareDistanceBothPacked(byte[] a, byte[] b); + int int4SquareDistanceBothPacked(IByteVector a, IByteVector b); + + default int uint8SquareDistance(byte[] a, byte[] b) { + return uint8SquareDistance(bytesFromArray(a), bytesFromArray(b)); + } + + default int uint8SquareDistance(byte[] a, MemorySegment b) { + return uint8SquareDistance(bytesFromArray(a), bytesFromMemorySegment(b)); + } + + default int uint8SquareDistance(MemorySegment a, MemorySegment b) { + return uint8SquareDistance(bytesFromMemorySegment(a), bytesFromMemorySegment(b)); + } /** Returns the sum of squared differences of the two unsigned byte vectors. */ - int uint8SquareDistance(byte[] a, byte[] b); + int uint8SquareDistance(IByteVector a, IByteVector b); /** * Given an array {@code buffer} that is sorted between indexes {@code 0} inclusive and {@code to} @@ -81,6 +234,10 @@ public interface VectorUtilSupport { */ int findNextGEQ(int[] buffer, int target, int from, int to); + default long int4BitDotProduct(byte[] a, byte[] b) { + return int4BitDotProduct(bytesFromArray(a), bytesFromArray(b)); + } + /** * Compute the dot product between a quantized int4 vector and a binary quantized vector. It is * assumed that the int4 quantized bits are packed in the byte array in the same way as the {@link @@ -92,7 +249,7 @@ public interface VectorUtilSupport { * @param binaryQuantized byte packed binary quantized vector * @return the dot product */ - long int4BitDotProduct(byte[] int4Quantized, byte[] binaryQuantized); + long int4BitDotProduct(IByteVector int4Quantized, IByteVector binaryQuantized); /** * Quantizes {@code vector}, putting the result into {@code dest}. 
diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java index cf9c56c59774..43d8bbc371ba 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java @@ -104,7 +104,7 @@ public static VectorizationProvider getInstance() { * Returns a singleton (stateless) {@link VectorUtilSupport} to support SIMD usage in {@link * VectorUtil}. */ - public abstract VectorUtilSupport getVectorUtilSupport(); + public abstract VectorUtilSupport getVectorUtilSupport(); /** Returns a FlatVectorsScorer that supports the Lucene99 format. */ public abstract FlatVectorsScorer getLucene99FlatVectorsScorer(); diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java index a410d2707530..465083c4faaf 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java @@ -17,6 +17,7 @@ package org.apache.lucene.util; +import java.lang.foreign.MemorySegment; import java.util.stream.IntStream; import org.apache.lucene.internal.vectorization.VectorUtilSupport; import org.apache.lucene.internal.vectorization.VectorizationProvider; @@ -50,7 +51,7 @@ public final class VectorUtil { public static final float EPSILON = 1e-4f; - private static final VectorUtilSupport IMPL = + private static final VectorUtilSupport IMPL = VectorizationProvider.getInstance().getVectorUtilSupport(); private VectorUtil() {} @@ -60,11 +61,11 @@ private VectorUtil() {} * * @throws IllegalArgumentException if the vectors' dimensions differ. 
*/ - public static float dotProduct(float[] a, float[] b) { + public static float dotProductFloats(float[] a, float[] b) { if (a.length != b.length) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - float r = IMPL.dotProduct(a, b); + float r = IMPL.dotProductFloats(a, b); assert Float.isFinite(r) : "not finite: " + r @@ -76,51 +77,82 @@ public static float dotProduct(float[] a, float[] b) { return r; } - /** - * Returns the cosine similarity between the two vectors. - * - * @throws IllegalArgumentException if the vectors' dimensions differ. - */ - public static float cosine(float[] a, float[] b) { + /** Returns the cosine similarity between two vectors, both on-heap. */ + public static float cosineFloats(float[] a, float[] b) { if (a.length != b.length) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - float r = IMPL.cosine(a, b); + float r = IMPL.cosineFloats(a, b); assert Float.isFinite(r); return r; } - /** Returns the cosine similarity between the two vectors. */ - public static float cosine(byte[] a, byte[] b) { + /** Returns the cosine similarity between two vectors, both on-heap. */ + public static float cosineBytes(byte[] a, byte[] b) { if (a.length != b.length) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - return IMPL.cosine(a, b); + return IMPL.cosineBytes(a, b); } - /** - * Returns the sum of squared differences of the two vectors. - * - * @throws IllegalArgumentException if the vectors' dimensions differ. - */ - public static float squareDistance(float[] a, float[] b) { + /** Returns the cosine similarity between two vectors, one on-heap and one off-heap. 
*/ + public static float cosineBytes(byte[] a, MemorySegment b) { + if (a.length != b.byteSize()) { + throw new IllegalArgumentException( + "vector dimensions differ: " + a.length + "!=" + b.byteSize()); + } + return IMPL.cosineBytes(a, b); + } + + /** Returns the cosine similarity between two vectors, both off-heap. */ + public static float cosineBytes(MemorySegment a, MemorySegment b) { + if (a.byteSize() != b.byteSize()) { + throw new IllegalArgumentException( + "vector dimensions differ: " + a.byteSize() + "!=" + b.byteSize()); + } + return IMPL.cosineBytes(a, b); + } + + /** Returns the sum of squared differences of the two vectors, both on-heap. */ + public static float squareDistanceFloats(float[] a, float[] b) { if (a.length != b.length) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - float r = IMPL.squareDistance(a, b); + float r = IMPL.squareDistanceFloats(a, b); assert Float.isFinite(r); return r; } - /** Returns the sum of squared differences of the two vectors. */ - public static int squareDistance(byte[] a, byte[] b) { + /** Returns the sum of squared differences of the two vectors, both on-heap. */ + public static int squareDistanceBytes(byte[] a, byte[] b) { if (a.length != b.length) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - return IMPL.squareDistance(a, b); + return IMPL.squareDistanceBytes(a, b); + } + + /** Returns the sum of squared differences of the two vectors, one on-heap and one off-heap. */ + public static int squareDistanceBytes(byte[] a, MemorySegment b) { + if (a.length != b.byteSize()) { + throw new IllegalArgumentException( + "vector dimensions differ: " + a.length + "!=" + b.byteSize()); + } + return IMPL.squareDistanceBytes(a, b); } - /** Returns the sum of squared differences between two uint4 (values between [0,15]) vectors. */ + /** Returns the sum of squared differences of the two vectors, both off-heap. 
*/ + public static int squareDistanceBytes(MemorySegment a, MemorySegment b) { + if (a.byteSize() != b.byteSize()) { + throw new IllegalArgumentException( + "vector dimensions differ: " + a.byteSize() + "!=" + b.byteSize()); + } + return IMPL.squareDistanceBytes(a, b); + } + + /** + * Returns the sum of squared differences between two uint4 (values between [0,15]) vectors, both + * on-heap. + */ public static int int4SquareDistance(byte[] a, byte[] b) { if (a.length != b.length) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); @@ -129,8 +161,32 @@ public static int int4SquareDistance(byte[] a, byte[] b) { } /** - * Returns the sum of squared differences between two uint4 (values between [0,15]) vectors. The - * second vector is considered "packed" (i.e. every byte representing two values). + * Returns the sum of squared differences between two uint4 (values between [0,15]) vectors, one + * on-heap and one off-heap. + */ + public static int int4SquareDistance(byte[] a, MemorySegment b) { + if (a.length != b.byteSize()) { + throw new IllegalArgumentException( + "vector dimensions differ: " + a.length + "!=" + b.byteSize()); + } + return IMPL.int4SquareDistance(a, b); + } + + /** + * Returns the sum of squared differences between two uint4 (values between [0,15]) vectors, both + * off-heap. + */ + public static int int4SquareDistance(MemorySegment a, MemorySegment b) { + if (a.byteSize() != b.byteSize()) { + throw new IllegalArgumentException( + "vector dimensions differ: " + a.byteSize() + "!=" + b.byteSize()); + } + return IMPL.int4SquareDistance(a, b); + } + + /** + * Returns the sum of squared differences between two uint4 (values between [0,15]) vectors, both + * on-heap. The second vector is considered "packed" (i.e. every byte representing two values). 
*/ public static int int4SquareDistanceSinglePacked(byte[] unpacked, byte[] packed) { if (packed.length != ((unpacked.length + 1) >> 1)) { @@ -141,9 +197,18 @@ public static int int4SquareDistanceSinglePacked(byte[] unpacked, byte[] packed) } /** - * Returns the sum of squared differences between two uint4 (values between [0,15]) vectors. Both - * vectors are considered "packed" (i.e. every byte representing two values). + * Returns the sum of squared differences between two uint4 (values between [0,15]) vectors, one + * on-heap and one off-heap. The second vector is considered "packed" (i.e. every byte + * representing two values). */ + public static int int4SquareDistanceSinglePacked(byte[] unpacked, MemorySegment packed) { + if (packed.byteSize() != ((unpacked.length + 1) >> 1)) { + throw new IllegalArgumentException( + "vector dimensions differ: " + unpacked.length + "!= 2 * " + packed.byteSize()); + } + return IMPL.int4SquareDistanceSinglePacked(unpacked, packed); + } + public static int int4SquareDistanceBothPacked(byte[] a, byte[] b) { if (a.length != b.length) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); @@ -151,7 +216,22 @@ public static int int4SquareDistanceBothPacked(byte[] a, byte[] b) { return IMPL.int4SquareDistanceBothPacked(a, b); } - /** Returns the sum of squared differences of the two vectors where each byte is unsigned */ + /** + * Returns the sum of squared differences between two uint4 (values between [0,15]) vectors, both + * off-heap. Both vectors are considered "packed" (i.e. every byte representing two values). 
+ */ + public static int int4SquareDistanceBothPacked(MemorySegment a, MemorySegment b) { + if (a.byteSize() != b.byteSize()) { + throw new IllegalArgumentException( + "vector dimensions differ: " + a.byteSize() + "!=" + b.byteSize()); + } + return IMPL.int4SquareDistanceBothPacked(a, b); + } + + /** + * Returns the sum of squared differences of the two vectors where each byte is unsigned, both + * on-heap. + */ public static int uint8SquareDistance(byte[] a, byte[] b) { if (a.length != b.length) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); @@ -159,6 +239,30 @@ public static int uint8SquareDistance(byte[] a, byte[] b) { return IMPL.uint8SquareDistance(a, b); } + /** + * Returns the sum of squared differences of the two vectors where each byte is unsigned, one + * on-heap and one off-heap. + */ + public static int uint8SquareDistance(byte[] a, MemorySegment b) { + if (a.length != b.byteSize()) { + throw new IllegalArgumentException( + "vector dimensions differ: " + a.length + "!=" + b.byteSize()); + } + return IMPL.uint8SquareDistance(a, b); + } + + /** + * Returns the sum of squared differences of the two vectors where each byte is unsigned, both + * off-heap. + */ + public static int uint8SquareDistance(MemorySegment a, MemorySegment b) { + if (a.byteSize() != b.byteSize()) { + throw new IllegalArgumentException( + "vector dimensions differ: " + a.byteSize() + "!=" + b.byteSize()); + } + return IMPL.uint8SquareDistance(a, b); + } + /** * Modifies the argument to be unit length, dividing by its l2-norm. IllegalArgumentException is * thrown for zero vectors. 
@@ -171,7 +275,7 @@ public static float[] l2normalize(float[] v) { } public static boolean isUnitVector(float[] v) { - double l1norm = IMPL.dotProduct(v, v); + double l1norm = IMPL.dotProductFloats(v, v); return Math.abs(l1norm - 1.0d) <= EPSILON; } @@ -199,27 +303,33 @@ public static void add(float[] u, float[] v) { } } - /** - * Dot product computed over signed bytes. - * - * @param a bytes containing a vector - * @param b bytes containing another vector, of the same dimension - * @return the value of the dot product of the two vectors - */ - public static int dotProduct(byte[] a, byte[] b) { + /** Dot product computed over signed bytes, both on-heap. */ + public static int dotProductBytes(byte[] a, byte[] b) { if (a.length != b.length) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - return IMPL.dotProduct(a, b); + return IMPL.dotProductBytes(a, b); } - /** - * Dot product over bytes assuming that the values are actually unsigned. - * - * @param a uint8 byte vector - * @param b another uint8 byte vector of the same dimension - * @return the value of the dot product of the two vectors - */ + /** Dot product computed over signed bytes, one on-heap and one off-heap. */ + public static int dotProductBytes(byte[] a, MemorySegment b) { + if (a.length != b.byteSize()) { + throw new IllegalArgumentException( + "vector dimensions differ: " + a.length + "!=" + b.byteSize()); + } + return IMPL.dotProductBytes(a, b); + } + + /** Dot product computed over signed bytes, both off-heap. */ + public static int dotProductBytes(MemorySegment a, MemorySegment b) { + if (a.byteSize() != b.byteSize()) { + throw new IllegalArgumentException( + "vector dimensions differ: " + a.byteSize() + "!=" + b.byteSize()); + } + return IMPL.dotProductBytes(a, b); + } + + /** Dot product over bytes assuming that the values are actually unsigned, both on-heap. 
*/ public static int uint8DotProduct(byte[] a, byte[] b) { if (a.length != b.length) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); @@ -228,12 +338,27 @@ public static int uint8DotProduct(byte[] a, byte[] b) { } /** - * Dot product computed over uint4 (values between [0,15]) bytes. - * - * @param a bytes containing a vector - * @param b bytes containing another vector, of the same dimension - * @return the value of the dot product of the two vectors + * Dot product over bytes assuming that the values are actually unsigned, one on-heap and one + * off-heap. */ + public static int uint8DotProduct(byte[] a, MemorySegment b) { + if (a.length != b.byteSize()) { + throw new IllegalArgumentException( + "vector dimensions differ: " + a.length + "!=" + b.byteSize()); + } + return IMPL.uint8DotProduct(a, b); + } + + /** Dot product over bytes assuming that the values are actually unsigned, both off-heap. */ + public static int uint8DotProduct(MemorySegment a, MemorySegment b) { + if (a.byteSize() != b.byteSize()) { + throw new IllegalArgumentException( + "vector dimensions differ: " + a.byteSize() + "!=" + b.byteSize()); + } + return IMPL.uint8DotProduct(a, b); + } + + /** Dot product computed over uint4 (values between [0,15]) bytes, both on-heap. */ public static int int4DotProduct(byte[] a, byte[] b) { if (a.length != b.length) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); @@ -242,8 +367,29 @@ public static int int4DotProduct(byte[] a, byte[] b) { } /** - * Dot product computed over uint4 (values between [0,15]) bytes. The second vector is considered - * "packed" (i.e. every byte representing two values). The following packing is assumed: + * Dot product computed over uint4 (values between [0,15]) bytes, one on-heap and one off-heap. 
+ */ + public static int int4DotProduct(byte[] a, MemorySegment b) { + if (a.length != b.byteSize()) { + throw new IllegalArgumentException( + "vector dimensions differ: " + a.length + "!=" + b.byteSize()); + } + return IMPL.int4DotProduct(a, b); + } + + /** Dot product computed over uint4 (values between [0,15]) bytes, both off-heap. */ + public static int int4DotProduct(MemorySegment a, MemorySegment b) { + if (a.byteSize() != b.byteSize()) { + throw new IllegalArgumentException( + "vector dimensions differ: " + a.byteSize() + "!=" + b.byteSize()); + } + return IMPL.int4DotProduct(a, b); + } + + /** + * Dot product computed over uint4 (values between [0,15]) bytes, both on-heap. The second vector + * is considered "packed" (i.e. every byte representing two values). The following packing is + * assumed: * *
    *   packed[0] = (raw[0] * 16) | raw[packed.length];
@@ -265,13 +411,33 @@ public static int int4DotProductSinglePacked(byte[] unpacked, byte[] packed) {
   }
 
   /**
-   * Dot product computed over uint4 (values between [0,15]) bytes. Both vectors are considered
-   * "packed" (i.e. every byte representing two values).
+   * Dot product computed over uint4 (values between [0,15]) bytes, one on-heap and one off-heap.
+   * The second vector is considered "packed" (i.e. every byte representing two values). The
+   * following packing is assumed:
+   *
+   * <pre class="prettyprint">
+   *   packed[0] = (raw[0] * 16) | raw[packed.length];
+   *   packed[1] = (raw[1] * 16) | raw[packed.length + 1];
+   *   ...
+   *   packed[packed.length - 1] = (raw[packed.length - 1] * 16) | raw[2 * packed.length - 1];
+   * </pre>
* - * @param a bytes containing a packed vector - * @param b bytes containing another packed vector, of the same dimension + * @param unpacked the unpacked vector, of even length + * @param packed the packed vector, of length {@code (unpacked.length + 1) / 2} * @return the value of the dot product of the two vectors */ + public static int int4DotProductSinglePacked(byte[] unpacked, MemorySegment packed) { + if (packed.byteSize() != ((unpacked.length + 1) >> 1)) { + throw new IllegalArgumentException( + "vector dimensions differ: " + unpacked.length + " != 2 * " + packed.byteSize()); + } + return IMPL.int4DotProductSinglePacked(unpacked, packed); + } + + /** + * Dot product computed over uint4 (values between [0,15]) bytes, both on-heap. Both vectors are + * considered "packed" (i.e. every byte representing two values). + */ public static int int4DotProductBothPacked(byte[] a, byte[] b) { if (a.length != b.length) { throw new IllegalArgumentException( @@ -280,6 +446,18 @@ public static int int4DotProductBothPacked(byte[] a, byte[] b) { return IMPL.int4DotProductBothPacked(a, b); } + /** + * Dot product computed over uint4 (values between [0,15]) bytes, both off-heap. Both vectors are + * considered "packed" (i.e. every byte representing two values). + */ + public static int int4DotProductBothPacked(MemorySegment a, MemorySegment b) { + if (a.byteSize() != b.byteSize()) { + throw new IllegalArgumentException( + "vector dimensions differ: " + a.byteSize() + " != " + b.byteSize()); + } + return IMPL.int4DotProductBothPacked(a, b); + } + /** * Dot product computed over int4 (values between [0,15]) bytes and a binary vector. 
* @@ -361,7 +539,7 @@ static int xorBitCountLong(byte[] a, byte[] b) { public static float dotProductScore(byte[] a, byte[] b) { // divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len float denom = (float) (a.length * (1 << 15)); - return 0.5f + dotProduct(a, b) / denom; + return 0.5f + dotProductBytes(a, b) / denom; } /** diff --git a/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java b/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java index 8f60de6cda13..98dc3bcc7dcd 100644 --- a/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java +++ b/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java @@ -20,6 +20,8 @@ import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.lucene95.HasIndexSlice; +import org.apache.lucene.codecs.lucene99.Lucene99MemorySegmentByteVectorScorer; +import org.apache.lucene.codecs.lucene99.Lucene99MemorySegmentByteVectorScorerSupplier; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.KnnVectorValues; diff --git a/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java b/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java index 12b95f6c2ff2..1b05929ebbee 100644 --- a/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java +++ b/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java @@ -219,32 +219,32 @@ public float score(int node) throws IOException { @Override int 
euclidean(MemorySegment doc) { - return PanamaVectorUtilSupport.uint8SquareDistance(targetBytes, doc); + return VectorUtil.uint8SquareDistance(targetBytes, doc); } @Override int int4Euclidean(MemorySegment doc) { - return PanamaVectorUtilSupport.int4SquareDistance(targetBytes, doc); + return VectorUtil.int4SquareDistance(targetBytes, doc); } @Override int compressedInt4Euclidean(MemorySegment doc) { - return PanamaVectorUtilSupport.int4SquareDistanceSinglePacked(targetBytes, doc); + return VectorUtil.int4SquareDistanceSinglePacked(targetBytes, doc); } @Override int dotProduct(MemorySegment doc) { - return PanamaVectorUtilSupport.uint8DotProduct(targetBytes, doc); + return VectorUtil.uint8DotProduct(targetBytes, doc); } @Override int int4DotProduct(MemorySegment doc) { - return PanamaVectorUtilSupport.int4DotProduct(targetBytes, doc); + return VectorUtil.int4DotProduct(targetBytes, doc); } @Override int compressedInt4DotProduct(MemorySegment doc) { - return PanamaVectorUtilSupport.int4DotProductSinglePacked(targetBytes, doc); + return VectorUtil.int4DotProductSinglePacked(targetBytes, doc); } } @@ -292,32 +292,32 @@ public float score(int node) throws IOException { @Override int euclidean(MemorySegment doc) { - return PanamaVectorUtilSupport.uint8SquareDistance(query, doc); + return VectorUtil.uint8SquareDistance(query, doc); } @Override int int4Euclidean(MemorySegment doc) { - return PanamaVectorUtilSupport.int4SquareDistance(query, doc); + return VectorUtil.int4SquareDistance(query, doc); } @Override int compressedInt4Euclidean(MemorySegment doc) { - return PanamaVectorUtilSupport.int4SquareDistanceBothPacked(query, doc); + return VectorUtil.int4SquareDistanceBothPacked(query, doc); } @Override int dotProduct(MemorySegment doc) { - return PanamaVectorUtilSupport.uint8DotProduct(query, doc); + return VectorUtil.uint8DotProduct(query, doc); } @Override int int4DotProduct(MemorySegment doc) { - return PanamaVectorUtilSupport.int4DotProduct(query, doc); + return 
VectorUtil.int4DotProduct(query, doc); } @Override int compressedInt4DotProduct(MemorySegment doc) { - return PanamaVectorUtilSupport.int4DotProductBothPacked(query, doc); + return VectorUtil.int4DotProductBothPacked(query, doc); } } } diff --git a/lucene/core/src/java25/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java b/lucene/core/src/java25/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java index 390dc97e8f85..2758d75f47bc 100644 --- a/lucene/core/src/java25/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java +++ b/lucene/core/src/java25/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java @@ -16,7 +16,6 @@ */ package org.apache.lucene.internal.vectorization; -import static java.lang.foreign.ValueLayout.JAVA_BYTE; import static java.nio.ByteOrder.LITTLE_ENDIAN; import static jdk.incubator.vector.VectorOperators.ADD; import static jdk.incubator.vector.VectorOperators.B2I; @@ -54,7 +53,9 @@ * * Setting these properties will make this code run EXTREMELY slow! 
*/ -final class PanamaVectorUtilSupport implements VectorUtilSupport { +final class PanamaVectorUtilSupport + implements VectorUtilSupport< + PanamaVectorUtilSupport.IByteVector, PanamaVectorUtilSupport.IFloatVector> { // preferred vector sizes, which can be altered for testing private static final VectorSpecies FLOAT_SPECIES; @@ -86,6 +87,62 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport { } } + interface IByteVector extends DefaultVectorUtilSupport.IByteVector { + ByteVector get(VectorSpecies species, int index); // additional vectorized impl + } + + static class ArrayByteVector extends DefaultVectorUtilSupport.ArrayByteVector + implements IByteVector { + ArrayByteVector(byte[] array) { + super(array); + } + + @Override + public final ByteVector get(VectorSpecies species, int index) { + return ByteVector.fromArray(species, array, index); + } + } + + static class MemorySegmentByteVector extends DefaultVectorUtilSupport.MemorySegmentByteVector + implements IByteVector { + MemorySegmentByteVector(MemorySegment segment) { + super(segment); + } + + @Override + public final ByteVector get(VectorSpecies species, int index) { + return ByteVector.fromMemorySegment(species, segment, index, LITTLE_ENDIAN); + } + } + + interface IFloatVector extends DefaultVectorUtilSupport.IFloatVector { + FloatVector get(VectorSpecies species, int index); // additional vectorized impl + } + + static class ArrayFloatVector extends DefaultVectorUtilSupport.ArrayFloatVector + implements IFloatVector { + ArrayFloatVector(float[] array) { + super(array); + } + + @Override + public final FloatVector get(VectorSpecies species, int index) { + return FloatVector.fromArray(species, array, index); + } + } + + static class MemorySegmentFloatVector extends DefaultVectorUtilSupport.MemorySegmentFloatVector + implements IFloatVector { + MemorySegmentFloatVector(MemorySegment segment) { + super(segment); + } + + @Override + public final FloatVector get(VectorSpecies species, int 
index) { + return FloatVector.fromMemorySegment(species, segment, index, LITTLE_ENDIAN); + } + } + // the way FMA should work! if available use it, otherwise fall back to mul/add @SuppressForbidden(reason = "Uses FMA only where fast and carefully contained") static FloatVector fma(FloatVector a, FloatVector b, FloatVector c) { @@ -106,25 +163,46 @@ static float fma(float a, float b, float c) { } @Override - public float dotProduct(float[] a, float[] b) { + public IByteVector bytesFromArray(byte[] array) { + return new ArrayByteVector(array); + } + + @Override + public IByteVector bytesFromMemorySegment(MemorySegment segment) { + return new MemorySegmentByteVector(segment); + } + + @Override + public IFloatVector floatsFromArray(float[] array) { + return new ArrayFloatVector(array); + } + + @Override + public IFloatVector floatsFromMemorySegment(MemorySegment segment) { + return new MemorySegmentFloatVector(segment); + } + + @Override + public float dotProductFloats(IFloatVector a, IFloatVector b) { int i = 0; float res = 0; + int length = a.length(); // if the array size is large (> 2x platform vector size), it's worth the overhead to vectorize - if (a.length > 2 * FLOAT_SPECIES.length()) { - i += FLOAT_SPECIES.loopBound(a.length); + if (length > 2 * FLOAT_SPECIES.length()) { + i += FLOAT_SPECIES.loopBound(length); res += dotProductBody(a, b, i); } // scalar tail - for (; i < a.length; i++) { - res = fma(a[i], b[i], res); + for (; i < length; i++) { + res = fma(a.get(i), b.get(i), res); } return res; } /** vectorized float dot product body */ - private float dotProductBody(float[] a, float[] b, int limit) { + private float dotProductBody(IFloatVector a, IFloatVector b, int limit) { int i = 0; // vector loop is unrolled 4x (4 accumulators in parallel) // we don't know how many the cpu can do at once, some can do 2, some 4 @@ -135,29 +213,29 @@ private float dotProductBody(float[] a, float[] b, int limit) { int unrolledLimit = limit - 3 * FLOAT_SPECIES.length(); for 
(; i < unrolledLimit; i += 4 * FLOAT_SPECIES.length()) { // one - FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i); - FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i); + FloatVector va = a.get(FLOAT_SPECIES, i); + FloatVector vb = b.get(FLOAT_SPECIES, i); acc1 = fma(va, vb, acc1); // two - FloatVector vc = FloatVector.fromArray(FLOAT_SPECIES, a, i + FLOAT_SPECIES.length()); - FloatVector vd = FloatVector.fromArray(FLOAT_SPECIES, b, i + FLOAT_SPECIES.length()); + FloatVector vc = a.get(FLOAT_SPECIES, i + FLOAT_SPECIES.length()); + FloatVector vd = b.get(FLOAT_SPECIES, i + FLOAT_SPECIES.length()); acc2 = fma(vc, vd, acc2); // three - FloatVector ve = FloatVector.fromArray(FLOAT_SPECIES, a, i + 2 * FLOAT_SPECIES.length()); - FloatVector vf = FloatVector.fromArray(FLOAT_SPECIES, b, i + 2 * FLOAT_SPECIES.length()); + FloatVector ve = a.get(FLOAT_SPECIES, i + 2 * FLOAT_SPECIES.length()); + FloatVector vf = b.get(FLOAT_SPECIES, i + 2 * FLOAT_SPECIES.length()); acc3 = fma(ve, vf, acc3); // four - FloatVector vg = FloatVector.fromArray(FLOAT_SPECIES, a, i + 3 * FLOAT_SPECIES.length()); - FloatVector vh = FloatVector.fromArray(FLOAT_SPECIES, b, i + 3 * FLOAT_SPECIES.length()); + FloatVector vg = a.get(FLOAT_SPECIES, i + 3 * FLOAT_SPECIES.length()); + FloatVector vh = b.get(FLOAT_SPECIES, i + 3 * FLOAT_SPECIES.length()); acc4 = fma(vg, vh, acc4); } // vector tail: less scalar computations for unaligned sizes, esp with big vector sizes for (; i < limit; i += FLOAT_SPECIES.length()) { - FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i); - FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i); + FloatVector va = a.get(FLOAT_SPECIES, i); + FloatVector vb = b.get(FLOAT_SPECIES, i); acc1 = fma(va, vb, acc1); } // reduce @@ -167,15 +245,16 @@ private float dotProductBody(float[] a, float[] b, int limit) { } @Override - public float cosine(float[] a, float[] b) { + public float cosineFloats(IFloatVector a, IFloatVector b) { int i = 0; float sum = 
0; float norm1 = 0; float norm2 = 0; + int length = a.length(); // if the array size is large (> 2x platform vector size), it's worth the overhead to vectorize - if (a.length > 2 * FLOAT_SPECIES.length()) { - i += FLOAT_SPECIES.loopBound(a.length); + if (length > 2 * FLOAT_SPECIES.length()) { + i += FLOAT_SPECIES.loopBound(length); float[] ret = cosineBody(a, b, i); sum += ret[0]; norm1 += ret[1]; @@ -183,16 +262,16 @@ public float cosine(float[] a, float[] b) { } // scalar tail - for (; i < a.length; i++) { - sum = fma(a[i], b[i], sum); - norm1 = fma(a[i], a[i], norm1); - norm2 = fma(b[i], b[i], norm2); + for (; i < length; i++) { + sum = fma(a.get(i), b.get(i), sum); + norm1 = fma(a.get(i), a.get(i), norm1); + norm2 = fma(b.get(i), b.get(i), norm2); } return (float) (sum / Math.sqrt((double) norm1 * (double) norm2)); } /** vectorized cosine body */ - private float[] cosineBody(float[] a, float[] b, int limit) { + private float[] cosineBody(IFloatVector a, IFloatVector b, int limit) { int i = 0; // vector loop is unrolled 2x (2 accumulators in parallel) // each iteration has 3 FMAs, so its a lot already, no need to unroll more @@ -205,23 +284,23 @@ private float[] cosineBody(float[] a, float[] b, int limit) { int unrolledLimit = limit - FLOAT_SPECIES.length(); for (; i < unrolledLimit; i += 2 * FLOAT_SPECIES.length()) { // one - FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i); - FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i); + FloatVector va = a.get(FLOAT_SPECIES, i); + FloatVector vb = b.get(FLOAT_SPECIES, i); sum1 = fma(va, vb, sum1); norm1_1 = fma(va, va, norm1_1); norm2_1 = fma(vb, vb, norm2_1); // two - FloatVector vc = FloatVector.fromArray(FLOAT_SPECIES, a, i + FLOAT_SPECIES.length()); - FloatVector vd = FloatVector.fromArray(FLOAT_SPECIES, b, i + FLOAT_SPECIES.length()); + FloatVector vc = a.get(FLOAT_SPECIES, i + FLOAT_SPECIES.length()); + FloatVector vd = b.get(FLOAT_SPECIES, i + FLOAT_SPECIES.length()); sum2 = fma(vc, vd, 
sum2); norm1_2 = fma(vc, vc, norm1_2); norm2_2 = fma(vd, vd, norm2_2); } // vector tail: less scalar computations for unaligned sizes, esp with big vector sizes for (; i < limit; i += FLOAT_SPECIES.length()) { - FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i); - FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i); + FloatVector va = a.get(FLOAT_SPECIES, i); + FloatVector vb = b.get(FLOAT_SPECIES, i); sum1 = fma(va, vb, sum1); norm1_1 = fma(va, va, norm1_1); norm2_1 = fma(vb, vb, norm2_1); @@ -234,26 +313,27 @@ private float[] cosineBody(float[] a, float[] b, int limit) { } @Override - public float squareDistance(float[] a, float[] b) { + public float squareDistanceFloats(IFloatVector a, IFloatVector b) { int i = 0; float res = 0; + int length = a.length(); // if the array size is large (> 2x platform vector size), it's worth the overhead to vectorize - if (a.length > 2 * FLOAT_SPECIES.length()) { - i += FLOAT_SPECIES.loopBound(a.length); + if (length > 2 * FLOAT_SPECIES.length()) { + i += FLOAT_SPECIES.loopBound(length); res += squareDistanceBody(a, b, i); } // scalar tail - for (; i < a.length; i++) { - float diff = a[i] - b[i]; + for (; i < length; i++) { + float diff = a.get(i) - b.get(i); res = fma(diff, diff, res); } return res; } /** vectorized square distance body */ - private float squareDistanceBody(float[] a, float[] b, int limit) { + private float squareDistanceBody(IFloatVector a, IFloatVector b, int limit) { int i = 0; // vector loop is unrolled 4x (4 accumulators in parallel) // we don't know how many the cpu can do at once, some can do 2, some 4 @@ -264,33 +344,33 @@ private float squareDistanceBody(float[] a, float[] b, int limit) { int unrolledLimit = limit - 3 * FLOAT_SPECIES.length(); for (; i < unrolledLimit; i += 4 * FLOAT_SPECIES.length()) { // one - FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i); - FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i); + FloatVector va = a.get(FLOAT_SPECIES, i); + 
FloatVector vb = b.get(FLOAT_SPECIES, i); FloatVector diff1 = va.sub(vb); acc1 = fma(diff1, diff1, acc1); // two - FloatVector vc = FloatVector.fromArray(FLOAT_SPECIES, a, i + FLOAT_SPECIES.length()); - FloatVector vd = FloatVector.fromArray(FLOAT_SPECIES, b, i + FLOAT_SPECIES.length()); + FloatVector vc = a.get(FLOAT_SPECIES, i + FLOAT_SPECIES.length()); + FloatVector vd = b.get(FLOAT_SPECIES, i + FLOAT_SPECIES.length()); FloatVector diff2 = vc.sub(vd); acc2 = fma(diff2, diff2, acc2); // three - FloatVector ve = FloatVector.fromArray(FLOAT_SPECIES, a, i + 2 * FLOAT_SPECIES.length()); - FloatVector vf = FloatVector.fromArray(FLOAT_SPECIES, b, i + 2 * FLOAT_SPECIES.length()); + FloatVector ve = a.get(FLOAT_SPECIES, i + 2 * FLOAT_SPECIES.length()); + FloatVector vf = b.get(FLOAT_SPECIES, i + 2 * FLOAT_SPECIES.length()); FloatVector diff3 = ve.sub(vf); acc3 = fma(diff3, diff3, acc3); // four - FloatVector vg = FloatVector.fromArray(FLOAT_SPECIES, a, i + 3 * FLOAT_SPECIES.length()); - FloatVector vh = FloatVector.fromArray(FLOAT_SPECIES, b, i + 3 * FLOAT_SPECIES.length()); + FloatVector vg = a.get(FLOAT_SPECIES, i + 3 * FLOAT_SPECIES.length()); + FloatVector vh = b.get(FLOAT_SPECIES, i + 3 * FLOAT_SPECIES.length()); FloatVector diff4 = vg.sub(vh); acc4 = fma(diff4, diff4, acc4); } // vector tail: less scalar computations for unaligned sizes, esp with big vector sizes for (; i < limit; i += FLOAT_SPECIES.length()) { - FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i); - FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i); + FloatVector va = a.get(FLOAT_SPECIES, i); + FloatVector vb = b.get(FLOAT_SPECIES, i); FloatVector diff = va.sub(vb); acc1 = fma(diff, diff, acc1); } @@ -312,121 +392,59 @@ private float squareDistanceBody(float[] a, float[] b, int limit) { // We also support 128 bit vectors, going 32 bits at a time. // This is slower but still faster than not vectorizing at all. 
- private interface ByteVectorLoader { - int length(); - - ByteVector load(VectorSpecies species, int index); - - byte tail(int index); - } - - private record ArrayLoader(byte[] arr) implements ByteVectorLoader { - @Override - public int length() { - return arr.length; - } - - @Override - public ByteVector load(VectorSpecies species, int index) { - assert index + species.length() <= length(); - return ByteVector.fromArray(species, arr, index); - } - - @Override - public byte tail(int index) { - assert index <= length(); - return arr[index]; - } - } - - private record MemorySegmentLoader(MemorySegment segment) implements ByteVectorLoader { - @Override - public int length() { - return Math.toIntExact(segment.byteSize()); - } - - @Override - public ByteVector load(VectorSpecies species, int index) { - assert index + species.length() <= length(); - return ByteVector.fromMemorySegment(species, segment, index, LITTLE_ENDIAN); - } - - @Override - public byte tail(int index) { - assert index <= length(); - return segment.get(JAVA_BYTE, index); - } - } - @Override - public int dotProduct(byte[] a, byte[] b) { - return dotProductBody(new ArrayLoader(a), new ArrayLoader(b), true); + public int dotProductBytes(IByteVector a, IByteVector b) { + return dotProductBody(a, b, true); } @Override - public int uint8DotProduct(byte[] a, byte[] b) { - return dotProductBody(new ArrayLoader(a), new ArrayLoader(b), false); - } - - public static int dotProduct(byte[] a, MemorySegment b) { - return dotProductBody(new ArrayLoader(a), new MemorySegmentLoader(b), true); - } - - public static int dotProduct(MemorySegment a, MemorySegment b) { - return dotProductBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b), true); - } - - public static int uint8DotProduct(byte[] a, MemorySegment b) { - return dotProductBody(new ArrayLoader(a), new MemorySegmentLoader(b), false); + public int uint8DotProduct(IByteVector a, IByteVector b) { + return dotProductBody(a, b, false); } - public static int 
uint8DotProduct(MemorySegment a, MemorySegment b) { - return dotProductBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b), false); - } - - private static int dotProductBody(ByteVectorLoader a, ByteVectorLoader b, boolean signed) { + private static int dotProductBody(IByteVector a, IByteVector b, boolean signed) { assert a.length() == b.length(); int i = 0; int res = 0; + int length = a.length(); // only vectorize if we'll at least enter the loop a single time - if (a.length() >= 16) { + if (length >= 16) { // compute vectorized dot product consistent with VPDPBUSD instruction if (VECTOR_BITSIZE >= 512) { - i += BYTE_SPECIES.loopBound(a.length()); + i += BYTE_SPECIES.loopBound(length); res += dotProductBody512(a, b, i, signed); } else if (VECTOR_BITSIZE == 256) { - i += BYTE_SPECIES.loopBound(a.length()); + i += BYTE_SPECIES.loopBound(length); res += dotProductBody256(a, b, i, signed); } else { // tricky: we don't have SPECIES_32, so we workaround with "overlapping read" - i += ByteVector.SPECIES_64.loopBound(a.length() - ByteVector.SPECIES_64.length()); + i += ByteVector.SPECIES_64.loopBound(length - ByteVector.SPECIES_64.length()); res += dotProductBody128(a, b, i, signed); } } // scalar tail if (signed) { - for (; i < a.length(); i++) { - res += a.tail(i) * b.tail(i); + for (; i < length; i++) { + res += a.get(i) * b.get(i); } } else { - for (; i < a.length(); i++) { - res += Byte.toUnsignedInt(a.tail(i)) * Byte.toUnsignedInt(b.tail(i)); + for (; i < length; i++) { + res += Byte.toUnsignedInt(a.get(i)) * Byte.toUnsignedInt(b.get(i)); } } return res; } /** vectorized dot product body (512 bit vectors) */ - private static int dotProductBody512( - ByteVectorLoader a, ByteVectorLoader b, int limit, boolean signed) { + private static int dotProductBody512(IByteVector a, IByteVector b, int limit, boolean signed) { IntVector acc = IntVector.zero(INT_SPECIES); var conversion_short = signed ? B2S : ZERO_EXTEND_B2S; var conversion_int = signed ? 
S2I : ZERO_EXTEND_S2I; for (int i = 0; i < limit; i += BYTE_SPECIES.length()) { - ByteVector va8 = a.load(BYTE_SPECIES, i); - ByteVector vb8 = b.load(BYTE_SPECIES, i); + ByteVector va8 = a.get(BYTE_SPECIES, i); + ByteVector vb8 = b.get(BYTE_SPECIES, i); // 16-bit multiply: avoid AVX-512 heavy multiply on zmm Vector va16 = va8.convertShape(conversion_short, SHORT_SPECIES, 0); @@ -442,13 +460,12 @@ private static int dotProductBody512( } /** vectorized dot product body (256 bit vectors) */ - private static int dotProductBody256( - ByteVectorLoader a, ByteVectorLoader b, int limit, boolean signed) { + private static int dotProductBody256(IByteVector a, IByteVector b, int limit, boolean signed) { IntVector acc = IntVector.zero(IntVector.SPECIES_256); var conversion = signed ? B2I : ZERO_EXTEND_B2I; for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) { - ByteVector va8 = a.load(ByteVector.SPECIES_64, i); - ByteVector vb8 = b.load(ByteVector.SPECIES_64, i); + ByteVector va8 = a.get(ByteVector.SPECIES_64, i); + ByteVector vb8 = b.get(ByteVector.SPECIES_64, i); // 32-bit multiply and add into accumulator Vector va32 = va8.convertShape(conversion, IntVector.SPECIES_256, 0); @@ -460,16 +477,15 @@ private static int dotProductBody256( } /** vectorized dot product body (128 bit vectors) */ - private static int dotProductBody128( - ByteVectorLoader a, ByteVectorLoader b, int limit, boolean signed) { + private static int dotProductBody128(IByteVector a, IByteVector b, int limit, boolean signed) { IntVector acc = IntVector.zero(IntVector.SPECIES_128); var conversion_short = signed ? B2S : ZERO_EXTEND_B2S; var conversion_int = signed ? S2I : ZERO_EXTEND_S2I; // 4 bytes at a time (re-loading half the vector each time!) 
for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length() >> 1) { // load 8 bytes - ByteVector va8 = a.load(ByteVector.SPECIES_64, i); - ByteVector vb8 = b.load(ByteVector.SPECIES_64, i); + ByteVector va8 = a.get(ByteVector.SPECIES_64, i); + ByteVector vb8 = b.get(ByteVector.SPECIES_64, i); // process first "half" only: 16-bit multiply Vector va16 = va8.convert(conversion_short, 0); @@ -506,33 +522,27 @@ private static class Int4Constants { } @Override - public int int4DotProduct(byte[] a, byte[] b) { - return int4DotProductBody(new ArrayLoader(a), new ArrayLoader(b)); - } - - public static int int4DotProduct(byte[] a, MemorySegment b) { - return int4DotProductBody(new ArrayLoader(a), new MemorySegmentLoader(b)); + public int int4DotProduct(IByteVector a, IByteVector b) { + return int4DotProductBody(a, b); } - public static int int4DotProduct(MemorySegment a, MemorySegment b) { - return int4DotProductBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b)); - } - - private static int int4DotProductBody(ByteVectorLoader a, ByteVectorLoader b) { + private static int int4DotProductBody(IByteVector a, IByteVector b) { int i = 0; int res = 0; - if (a.length() >= 32) { - i += Int4Constants.BYTE_SPECIES.loopBound(a.length()); + int length = a.length(); + + if (length >= 32) { + i += Int4Constants.BYTE_SPECIES.loopBound(length); res += int4DotProductBody(a, b, i); } // scalar tail - for (; i < a.length(); i++) { - res += a.tail(i) * b.tail(i); + for (; i < length; i++) { + res += a.get(i) * b.get(i); } return res; } - private static int int4DotProductBody(ByteVectorLoader a, ByteVectorLoader b, int limit) { + private static int int4DotProductBody(IByteVector a, IByteVector b, int limit) { int sum = 0; // iterate in chunks to ensure we don't overflow the short accumulator for (int i = 0; i < limit; i += Int4Constants.CHUNK) { @@ -540,11 +550,11 @@ private static int int4DotProductBody(ByteVectorLoader a, ByteVectorLoader b, in int innerLimit = Math.min(limit - i, 
Int4Constants.CHUNK); for (int j = 0; j < innerLimit; j += Int4Constants.BYTE_SPECIES.length()) { // unpacked - ByteVector vb8 = b.load(Int4Constants.BYTE_SPECIES, i + j); + ByteVector vb8 = b.get(Int4Constants.BYTE_SPECIES, i + j); Vector vb16 = vb8.convertShape(B2S, Int4Constants.SHORT_SPECIES, 0); // unpacked - ByteVector va8 = a.load(Int4Constants.BYTE_SPECIES, i + j); + ByteVector va8 = a.get(Int4Constants.BYTE_SPECIES, i + j); Vector va16 = va8.convertShape(B2S, Int4Constants.SHORT_SPECIES, 0); acc = acc.add(vb16.mul(va16)); @@ -557,28 +567,24 @@ private static int int4DotProductBody(ByteVectorLoader a, ByteVectorLoader b, in } @Override - public int int4DotProductSinglePacked(byte[] unpacked, byte[] packed) { - return int4DotProductSinglePackedBody(new ArrayLoader(unpacked), new ArrayLoader(packed)); + public int int4DotProductSinglePacked(IByteVector unpacked, IByteVector packed) { + return int4DotProductSinglePackedBody(unpacked, packed); } - public static int int4DotProductSinglePacked(byte[] unpacked, MemorySegment packed) { - return int4DotProductSinglePackedBody( - new ArrayLoader(unpacked), new MemorySegmentLoader(packed)); - } - - private static int int4DotProductSinglePackedBody( - ByteVectorLoader unpacked, ByteVectorLoader packed) { + private static int int4DotProductSinglePackedBody(IByteVector unpacked, IByteVector packed) { int i = 0; int res = 0; - if (packed.length() >= 32) { - i += Int4Constants.BYTE_SPECIES.loopBound(packed.length()); + int length = packed.length(); + + if (length >= 32) { + i += Int4Constants.BYTE_SPECIES.loopBound(length); res += int4DotProductSinglePackedBody(unpacked, packed, i); } // scalar tail - for (; i < packed.length(); i++) { - byte packedByte = packed.tail(i); - byte unpacked1 = unpacked.tail(i); - byte unpacked2 = unpacked.tail(i + packed.length()); + for (; i < length; i++) { + byte packedByte = packed.get(i); + byte unpacked1 = unpacked.get(i); + byte unpacked2 = unpacked.get(i + length); res += (packedByte & 
0x0F) * unpacked2; res += ((packedByte & 0xFF) >> 4) * unpacked1; } @@ -586,7 +592,7 @@ private static int int4DotProductSinglePackedBody( } private static int int4DotProductSinglePackedBody( - ByteVectorLoader unpacked, ByteVectorLoader packed, int limit) { + IByteVector unpacked, IByteVector packed, int limit) { int sum = 0; // iterate in chunks to ensure we don't overflow the short accumulator for (int i = 0; i < limit; i += Int4Constants.CHUNK) { @@ -595,16 +601,16 @@ private static int int4DotProductSinglePackedBody( int innerLimit = Math.min(limit - i, Int4Constants.CHUNK); for (int j = 0; j < innerLimit; j += Int4Constants.BYTE_SPECIES.length()) { // packed - ByteVector vb8 = packed.load(Int4Constants.BYTE_SPECIES, i + j); + ByteVector vb8 = packed.get(Int4Constants.BYTE_SPECIES, i + j); // upper - ByteVector va8 = unpacked.load(Int4Constants.BYTE_SPECIES, i + j + packed.length()); + ByteVector va8 = unpacked.get(Int4Constants.BYTE_SPECIES, i + j + packed.length()); ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8); Vector prod16 = prod8.convertShape(ZERO_EXTEND_B2S, Int4Constants.SHORT_SPECIES, 0); acc0 = acc0.add(prod16); // lower - ByteVector vc8 = unpacked.load(Int4Constants.BYTE_SPECIES, i + j); + ByteVector vc8 = unpacked.get(Int4Constants.BYTE_SPECIES, i + j); ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8); Vector prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, Int4Constants.SHORT_SPECIES, 0); @@ -620,33 +626,30 @@ private static int int4DotProductSinglePackedBody( } @Override - public int int4DotProductBothPacked(byte[] a, byte[] b) { - return int4DotProductBothPackedBody(new ArrayLoader(a), new ArrayLoader(b)); + public int int4DotProductBothPacked(IByteVector a, IByteVector b) { + return int4DotProductBothPackedBody(a, b); } - public static int int4DotProductBothPacked(MemorySegment a, MemorySegment b) { - return int4DotProductBothPackedBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b)); - } - - private static int 
int4DotProductBothPackedBody(ByteVectorLoader a, ByteVectorLoader b) { + private static int int4DotProductBothPackedBody(IByteVector a, IByteVector b) { int i = 0; int res = 0; - if (a.length() >= 32) { - i += Int4Constants.BYTE_SPECIES.loopBound(a.length()); + int length = a.length(); + + if (length >= 32) { + i += Int4Constants.BYTE_SPECIES.loopBound(length); res += int4DotProductBothPackedBody(a, b, i); } // scalar tail - for (; i < a.length(); i++) { - byte aByte = a.tail(i); - byte bByte = b.tail(i); + for (; i < length; i++) { + byte aByte = a.get(i); + byte bByte = b.get(i); res += (aByte & 0x0F) * (bByte & 0x0F); res += ((aByte & 0xFF) >> 4) * ((bByte & 0xFF) >> 4); } return res; } - private static int int4DotProductBothPackedBody( - ByteVectorLoader a, ByteVectorLoader b, int limit) { + private static int int4DotProductBothPackedBody(IByteVector a, IByteVector b, int limit) { int sum = 0; // iterate in chunks to ensure we don't overflow the short accumulator for (int i = 0; i < limit; i += Int4Constants.CHUNK) { @@ -655,9 +658,9 @@ private static int int4DotProductBothPackedBody( int innerLimit = Math.min(limit - i, Int4Constants.CHUNK); for (int j = 0; j < innerLimit; j += Int4Constants.BYTE_SPECIES.length()) { // packed - var vb8 = b.load(Int4Constants.BYTE_SPECIES, i + j); + var vb8 = b.get(Int4Constants.BYTE_SPECIES, i + j); // packed - var va8 = a.load(Int4Constants.BYTE_SPECIES, i + j); + var va8 = a.get(Int4Constants.BYTE_SPECIES, i + j); // upper ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8.and((byte) 0x0F)); @@ -680,36 +683,29 @@ private static int int4DotProductBothPackedBody( } @Override - public float cosine(byte[] a, byte[] b) { - return cosineBody(new ArrayLoader(a), new ArrayLoader(b)); - } - - public static float cosine(MemorySegment a, MemorySegment b) { - return cosineBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b)); - } - - public static float cosine(byte[] a, MemorySegment b) { - return cosineBody(new ArrayLoader(a), new 
MemorySegmentLoader(b)); + public float cosineBytes(IByteVector a, IByteVector b) { + return cosineBody(a, b); } - private static float cosineBody(ByteVectorLoader a, ByteVectorLoader b) { + private static float cosineBody(IByteVector a, IByteVector b) { int i = 0; int sum = 0; int norm1 = 0; int norm2 = 0; + int length = a.length(); // only vectorize if we'll at least enter the loop a single time - if (a.length() >= 16) { + if (length >= 16) { final int[] ret; if (VECTOR_BITSIZE >= 512) { - i += BYTE_SPECIES.loopBound(a.length()); + i += BYTE_SPECIES.loopBound(length); ret = cosineBody512(a, b, i); } else if (VECTOR_BITSIZE == 256) { - i += BYTE_SPECIES.loopBound(a.length()); + i += BYTE_SPECIES.loopBound(length); ret = cosineBody256(a, b, i); } else { // tricky: we don't have SPECIES_32, so we workaround with "overlapping read" - i += ByteVector.SPECIES_64.loopBound(a.length() - ByteVector.SPECIES_64.length()); + i += ByteVector.SPECIES_64.loopBound(length - ByteVector.SPECIES_64.length()); ret = cosineBody128(a, b, i); } sum += ret[0]; @@ -718,9 +714,9 @@ private static float cosineBody(ByteVectorLoader a, ByteVectorLoader b) { } // scalar tail - for (; i < a.length(); i++) { - byte elem1 = a.tail(i); - byte elem2 = b.tail(i); + for (; i < length; i++) { + byte elem1 = a.get(i); + byte elem2 = b.get(i); sum += elem1 * elem2; norm1 += elem1 * elem1; norm2 += elem2 * elem2; @@ -729,13 +725,13 @@ private static float cosineBody(ByteVectorLoader a, ByteVectorLoader b) { } /** vectorized cosine body (512 bit vectors) */ - private static int[] cosineBody512(ByteVectorLoader a, ByteVectorLoader b, int limit) { + private static int[] cosineBody512(IByteVector a, IByteVector b, int limit) { IntVector accSum = IntVector.zero(INT_SPECIES); IntVector accNorm1 = IntVector.zero(INT_SPECIES); IntVector accNorm2 = IntVector.zero(INT_SPECIES); for (int i = 0; i < limit; i += BYTE_SPECIES.length()) { - ByteVector va8 = a.load(BYTE_SPECIES, i); - ByteVector vb8 = 
b.load(BYTE_SPECIES, i); + ByteVector va8 = a.get(BYTE_SPECIES, i); + ByteVector vb8 = b.get(BYTE_SPECIES, i); // 16-bit multiply: avoid AVX-512 heavy multiply on zmm Vector va16 = va8.convertShape(B2S, SHORT_SPECIES, 0); @@ -759,13 +755,13 @@ private static int[] cosineBody512(ByteVectorLoader a, ByteVectorLoader b, int l } /** vectorized cosine body (256 bit vectors) */ - private static int[] cosineBody256(ByteVectorLoader a, ByteVectorLoader b, int limit) { + private static int[] cosineBody256(IByteVector a, IByteVector b, int limit) { IntVector accSum = IntVector.zero(IntVector.SPECIES_256); IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_256); IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_256); for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) { - ByteVector va8 = a.load(ByteVector.SPECIES_64, i); - ByteVector vb8 = b.load(ByteVector.SPECIES_64, i); + ByteVector va8 = a.get(ByteVector.SPECIES_64, i); + ByteVector vb8 = b.get(ByteVector.SPECIES_64, i); // 16-bit multiply, and add into accumulators Vector va32 = va8.convertShape(B2I, IntVector.SPECIES_256, 0); @@ -784,13 +780,13 @@ private static int[] cosineBody256(ByteVectorLoader a, ByteVectorLoader b, int l } /** vectorized cosine body (128 bit vectors) */ - private static int[] cosineBody128(ByteVectorLoader a, ByteVectorLoader b, int limit) { + private static int[] cosineBody128(IByteVector a, IByteVector b, int limit) { IntVector accSum = IntVector.zero(IntVector.SPECIES_128); IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_128); IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_128); for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length() >> 1) { - ByteVector va8 = a.load(ByteVector.SPECIES_64, i); - ByteVector vb8 = b.load(ByteVector.SPECIES_64, i); + ByteVector va8 = a.get(ByteVector.SPECIES_64, i); + ByteVector vb8 = b.get(ByteVector.SPECIES_64, i); // process first half only: 16-bit multiply Vector va16 = va8.convert(B2S, 0); @@ -811,56 +807,41 @@ private 
static int[] cosineBody128(ByteVectorLoader a, ByteVectorLoader b, int l } @Override - public int squareDistance(byte[] a, byte[] b) { - return squareDistanceBody(new ArrayLoader(a), new ArrayLoader(b), true); + public int squareDistanceBytes(IByteVector a, IByteVector b) { + return squareDistanceBody(a, b, true); } @Override - public int uint8SquareDistance(byte[] a, byte[] b) { - return squareDistanceBody(new ArrayLoader(a), new ArrayLoader(b), false); + public int uint8SquareDistance(IByteVector a, IByteVector b) { + return squareDistanceBody(a, b, false); } - public static int squareDistance(MemorySegment a, MemorySegment b) { - return squareDistanceBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b), true); - } - - public static int squareDistance(byte[] a, MemorySegment b) { - return squareDistanceBody(new ArrayLoader(a), new MemorySegmentLoader(b), true); - } - - public static int uint8SquareDistance(MemorySegment a, MemorySegment b) { - return squareDistanceBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b), false); - } - - public static int uint8SquareDistance(byte[] a, MemorySegment b) { - return squareDistanceBody(new ArrayLoader(a), new MemorySegmentLoader(b), false); - } - - private static int squareDistanceBody(ByteVectorLoader a, ByteVectorLoader b, boolean signed) { + private static int squareDistanceBody(IByteVector a, IByteVector b, boolean signed) { assert a.length() == b.length(); int i = 0; int res = 0; + int length = a.length(); // only vectorize if we'll at least enter the loop a single time - if (a.length() >= 16) { + if (length >= 16) { if (VECTOR_BITSIZE >= 256) { - i += BYTE_SPECIES.loopBound(a.length()); + i += BYTE_SPECIES.loopBound(length); res += squareDistanceBody256(a, b, i, signed); } else { - i += ByteVector.SPECIES_64.loopBound(a.length()); + i += ByteVector.SPECIES_64.loopBound(length); res += squareDistanceBody128(a, b, i, signed); } } // scalar tail if (signed) { - for (; i < a.length(); i++) { - int diff = 
a.tail(i) - b.tail(i); + for (; i < length; i++) { + int diff = a.get(i) - b.get(i); res += diff * diff; } } else { - for (; i < a.length(); i++) { - int diff = Byte.toUnsignedInt(a.tail(i)) - Byte.toUnsignedInt(b.tail(i)); + for (; i < length; i++) { + int diff = Byte.toUnsignedInt(a.get(i)) - Byte.toUnsignedInt(b.get(i)); res += diff * diff; } } @@ -869,12 +850,12 @@ private static int squareDistanceBody(ByteVectorLoader a, ByteVectorLoader b, bo /** vectorized square distance body (256+ bit vectors) */ private static int squareDistanceBody256( - ByteVectorLoader a, ByteVectorLoader b, int limit, boolean signed) { + IByteVector a, IByteVector b, int limit, boolean signed) { IntVector acc = IntVector.zero(INT_SPECIES); var conversion = signed ? B2I : ZERO_EXTEND_B2I; for (int i = 0; i < limit; i += BYTE_SPECIES.length()) { - ByteVector va8 = a.load(BYTE_SPECIES, i); - ByteVector vb8 = b.load(BYTE_SPECIES, i); + ByteVector va8 = a.get(BYTE_SPECIES, i); + ByteVector vb8 = b.get(BYTE_SPECIES, i); // 32-bit sub, multiply, and add into accumulators // TODO: uses AVX-512 heavy multiply on zmm, should we just use 256-bit vectors on AVX-512? @@ -889,15 +870,15 @@ private static int squareDistanceBody256( /** vectorized square distance body (128 bit vectors) */ private static int squareDistanceBody128( - ByteVectorLoader a, ByteVectorLoader b, int limit, boolean signed) { + IByteVector a, IByteVector b, int limit, boolean signed) { // 128-bit implementation, which must "split up" vectors due to widening conversions // it doesn't help to do the overlapping read trick, due to 32-bit multiply in the formula IntVector acc1 = IntVector.zero(IntVector.SPECIES_128); IntVector acc2 = IntVector.zero(IntVector.SPECIES_128); var conversion_short = signed ? 
B2S : ZERO_EXTEND_B2S; for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) { - ByteVector va8 = a.load(ByteVector.SPECIES_64, i); - ByteVector vb8 = b.load(ByteVector.SPECIES_64, i); + ByteVector va8 = a.get(ByteVector.SPECIES_64, i); + ByteVector vb8 = b.get(ByteVector.SPECIES_64, i); // 16-bit sub Vector va16 = va8.convertShape(conversion_short, ShortVector.SPECIES_128, 0); @@ -915,34 +896,28 @@ private static int squareDistanceBody128( } @Override - public int int4SquareDistance(byte[] a, byte[] b) { - return int4SquareDistanceBody(new ArrayLoader(a), new ArrayLoader(b)); - } - - public static int int4SquareDistance(byte[] a, MemorySegment b) { - return int4SquareDistanceBody(new ArrayLoader(a), new MemorySegmentLoader(b)); + public int int4SquareDistance(IByteVector a, IByteVector b) { + return int4SquareDistanceBody(a, b); } - public static int int4SquareDistance(MemorySegment a, MemorySegment b) { - return int4SquareDistanceBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b)); - } - - private static int int4SquareDistanceBody(ByteVectorLoader a, ByteVectorLoader b) { + private static int int4SquareDistanceBody(IByteVector a, IByteVector b) { int i = 0; int res = 0; - if (a.length() >= 32) { - i += Int4Constants.BYTE_SPECIES.loopBound(a.length()); + int length = a.length(); + + if (length >= 32) { + i += Int4Constants.BYTE_SPECIES.loopBound(length); res += int4SquareDistanceBody(a, b, i); } // scalar tail - for (; i < a.length(); i++) { - int diff = a.tail(i) - b.tail(i); + for (; i < length; i++) { + int diff = a.get(i) - b.get(i); res += diff * diff; } return res; } - private static int int4SquareDistanceBody(ByteVectorLoader a, ByteVectorLoader b, int limit) { + private static int int4SquareDistanceBody(IByteVector a, IByteVector b, int limit) { int sum = 0; // iterate in chunks to ensure we don't overflow the short accumulator for (int i = 0; i < limit; i += Int4Constants.CHUNK) { @@ -950,9 +925,9 @@ private static int 
int4SquareDistanceBody(ByteVectorLoader a, ByteVectorLoader b int innerLimit = Math.min(limit - i, Int4Constants.CHUNK); for (int j = 0; j < innerLimit; j += Int4Constants.BYTE_SPECIES.length()) { // unpacked - var vb8 = b.load(Int4Constants.BYTE_SPECIES, i + j); + var vb8 = b.get(Int4Constants.BYTE_SPECIES, i + j); // unpacked - var va8 = a.load(Int4Constants.BYTE_SPECIES, i + j); + var va8 = a.get(Int4Constants.BYTE_SPECIES, i + j); ByteVector diff8 = vb8.sub(va8); Vector diff16 = diff8.convertShape(B2S, Int4Constants.SHORT_SPECIES, 0); @@ -966,27 +941,24 @@ private static int int4SquareDistanceBody(ByteVectorLoader a, ByteVectorLoader b } @Override - public int int4SquareDistanceSinglePacked(byte[] a, byte[] b) { - return int4SquareDistanceSinglePackedBody(new ArrayLoader(a), new ArrayLoader(b)); + public int int4SquareDistanceSinglePacked(IByteVector a, IByteVector b) { + return int4SquareDistanceSinglePackedBody(a, b); } - public static int int4SquareDistanceSinglePacked(byte[] a, MemorySegment b) { - return int4SquareDistanceSinglePackedBody(new ArrayLoader(a), new MemorySegmentLoader(b)); - } - - private static int int4SquareDistanceSinglePackedBody( - ByteVectorLoader unpacked, ByteVectorLoader packed) { + private static int int4SquareDistanceSinglePackedBody(IByteVector unpacked, IByteVector packed) { int i = 0; int res = 0; - if (packed.length() >= 32) { - i += Int4Constants.BYTE_SPECIES.loopBound(packed.length()); + int length = packed.length(); + + if (length >= 32) { + i += Int4Constants.BYTE_SPECIES.loopBound(length); res += int4SquareDistanceSinglePackedBody(unpacked, packed, i); } // scalar tail - for (; i < packed.length(); i++) { - byte packedByte = packed.tail(i); - byte unpacked1 = unpacked.tail(i); - byte unpacked2 = unpacked.tail(i + packed.length()); + for (; i < length; i++) { + byte packedByte = packed.get(i); + byte unpacked1 = unpacked.get(i); + byte unpacked2 = unpacked.get(i + length); int diff1 = (packedByte & 0x0F) - unpacked2; int 
diff2 = ((packedByte & 0xFF) >> 4) - unpacked1; @@ -997,7 +969,7 @@ private static int int4SquareDistanceSinglePackedBody( } private static int int4SquareDistanceSinglePackedBody( - ByteVectorLoader unpacked, ByteVectorLoader packed, int limit) { + IByteVector unpacked, IByteVector packed, int limit) { int sum = 0; // iterate in chunks to ensure we don't overflow the short accumulator for (int i = 0; i < limit; i += Int4Constants.CHUNK) { @@ -1006,16 +978,16 @@ private static int int4SquareDistanceSinglePackedBody( int innerLimit = Math.min(limit - i, Int4Constants.CHUNK); for (int j = 0; j < innerLimit; j += Int4Constants.BYTE_SPECIES.length()) { // packed - ByteVector vb8 = packed.load(Int4Constants.BYTE_SPECIES, i + j); + ByteVector vb8 = packed.get(Int4Constants.BYTE_SPECIES, i + j); // upper - ByteVector va8 = unpacked.load(Int4Constants.BYTE_SPECIES, i + j + packed.length()); + ByteVector va8 = unpacked.get(Int4Constants.BYTE_SPECIES, i + j + packed.length()); ByteVector diff8 = vb8.and((byte) 0x0F).sub(va8); Vector diff16 = diff8.convertShape(B2S, Int4Constants.SHORT_SPECIES, 0); acc0 = acc0.add(diff16.mul(diff16)); // lower - ByteVector vc8 = unpacked.load(Int4Constants.BYTE_SPECIES, i + j); + ByteVector vc8 = unpacked.get(Int4Constants.BYTE_SPECIES, i + j); ByteVector diff8a = vb8.lanewise(LSHR, 4).sub(vc8); Vector diff16a = diff8a.convertShape(B2S, Int4Constants.SHORT_SPECIES, 0); acc1 = acc1.add(diff16a.mul(diff16a)); @@ -1030,25 +1002,23 @@ private static int int4SquareDistanceSinglePackedBody( } @Override - public int int4SquareDistanceBothPacked(byte[] a, byte[] b) { - return int4SquareDistanceBothPackedBody(new ArrayLoader(a), new ArrayLoader(b)); + public int int4SquareDistanceBothPacked(IByteVector a, IByteVector b) { + return int4SquareDistanceBothPackedBody(a, b); } - public static int int4SquareDistanceBothPacked(MemorySegment a, MemorySegment b) { - return int4SquareDistanceBothPackedBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b)); 
- } - - private static int int4SquareDistanceBothPackedBody(ByteVectorLoader a, ByteVectorLoader b) { + private static int int4SquareDistanceBothPackedBody(IByteVector a, IByteVector b) { int i = 0; int res = 0; - if (a.length() >= 32) { - i += Int4Constants.BYTE_SPECIES.loopBound(a.length()); + int length = a.length(); + + if (length >= 32) { + i += Int4Constants.BYTE_SPECIES.loopBound(length); res += int4SquareDistanceBothPackedBody(a, b, i); } // scalar tail - for (; i < a.length(); i++) { - byte aByte = a.tail(i); - byte bByte = b.tail(i); + for (; i < length; i++) { + byte aByte = a.get(i); + byte bByte = b.get(i); int diff1 = (aByte & 0x0F) - (bByte & 0x0F); int diff2 = ((aByte & 0xFF) >> 4) - ((bByte & 0xFF) >> 4); @@ -1058,8 +1028,7 @@ private static int int4SquareDistanceBothPackedBody(ByteVectorLoader a, ByteVect return res; } - private static int int4SquareDistanceBothPackedBody( - ByteVectorLoader a, ByteVectorLoader b, int limit) { + private static int int4SquareDistanceBothPackedBody(IByteVector a, IByteVector b, int limit) { int sum = 0; // iterate in chunks to ensure we don't overflow the short accumulator for (int i = 0; i < limit; i += Int4Constants.CHUNK) { @@ -1068,9 +1037,9 @@ private static int int4SquareDistanceBothPackedBody( int innerLimit = Math.min(limit - i, Int4Constants.CHUNK); for (int j = 0; j < innerLimit; j += Int4Constants.BYTE_SPECIES.length()) { // packed - var vb8 = b.load(Int4Constants.BYTE_SPECIES, i + j); + var vb8 = b.get(Int4Constants.BYTE_SPECIES, i + j); // packed - var va8 = a.load(Int4Constants.BYTE_SPECIES, i + j); + var va8 = a.get(Int4Constants.BYTE_SPECIES, i + j); // upper ByteVector diff8 = vb8.and((byte) 0x0F).sub(va8.and((byte) 0x0F)); @@ -1119,10 +1088,10 @@ public int findNextGEQ(int[] buffer, int target, int from, int to) { } @Override - public long int4BitDotProduct(byte[] q, byte[] d) { - assert q.length == d.length * 4; + public long int4BitDotProduct(IByteVector q, IByteVector d) { + assert q.length() == 
d.length() * 4; // 128 / 8 == 16 - if (d.length >= 16) { + if (d.length() >= 16) { if (VECTOR_BITSIZE >= 256) { return int4BitDotProduct256(q, d); } else if (VECTOR_BITSIZE == 128) { @@ -1132,25 +1101,26 @@ public long int4BitDotProduct(byte[] q, byte[] d) { return DefaultVectorUtilSupport.int4BitDotProductImpl(q, d); } - static long int4BitDotProduct256(byte[] q, byte[] d) { + static long int4BitDotProduct256(IByteVector q, IByteVector d) { long subRet0 = 0; long subRet1 = 0; long subRet2 = 0; long subRet3 = 0; int i = 0; + int length = d.length(); - if (d.length >= ByteVector.SPECIES_256.vectorByteSize() * 2) { - int limit = ByteVector.SPECIES_256.loopBound(d.length); + if (length >= ByteVector.SPECIES_256.vectorByteSize() * 2) { + int limit = ByteVector.SPECIES_256.loopBound(length); var sum0 = LongVector.zero(LongVector.SPECIES_256); var sum1 = LongVector.zero(LongVector.SPECIES_256); var sum2 = LongVector.zero(LongVector.SPECIES_256); var sum3 = LongVector.zero(LongVector.SPECIES_256); for (; i < limit; i += ByteVector.SPECIES_256.length()) { - var vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs(); - var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + d.length).reinterpretAsLongs(); - var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + d.length * 2).reinterpretAsLongs(); - var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + d.length * 3).reinterpretAsLongs(); - var vd = ByteVector.fromArray(BYTE_SPECIES_256, d, i).reinterpretAsLongs(); + var vq0 = q.get(BYTE_SPECIES_256, i).reinterpretAsLongs(); + var vq1 = q.get(BYTE_SPECIES_256, i + length).reinterpretAsLongs(); + var vq2 = q.get(BYTE_SPECIES_256, i + length * 2).reinterpretAsLongs(); + var vq3 = q.get(BYTE_SPECIES_256, i + length * 3).reinterpretAsLongs(); + var vd = d.get(BYTE_SPECIES_256, i).reinterpretAsLongs(); sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); sum2 = 
sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); @@ -1162,18 +1132,18 @@ static long int4BitDotProduct256(byte[] q, byte[] d) { subRet3 += sum3.reduceLanes(VectorOperators.ADD); } - if (d.length - i >= ByteVector.SPECIES_128.vectorByteSize()) { + if (length - i >= ByteVector.SPECIES_128.vectorByteSize()) { var sum0 = LongVector.zero(LongVector.SPECIES_128); var sum1 = LongVector.zero(LongVector.SPECIES_128); var sum2 = LongVector.zero(LongVector.SPECIES_128); var sum3 = LongVector.zero(LongVector.SPECIES_128); - int limit = ByteVector.SPECIES_128.loopBound(d.length); + int limit = ByteVector.SPECIES_128.loopBound(length); for (; i < limit; i += ByteVector.SPECIES_128.length()) { - var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs(); - var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length).reinterpretAsLongs(); - var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 2).reinterpretAsLongs(); - var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 3).reinterpretAsLongs(); - var vd = ByteVector.fromArray(BYTE_SPECIES_128, d, i).reinterpretAsLongs(); + var vq0 = q.get(BYTE_SPECIES_128, i).reinterpretAsLongs(); + var vq1 = q.get(BYTE_SPECIES_128, i + length).reinterpretAsLongs(); + var vq2 = q.get(BYTE_SPECIES_128, i + length * 2).reinterpretAsLongs(); + var vq3 = q.get(BYTE_SPECIES_128, i + length * 3).reinterpretAsLongs(); + var vd = d.get(BYTE_SPECIES_128, i).reinterpretAsLongs(); sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); @@ -1185,33 +1155,34 @@ static long int4BitDotProduct256(byte[] q, byte[] d) { subRet3 += sum3.reduceLanes(VectorOperators.ADD); } // tail as bytes - for (; i < d.length; i++) { - subRet0 += Integer.bitCount((q[i] & d[i]) & 0xFF); - subRet1 += Integer.bitCount((q[i + d.length] & d[i]) & 0xFF); - subRet2 += Integer.bitCount((q[i 
+ 2 * d.length] & d[i]) & 0xFF); - subRet3 += Integer.bitCount((q[i + 3 * d.length] & d[i]) & 0xFF); + for (; i < length; i++) { + subRet0 += Integer.bitCount((q.get(i) & d.get(i)) & 0xFF); + subRet1 += Integer.bitCount((q.get(i + length) & d.get(i)) & 0xFF); + subRet2 += Integer.bitCount((q.get(i + 2 * length) & d.get(i)) & 0xFF); + subRet3 += Integer.bitCount((q.get(i + 3 * length) & d.get(i)) & 0xFF); } return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); } - public static long int4BitDotProduct128(byte[] q, byte[] d) { + public static long int4BitDotProduct128(IByteVector q, IByteVector d) { long subRet0 = 0; long subRet1 = 0; long subRet2 = 0; long subRet3 = 0; int i = 0; + int length = d.length(); var sum0 = IntVector.zero(IntVector.SPECIES_128); var sum1 = IntVector.zero(IntVector.SPECIES_128); var sum2 = IntVector.zero(IntVector.SPECIES_128); var sum3 = IntVector.zero(IntVector.SPECIES_128); - int limit = ByteVector.SPECIES_128.loopBound(d.length); + int limit = ByteVector.SPECIES_128.loopBound(length); for (; i < limit; i += ByteVector.SPECIES_128.length()) { - var vd = ByteVector.fromArray(BYTE_SPECIES_128, d, i).reinterpretAsInts(); - var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts(); - var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length).reinterpretAsInts(); - var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 2).reinterpretAsInts(); - var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 3).reinterpretAsInts(); + var vd = d.get(BYTE_SPECIES_128, i).reinterpretAsInts(); + var vq0 = q.get(BYTE_SPECIES_128, i).reinterpretAsInts(); + var vq1 = q.get(BYTE_SPECIES_128, i + length).reinterpretAsInts(); + var vq2 = q.get(BYTE_SPECIES_128, i + length * 2).reinterpretAsInts(); + var vq3 = q.get(BYTE_SPECIES_128, i + length * 3).reinterpretAsInts(); sum0 = sum0.add(vd.and(vq0).lanewise(VectorOperators.BIT_COUNT)); sum1 = sum1.add(vd.and(vq1).lanewise(VectorOperators.BIT_COUNT)); sum2 
= sum2.add(vd.and(vq2).lanewise(VectorOperators.BIT_COUNT)); @@ -1222,12 +1193,12 @@ public static long int4BitDotProduct128(byte[] q, byte[] d) { subRet2 += sum2.reduceLanes(VectorOperators.ADD); subRet3 += sum3.reduceLanes(VectorOperators.ADD); // tail as bytes - for (; i < d.length; i++) { - int dValue = d[i]; - subRet0 += Integer.bitCount((dValue & q[i]) & 0xFF); - subRet1 += Integer.bitCount((dValue & q[i + d.length]) & 0xFF); - subRet2 += Integer.bitCount((dValue & q[i + 2 * d.length]) & 0xFF); - subRet3 += Integer.bitCount((dValue & q[i + 3 * d.length]) & 0xFF); + for (; i < length; i++) { + int dValue = d.get(i); + subRet0 += Integer.bitCount((dValue & q.get(i)) & 0xFF); + subRet1 += Integer.bitCount((dValue & q.get(i + length)) & 0xFF); + subRet2 += Integer.bitCount((dValue & q.get(i + 2 * length)) & 0xFF); + subRet3 += Integer.bitCount((dValue & q.get(i + 3 * length)) & 0xFF); } return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); } @@ -1360,7 +1331,7 @@ public int filterByScore( @Override public float[] l2normalize(float[] v, boolean throwOnZero) { - double l1norm = this.dotProduct(v, v); + double l1norm = this.dotProductFloats(v, v); if (l1norm == 0) { if (throwOnZero) { throw new IllegalArgumentException("Cannot normalize a zero-length vector"); diff --git a/lucene/core/src/java25/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java b/lucene/core/src/java25/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java index cf3ab94f417c..dbc17a688cc4 100644 --- a/lucene/core/src/java25/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java +++ b/lucene/core/src/java25/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java @@ -34,7 +34,7 @@ final class PanamaVectorizationProvider extends VectorizationProvider { // would get called before we have a chance to perform sanity checks around the vector API in the // constructor of this class. 
Put them in PanamaVectorConstants instead. - private final VectorUtilSupport vectorUtilSupport; + private final VectorUtilSupport vectorUtilSupport; PanamaVectorizationProvider() { // hack to work around for JDK-8309727: @@ -69,7 +69,7 @@ private void logIncubatorSetup() { } @Override - public VectorUtilSupport getVectorUtilSupport() { + public VectorUtilSupport getVectorUtilSupport() { return vectorUtilSupport; } diff --git a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java index 78280e7e4c36..7faf5690673e 100644 --- a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java +++ b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java @@ -54,9 +54,9 @@ public void testFloatVectors() { a[i] = random().nextFloat(); b[i] = random().nextFloat(); } - assertFloatReturningProviders(p -> p.dotProduct(a, b)); - assertFloatReturningProviders(p -> p.squareDistance(a, b)); - assertFloatReturningProviders(p -> p.cosine(a, b)); + assertFloatReturningProviders(p -> p.dotProductFloats(a, b)); + assertFloatReturningProviders(p -> p.squareDistanceFloats(a, b)); + assertFloatReturningProviders(p -> p.cosineFloats(a, b)); } public void testBinaryVectors() { @@ -64,9 +64,9 @@ public void testBinaryVectors() { var b = new byte[size]; random().nextBytes(a); random().nextBytes(b); - assertIntReturningProviders(p -> p.dotProduct(a, b)); - assertIntReturningProviders(p -> p.squareDistance(a, b)); - assertFloatReturningProviders(p -> p.cosine(a, b)); + assertIntReturningProviders(p -> p.dotProductBytes(a, b)); + assertIntReturningProviders(p -> p.squareDistanceBytes(a, b)); + assertFloatReturningProviders(p -> p.cosineBytes(a, b)); } public void testBinaryVectorsBoundaries() { @@ -75,27 +75,27 @@ public void testBinaryVectorsBoundaries() { Arrays.fill(a, Byte.MIN_VALUE); Arrays.fill(b, 
Byte.MIN_VALUE); - assertIntReturningProviders(p -> p.dotProduct(a, b)); - assertIntReturningProviders(p -> p.squareDistance(a, b)); - assertFloatReturningProviders(p -> p.cosine(a, b)); + assertIntReturningProviders(p -> p.dotProductBytes(a, b)); + assertIntReturningProviders(p -> p.squareDistanceBytes(a, b)); + assertFloatReturningProviders(p -> p.cosineBytes(a, b)); Arrays.fill(a, Byte.MAX_VALUE); Arrays.fill(b, Byte.MAX_VALUE); - assertIntReturningProviders(p -> p.dotProduct(a, b)); - assertIntReturningProviders(p -> p.squareDistance(a, b)); - assertFloatReturningProviders(p -> p.cosine(a, b)); + assertIntReturningProviders(p -> p.dotProductBytes(a, b)); + assertIntReturningProviders(p -> p.squareDistanceBytes(a, b)); + assertFloatReturningProviders(p -> p.cosineBytes(a, b)); Arrays.fill(a, Byte.MIN_VALUE); Arrays.fill(b, Byte.MAX_VALUE); - assertIntReturningProviders(p -> p.dotProduct(a, b)); - assertIntReturningProviders(p -> p.squareDistance(a, b)); - assertFloatReturningProviders(p -> p.cosine(a, b)); + assertIntReturningProviders(p -> p.dotProductBytes(a, b)); + assertIntReturningProviders(p -> p.squareDistanceBytes(a, b)); + assertFloatReturningProviders(p -> p.cosineBytes(a, b)); Arrays.fill(a, Byte.MAX_VALUE); Arrays.fill(b, Byte.MIN_VALUE); - assertIntReturningProviders(p -> p.dotProduct(a, b)); - assertIntReturningProviders(p -> p.squareDistance(a, b)); - assertFloatReturningProviders(p -> p.cosine(a, b)); + assertIntReturningProviders(p -> p.dotProductBytes(a, b)); + assertIntReturningProviders(p -> p.squareDistanceBytes(a, b)); + assertFloatReturningProviders(p -> p.cosineBytes(a, b)); } public void testInt4DotProduct() { @@ -113,16 +113,16 @@ public void testInt4DotProduct() { assertIntReturningProviders(p -> p.int4DotProductBothPacked(pack(a), pack(b))); assertEquals( - LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + LUCENE_PROVIDER.getVectorUtilSupport().dotProductBytes(a, b), PANAMA_PROVIDER.getVectorUtilSupport().int4DotProduct(a, 
b)); assertEquals( - LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + LUCENE_PROVIDER.getVectorUtilSupport().dotProductBytes(a, b), PANAMA_PROVIDER.getVectorUtilSupport().int4DotProductSinglePacked(a, pack(b))); assertEquals( - LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + LUCENE_PROVIDER.getVectorUtilSupport().dotProductBytes(a, b), PANAMA_PROVIDER.getVectorUtilSupport().int4DotProductSinglePacked(b, pack(a))); assertEquals( - LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + LUCENE_PROVIDER.getVectorUtilSupport().dotProductBytes(a, b), PANAMA_PROVIDER.getVectorUtilSupport().int4DotProductBothPacked(pack(a), pack(b))); } @@ -141,16 +141,16 @@ public void testInt4DotProductBoundaries() { assertIntReturningProviders(p -> p.int4DotProductBothPacked(pack(a), pack(b))); assertEquals( - LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + LUCENE_PROVIDER.getVectorUtilSupport().dotProductBytes(a, b), PANAMA_PROVIDER.getVectorUtilSupport().int4DotProduct(a, b)); assertEquals( - LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + LUCENE_PROVIDER.getVectorUtilSupport().dotProductBytes(a, b), PANAMA_PROVIDER.getVectorUtilSupport().int4DotProductSinglePacked(a, pack(b))); assertEquals( - LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + LUCENE_PROVIDER.getVectorUtilSupport().dotProductBytes(a, b), PANAMA_PROVIDER.getVectorUtilSupport().int4DotProductSinglePacked(b, pack(a))); assertEquals( - LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + LUCENE_PROVIDER.getVectorUtilSupport().dotProductBytes(a, b), PANAMA_PROVIDER.getVectorUtilSupport().int4DotProductBothPacked(pack(a), pack(b))); byte MIN_VALUE = 0; @@ -163,16 +163,16 @@ public void testInt4DotProductBoundaries() { assertIntReturningProviders(p -> p.int4DotProductBothPacked(pack(a), pack(b))); assertEquals( - LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + LUCENE_PROVIDER.getVectorUtilSupport().dotProductBytes(a, b), 
PANAMA_PROVIDER.getVectorUtilSupport().int4DotProduct(a, b)); assertEquals( - LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + LUCENE_PROVIDER.getVectorUtilSupport().dotProductBytes(a, b), PANAMA_PROVIDER.getVectorUtilSupport().int4DotProductSinglePacked(a, pack(b))); assertEquals( - LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + LUCENE_PROVIDER.getVectorUtilSupport().dotProductBytes(a, b), PANAMA_PROVIDER.getVectorUtilSupport().int4DotProductSinglePacked(b, pack(a))); assertEquals( - LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + LUCENE_PROVIDER.getVectorUtilSupport().dotProductBytes(a, b), PANAMA_PROVIDER.getVectorUtilSupport().int4DotProductBothPacked(pack(a), pack(b))); } @@ -191,16 +191,16 @@ public void testInt4SquareDistance() { assertIntReturningProviders(p -> p.int4SquareDistanceBothPacked(pack(a), pack(b))); assertEquals( - LUCENE_PROVIDER.getVectorUtilSupport().squareDistance(a, b), + LUCENE_PROVIDER.getVectorUtilSupport().squareDistanceBytes(a, b), PANAMA_PROVIDER.getVectorUtilSupport().int4SquareDistance(a, b)); assertEquals( - LUCENE_PROVIDER.getVectorUtilSupport().squareDistance(a, b), + LUCENE_PROVIDER.getVectorUtilSupport().squareDistanceBytes(a, b), PANAMA_PROVIDER.getVectorUtilSupport().int4SquareDistanceSinglePacked(a, pack(b))); assertEquals( - LUCENE_PROVIDER.getVectorUtilSupport().squareDistance(a, b), + LUCENE_PROVIDER.getVectorUtilSupport().squareDistanceBytes(a, b), PANAMA_PROVIDER.getVectorUtilSupport().int4SquareDistanceSinglePacked(b, pack(a))); assertEquals( - LUCENE_PROVIDER.getVectorUtilSupport().squareDistance(a, b), + LUCENE_PROVIDER.getVectorUtilSupport().squareDistanceBytes(a, b), PANAMA_PROVIDER.getVectorUtilSupport().int4SquareDistanceBothPacked(pack(a), pack(b))); } @@ -223,16 +223,16 @@ public void testInt4SquareDistanceBoundaries() { assertIntReturningProviders(p -> p.int4DotProductBothPacked(pack(a), pack(b))); assertEquals( - LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + 
LUCENE_PROVIDER.getVectorUtilSupport().dotProductBytes(a, b), PANAMA_PROVIDER.getVectorUtilSupport().int4DotProduct(a, b)); assertEquals( - LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + LUCENE_PROVIDER.getVectorUtilSupport().dotProductBytes(a, b), PANAMA_PROVIDER.getVectorUtilSupport().int4DotProductSinglePacked(a, pack(b))); assertEquals( - LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + LUCENE_PROVIDER.getVectorUtilSupport().dotProductBytes(a, b), PANAMA_PROVIDER.getVectorUtilSupport().int4DotProductSinglePacked(b, pack(a))); assertEquals( - LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + LUCENE_PROVIDER.getVectorUtilSupport().dotProductBytes(a, b), PANAMA_PROVIDER.getVectorUtilSupport().int4DotProductBothPacked(pack(a), pack(b))); } @@ -305,20 +305,20 @@ public void testMinMaxScalarQuantize() { outputs.getFirst(), alpha, min, newScale, newAlpha, newMin, newMax)); } - private void assertFloatReturningProviders(ToDoubleFunction func) { + private void assertFloatReturningProviders(ToDoubleFunction> func) { assertEquals( func.applyAsDouble(LUCENE_PROVIDER.getVectorUtilSupport()), func.applyAsDouble(PANAMA_PROVIDER.getVectorUtilSupport()), delta); } - private void assertIntReturningProviders(ToIntFunction func) { + private void assertIntReturningProviders(ToIntFunction> func) { assertEquals( func.applyAsInt(LUCENE_PROVIDER.getVectorUtilSupport()), func.applyAsInt(PANAMA_PROVIDER.getVectorUtilSupport())); } - private void assertLongReturningProviders(ToLongFunction func) { + private void assertLongReturningProviders(ToLongFunction> func) { assertEquals( func.applyAsLong(LUCENE_PROVIDER.getVectorUtilSupport()), func.applyAsLong(PANAMA_PROVIDER.getVectorUtilSupport())); diff --git a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java index 1953d6dff758..4147138edf83 100644 --- a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java +++ 
b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java @@ -32,13 +32,13 @@ public class TestVectorUtil extends LuceneTestCase { public static final double DELTA = 1e-4; public void testBasicDotProduct() { - assertEquals(5, VectorUtil.dotProduct(new float[] {1, 2, 3}, new float[] {-10, 0, 5}), 0); + assertEquals(5, VectorUtil.dotProductFloats(new float[] {1, 2, 3}, new float[] {-10, 0, 5}), 0); } public void testSelfDotProduct() { // the dot product of a vector with itself is equal to the sum of the squares of its components float[] v = randomVector(); - assertEquals(l2(v), VectorUtil.dotProduct(v, v), DELTA); + assertEquals(l2(v), VectorUtil.dotProductFloats(v, v), DELTA); } public void testOrthogonalDotProduct() { @@ -49,27 +49,28 @@ public void testOrthogonalDotProduct() { float[] u = new float[2]; u[0] = v[1]; u[1] = -v[0]; - assertEquals(0, VectorUtil.dotProduct(u, v), DELTA); + assertEquals(0, VectorUtil.dotProductFloats(u, v), DELTA); } public void testDotProductThrowsForDimensionMismatch() { float[] v = {1, 0, 0}, u = {0, 1}; - expectThrows(IllegalArgumentException.class, () -> VectorUtil.dotProduct(u, v)); + expectThrows(IllegalArgumentException.class, () -> VectorUtil.dotProductFloats(u, v)); } public void testSelfSquareDistance() { // the l2 distance of a vector with itself is zero float[] v = randomVector(); - assertEquals(0, VectorUtil.squareDistance(v, v), DELTA); + assertEquals(0, VectorUtil.squareDistanceFloats(v, v), DELTA); } public void testBasicSquareDistance() { - assertEquals(12, VectorUtil.squareDistance(new float[] {1, 2, 3}, new float[] {-1, 0, 5}), 0); + assertEquals( + 12, VectorUtil.squareDistanceFloats(new float[] {1, 2, 3}, new float[] {-1, 0, 5}), 0); } public void testSquareDistanceThrowsForDimensionMismatch() { float[] v = {1, 0, 0}, u = {0, 1}; - expectThrows(IllegalArgumentException.class, () -> VectorUtil.squareDistance(u, v)); + expectThrows(IllegalArgumentException.class, () -> VectorUtil.squareDistanceFloats(u, v)); 
} public void testRandomSquareDistance() { @@ -77,12 +78,12 @@ public void testRandomSquareDistance() { // its components float[] v = randomVector(); float[] u = negative(v); - assertEquals(4 * l2(v), VectorUtil.squareDistance(u, v), DELTA); + assertEquals(4 * l2(v), VectorUtil.squareDistanceFloats(u, v), DELTA); } public void testBasicCosine() { assertEquals( - 0.11952f, VectorUtil.cosine(new float[] {1, 2, 3}, new float[] {-10, 0, 5}), DELTA); + 0.11952f, VectorUtil.cosineFloats(new float[] {1, 2, 3}, new float[] {-10, 0, 5}), DELTA); } public void testSelfCosine() { @@ -90,7 +91,7 @@ public void testSelfCosine() { float[] v = randomVector(); // ensure the vector is non-zero so that cosine is defined v[0] = random().nextFloat() + 0.01f; - assertEquals(1.0f, VectorUtil.cosine(v, v), DELTA); + assertEquals(1.0f, VectorUtil.cosineFloats(v, v), DELTA); } public void testOrthogonalCosine() { @@ -102,12 +103,12 @@ public void testOrthogonalCosine() { float[] u = new float[2]; u[0] = v[1]; u[1] = -v[0]; - assertEquals(0, VectorUtil.cosine(u, v), DELTA); + assertEquals(0, VectorUtil.cosineFloats(u, v), DELTA); } public void testCosineThrowsForDimensionMismatch() { float[] v = {1, 0, 0}, u = {0, 1}; - expectThrows(IllegalArgumentException.class, () -> VectorUtil.cosine(u, v)); + expectThrows(IllegalArgumentException.class, () -> VectorUtil.cosineFloats(u, v)); } public void testNormalize() { @@ -226,7 +227,7 @@ public static byte[] randomVectorBytes(int dim) { public void testBasicDotProductBytes() { byte[] a = new byte[] {1, 2, 3}; byte[] b = new byte[] {-10, 0, 5}; - assertEquals(5, VectorUtil.dotProduct(a, b), 0); + assertEquals(5, VectorUtil.dotProductBytes(a, b), 0); float denom = a.length * (1 << 15); assertEquals(0.5 + 5 / denom, VectorUtil.dotProductScore(a, b), DELTA); @@ -246,7 +247,7 @@ public void testBasicDotProductBytes() { public void testSelfDotProductBytes() { // the dot product of a vector with itself is equal to the sum of the squares of its components 
byte[] v = randomVectorBytes(); - assertEquals(l2(v), VectorUtil.dotProduct(v, v), DELTA); + assertEquals(l2(v), VectorUtil.dotProductBytes(v, v), DELTA); } public void testOrthogonalDotProductBytes() { @@ -257,17 +258,18 @@ public void testOrthogonalDotProductBytes() { byte[] b = new byte[2]; b[0] = a[1]; b[1] = (byte) -a[0]; - assertEquals(0, VectorUtil.dotProduct(a, b), DELTA); + assertEquals(0, VectorUtil.dotProductBytes(a, b), DELTA); } public void testSelfSquareDistanceBytes() { // the l2 distance of a vector with itself is zero byte[] v = randomVectorBytes(); - assertEquals(0, VectorUtil.squareDistance(v, v), DELTA); + assertEquals(0, VectorUtil.squareDistanceBytes(v, v), DELTA); } public void testBasicSquareDistanceBytes() { - assertEquals(12, VectorUtil.squareDistance(new byte[] {1, 2, 3}, new byte[] {-1, 0, 5}), 0); + assertEquals( + 12, VectorUtil.squareDistanceBytes(new byte[] {1, 2, 3}, new byte[] {-1, 0, 5}), 0); } public void testRandomSquareDistanceBytes() { @@ -275,7 +277,7 @@ public void testRandomSquareDistanceBytes() { // its components byte[] v = randomVectorBytes(); byte[] u = negative(v); - assertEquals(4 * l2(v), VectorUtil.squareDistance(u, v), DELTA); + assertEquals(4 * l2(v), VectorUtil.squareDistanceBytes(u, v), DELTA); } public void testBasicDotProductUint8() { @@ -306,7 +308,8 @@ public void testBasicSquareDistanceUint8() { } public void testBasicCosineBytes() { - assertEquals(0.11952f, VectorUtil.cosine(new byte[] {1, 2, 3}, new byte[] {-10, 0, 5}), DELTA); + assertEquals( + 0.11952f, VectorUtil.cosineBytes(new byte[] {1, 2, 3}, new byte[] {-10, 0, 5}), DELTA); } public void testSelfCosineBytes() { @@ -314,7 +317,7 @@ public void testSelfCosineBytes() { byte[] v = randomVectorBytes(); // ensure the vector is non-zero so that cosine is defined v[0] = (byte) (random().nextInt(126) + 1); - assertEquals(1.0f, VectorUtil.cosine(v, v), DELTA); + assertEquals(1.0f, VectorUtil.cosineBytes(v, v), DELTA); } public void 
testOrthogonalCosineBytes() { @@ -326,7 +329,7 @@ public void testOrthogonalCosineBytes() { float[] u = new float[2]; u[0] = v[1]; u[1] = -v[0]; - assertEquals(0, VectorUtil.cosine(u, v), DELTA); + assertEquals(0, VectorUtil.cosineFloats(u, v), DELTA); } interface ToIntBiFunction { diff --git a/lucene/misc/src/java/org/apache/lucene/misc/index/BpVectorReorderer.java b/lucene/misc/src/java/org/apache/lucene/misc/index/BpVectorReorderer.java index 246109ede04c..9d0280eb91be 100644 --- a/lucene/misc/src/java/org/apache/lucene/misc/index/BpVectorReorderer.java +++ b/lucene/misc/src/java/org/apache/lucene/misc/index/BpVectorReorderer.java @@ -278,7 +278,7 @@ static void computeCentroid( case EUCLIDEAN, MAXIMUM_INNER_PRODUCT -> vectorScalarMul(1 / (float) ids.length, centroid); case DOT_PRODUCT, COSINE -> vectorScalarMul( - 1 / (float) Math.sqrt(VectorUtil.dotProduct(centroid, centroid)), centroid); + 1 / (float) Math.sqrt(VectorUtil.dotProductFloats(centroid, centroid)), centroid); } } @@ -312,7 +312,7 @@ private int shuffle( vectorScore) .compute(); vectorSubtract(leftCentroid, rightCentroid, scratch); - float scale = (float) Math.sqrt(VectorUtil.dotProduct(scratch, scratch)); + float scale = (float) Math.sqrt(VectorUtil.dotProductFloats(scratch, scratch)); float maxLeftBias = Float.NEGATIVE_INFINITY; for (int i = ids.offset; i < midPoint; ++i) { maxLeftBias = Math.max(maxLeftBias, biases[i]); @@ -564,11 +564,11 @@ protected void compute() { private float computeBias(float[] vector, float[] leftCentroid, float[] rightCentroid) { return switch (vectorScore) { case EUCLIDEAN -> - VectorUtil.squareDistance(vector, leftCentroid) - - VectorUtil.squareDistance(vector, rightCentroid); + VectorUtil.squareDistanceFloats(vector, leftCentroid) + - VectorUtil.squareDistanceFloats(vector, rightCentroid); case MAXIMUM_INNER_PRODUCT, COSINE, DOT_PRODUCT -> - VectorUtil.dotProduct(vector, rightCentroid) - - VectorUtil.dotProduct(vector, leftCentroid); + 
VectorUtil.dotProductFloats(vector, rightCentroid) + - VectorUtil.dotProductFloats(vector, leftCentroid); default -> throw new IllegalStateException("unsupported vector score: " + vectorScore); }; } diff --git a/lucene/misc/src/test/org/apache/lucene/misc/index/TestBpVectorReorderer.java b/lucene/misc/src/test/org/apache/lucene/misc/index/TestBpVectorReorderer.java index 653c519729a9..0f62de121b6c 100644 --- a/lucene/misc/src/test/org/apache/lucene/misc/index/TestBpVectorReorderer.java +++ b/lucene/misc/src/test/org/apache/lucene/misc/index/TestBpVectorReorderer.java @@ -110,7 +110,7 @@ private static double sumClosestDistances(List points) { if (j == i) { continue; } - double distance = VectorUtil.squareDistance(points.get(i), points.get(j)); + double distance = VectorUtil.squareDistanceFloats(points.get(i), points.get(j)); if (distance < closeness) { closest = j; closeness = distance; diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/KMeans.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/KMeans.java index 036f4396e3a2..a463f094011e 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/KMeans.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/KMeans.java @@ -236,7 +236,8 @@ private float[][] initializePlusPlus() throws IOException { double totalSum = 0; for (int j = 0; j < numVectors; j++) { // TODO: replace with RandomVectorScorer::score possible on quantized vectors - float dist = VectorUtil.squareDistance(vectors.vectorValue(j), initialCentroids[i - 1]); + float dist = + VectorUtil.squareDistanceFloats(vectors.vectorValue(j), initialCentroids[i - 1]); if (dist < minDistances[j]) { minDistances[j] = dist; } @@ -297,7 +298,7 @@ private static double runKMeansStep( float minSquaredDist = Float.MAX_VALUE; for (short c = 0; c < numCentroids; c++) { // TODO: replace with RandomVectorScorer::score possible on quantized vectors - float squareDist = 
VectorUtil.squareDistance(centroids[c], vector); + float squareDist = VectorUtil.squareDistanceFloats(centroids[c], vector); if (squareDist < minSquaredDist) { bestCentroid = c; minSquaredDist = squareDist; @@ -359,7 +360,8 @@ static void assignCentroids( for (int i = 0; i < vectors.size(); i++) { float[] vector = vectors.vectorValue(i); for (short j = 0; j < assignedCentroidsIdxs.length; j++) { - float squareDist = VectorUtil.squareDistance(centroids[assignedCentroidsIdxs[j]], vector); + float squareDist = + VectorUtil.squareDistanceFloats(centroids[assignedCentroidsIdxs[j]], vector); queue.insertWithOverflow(i, squareDist); } }