Skip to content

Commit 951c262

Browse files
authored
[ENH] Use Idf if the index is BM25 (#5609)
## Description of changes

_Summarize the changes made by this PR._

- Improvements & Bug fixes
  - N/A
- New functionality
  - Add a `bm25` flag to the sparse vector index config in the schema
  - Scale the query by IDF automatically if the key corresponds to a BM25 index

## Test plan

_How are these changes tested?_

- [ ] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust

## Migration plan

_Are there any migrations, or any forwards/backwards compatibility changes needed in order to make sure this change deploys reliably?_

## Observability plan

_What is the plan to instrument and monitor this change?_

## Documentation Changes

_Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs section](https://github.com/chroma-core/chroma/tree/main/docs/docs.trychroma.com)?_
1 parent bd4398d commit 951c262

File tree

6 files changed

+56
-20
lines changed

6 files changed

+56
-20
lines changed

clients/new-js/packages/chromadb/src/api/types.gen.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,10 @@ export type SparseVector = {
380380
};
381381

382382
export type SparseVectorIndexConfig = {
383+
/**
384+
* Whether this embedding is BM25
385+
*/
386+
bm25?: boolean | null;
383387
embedding_function?: null | EmbeddingFunctionConfiguration;
384388
/**
385389
* Key to source the sparse vector from

rust/frontend/src/server.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1989,11 +1989,13 @@ async fn collection_get(
19891989
payload.offset.unwrap_or(0),
19901990
payload.include,
19911991
)?;
1992-
let res = server
1993-
.frontend
1994-
.get(request)
1995-
.meter(metering_context_container)
1996-
.await?;
1992+
let res = Box::pin(
1993+
server
1994+
.frontend
1995+
.get(request)
1996+
.meter(metering_context_container),
1997+
)
1998+
.await?;
19971999
Ok(Json(res))
19982000
}
19992001

rust/frontend/tests/proptest_helpers/frontend_under_test.rs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,8 @@ impl StateMachineTest for FrontendUnderTest {
110110
// Update stats
111111
{
112112
if request.r#where.is_some() {
113-
let filtered_records = state
114-
.frontend
115-
.clone()
116-
.get(
113+
let filtered_records = Box::pin(
114+
state.frontend.clone().get(
117115
GetRequest::try_new(
118116
collection.tenant,
119117
collection.database,
@@ -125,9 +123,10 @@ impl StateMachineTest for FrontendUnderTest {
125123
IncludeList(vec![]),
126124
)
127125
.unwrap(),
128-
)
129-
.await
130-
.unwrap();
126+
),
127+
)
128+
.await
129+
.unwrap();
131130

132131
STATS.with_borrow_mut(|stats| {
133132
stats.num_log_operations += filtered_records.ids.len()
@@ -199,7 +198,7 @@ impl StateMachineTest for FrontendUnderTest {
199198
request.tenant_id = collection.tenant;
200199
request.database_name = collection.database;
201200

202-
state.frontend.get(request.clone()).await.unwrap()
201+
Box::pin(state.frontend.get(request.clone())).await.unwrap()
203202
};
204203

205204
check_get_responses_are_close_to_equal(expected_result, received_result);
@@ -332,8 +331,8 @@ impl StateMachineTest for FrontendUnderTest {
332331
)
333332
.unwrap();
334333

335-
let received_results = frontend_under_test
336-
.get(
334+
let received_results = Box::pin(
335+
frontend_under_test.get(
337336
GetRequest::try_new(
338337
collection_under_test.tenant,
339338
collection_under_test.database,
@@ -345,9 +344,10 @@ impl StateMachineTest for FrontendUnderTest {
345344
IncludeList::default_get(),
346345
)
347346
.unwrap(),
348-
)
349-
.await
350-
.unwrap();
347+
),
348+
)
349+
.await
350+
.unwrap();
351351

352352
check_get_responses_are_close_to_equal(expected_results, received_results);
353353
});

rust/python_bindings/src/bindings.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -623,7 +623,7 @@ impl Bindings {
623623
let mut frontend_clone = self.frontend.clone();
624624
let result = py.allow_threads(move || {
625625
self.runtime
626-
.block_on(async { frontend_clone.get(request).await })
626+
.block_on(async { Box::pin(frontend_clone.get(request)).await })
627627
})?;
628628
Ok(result)
629629
}

rust/types/src/collection_schema.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,7 @@ impl InternalSchema {
311311
config: SparseVectorIndexConfig {
312312
embedding_function: Some(EmbeddingFunctionConfiguration::Legacy),
313313
source_key: None,
314+
bm25: Some(false),
314315
},
315316
}),
316317
}),
@@ -909,6 +910,7 @@ impl InternalSchema {
909910
.clone()
910911
.or(default.embedding_function.clone()),
911912
source_key: user.source_key.clone().or(default.source_key.clone()),
913+
bm25: user.bm25.or(default.bm25),
912914
})
913915
}
914916

@@ -1436,6 +1438,9 @@ pub struct SparseVectorIndexConfig {
14361438
/// Key to source the sparse vector from
14371439
#[serde(skip_serializing_if = "Option::is_none")]
14381440
pub source_key: Option<String>,
1441+
/// Whether this embedding is BM25
1442+
#[serde(skip_serializing_if = "Option::is_none")]
1443+
pub bm25: Option<bool>,
14391444
}
14401445

14411446
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, ToSchema)]
@@ -1791,6 +1796,7 @@ mod tests {
17911796
config: SparseVectorIndexConfig {
17921797
embedding_function: Some(EmbeddingFunctionConfiguration::Legacy),
17931798
source_key: None,
1799+
bm25: None,
17941800
},
17951801
}),
17961802
}),
@@ -2072,11 +2078,13 @@ mod tests {
20722078
let default_config = SparseVectorIndexConfig {
20732079
embedding_function: Some(EmbeddingFunctionConfiguration::Legacy),
20742080
source_key: Some("default_sparse_key".to_string()),
2081+
bm25: None,
20752082
};
20762083

20772084
let user_config = SparseVectorIndexConfig {
20782085
embedding_function: None, // Will use default
20792086
source_key: Some("user_sparse_key".to_string()), // Override
2087+
bm25: None,
20802088
};
20812089

20822090
let result =

rust/worker/src/execution/orchestration/sparse_knn.rs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,29 @@ impl Orchestrator for SparseKnnOrchestrator {
209209
&mut self,
210210
ctx: &ComponentContext<Self>,
211211
) -> Vec<(TaskMessage, Option<Span>)> {
212-
if self.use_bm25 {
212+
let use_bm25 = self.use_bm25
213+
|| self
214+
.collection_and_segments
215+
.collection
216+
.schema
217+
.as_ref()
218+
.is_some_and(|schema| {
219+
if let Some(flag) = schema.key_overrides.get(&self.key).and_then(|uvt| {
220+
uvt.sparse_vector.as_ref().and_then(|vt| {
221+
vt.sparse_vector_index
222+
.as_ref()
223+
.and_then(|it| it.config.bm25)
224+
})
225+
}) {
226+
return flag;
227+
}
228+
schema.defaults.sparse_vector.as_ref().is_some_and(|vt| {
229+
vt.sparse_vector_index
230+
.as_ref()
231+
.is_some_and(|it| it.config.bm25.unwrap_or_default())
232+
})
233+
});
234+
if use_bm25 {
213235
let idf_task = wrap(
214236
Box::new(Idf {
215237
query: self.query.clone(),

0 commit comments

Comments
 (0)