Skip to content

Commit 58f814d

Browse files
committed
feat: add strategies for building query embedding vector
1 parent baedf85 commit 58f814d

File tree

2 files changed

+69
-13
lines changed

2 files changed

+69
-13
lines changed

crates/llm-ls/src/main.rs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ use uuid::Uuid;
3333
use crate::backend::{build_body, build_headers, parse_generations};
3434
use crate::document::Document;
3535
use crate::error::{internal_error, Error, Result};
36+
use crate::retrieval::BuildFrom;
3637

3738
mod backend;
3839
mod config;
@@ -238,11 +239,21 @@ async fn build_prompt(
238239
after_line = after_iter.next();
239240
}
240241
let before = before.into_iter().rev().collect::<Vec<_>>().join("");
242+
let query = snippet_retriever
243+
.read()
244+
.await
245+
.build_query(
246+
format!("{before}{after}"),
247+
BuildFrom::Cursor {
248+
cursor_position: before.len(),
249+
},
250+
)
251+
.await?;
241252
let snippets = snippet_retriever
242253
.read()
243254
.await
244255
.search(
245-
format!("{before}{after}"),
256+
&query,
246257
Some(FilterBuilder::new().comparison(
247258
"file_url".to_owned(),
248259
Compare::Neq,
@@ -281,11 +292,16 @@ async fn build_prompt(
281292
before.push(line);
282293
}
283294
let prompt = before.into_iter().rev().collect::<Vec<_>>().join("");
295+
let query = snippet_retriever
296+
.read()
297+
.await
298+
.build_query(prompt.clone(), BuildFrom::End)
299+
.await?;
284300
let snippets = snippet_retriever
285301
.read()
286302
.await
287303
.search(
288-
prompt.clone(),
304+
&query,
289305
Some(FilterBuilder::new().comparison(
290306
"file_url".to_owned(),
291307
Compare::Neq,

crates/llm-ls/src/retrieval.rs

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -377,29 +377,60 @@ impl SnippetRetriever {
377377
Ok(())
378378
}
379379

380-
pub(crate) async fn search(
380+
pub(crate) async fn build_query(
381381
&self,
382382
snippet: String,
383+
strategy: BuildFrom,
384+
) -> Result<Vec<f32>> {
385+
match strategy {
386+
BuildFrom::Start => {
387+
let mut encoding = self.tokenizer.encode(snippet.clone(), true)?;
388+
encoding.truncate(
389+
self.model_config.max_input_size,
390+
1,
391+
TruncationDirection::Right,
392+
);
393+
self.generate_embedding(encoding, self.model.clone()).await
394+
}
395+
BuildFrom::Cursor { cursor_position } => {
396+
let (before, after) = snippet.split_at(cursor_position);
397+
let mut before_encoding = self.tokenizer.encode(before, true)?;
398+
let mut after_encoding = self.tokenizer.encode(after, true)?;
399+
let share = self.model_config.max_input_size / 2;
400+
before_encoding.truncate(share, 1, TruncationDirection::Left);
401+
after_encoding.truncate(share, 1, TruncationDirection::Right);
402+
before_encoding.take_overflowing();
403+
after_encoding.take_overflowing();
404+
before_encoding.merge_with(after_encoding, false);
405+
self.generate_embedding(before_encoding, self.model.clone())
406+
.await
407+
}
408+
BuildFrom::End => {
409+
let mut encoding = self.tokenizer.encode(snippet.clone(), true)?;
410+
encoding.truncate(
411+
self.model_config.max_input_size,
412+
1,
413+
TruncationDirection::Left,
414+
);
415+
self.generate_embedding(encoding, self.model.clone()).await
416+
}
417+
}
418+
}
419+
420+
pub(crate) async fn search(
421+
&self,
422+
query: &[f32],
383423
filter: Option<FilterBuilder>,
384424
) -> Result<Vec<Snippet>> {
385425
let db = match self.db.as_ref() {
386426
Some(db) => db.clone(),
387427
None => return Err(Error::UninitialisedDatabase),
388428
};
389429
let col = db.get_collection(&self.collection_name).await?;
390-
let mut encoding = self.tokenizer.encode(snippet.clone(), true)?;
391-
encoding.truncate(
392-
self.model_config.max_input_size,
393-
1,
394-
TruncationDirection::Right,
395-
);
396-
let query = self
397-
.generate_embedding(encoding, self.model.clone())
398-
.await?;
399430
let result = col
400431
.read()
401432
.await
402-
.get(&query, 5, filter)
433+
.get(query, 5, filter)
403434
.await?
404435
.iter()
405436
.map(TryInto::try_into)
@@ -537,3 +568,12 @@ impl SnippetRetriever {
537568
Ok(())
538569
}
539570
}
571+
572+
/// Strategy selecting which part of a snippet is used to build the query
/// embedding when the snippet has to be truncated.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum BuildFrom {
    /// Build the query around a cursor: the text before and the text after
    /// the cursor are tokenized and truncated separately.
    Cursor {
        /// Byte offset of the cursor inside the snippet.
        cursor_position: usize,
    },
    /// Build the query from the end of the snippet.
    End,
    /// Build the query from the start of the snippet.
    // No call site visible in this change constructs `Start`; the lint is
    // silenced so the strategy stays available without a warning.
    #[allow(dead_code)]
    Start,
}

0 commit comments

Comments
 (0)