@@ -22,17 +22,9 @@ export type NearestEntry = {
   index: number;
   dist: number;
 };
-/**
- * Optimal size for the height of the matrix when doing computation on the GPU
- * using WebGL. This was found experimentally.
- *
- * This also guarantees that for computing pair-wise distance for up to 10K
- * vectors, no more than 40MB will be allocated in the GPU. Without the
- * allocation limit, we can freeze the graphics of the whole OS.
- */
-const OPTIMAL_GPU_BLOCK_SIZE = 256;
-/** Id of message box used for knn gpu progress bar. */
-const KNN_GPU_MSG_ID = 'knn-gpu';
+
+/** Id of message box used for knn. */
+const KNN_MSG_ID = 'knn';
 
 /**
  * Returns the K nearest neighbors for each vector where the distance
@@ -52,105 +44,63 @@ export function findKNNGPUCosDistNorm<T>(
   const N = dataPoints.length;
   const dim = accessor(dataPoints[0]).length;
   // The goal is to compute a large matrix multiplication A*A.T where A is of
-  // size NxD and A.T is its transpose. This results in a NxN matrix which
-  // could be too big to store on the GPU memory. To avoid memory overflow, we
-  // compute multiple A*partial_A.T where partial_A is of size BxD (B is much
-  // smaller than N). This results in storing only NxB size matrices on the GPU
-  // at a given time.
+  // size NxD and A.T is its transpose. This results in a NxN matrix.
   // A*A.T will give us NxN matrix holding the cosine distance between every
   // pair of points, which we sort using KMin data structure to obtain the
   // K nearest neighbors for each point.
   const nearest: NearestEntry[][] = new Array(N);
-  let numPieces = Math.ceil(N / OPTIMAL_GPU_BLOCK_SIZE);
-  const actualPieceSize = Math.floor(N / numPieces);
-  const modulo = N % actualPieceSize;
-  numPieces += modulo ? 1 : 0;
-  let offset = 0;
-  let progress = 0;
-  let progressDiff = 1 / (2 * numPieces);
-  let piece = 0;
-
-  const typedArray = vector.toTypedArray(dataPoints, accessor);
-  const bigMatrix = tf.tensor(typedArray, [N, dim]);
-  const bigMatrixTransposed = tf.transpose(bigMatrix);
-  // 1 - A * A^T.
-  const bigMatrixSquared = tf.matMul(bigMatrix, bigMatrixTransposed);
-  const cosDistMatrix = tf.sub(1, bigMatrixSquared);
-
-  let maybePaddedCosDistMatrix = cosDistMatrix;
-  if (actualPieceSize * numPieces > N) {
-    // Expect the input to be rank 2 (though it is not typed that way) so we
-    // want to pad the first dimension so we split very evenly (all splitted
-    // tensor have exactly the same dimesion).
-    const padding: Array<[number, number]> = [
-      [0, actualPieceSize * numPieces - N],
-      [0, 0],
-    ];
-    maybePaddedCosDistMatrix = tf.pad(cosDistMatrix, padding);
-  }
-  const splits = tf.split(
-    maybePaddedCosDistMatrix,
-    new Array(numPieces).fill(actualPieceSize),
-    0
-  );
-
   function step(resolve: (result: NearestEntry[][]) => void) {
-    let progressMsg =
-      'Finding nearest neighbors: ' + (progress * 100).toFixed() + '%';
     util
       .runAsyncTask(
-        progressMsg,
+        'Finding nearest neighbors...',
         async () => {
+          const cosSimilarityMatrix = tf.tidy(() => {
+            const typedArray = vector.toTypedArray(dataPoints, accessor);
+            const bigMatrix = tf.tensor(typedArray, [N, dim]);
+            const bigMatrixTransposed = tf.transpose(bigMatrix);
+            // A * A^T.
+            return tf.matMul(bigMatrix, bigMatrixTransposed);
+          });
           // `.data()` returns flattened Float32Array of B * N dimension.
           // For matrix of
           // [ 1 2 ]
           // [ 3 4 ],
           // `.data()` returns [1, 2, 3, 4].
-          const partial = await splits[piece].data();
-          progress += progressDiff;
-          for (let i = 0; i < actualPieceSize; i++) {
+          let partial;
+          try {
+            partial = await cosSimilarityMatrix.data();
+          } finally {
+            // Discard all tensors and free up the memory.
+            cosSimilarityMatrix.dispose();
+          }
+          for (let i = 0; i < N; i++) {
             let kMin = new KMin<NearestEntry>(k);
-            let iReal = offset + i;
-            if (iReal >= N) break;
             for (let j = 0; j < N; j++) {
               // Skip diagonal entries.
-              if (j === iReal) {
+              if (j === i) {
                 continue;
               }
               // Access i * N's row at `j` column.
               // Reach row has N entries and j-th index has cosine distance
-              // between iReal vs. j-th vectors.
-              const cosDist = partial[i * N + j];
+              // between i-th vs. j-th vectors.
+              const cosDist = 1 - partial[i * N + j];
               if (cosDist >= 0) {
                 kMin.add(cosDist, {index: j, dist: cosDist});
               }
             }
-            nearest[iReal] = kMin.getMinKItems();
+            nearest[i] = kMin.getMinKItems();
           }
-          progress += progressDiff;
-          offset += actualPieceSize;
-          piece++;
         },
-        KNN_GPU_MSG_ID
+        KNN_MSG_ID
       )
       .then(
         () => {
-          if (piece < numPieces) {
-            step(resolve);
-          } else {
-            logging.setModalMessage(null!, KNN_GPU_MSG_ID);
-            // Discard all tensors and free up the memory.
-            bigMatrix.dispose();
-            bigMatrixTransposed.dispose();
-            bigMatrixSquared.dispose();
-            cosDistMatrix.dispose();
-            splits.forEach((split) => split.dispose());
-            resolve(nearest);
-          }
+          logging.setModalMessage(null!, KNN_MSG_ID);
+          resolve(nearest);
         },
         (error) => {
           // GPU failed. Reverting back to CPU.
-          logging.setModalMessage(null!, KNN_GPU_MSG_ID);
+          logging.setModalMessage(null!, KNN_MSG_ID);
           let distFunc = (a, b, limit) => vector.cosDistNorm(a, b);
           findKNN(dataPoints, k, accessor, distFunc).then((nearest) => {
             resolve(nearest);
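
For unit-normalized rows, the matrix product A*A.T holds every pairwise cosine similarity, which is why the new loop above can read the cosine distance as 1 - partial[i * N + j]. A minimal sketch of that identity, assuming @tensorflow/tfjs is available (the variable names here are illustrative, not part of the patch):

import * as tf from '@tensorflow/tfjs';

// Two unit-length rows: e1 and the normalized bisector of e1 and e2.
const a = tf.tensor2d([
  [1, 0, 0],
  [Math.SQRT1_2, Math.SQRT1_2, 0],
]);
// sim[i][j] = dot(row_i, row_j) = cosine similarity, since rows have norm 1.
const sim = tf.matMul(a, a, false, /* transposeB= */ true); // A * A^T
// Cosine distance, as in the patched loop: 1 - similarity.
tf.sub(1, sim).print(); // [[0, ~0.2929], [~0.2929, 0]]
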
@@ -212,47 +162,12 @@ export function findKNN<T>(
       for (let i = 0; i < N; i++) {
         nearest[i] = kMin[i].getMinKItems();
       }
+      logging.setModalMessage(null!, KNN_MSG_ID);
       return nearest;
-    }
+    },
+    KNN_MSG_ID
   );
 }
-/** Calculates the minimum distance between a search point and a rectangle. */
-function minDist(
-  point: [number, number],
-  x1: number,
-  y1: number,
-  x2: number,
-  y2: number
-) {
-  let x = point[0];
-  let y = point[1];
-  let dx1 = x - x1;
-  let dx2 = x - x2;
-  let dy1 = y - y1;
-  let dy2 = y - y2;
-  if (dx1 * dx2 <= 0) {
-    // x is between x1 and x2
-    if (dy1 * dy2 <= 0) {
-      // (x,y) is inside the rectangle
-      return 0; // return 0 as point is in rect
-    }
-    return Math.min(Math.abs(dy1), Math.abs(dy2));
-  }
-  if (dy1 * dy2 <= 0) {
-    // y is between y1 and y2
-    // We know it is already inside the rectangle
-    return Math.min(Math.abs(dx1), Math.abs(dx2));
-  }
-  let corner: [number, number];
-  if (x > x2) {
-    // Upper-right vs lower-right.
-    corner = y > y2 ? [x2, y2] : [x2, y1];
-  } else {
-    // Upper-left vs lower-left.
-    corner = y > y2 ? [x1, y2] : [x1, y1];
-  }
-  return Math.sqrt(vector.dist22D([x, y], corner));
-}
 /**
  * Returns the nearest neighbors of a particular point.
  *
@@ -281,5 +196,3 @@ export function findKNNofPoint<T>(
   }
   return kMin.getMinKItems();
 }
-
-export const TEST_ONLY = {OPTIMAL_GPU_BLOCK_SIZE};
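
With the chunked OPTIMAL_GPU_BLOCK_SIZE pipeline gone, memory is kept in check differently: the full NxN product is computed once inside tf.tidy, which reclaims the intermediate tensors, and the result tensor is disposed in a finally block after it has been read back. A standalone sketch of that pattern, assuming @tensorflow/tfjs; the helper name pairwiseCosDistances and its signature are illustrative, not part of the patch:

import * as tf from '@tensorflow/tfjs';

// `rows` holds n unit-normalized vectors of length dim, flattened row-major.
async function pairwiseCosDistances(
  rows: Float32Array,
  n: number,
  dim: number
): Promise<Float32Array> {
  // tf.tidy disposes every tensor created inside the callback except the one
  // it returns, so the input matrix and its transpose never outlive this call.
  const sim = tf.tidy(() => {
    const a = tf.tensor(rows, [n, dim]);
    return tf.matMul(a, tf.transpose(a)); // A * A^T: cosine similarities.
  });
  try {
    // .data() downloads the flattened n * n result as a typed array.
    const flat = (await sim.data()) as Float32Array;
    // Convert similarity to distance, mirroring 1 - partial[i * N + j].
    return flat.map((s) => 1 - s);
  } finally {
    // The returned tensor survives tidy, so free it explicitly even when
    // .data() rejects; the try/finally in the patch does the same.
    sim.dispose();
  }
}

Given the output of vector.toTypedArray(dataPoints, accessor), this would produce the same flattened distances that the patched inner loop indexes into.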