Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions chromadb/api/rust.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,14 +229,25 @@ def create_collection(
else:
configuration_json_str = None

if schema:
schema_json_str = schema.serialize_to_json()
else:
schema_json_str = None

collection = self.bindings.create_collection(
name, configuration_json_str, metadata, get_or_create, tenant, database
name,
configuration_json_str,
schema_json_str,
metadata,
get_or_create,
tenant,
database,
)
collection_model = CollectionModel(
id=collection.id,
name=collection.name,
configuration_json=collection.configuration,
serialized_schema=None,
serialized_schema=collection.schema,
metadata=collection.metadata,
dimension=collection.dimension,
tenant=collection.tenant,
Expand All @@ -256,7 +267,7 @@ def get_collection(
id=collection.id,
name=collection.name,
configuration_json=collection.configuration,
serialized_schema=None,
serialized_schema=collection.schema,
metadata=collection.metadata,
dimension=collection.dimension,
tenant=collection.tenant,
Expand Down
1 change: 1 addition & 0 deletions chromadb/chromadb_rust_bindings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class Bindings:
self,
name: str,
configuration_json_str: Optional[str] = None,
schema_json_str: Optional[str] = None,
metadata: Optional[CollectionMetadata] = None,
get_or_create: bool = False,
tenant: str = DEFAULT_TENANT,
Expand Down
16 changes: 11 additions & 5 deletions rust/python_bindings/src/bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,22 +252,22 @@ impl Bindings {

#[allow(clippy::too_many_arguments)]
#[pyo3(
signature = (name, configuration_json_str, metadata = None, get_or_create = false, tenant = DEFAULT_TENANT.to_string(), database = DEFAULT_DATABASE.to_string())
signature = (name, configuration_json_str = None, schema_json_str = None, metadata = None, get_or_create = false, tenant = DEFAULT_TENANT.to_string(), database = DEFAULT_DATABASE.to_string())
)]
fn create_collection(
&self,
name: String,
configuration_json_str: Option<String>,
schema_json_str: Option<String>,
metadata: Option<Metadata>,
get_or_create: bool,
tenant: String,
database: String,
) -> ChromaPyResult<Collection> {
let configuration_json = match configuration_json_str {
Some(configuration_json_str) => {
let configuration_json =
serde_json::from_str::<CollectionConfiguration>(&configuration_json_str)
.map_err(WrappedSerdeJsonError::SerdeJsonError)?;
let configuration_json = serde_json::from_str(&configuration_json_str)
.map_err(WrappedSerdeJsonError::SerdeJsonError)?;

Some(configuration_json)
}
Expand All @@ -291,13 +291,19 @@ impl Bindings {
)?),
};

let schema = match schema_json_str {
Some(schema_json_string) => serde_json::from_str(&schema_json_string)
.map_err(WrappedSerdeJsonError::SerdeJsonError)?,
None => None,
};

let request = CreateCollectionRequest::try_new(
tenant,
database,
name,
metadata,
configuration,
None,
schema,
get_or_create,
)?;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
-- Adds a nullable TEXT column to `collections` that stores each collection's
-- schema serialized as a JSON string. Rows that predate this migration keep
-- NULL in this column.
ALTER TABLE collections ADD COLUMN schema_json_str TEXT;
1 change: 1 addition & 0 deletions rust/sqlite/src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ pub enum Collections {
Dimension,
DatabaseId,
ConfigJsonStr,
SchemaJsonStr,
}

#[derive(Iden)]
Expand Down
68 changes: 37 additions & 31 deletions rust/sysdb/src/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ use chroma_types::{
CreateTenantError, CreateTenantResponse, Database, DatabaseUuid, DeleteCollectionError,
DeleteDatabaseError, DeleteDatabaseResponse, GetCollectionWithSegmentsError,
GetCollectionsError, GetDatabaseError, GetSegmentsError, GetTenantError, GetTenantResponse,
InternalCollectionConfiguration, InternalUpdateCollectionConfiguration, ListDatabasesError,
Metadata, MetadataValue, ResetError, ResetResponse, Segment, SegmentScope, SegmentType,
SegmentUuid, UpdateCollectionError, UpdateTenantError, UpdateTenantResponse,
InternalCollectionConfiguration, InternalSchema, InternalUpdateCollectionConfiguration,
KnnIndex, ListDatabasesError, Metadata, MetadataValue, ResetError, ResetResponse, Segment,
SegmentScope, SegmentType, SegmentUuid, UpdateCollectionError, UpdateTenantError,
UpdateTenantResponse,
};
use futures::TryStreamExt;
use sea_query_binder::SqlxBinder;
Expand Down Expand Up @@ -250,7 +251,7 @@ impl SqliteSysDb {
collection_id: CollectionUuid,
name: String,
segments: Vec<Segment>,
configuration: InternalCollectionConfiguration,
schema: InternalSchema,
metadata: Option<Metadata>,
dimension: Option<i32>,
get_or_create: bool,
Expand Down Expand Up @@ -307,13 +308,13 @@ impl SqliteSysDb {
sqlx::query(
r#"
INSERT INTO collections
(id, name, config_json_str, dimension, database_id)
(id, name, schema_json_str, dimension, database_id)
VALUES ($1, $2, $3, $4, $5)
"#,
)
.bind(collection_id.to_string())
.bind(&name)
.bind(serde_json::to_string(&configuration).map_err(CreateCollectionError::Configuration)?)
.bind(serde_json::to_string(&schema).map_err(CreateCollectionError::Configuration)?)
.bind(dimension)
.bind(database_id)
.execute(&mut *tx)
Expand Down Expand Up @@ -345,9 +346,9 @@ impl SqliteSysDb {
name,
tenant,
database,
config: configuration,
config: InternalCollectionConfiguration::default_hnsw(),
metadata,
schema: None,
schema: Some(schema),
dimension,
log_position: 0,
total_records_post_compaction: 0,
Expand Down Expand Up @@ -683,6 +684,7 @@ impl SqliteSysDb {
.column((table::Collections::Table, table::Collections::Id))
.column((table::Collections::Table, table::Collections::Name))
.column((table::Collections::Table, table::Collections::ConfigJsonStr))
.column((table::Collections::Table, table::Collections::SchemaJsonStr))
.column((table::Collections::Table, table::Collections::Dimension))
.column((table::Collections::Table, table::Collections::DatabaseId))
.inner_join(
Expand Down Expand Up @@ -735,6 +737,7 @@ impl SqliteSysDb {
.column((table::Collections::Table, table::Collections::Id))
.column((table::Collections::Table, table::Collections::Name))
.column((table::Collections::Table, table::Collections::ConfigJsonStr))
.column((table::Collections::Table, table::Collections::SchemaJsonStr))
.column((table::Collections::Table, table::Collections::Dimension))
.column((table::Databases::Table, table::Databases::TenantId))
.column((table::Databases::Table, table::Databases::Name))
Expand Down Expand Up @@ -778,16 +781,19 @@ impl SqliteSysDb {
let first_row = rows.first().unwrap();

let configuration = match first_row.get::<Option<&str>, _>(2) {
Some(json_str) => {
match serde_json::from_str::<InternalCollectionConfiguration>(json_str)
.map_err(GetCollectionsError::Configuration)
{
Ok(configuration) => configuration,
Err(e) => return Some(Err(e)),
}
}
Some(json_str) => match serde_json::from_str(json_str) {
Ok(configuration) => configuration,
Err(e) => return Some(Err(GetCollectionsError::Configuration(e))),
},
None => InternalCollectionConfiguration::default_hnsw(),
};
let schema = match first_row.get::<Option<&str>, _>(3) {
Some(json_str) => match serde_json::from_str(json_str) {
Ok(schema) => schema,
Err(e) => return Some(Err(GetCollectionsError::Configuration(e))),
},
None => InternalSchema::new_default(KnnIndex::Hnsw),
};
let database_id = match DatabaseUuid::from_str(first_row.get(6)) {
Ok(db_id) => db_id,
Err(_) => return Some(Err(GetCollectionsError::DatabaseId)),
Expand All @@ -796,15 +802,15 @@ impl SqliteSysDb {
Some(Ok(Collection {
collection_id,
config: configuration,
schema: None,
schema: Some(schema),
metadata,
total_records_post_compaction: 0,
version: 0,
log_position: 0,
dimension: first_row.get(3),
dimension: first_row.get(4),
name: first_row.get(1),
tenant: first_row.get(4),
database: first_row.get(5),
tenant: first_row.get(5),
database: first_row.get(6),
size_bytes_post_compaction: 0,
last_compaction_time_secs: 0,
version_file_path: None,
Expand Down Expand Up @@ -1112,7 +1118,7 @@ mod tests {
use super::*;
use chroma_sqlite::db::test_utils::get_new_sqlite_db;
use chroma_types::{
InternalUpdateCollectionConfiguration, SegmentScope, SegmentType, SegmentUuid,
InternalUpdateCollectionConfiguration, KnnIndex, SegmentScope, SegmentType, SegmentUuid,
UpdateHnswConfiguration, UpdateMetadata, UpdateMetadataValue,
UpdateVectorIndexConfiguration, VectorIndexConfiguration,
};
Expand Down Expand Up @@ -1294,7 +1300,7 @@ mod tests {
collection_id,
"test_collection".to_string(),
segments.clone(),
InternalCollectionConfiguration::default_hnsw(),
InternalSchema::new_default(KnnIndex::Hnsw),
Some(collection_metadata.clone()),
None,
false,
Expand Down Expand Up @@ -1337,7 +1343,7 @@ mod tests {
collection_id,
"test_collection".to_string(),
segments.clone(),
InternalCollectionConfiguration::default_hnsw(),
InternalSchema::new_default(KnnIndex::Hnsw),
None,
None,
false,
Expand All @@ -1354,7 +1360,7 @@ mod tests {
collection_id,
"test_collection".to_string(),
segments,
InternalCollectionConfiguration::default_hnsw(),
InternalSchema::new_default(KnnIndex::Hnsw),
None,
None,
false,
Expand Down Expand Up @@ -1384,7 +1390,7 @@ mod tests {
collection_id,
"test_collection".to_string(),
segments.clone(),
InternalCollectionConfiguration::default_hnsw(),
InternalSchema::new_default(KnnIndex::Hnsw),
None,
None,
false,
Expand All @@ -1401,7 +1407,7 @@ mod tests {
CollectionUuid::new(),
"test_collection".to_string(),
vec![],
InternalCollectionConfiguration::default_hnsw(),
InternalSchema::new_default(KnnIndex::Hnsw),
None,
None,
true,
Expand All @@ -1424,7 +1430,7 @@ mod tests {
collection_id,
"test_collection".to_string(),
vec![],
InternalCollectionConfiguration::default_hnsw(),
InternalSchema::new_default(KnnIndex::Hnsw),
None,
None,
false,
Expand Down Expand Up @@ -1497,7 +1503,7 @@ mod tests {
collection_id,
"test_collection".to_string(),
vec![],
InternalCollectionConfiguration::default_hnsw(),
InternalSchema::new_default(KnnIndex::Hnsw),
None,
None,
false,
Expand Down Expand Up @@ -1578,7 +1584,7 @@ mod tests {
collection_id,
"test_collection".to_string(),
segments.clone(),
InternalCollectionConfiguration::default_hnsw(),
InternalSchema::new_default(KnnIndex::Hnsw),
Some(collection_metadata.clone()),
None,
false,
Expand Down Expand Up @@ -1628,7 +1634,7 @@ mod tests {
collection_id,
"test_collection".to_string(),
segments.clone(),
InternalCollectionConfiguration::default_hnsw(),
InternalSchema::new_default(KnnIndex::Hnsw),
Some(collection_metadata.clone()),
None,
false,
Expand Down Expand Up @@ -1658,7 +1664,7 @@ mod tests {
collection_id,
"test_collection".to_string(),
vec![],
InternalCollectionConfiguration::default_hnsw(),
InternalSchema::new_default(KnnIndex::Hnsw),
None,
None,
false,
Expand Down
9 changes: 8 additions & 1 deletion rust/sysdb/src/sysdb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,14 +320,21 @@ impl SysDb {
.await
}
SysDb::Sqlite(sqlite) => {
let reconciled_schema = InternalSchema::reconcile_schema_and_config(
schema,
configuration,
)
.map_err(|err| {
CreateCollectionError::Schema(SchemaError::InvalidSchema { reason: err })
})?;
sqlite
.create_collection(
tenant,
database,
collection_id,
name,
segments,
configuration.unwrap_or(InternalCollectionConfiguration::default_hnsw()),
reconciled_schema,
metadata,
dimension,
get_or_create,
Expand Down
Loading