Skip to content

Commit 11d2906

Browse files
committed
feat: add tinyvec-embed
1 parent d0a02eb commit 11d2906

File tree

9 files changed

+1248
-604
lines changed

9 files changed

+1248
-604
lines changed

Cargo.lock

Lines changed: 845 additions & 590 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/llm-ls/Cargo.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ edition = "2021"
77
name = "llm-ls"
88

99
[dependencies]
10-
arrow-array = "49"
11-
arrow-schema = "49"
10+
arrow-array = "50"
11+
arrow-schema = "50"
1212
candle = { version = "0.3", package = "candle-core", default-features = false }
1313
candle-nn = "0.3"
1414
candle-transformers = "0.3"
@@ -50,15 +50,15 @@ tree-sitter-css = "0.20"
5050
tree-sitter-elixir = "0.1"
5151
tree-sitter-erlang = "0.4"
5252
tree-sitter-go = "0.20"
53-
tree-sitter-html = "0.19"
53+
tree-sitter-html = "0.20"
5454
tree-sitter-java = "0.20"
5555
tree-sitter-javascript = "0.20"
5656
tree-sitter-json = "0.20"
5757
tree-sitter-kotlin = "0.3.1"
5858
tree-sitter-lua = "0.0.19"
5959
tree-sitter-md = "0.1"
6060
tree-sitter-objc = "3"
61-
tree-sitter-php = "0.21"
61+
tree-sitter-php = "0.22"
6262
tree-sitter-python = "0.20"
6363
tree-sitter-r = "0.19"
6464
tree-sitter-ruby = "0.20"

crates/llm-ls/src/retrieval.rs

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -221,16 +221,21 @@ async fn initialse_database(cache_path: PathBuf) -> Arc<dyn Table> {
221221
],
222222
)
223223
.expect("failure while defining schema");
224-
db.create_table(
225-
"code-slices",
226-
Box::new(RecordBatchIterator::new(
227-
vec![batch].into_iter().map(Ok),
228-
schema,
229-
)),
230-
None,
231-
)
232-
.await
233-
.expect("failed to create table")
224+
let tbl = db
225+
.create_table(
226+
"code-slices",
227+
Box::new(RecordBatchIterator::new(vec![].into_iter().map(Ok), schema)),
228+
None,
229+
)
230+
.await
231+
.expect("failed to create table");
232+
tbl.create_index(&["vector"])
233+
.ivf_pq()
234+
.num_partitions(256)
235+
.build()
236+
.await
237+
.expect("failed to create index");
238+
tbl
234239
}
235240
Err(err) => panic!("error while opening table: {}", err),
236241
}

crates/tinyvec-embed/Cargo.toml

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
[package]
2+
name = "tinyvec-embed"
3+
version = "0.1.0"
4+
edition.workspace = true
5+
license.workspace = true
6+
authors.workspace = true
7+
8+
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
9+
10+
[dependencies]
11+
bincode = "1"
12+
serde = "1"
13+
thiserror = "1"
14+
tokio = { version = "1", features = [
15+
"fs",
16+
"macros",
17+
"rt-multi-thread",
18+
"sync",
19+
] }
20+
tracing = "0.1"
21+
22+
[dependencies.uuid]
23+
version = "1.7.0"
24+
features = ["v4", "fast-rng", "macro-diagnostics"]

crates/tinyvec-embed/README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# tinyvec-embed
2+
3+
Tiny embedded vector database.
4+
5+
Inspired by [tinyvector](https://github.com/m1guelpf/tinyvector).

crates/tinyvec-embed/src/db.rs

Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
use serde::{Deserialize, Serialize};
2+
use std::{
3+
collections::{BinaryHeap, HashMap},
4+
fs,
5+
path::{Path, PathBuf},
6+
sync::Arc,
7+
};
8+
use tokio::{sync::Semaphore, task::JoinSet};
9+
use tracing::debug;
10+
use uuid::Uuid;
11+
12+
use crate::{
13+
error::{Collection as Error, Result},
14+
similarity::{Distance, ScoreIndex},
15+
};
16+
17+
#[derive(Debug, Serialize, Deserialize)]
18+
pub struct Db {
19+
pub collections: HashMap<String, Collection>,
20+
pub location: PathBuf,
21+
}
22+
23+
impl Db {
24+
pub fn open(path: impl AsRef<Path>) -> Result<Self> {
25+
let path = path.as_ref();
26+
if !path.exists() {
27+
debug!("Creating database store");
28+
fs::create_dir_all(
29+
path.parent()
30+
.ok_or(Error::InvalidPath(path.to_path_buf()))?,
31+
)
32+
.map_err(Into::<Error>::into)?;
33+
34+
return Ok(Self {
35+
collections: HashMap::new(),
36+
location: path.to_path_buf(),
37+
});
38+
}
39+
debug!("Loading database from store");
40+
let db = fs::read(path).map_err(Into::<Error>::into)?;
41+
Ok(bincode::deserialize(&db[..]).map_err(Into::<Error>::into)?)
42+
}
43+
44+
pub fn create_collection(
45+
&mut self,
46+
name: String,
47+
dimension: usize,
48+
distance: Distance,
49+
) -> Result<Collection> {
50+
if self.collections.contains_key(&name) {
51+
return Err(Error::UniqueViolation.into());
52+
}
53+
54+
let collection = Collection {
55+
dimension,
56+
distance,
57+
embeddings: Vec::new(),
58+
};
59+
60+
self.collections.insert(name, collection.clone());
61+
62+
Ok(collection)
63+
}
64+
65+
pub fn delete_collection(&mut self, name: &str) {
66+
self.collections.remove(name);
67+
}
68+
69+
pub fn get_collection(&self, name: &str) -> Result<&Collection> {
70+
self.collections.get(name).ok_or(Error::NotFound.into())
71+
}
72+
73+
fn save_to_store(&self) -> Result<()> {
74+
let db = bincode::serialize(self).map_err(Into::<Error>::into)?;
75+
76+
fs::write(self.location.as_path(), db).map_err(Into::<Error>::into)?;
77+
78+
Ok(())
79+
}
80+
}
81+
82+
impl Drop for Db {
83+
fn drop(&mut self) {
84+
debug!("Saving database to store");
85+
let _ = self.save_to_store();
86+
}
87+
}
88+
89+
#[derive(Debug, Clone, Serialize, Deserialize)]
90+
pub struct SimilarityResult {
91+
score: f32,
92+
embedding: Embedding,
93+
}
94+
95+
#[derive(Debug, Clone, Serialize, Deserialize)]
96+
pub struct Collection {
97+
/// Dimension of the vectors in the collection
98+
pub dimension: usize,
99+
/// Distance metric used for querying
100+
pub distance: Distance,
101+
/// Embeddings in the collection
102+
#[serde(default)]
103+
pub embeddings: Vec<Embedding>,
104+
}
105+
106+
impl Collection {
107+
pub fn filter(&self) -> FilterBuilder {
108+
FilterBuilder::new()
109+
}
110+
111+
pub async fn get(
112+
&self,
113+
query: &[f32],
114+
k: usize,
115+
filter: Option<impl FnMut(&&Embedding) -> bool>,
116+
) -> Result<Vec<SimilarityResult>> {
117+
let embeddings = if let Some(filter) = filter {
118+
self.embeddings.iter().filter(filter).collect::<Vec<_>>()
119+
} else {
120+
self.embeddings.iter().collect::<Vec<_>>()
121+
};
122+
get_similarity(self.distance, &embeddings, query, k).await
123+
}
124+
125+
pub fn insert(&mut self, embedding: Embedding) -> Result<()> {
126+
if embedding.vector.len() != self.dimension {
127+
return Err(Error::DimensionMismatch.into());
128+
}
129+
130+
self.embeddings.push(embedding);
131+
132+
Ok(())
133+
}
134+
}
135+
136+
#[derive(Debug, Clone, Serialize, Deserialize)]
137+
pub struct Embedding {
138+
pub id: Uuid,
139+
pub metadata: Option<HashMap<String, String>>,
140+
pub vector: Vec<f32>,
141+
}
142+
143+
impl Embedding {
144+
pub fn new(vector: Vec<f32>, metadata: Option<HashMap<String, String>>) -> Self {
145+
Self {
146+
id: Uuid::new_v4(),
147+
metadata,
148+
vector,
149+
}
150+
}
151+
}
152+
153+
pub enum Compare {
154+
Eq,
155+
Neq,
156+
Gt,
157+
Lt,
158+
}
159+
160+
#[derive(Clone)]
161+
enum Chain {
162+
And,
163+
Or,
164+
}
165+
166+
pub struct FilterBuilder {
167+
filter: Vec<(String, Compare, String, Option<Chain>)>,
168+
}
169+
170+
impl FilterBuilder {
171+
pub fn new() -> Self {
172+
Self { filter: Vec::new() }
173+
}
174+
175+
pub fn and(mut self) -> Self {
176+
self.filter
177+
.last_mut()
178+
.map(|c| c.3.as_mut().map(|c| *c = Chain::And));
179+
self
180+
}
181+
182+
pub fn or(mut self) -> Self {
183+
self.filter
184+
.last_mut()
185+
.map(|c| c.3.as_mut().map(|c| *c = Chain::Or));
186+
self
187+
}
188+
189+
pub fn condtion(mut self, lhs: String, op: Compare, rhs: String) -> Self {
190+
self.filter.push((lhs, op, rhs, None));
191+
self
192+
}
193+
194+
pub fn build(self) -> impl Fn(&&Embedding) -> bool {
195+
move |e| {
196+
let mut ret = true;
197+
let mut prev = None;
198+
for condition in &self.filter {
199+
let cond_res = match condition.1 {
200+
Compare::Eq => e
201+
.metadata
202+
.as_ref()
203+
.map(|f| f.get(&condition.0) == Some(&condition.2))
204+
.unwrap_or(false),
205+
Compare::Neq => e
206+
.metadata
207+
.as_ref()
208+
.map(|f| f.get(&condition.0) != Some(&condition.2))
209+
.unwrap_or(false),
210+
Compare::Gt => e
211+
.metadata
212+
.as_ref()
213+
.map(|f| f.get(&condition.0) > Some(&condition.2))
214+
.unwrap_or(false),
215+
Compare::Lt => e
216+
.metadata
217+
.as_ref()
218+
.map(|f| f.get(&condition.0) < Some(&condition.2))
219+
.unwrap_or(false),
220+
};
221+
if let Some(prev) = prev {
222+
match prev {
223+
Chain::And => ret = ret && cond_res,
224+
Chain::Or => ret = ret || cond_res,
225+
}
226+
}
227+
prev = condition.3.clone();
228+
}
229+
ret
230+
}
231+
}
232+
}
233+
234+
async fn get_similarity(
235+
distance: Distance,
236+
embeddings: &[&Embedding],
237+
query: &[f32],
238+
k: usize,
239+
) -> Result<Vec<SimilarityResult>> {
240+
let semaphore = Arc::new(Semaphore::new(8));
241+
let mut set = JoinSet::new();
242+
for (index, embedding) in embeddings.into_iter().enumerate() {
243+
let embedding = (*embedding).clone();
244+
let query = query.to_owned();
245+
let permit = semaphore.clone().acquire_owned().await.unwrap();
246+
set.spawn_blocking(move || {
247+
let score = distance.compute(&embedding.vector, &query);
248+
drop(permit);
249+
ScoreIndex { score, index }
250+
});
251+
}
252+
253+
let mut heap = BinaryHeap::new();
254+
while let Some(res) = set.join_next().await {
255+
let score_index = res.map_err(Into::<Error>::into)?;
256+
if heap.len() < k || score_index < *heap.peek().unwrap() {
257+
heap.push(score_index);
258+
259+
if heap.len() > k {
260+
heap.pop();
261+
}
262+
}
263+
}
264+
Ok(heap
265+
.into_sorted_vec()
266+
.into_iter()
267+
.map(|ScoreIndex { score, index }| SimilarityResult {
268+
score,
269+
embedding: embeddings[index].clone(),
270+
})
271+
.collect())
272+
}

crates/tinyvec-embed/src/error.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
use std::path::PathBuf;
2+
3+
#[derive(Debug, thiserror::Error)]
4+
pub enum Collection {
5+
#[error("bincode error: {0}")]
6+
Bincode(#[from] bincode::Error),
7+
#[error("The dimension of the vector doesn't match the dimension of the collection")]
8+
DimensionMismatch,
9+
#[error("io error: {0}")]
10+
Io(#[from] std::io::Error),
11+
#[error("invalid path: {0}")]
12+
InvalidPath(PathBuf),
13+
#[error("join error: {0}")]
14+
Join(#[from] tokio::task::JoinError),
15+
#[error("Collection doesn't exist")]
16+
NotFound,
17+
#[error("error sending message in channel")]
18+
Send,
19+
#[error("Collection already exists")]
20+
UniqueViolation,
21+
}
22+
23+
#[derive(Debug, thiserror::Error)]
24+
pub enum Error {
25+
#[error("collection error: {0}")]
26+
Collection(#[from] Collection),
27+
}
28+
29+
pub type Result<T> = std::result::Result<T, Error>;

crates/tinyvec-embed/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
mod db;
2+
mod error;
3+
mod similarity;

0 commit comments

Comments
 (0)