@@ -377,29 +377,60 @@ impl SnippetRetriever {
377
377
Ok ( ( ) )
378
378
}
379
379
380
- pub ( crate ) async fn search (
380
+ pub ( crate ) async fn build_query (
381
381
& self ,
382
382
snippet : String ,
383
+ strategy : BuildFrom ,
384
+ ) -> Result < Vec < f32 > > {
385
+ match strategy {
386
+ BuildFrom :: Start => {
387
+ let mut encoding = self . tokenizer . encode ( snippet. clone ( ) , true ) ?;
388
+ encoding. truncate (
389
+ self . model_config . max_input_size ,
390
+ 1 ,
391
+ TruncationDirection :: Right ,
392
+ ) ;
393
+ self . generate_embedding ( encoding, self . model . clone ( ) ) . await
394
+ }
395
+ BuildFrom :: Cursor { cursor_position } => {
396
+ let ( before, after) = snippet. split_at ( cursor_position) ;
397
+ let mut before_encoding = self . tokenizer . encode ( before, true ) ?;
398
+ let mut after_encoding = self . tokenizer . encode ( after, true ) ?;
399
+ let share = self . model_config . max_input_size / 2 ;
400
+ before_encoding. truncate ( share, 1 , TruncationDirection :: Left ) ;
401
+ after_encoding. truncate ( share, 1 , TruncationDirection :: Right ) ;
402
+ before_encoding. take_overflowing ( ) ;
403
+ after_encoding. take_overflowing ( ) ;
404
+ before_encoding. merge_with ( after_encoding, false ) ;
405
+ self . generate_embedding ( before_encoding, self . model . clone ( ) )
406
+ . await
407
+ }
408
+ BuildFrom :: End => {
409
+ let mut encoding = self . tokenizer . encode ( snippet. clone ( ) , true ) ?;
410
+ encoding. truncate (
411
+ self . model_config . max_input_size ,
412
+ 1 ,
413
+ TruncationDirection :: Left ,
414
+ ) ;
415
+ self . generate_embedding ( encoding, self . model . clone ( ) ) . await
416
+ }
417
+ }
418
+ }
419
+
420
+ pub ( crate ) async fn search (
421
+ & self ,
422
+ query : & [ f32 ] ,
383
423
filter : Option < FilterBuilder > ,
384
424
) -> Result < Vec < Snippet > > {
385
425
let db = match self . db . as_ref ( ) {
386
426
Some ( db) => db. clone ( ) ,
387
427
None => return Err ( Error :: UninitialisedDatabase ) ,
388
428
} ;
389
429
let col = db. get_collection ( & self . collection_name ) . await ?;
390
- let mut encoding = self . tokenizer . encode ( snippet. clone ( ) , true ) ?;
391
- encoding. truncate (
392
- self . model_config . max_input_size ,
393
- 1 ,
394
- TruncationDirection :: Right ,
395
- ) ;
396
- let query = self
397
- . generate_embedding ( encoding, self . model . clone ( ) )
398
- . await ?;
399
430
let result = col
400
431
. read ( )
401
432
. await
402
- . get ( & query, 5 , filter)
433
+ . get ( query, 5 , filter)
403
434
. await ?
404
435
. iter ( )
405
436
. map ( TryInto :: try_into)
@@ -537,3 +568,12 @@ impl SnippetRetriever {
537
568
Ok ( ( ) )
538
569
}
539
570
}
571
+
572
/// Selects which part of a snippet is kept when it is tokenized and truncated
/// to build an embedding query (see `SnippetRetriever::build_query`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum BuildFrom {
    /// Keep the tokens surrounding a cursor.
    Cursor {
        /// Byte offset into the snippet at which it is split.
        cursor_position: usize,
    },
    /// Keep the end of the snippet (truncate on the left).
    End,
    /// Keep the start of the snippet (truncate on the right).
    // NOTE(review): presumably no caller constructs this variant yet, hence the
    // lint suppression - confirm and drop the attribute once it is used.
    #[allow(dead_code)]
    Start,
}
0 commit comments