@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212See the License for the specific language governing permissions and
1313limitations under the License.
1414==============================================================================*/
15- import { findKNN , findKNNGPUCosDistNorm , NearestEntry , TEST_ONLY } from './knn' ;
16- import { cosDistNorm , unit } from './vector' ;
15+ import { findKNN , findKNNGPUCosDist , NearestEntry , TEST_ONLY } from './knn' ;
16+ import { cosDist } from './vector' ;
1717
1818describe ( 'projector knn test' , ( ) => {
1919 function getIndices ( nearest : NearestEntry [ ] [ ] ) : number [ ] [ ] {
@@ -22,22 +22,16 @@ describe('projector knn test', () => {
2222 } ) ;
2323 }
2424
25- function unitVector ( vector : Float32Array ) : Float32Array {
26- // `unit` method replaces the vector in-place.
27- unit ( vector ) ;
28- return vector ;
29- }
30-
31- describe ( '#findKNNGPUCosDistNorm' , ( ) => {
25+ describe ( '#findKNNGPUCosDist' , ( ) => {
3226 it ( 'finds n-nearest neighbor for each item' , async ( ) => {
33- const values = await findKNNGPUCosDistNorm (
27+ const values = await findKNNGPUCosDist (
3428 [
35- { a : unitVector ( new Float32Array ( [ 1 , 2 , 0 ] ) ) } ,
36- { a : unitVector ( new Float32Array ( [ 1 , 1 , 3 ] ) ) } ,
37- { a : unitVector ( new Float32Array ( [ 100 , 30 , 0 ] ) ) } ,
38- { a : unitVector ( new Float32Array ( [ 95 , 23 , 3 ] ) ) } ,
39- { a : unitVector ( new Float32Array ( [ 100 , 10 , 0 ] ) ) } ,
40- { a : unitVector ( new Float32Array ( [ 95 , 23 , 100 ] ) ) } ,
29+ { a : new Float32Array ( [ 1 , 2 , 0 ] ) } ,
30+ { a : new Float32Array ( [ 1 , 1 , 3 ] ) } ,
31+ { a : new Float32Array ( [ 100 , 30 , 0 ] ) } ,
32+ { a : new Float32Array ( [ 95 , 23 , 3 ] ) } ,
33+ { a : new Float32Array ( [ 100 , 10 , 0 ] ) } ,
34+ { a : new Float32Array ( [ 95 , 23 , 100 ] ) } ,
4135 ] ,
4236 4 ,
4337 ( data ) => data . a
@@ -54,11 +48,8 @@ describe('projector knn test', () => {
5448 } ) ;
5549
5650 it ( 'returns less than N when number of item is lower' , async ( ) => {
57- const values = await findKNNGPUCosDistNorm (
58- [
59- unitVector ( new Float32Array ( [ 1 , 2 , 0 ] ) ) ,
60- unitVector ( new Float32Array ( [ 1 , 1 , 3 ] ) ) ,
61- ] ,
51+ const values = await findKNNGPUCosDist (
52+ [ new Float32Array ( [ 1 , 2 , 0 ] ) , new Float32Array ( [ 1 , 1 , 3 ] ) ] ,
6253 4 ,
6354 ( a ) => a
6455 ) ;
@@ -68,10 +59,8 @@ describe('projector knn test', () => {
6859
6960 it ( 'splits a large data into one that would fit into GPU memory' , async ( ) => {
7061 const size = TEST_ONLY . OPTIMAL_GPU_BLOCK_SIZE + 5 ;
71- const data = new Array ( size ) . fill (
72- unitVector ( new Float32Array ( [ 1 , 1 , 1 ] ) )
73- ) ;
74- const values = await findKNNGPUCosDistNorm ( data , 1 , ( a ) => a ) ;
62+ const data = new Array ( size ) . fill ( new Float32Array ( [ 1 , 1 , 1 ] ) ) ;
63+ const values = await findKNNGPUCosDist ( data , 1 , ( a ) => a ) ;
7564
7665 expect ( getIndices ( values ) ) . toEqual ( [
7766 // Since distance to the diagonal entries (distance to self is 0) is
@@ -84,25 +73,25 @@ describe('projector knn test', () => {
8473 } ) ;
8574
8675 describe ( '#findKNN' , ( ) => {
87- // Covered by equality tests below (#findKNNGPUCosDistNorm == #findKNN).
76+ // Covered by equality tests below (#findKNNGPUCosDist == #findKNN).
8877 } ) ;
8978
90- describe ( '#findKNNGPUCosDistNorm and #findKNN' , ( ) => {
79+ describe ( '#findKNNGPUCosDist and #findKNN' , ( ) => {
9180 it ( 'returns same value when dist metrics are cosine' , async ( ) => {
9281 const data = [
93- unitVector ( new Float32Array ( [ 1 , 2 , 0 ] ) ) ,
94- unitVector ( new Float32Array ( [ 1 , 1 , 3 ] ) ) ,
95- unitVector ( new Float32Array ( [ 100 , 30 , 0 ] ) ) ,
96- unitVector ( new Float32Array ( [ 95 , 23 , 3 ] ) ) ,
97- unitVector ( new Float32Array ( [ 100 , 10 , 0 ] ) ) ,
98- unitVector ( new Float32Array ( [ 95 , 23 , 100 ] ) ) ,
82+ new Float32Array ( [ 1 , 2 , 0 ] ) ,
83+ new Float32Array ( [ 1 , 1 , 3 ] ) ,
84+ new Float32Array ( [ 100 , 30 , 0 ] ) ,
85+ new Float32Array ( [ 95 , 23 , 3 ] ) ,
86+ new Float32Array ( [ 100 , 10 , 0 ] ) ,
87+ new Float32Array ( [ 95 , 23 , 100 ] ) ,
9988 ] ;
100- const findKnnGpuCosVal = await findKNNGPUCosDistNorm ( data , 2 , ( a ) => a ) ;
89+ const findKnnGpuCosVal = await findKNNGPUCosDist ( data , 2 , ( a ) => a ) ;
10190 const findKnnVal = await findKNN (
10291 data ,
10392 2 ,
10493 ( a ) => a ,
105- ( a , b , limit ) => cosDistNorm ( a , b )
94+ ( a , b , limit ) => cosDist ( a , b )
10695 ) ;
10796
10897 // Floating point precision makes it hard to test. Just assert indices.
@@ -112,15 +101,15 @@ describe('projector knn test', () => {
112101 it ( 'splits a large data without the result being wrong' , async ( ) => {
113102 const size = TEST_ONLY . OPTIMAL_GPU_BLOCK_SIZE + 5 ;
114103 const data = Array . from ( new Array ( size ) ) . map ( ( _ , index ) => {
115- return unitVector ( new Float32Array ( [ index + 1 , index + 1 ] ) ) ;
104+ return new Float32Array ( [ index + 1 , index + 2 ] ) ;
116105 } ) ;
117106
118- const findKnnGpuCosVal = await findKNNGPUCosDistNorm ( data , 2 , ( a ) => a ) ;
107+ const findKnnGpuCosVal = await findKNNGPUCosDist ( data , 2 , ( a ) => a ) ;
119108 const findKnnVal = await findKNN (
120109 data ,
121110 2 ,
122111 ( a ) => a ,
123- ( a , b , limit ) => cosDistNorm ( a , b )
112+ ( a , b , limit ) => cosDist ( a , b )
124113 ) ;
125114
126115 expect ( getIndices ( findKnnGpuCosVal ) ) . toEqual ( getIndices ( findKnnVal ) ) ;
0 commit comments