Skip to content

Commit 6e3d6c0

Browse files
committed
feat: add strategies for building query embedding vector
1 parent baedf85 commit 6e3d6c0

File tree

2 files changed

+74
-16
lines changed

2 files changed

+74
-16
lines changed

crates/llm-ls/src/main.rs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ use uuid::Uuid;
3333
use crate::backend::{build_body, build_headers, parse_generations};
3434
use crate::document::Document;
3535
use crate::error::{internal_error, Error, Result};
36+
use crate::retrieval::BuildFrom;
3637

3738
mod backend;
3839
mod config;
@@ -238,11 +239,21 @@ async fn build_prompt(
238239
after_line = after_iter.next();
239240
}
240241
let before = before.into_iter().rev().collect::<Vec<_>>().join("");
242+
let query = snippet_retriever
243+
.read()
244+
.await
245+
.build_query(
246+
format!("{before}{after}"),
247+
BuildFrom::Cursor {
248+
cursor_position: before.len(),
249+
},
250+
)
251+
.await?;
241252
let snippets = snippet_retriever
242253
.read()
243254
.await
244255
.search(
245-
format!("{before}{after}"),
256+
&query,
246257
Some(FilterBuilder::new().comparison(
247258
"file_url".to_owned(),
248259
Compare::Neq,
@@ -281,11 +292,16 @@ async fn build_prompt(
281292
before.push(line);
282293
}
283294
let prompt = before.into_iter().rev().collect::<Vec<_>>().join("");
295+
let query = snippet_retriever
296+
.read()
297+
.await
298+
.build_query(prompt.clone(), BuildFrom::End)
299+
.await?;
284300
let snippets = snippet_retriever
285301
.read()
286302
.await
287303
.search(
288-
prompt.clone(),
304+
&query,
289305
Some(FilterBuilder::new().comparison(
290306
"file_url".to_owned(),
291307
Compare::Neq,

crates/llm-ls/src/retrieval.rs

Lines changed: 56 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ impl SnippetRetriever {
251251
})
252252
}
253253

254-
pub(crate) async fn initialse_database(&mut self, db_name: &str) -> Result<Db> {
254+
pub(crate) async fn initialise_database(&mut self, db_name: &str) -> Result<Db> {
255255
let uri = self.cache_path.join(db_name);
256256
let mut db = Db::open(uri).await.expect("failed to open database");
257257
match db
@@ -282,13 +282,15 @@ impl SnippetRetriever {
282282
debug!("building workspace snippets");
283283
let workspace_root = PathBuf::from(workspace_root);
284284
if self.db.is_none() {
285-
self.initialse_database(
285+
self.initialise_database(&format!(
286+
"{}--{}",
286287
workspace_root
287288
.file_name()
288289
.ok_or_else(|| Error::NoFinalPath(workspace_root.clone()))?
289290
.to_str()
290291
.ok_or(Error::NonUnicode)?,
291-
)
292+
self.model_config.id.replace('/', "--"),
293+
))
292294
.await?;
293295
}
294296
let mut files = Vec::new();
@@ -377,29 +379,60 @@ impl SnippetRetriever {
377379
Ok(())
378380
}
379381

380-
pub(crate) async fn search(
382+
pub(crate) async fn build_query(
381383
&self,
382384
snippet: String,
385+
strategy: BuildFrom,
386+
) -> Result<Vec<f32>> {
387+
match strategy {
388+
BuildFrom::Start => {
389+
let mut encoding = self.tokenizer.encode(snippet.clone(), true)?;
390+
encoding.truncate(
391+
self.model_config.max_input_size,
392+
1,
393+
TruncationDirection::Right,
394+
);
395+
self.generate_embedding(encoding, self.model.clone()).await
396+
}
397+
BuildFrom::Cursor { cursor_position } => {
398+
let (before, after) = snippet.split_at(cursor_position);
399+
let mut before_encoding = self.tokenizer.encode(before, true)?;
400+
let mut after_encoding = self.tokenizer.encode(after, true)?;
401+
let share = self.model_config.max_input_size / 2;
402+
before_encoding.truncate(share, 1, TruncationDirection::Left);
403+
after_encoding.truncate(share, 1, TruncationDirection::Right);
404+
before_encoding.take_overflowing();
405+
after_encoding.take_overflowing();
406+
before_encoding.merge_with(after_encoding, false);
407+
self.generate_embedding(before_encoding, self.model.clone())
408+
.await
409+
}
410+
BuildFrom::End => {
411+
let mut encoding = self.tokenizer.encode(snippet.clone(), true)?;
412+
encoding.truncate(
413+
self.model_config.max_input_size,
414+
1,
415+
TruncationDirection::Left,
416+
);
417+
self.generate_embedding(encoding, self.model.clone()).await
418+
}
419+
}
420+
}
421+
422+
pub(crate) async fn search(
423+
&self,
424+
query: &[f32],
383425
filter: Option<FilterBuilder>,
384426
) -> Result<Vec<Snippet>> {
385427
let db = match self.db.as_ref() {
386428
Some(db) => db.clone(),
387429
None => return Err(Error::UninitialisedDatabase),
388430
};
389431
let col = db.get_collection(&self.collection_name).await?;
390-
let mut encoding = self.tokenizer.encode(snippet.clone(), true)?;
391-
encoding.truncate(
392-
self.model_config.max_input_size,
393-
1,
394-
TruncationDirection::Right,
395-
);
396-
let query = self
397-
.generate_embedding(encoding, self.model.clone())
398-
.await?;
399432
let result = col
400433
.read()
401434
.await
402-
.get(&query, 5, filter)
435+
.get(query, 5, filter)
403436
.await?
404437
.iter()
405438
.map(TryInto::try_into)
@@ -537,3 +570,12 @@ impl SnippetRetriever {
537570
Ok(())
538571
}
539572
}
573+
574+
pub(crate) enum BuildFrom {
575+
Cursor {
576+
cursor_position: usize,
577+
},
578+
End,
579+
#[allow(dead_code)]
580+
Start,
581+
}

0 commit comments

Comments
 (0)