@@ -251,7 +251,7 @@ impl SnippetRetriever {
251
251
} )
252
252
}
253
253
254
- pub ( crate ) async fn initialse_database ( & mut self , db_name : & str ) -> Result < Db > {
254
+ pub ( crate ) async fn initialise_database ( & mut self , db_name : & str ) -> Result < Db > {
255
255
let uri = self . cache_path . join ( db_name) ;
256
256
let mut db = Db :: open ( uri) . await . expect ( "failed to open database" ) ;
257
257
match db
@@ -282,13 +282,15 @@ impl SnippetRetriever {
282
282
debug ! ( "building workspace snippets" ) ;
283
283
let workspace_root = PathBuf :: from ( workspace_root) ;
284
284
if self . db . is_none ( ) {
285
- self . initialse_database (
285
+ self . initialise_database ( & format ! (
286
+ "{}--{}" ,
286
287
workspace_root
287
288
. file_name( )
288
289
. ok_or_else( || Error :: NoFinalPath ( workspace_root. clone( ) ) ) ?
289
290
. to_str( )
290
291
. ok_or( Error :: NonUnicode ) ?,
291
- )
292
+ self . model_config. id. replace( '/' , "--" ) ,
293
+ ) )
292
294
. await ?;
293
295
}
294
296
let mut files = Vec :: new ( ) ;
@@ -377,29 +379,60 @@ impl SnippetRetriever {
377
379
Ok ( ( ) )
378
380
}
379
381
380
- pub ( crate ) async fn search (
382
+ pub ( crate ) async fn build_query (
381
383
& self ,
382
384
snippet : String ,
385
+ strategy : BuildFrom ,
386
+ ) -> Result < Vec < f32 > > {
387
+ match strategy {
388
+ BuildFrom :: Start => {
389
+ let mut encoding = self . tokenizer . encode ( snippet. clone ( ) , true ) ?;
390
+ encoding. truncate (
391
+ self . model_config . max_input_size ,
392
+ 1 ,
393
+ TruncationDirection :: Right ,
394
+ ) ;
395
+ self . generate_embedding ( encoding, self . model . clone ( ) ) . await
396
+ }
397
+ BuildFrom :: Cursor { cursor_position } => {
398
+ let ( before, after) = snippet. split_at ( cursor_position) ;
399
+ let mut before_encoding = self . tokenizer . encode ( before, true ) ?;
400
+ let mut after_encoding = self . tokenizer . encode ( after, true ) ?;
401
+ let share = self . model_config . max_input_size / 2 ;
402
+ before_encoding. truncate ( share, 1 , TruncationDirection :: Left ) ;
403
+ after_encoding. truncate ( share, 1 , TruncationDirection :: Right ) ;
404
+ before_encoding. take_overflowing ( ) ;
405
+ after_encoding. take_overflowing ( ) ;
406
+ before_encoding. merge_with ( after_encoding, false ) ;
407
+ self . generate_embedding ( before_encoding, self . model . clone ( ) )
408
+ . await
409
+ }
410
+ BuildFrom :: End => {
411
+ let mut encoding = self . tokenizer . encode ( snippet. clone ( ) , true ) ?;
412
+ encoding. truncate (
413
+ self . model_config . max_input_size ,
414
+ 1 ,
415
+ TruncationDirection :: Left ,
416
+ ) ;
417
+ self . generate_embedding ( encoding, self . model . clone ( ) ) . await
418
+ }
419
+ }
420
+ }
421
+
422
+ pub ( crate ) async fn search (
423
+ & self ,
424
+ query : & [ f32 ] ,
383
425
filter : Option < FilterBuilder > ,
384
426
) -> Result < Vec < Snippet > > {
385
427
let db = match self . db . as_ref ( ) {
386
428
Some ( db) => db. clone ( ) ,
387
429
None => return Err ( Error :: UninitialisedDatabase ) ,
388
430
} ;
389
431
let col = db. get_collection ( & self . collection_name ) . await ?;
390
- let mut encoding = self . tokenizer . encode ( snippet. clone ( ) , true ) ?;
391
- encoding. truncate (
392
- self . model_config . max_input_size ,
393
- 1 ,
394
- TruncationDirection :: Right ,
395
- ) ;
396
- let query = self
397
- . generate_embedding ( encoding, self . model . clone ( ) )
398
- . await ?;
399
432
let result = col
400
433
. read ( )
401
434
. await
402
- . get ( & query, 5 , filter)
435
+ . get ( query, 5 , filter)
403
436
. await ?
404
437
. iter ( )
405
438
. map ( TryInto :: try_into)
@@ -537,3 +570,12 @@ impl SnippetRetriever {
537
570
Ok ( ( ) )
538
571
}
539
572
}
573
+
574
/// Selects which part of a snippet is kept when it has to be truncated
/// before being turned into a query embedding.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum BuildFrom {
    /// Keep the text on both sides of a cursor.
    Cursor {
        /// Offset into the snippet at which it is split into the part before
        /// and the part after the cursor.
        cursor_position: usize,
    },
    /// Keep the end of the snippet (truncate from the left).
    End,
    /// Keep the start of the snippet (truncate from the right).
    // Presumably not constructed anywhere yet; the allowance silences the
    // unused-variant lint until a caller needs it.
    #[allow(dead_code)]
    Start,
}
0 commit comments