diff --git a/chromadb/api/rust.py b/chromadb/api/rust.py index 3aae75d030e..11493526b0e 100644 --- a/chromadb/api/rust.py +++ b/chromadb/api/rust.py @@ -229,14 +229,25 @@ def create_collection( else: configuration_json_str = None + if schema: + schema_json_str = schema.serialize_to_json() + else: + schema_json_str = None + collection = self.bindings.create_collection( - name, configuration_json_str, metadata, get_or_create, tenant, database + name, + configuration_json_str, + schema_json_str, + metadata, + get_or_create, + tenant, + database, ) collection_model = CollectionModel( id=collection.id, name=collection.name, configuration_json=collection.configuration, - serialized_schema=None, + serialized_schema=collection.schema, metadata=collection.metadata, dimension=collection.dimension, tenant=collection.tenant, @@ -256,7 +267,7 @@ def get_collection( id=collection.id, name=collection.name, configuration_json=collection.configuration, - serialized_schema=None, + serialized_schema=collection.schema, metadata=collection.metadata, dimension=collection.dimension, tenant=collection.tenant, diff --git a/chromadb/chromadb_rust_bindings.pyi b/chromadb/chromadb_rust_bindings.pyi index a001425ed9c..602e5f2eb16 100644 --- a/chromadb/chromadb_rust_bindings.pyi +++ b/chromadb/chromadb_rust_bindings.pyi @@ -95,6 +95,7 @@ class Bindings: self, name: str, configuration_json_str: Optional[str] = None, + schema_json_str: Optional[str] = None, metadata: Optional[CollectionMetadata] = None, get_or_create: bool = False, tenant: str = DEFAULT_TENANT, diff --git a/rust/python_bindings/src/bindings.rs b/rust/python_bindings/src/bindings.rs index 3a97f2f2426..76bbf6ff1bd 100644 --- a/rust/python_bindings/src/bindings.rs +++ b/rust/python_bindings/src/bindings.rs @@ -252,12 +252,13 @@ impl Bindings { #[allow(clippy::too_many_arguments)] #[pyo3( - signature = (name, configuration_json_str, metadata = None, get_or_create = false, tenant = DEFAULT_TENANT.to_string(), database = DEFAULT_DATABASE.to_string()) + signature = (name, configuration_json_str = None, schema_json_str = None, metadata = None, get_or_create = false, tenant = DEFAULT_TENANT.to_string(), database = DEFAULT_DATABASE.to_string()) )] fn create_collection( &self, name: String, configuration_json_str: Option, + schema_json_str: Option, metadata: Option, get_or_create: bool, tenant: String, @@ -265,9 +266,8 @@ impl Bindings { ) -> ChromaPyResult { let configuration_json = match configuration_json_str { Some(configuration_json_str) => { - let configuration_json = - serde_json::from_str::(&configuration_json_str) - .map_err(WrappedSerdeJsonError::SerdeJsonError)?; + let configuration_json = serde_json::from_str(&configuration_json_str) + .map_err(WrappedSerdeJsonError::SerdeJsonError)?; Some(configuration_json) } @@ -291,13 +291,19 @@ impl Bindings { )?), }; + let schema = match schema_json_str { + Some(schema_json_string) => serde_json::from_str(&schema_json_string) + .map_err(WrappedSerdeJsonError::SerdeJsonError)?, + None => None, + }; + let request = CreateCollectionRequest::try_new( tenant, database, name, metadata, configuration, - None, + schema, get_or_create, )?; diff --git a/rust/sqlite/migrations/sysdb/00010-collection-schema.sqlite.sql b/rust/sqlite/migrations/sysdb/00010-collection-schema.sqlite.sql new file mode 100644 index 00000000000..e9a3568dabb --- /dev/null +++ b/rust/sqlite/migrations/sysdb/00010-collection-schema.sqlite.sql @@ -0,0 +1,2 @@ +-- Stores collection configuration dictionaries. +ALTER TABLE collections ADD COLUMN schema_json_str TEXT; diff --git a/rust/sqlite/src/table.rs b/rust/sqlite/src/table.rs index bb16b6a5862..ffce543124b 100644 --- a/rust/sqlite/src/table.rs +++ b/rust/sqlite/src/table.rs @@ -43,6 +43,7 @@ pub enum Collections { Dimension, DatabaseId, ConfigJsonStr, + SchemaJsonStr, } #[derive(Iden)] diff --git a/rust/sysdb/src/sqlite.rs b/rust/sysdb/src/sqlite.rs index 9b23982ab53..a7b8fa381bf 100644 --- a/rust/sysdb/src/sqlite.rs +++ b/rust/sysdb/src/sqlite.rs @@ -12,9 +12,10 @@ use chroma_types::{ CreateTenantError, CreateTenantResponse, Database, DatabaseUuid, DeleteCollectionError, DeleteDatabaseError, DeleteDatabaseResponse, GetCollectionWithSegmentsError, GetCollectionsError, GetDatabaseError, GetSegmentsError, GetTenantError, GetTenantResponse, - InternalCollectionConfiguration, InternalUpdateCollectionConfiguration, ListDatabasesError, - Metadata, MetadataValue, ResetError, ResetResponse, Segment, SegmentScope, SegmentType, - SegmentUuid, UpdateCollectionError, UpdateTenantError, UpdateTenantResponse, + InternalCollectionConfiguration, InternalSchema, InternalUpdateCollectionConfiguration, + KnnIndex, ListDatabasesError, Metadata, MetadataValue, ResetError, ResetResponse, Segment, + SegmentScope, SegmentType, SegmentUuid, UpdateCollectionError, UpdateTenantError, + UpdateTenantResponse, }; use futures::TryStreamExt; use sea_query_binder::SqlxBinder; @@ -250,7 +251,7 @@ impl SqliteSysDb { collection_id: CollectionUuid, name: String, segments: Vec, - configuration: InternalCollectionConfiguration, + schema: InternalSchema, metadata: Option, dimension: Option, get_or_create: bool, @@ -307,13 +308,13 @@ impl SqliteSysDb { sqlx::query( r#" INSERT INTO collections - (id, name, config_json_str, dimension, database_id) + (id, name, schema_json_str, dimension, database_id) VALUES ($1, $2, $3, $4, $5) "#, ) .bind(collection_id.to_string()) .bind(&name) - .bind(serde_json::to_string(&configuration).map_err(CreateCollectionError::Configuration)?) + .bind(serde_json::to_string(&schema).map_err(CreateCollectionError::Configuration)?) .bind(dimension) .bind(database_id) .execute(&mut *tx) @@ -345,9 +346,9 @@ impl SqliteSysDb { name, tenant, database, - config: configuration, + config: InternalCollectionConfiguration::default_hnsw(), metadata, - schema: None, + schema: Some(schema), dimension, log_position: 0, total_records_post_compaction: 0, @@ -683,6 +684,7 @@ impl SqliteSysDb { .column((table::Collections::Table, table::Collections::Id)) .column((table::Collections::Table, table::Collections::Name)) .column((table::Collections::Table, table::Collections::ConfigJsonStr)) + .column((table::Collections::Table, table::Collections::SchemaJsonStr)) .column((table::Collections::Table, table::Collections::Dimension)) .column((table::Collections::Table, table::Collections::DatabaseId)) .inner_join( @@ -735,6 +737,7 @@ impl SqliteSysDb { .column((table::Collections::Table, table::Collections::Id)) .column((table::Collections::Table, table::Collections::Name)) .column((table::Collections::Table, table::Collections::ConfigJsonStr)) + .column((table::Collections::Table, table::Collections::SchemaJsonStr)) .column((table::Collections::Table, table::Collections::Dimension)) .column((table::Databases::Table, table::Databases::TenantId)) .column((table::Databases::Table, table::Databases::Name)) @@ -778,16 +781,19 @@ impl SqliteSysDb { let first_row = rows.first().unwrap(); let configuration = match first_row.get::, _>(2) { - Some(json_str) => { - match serde_json::from_str::(json_str) - .map_err(GetCollectionsError::Configuration) - { - Ok(configuration) => configuration, - Err(e) => return Some(Err(e)), - } - } + Some(json_str) => match serde_json::from_str(json_str) { + Ok(configuration) => configuration, + Err(e) => return Some(Err(GetCollectionsError::Configuration(e))), + }, None => InternalCollectionConfiguration::default_hnsw(), }; + let schema = match first_row.get::, _>(3) { + Some(json_str) => match serde_json::from_str(json_str) { + Ok(schema) => schema, + Err(e) => return Some(Err(GetCollectionsError::Configuration(e))), + }, + None => InternalSchema::new_default(KnnIndex::Hnsw), + }; let database_id = match DatabaseUuid::from_str(first_row.get(6)) { Ok(db_id) => db_id, Err(_) => return Some(Err(GetCollectionsError::DatabaseId)), @@ -796,15 +802,15 @@ impl SqliteSysDb { Some(Ok(Collection { collection_id, config: configuration, - schema: None, + schema: Some(schema), metadata, total_records_post_compaction: 0, version: 0, log_position: 0, - dimension: first_row.get(3), + dimension: first_row.get(4), name: first_row.get(1), - tenant: first_row.get(4), - database: first_row.get(5), + tenant: first_row.get(5), + database: first_row.get(6), size_bytes_post_compaction: 0, last_compaction_time_secs: 0, version_file_path: None, @@ -1112,7 +1118,7 @@ mod tests { use super::*; use chroma_sqlite::db::test_utils::get_new_sqlite_db; use chroma_types::{ - InternalUpdateCollectionConfiguration, SegmentScope, SegmentType, SegmentUuid, + InternalUpdateCollectionConfiguration, KnnIndex, SegmentScope, SegmentType, SegmentUuid, UpdateHnswConfiguration, UpdateMetadata, UpdateMetadataValue, UpdateVectorIndexConfiguration, VectorIndexConfiguration, }; @@ -1294,7 +1300,7 @@ mod tests { collection_id, "test_collection".to_string(), segments.clone(), - InternalCollectionConfiguration::default_hnsw(), + InternalSchema::new_default(KnnIndex::Hnsw), Some(collection_metadata.clone()), None, false, @@ -1337,7 +1343,7 @@ mod tests { collection_id, "test_collection".to_string(), segments.clone(), - InternalCollectionConfiguration::default_hnsw(), + InternalSchema::new_default(KnnIndex::Hnsw), None, None, false, @@ -1354,7 +1360,7 @@ mod tests { collection_id, "test_collection".to_string(), segments, - InternalCollectionConfiguration::default_hnsw(), + InternalSchema::new_default(KnnIndex::Hnsw), None, None, false, @@ -1384,7 +1390,7 @@ mod tests { collection_id, "test_collection".to_string(), segments.clone(), - InternalCollectionConfiguration::default_hnsw(), + InternalSchema::new_default(KnnIndex::Hnsw), None, None, false, @@ -1401,7 +1407,7 @@ mod tests { CollectionUuid::new(), "test_collection".to_string(), vec![], - InternalCollectionConfiguration::default_hnsw(), + InternalSchema::new_default(KnnIndex::Hnsw), None, None, true, @@ -1424,7 +1430,7 @@ mod tests { collection_id, "test_collection".to_string(), vec![], - InternalCollectionConfiguration::default_hnsw(), + InternalSchema::new_default(KnnIndex::Hnsw), None, None, false, @@ -1497,7 +1503,7 @@ mod tests { collection_id, "test_collection".to_string(), vec![], - InternalCollectionConfiguration::default_hnsw(), + InternalSchema::new_default(KnnIndex::Hnsw), None, None, false, @@ -1578,7 +1584,7 @@ mod tests { collection_id, "test_collection".to_string(), segments.clone(), - InternalCollectionConfiguration::default_hnsw(), + InternalSchema::new_default(KnnIndex::Hnsw), Some(collection_metadata.clone()), None, false, @@ -1628,7 +1634,7 @@ mod tests { collection_id, "test_collection".to_string(), segments.clone(), - InternalCollectionConfiguration::default_hnsw(), + InternalSchema::new_default(KnnIndex::Hnsw), Some(collection_metadata.clone()), None, false, @@ -1658,7 +1664,7 @@ mod tests { collection_id, "test_collection".to_string(), vec![], - InternalCollectionConfiguration::default_hnsw(), + InternalSchema::new_default(KnnIndex::Hnsw), None, None, false, diff --git a/rust/sysdb/src/sysdb.rs b/rust/sysdb/src/sysdb.rs index 527d8db3617..68cbcc3a085 100644 --- a/rust/sysdb/src/sysdb.rs +++ b/rust/sysdb/src/sysdb.rs @@ -320,6 +320,13 @@ impl SysDb { .await } SysDb::Sqlite(sqlite) => { + let reconciled_schema = InternalSchema::reconcile_schema_and_config( + schema, + configuration, + ) + .map_err(|err| { + CreateCollectionError::Schema(SchemaError::InvalidSchema { reason: err }) + })?; sqlite .create_collection( tenant, @@ -327,7 +334,7 @@ impl SysDb { collection_id, name, segments, - configuration.unwrap_or(InternalCollectionConfiguration::default_hnsw()), + reconciled_schema, metadata, dimension, get_or_create,