diff --git a/chromadb/api/types.py b/chromadb/api/types.py index 08db0286c46..2887429be68 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -39,7 +39,7 @@ ) if TYPE_CHECKING: - pass + from chromadb.execution.expression.operator import Key try: from chromadb.is_thin_client import is_thin_client @@ -1530,10 +1530,34 @@ class VectorIndexConfig(BaseModel): model_config = {"arbitrary_types_allowed": True} space: Optional[Space] = None embedding_function: Optional[Any] = DefaultEmbeddingFunction() - source_key: Optional[str] = None # key to source the vector from + source_key: Optional[str] = None # key to source the vector from (accepts str or Key) hnsw: Optional[HnswIndexConfig] = None spann: Optional[SpannIndexConfig] = None + @field_validator("source_key", mode="before") + @classmethod + def validate_source_key_field(cls, v: Any) -> Optional[str]: + """Convert Key objects to strings automatically. Accepts both str and Key types.""" + if v is None: + return None + # Import Key at runtime to avoid circular import + from chromadb.execution.expression.operator import Key as KeyType + if isinstance(v, KeyType): + v = v.name # Extract string from Key + elif isinstance(v, str): + pass # Already a string + else: + raise ValueError(f"source_key must be str or Key, got {type(v).__name__}") + + # Validate: only #document is allowed if key starts with # + if v.startswith("#") and v != "#document": + raise ValueError( + "source_key cannot begin with '#'. " + "The only valid key starting with '#' is Key.DOCUMENT or '#document'." + ) + + return v # type: ignore[no-any-return] + @field_validator("embedding_function", mode="before") @classmethod def validate_embedding_function_field(cls, v: Any) -> Any: @@ -1553,9 +1577,33 @@ class SparseVectorIndexConfig(BaseModel): model_config = {"arbitrary_types_allowed": True} # TODO(Sanket): Change this to the appropriate sparse ef and use a default here. embedding_function: Optional[Any] = None - source_key: Optional[str] = None # key to source the sparse vector from + source_key: Optional[str] = None # key to source the sparse vector from (accepts str or Key) bm25: Optional[bool] = None + @field_validator("source_key", mode="before") + @classmethod + def validate_source_key_field(cls, v: Any) -> Optional[str]: + """Convert Key objects to strings automatically. Accepts both str and Key types.""" + if v is None: + return None + # Import Key at runtime to avoid circular import + from chromadb.execution.expression.operator import Key as KeyType + if isinstance(v, KeyType): + v = v.name # Extract string from Key + elif isinstance(v, str): + pass # Already a string + else: + raise ValueError(f"source_key must be str or Key, got {type(v).__name__}") + + # Validate: only #document is allowed if key starts with # + if v.startswith("#") and v != "#document": + raise ValueError( + "source_key cannot begin with '#'. " + "The only valid key starting with '#' is Key.DOCUMENT or '#document'." + ) + + return v # type: ignore[no-any-return] + @field_validator("embedding_function", mode="before") @classmethod def validate_embedding_function_field(cls, v: Any) -> Any: @@ -1739,9 +1787,14 @@ def __init__(self) -> None: self._initialize_keys() def create_index( - self, config: Optional[IndexConfig] = None, key: Optional[str] = None + self, config: Optional[IndexConfig] = None, key: Optional[Union[str, "Key"]] = None ) -> "Schema": """Create an index configuration.""" + # Convert Key to string if provided + from chromadb.execution.expression.operator import Key as KeyType + if key is not None and isinstance(key, KeyType): + key = key.name + # Disallow config=None and key=None - too dangerous if config is None and key is None: raise ValueError( @@ -1754,6 +1807,13 @@ def create_index( f"Cannot create index on special key '{key}'. These keys are managed automatically by the system. Invoke create_index(VectorIndexConfig(...)) without specifying a key to configure the vector index globally." ) + # Disallow any key starting with # + if key is not None and key.startswith("#"): + raise ValueError( + "key cannot begin with '#'. " + "Keys starting with '#' are reserved for system use." + ) + # Special handling for vector index if isinstance(config, VectorIndexConfig): if key is None: @@ -1809,9 +1869,14 @@ def create_index( return self def delete_index( - self, config: Optional[IndexConfig] = None, key: Optional[str] = None + self, config: Optional[IndexConfig] = None, key: Optional[Union[str, "Key"]] = None ) -> "Schema": """Disable an index configuration (set enabled=False).""" + # Convert Key to string if provided + from chromadb.execution.expression.operator import Key as KeyType + if key is not None and isinstance(key, KeyType): + key = key.name + # Case 1: Both config and key are None - fail the request if config is None and key is None: raise ValueError( @@ -1824,6 +1889,13 @@ def delete_index( f"Cannot delete index on special key '{key}'. These keys are managed automatically by the system." ) + # Disallow any key starting with # + if key is not None and key.startswith("#"): + raise ValueError( + "key cannot begin with '#'. " + "Keys starting with '#' are reserved for system use." + ) + # TODO: Consider removing these checks in the future to allow disabling vector, FTS, and sparse vector indexes # Temporarily disallow deleting vector index (both globally and per-key) if isinstance(config, VectorIndexConfig): diff --git a/chromadb/test/api/test_schema.py b/chromadb/test/api/test_schema.py index 92b832f62d2..29ac4aef97f 100644 --- a/chromadb/test/api/test_schema.py +++ b/chromadb/test/api/test_schema.py @@ -14,6 +14,7 @@ EmbeddingFunction, Embeddings, ) +from chromadb.execution.expression.operator import Key from typing import List, Dict, Any import pytest @@ -2058,3 +2059,281 @@ def test_sparse_vector_cannot_be_deleted() -> None: # Try to delete it - should fail with pytest.raises(ValueError, match="Deleting sparse vector index is not currently supported"): schema.delete_index(config=sparse_config, key="my_key") + + +def test_create_index_accepts_key_type() -> None: + """Test that create_index accepts both str and Key types for the key parameter.""" + schema = Schema() + + # Test with string key + string_config = StringInvertedIndexConfig() + schema.create_index(config=string_config, key="test_field_str") + + # Verify the index was created with string key + assert "test_field_str" in schema.keys + assert schema.keys["test_field_str"].string is not None + assert schema.keys["test_field_str"].string.string_inverted_index is not None + assert schema.keys["test_field_str"].string.string_inverted_index.enabled is True + + # Test with Key type + int_config = IntInvertedIndexConfig() + schema.create_index(config=int_config, key=Key("test_field_key")) + + # Verify the index was created with Key type (should be stored as string internally) + assert "test_field_key" in schema.keys + assert schema.keys["test_field_key"].int_value is not None + assert schema.keys["test_field_key"].int_value.int_inverted_index is not None + assert schema.keys["test_field_key"].int_value.int_inverted_index.enabled is True + + # Test that both approaches produce equivalent results + schema2 = Schema() + schema2.create_index(config=string_config, key="same_field") + + schema3 = Schema() + schema3.create_index(config=string_config, key=Key("same_field")) + + # Both should have the same configuration + assert schema2.keys["same_field"].string is not None + assert schema2.keys["same_field"].string.string_inverted_index is not None + assert schema3.keys["same_field"].string is not None + assert schema3.keys["same_field"].string.string_inverted_index is not None + assert schema2.keys["same_field"].string.string_inverted_index.enabled == \ + schema3.keys["same_field"].string.string_inverted_index.enabled + + +def test_delete_index_accepts_key_type() -> None: + """Test that delete_index accepts both str and Key types for the key parameter.""" + schema = Schema() + + # First, create some indexes to delete + string_config = StringInvertedIndexConfig() + int_config = IntInvertedIndexConfig() + + # Test delete with string key + schema.delete_index(config=string_config, key="test_field_str") + + # Verify the index was disabled with string key + assert "test_field_str" in schema.keys + assert schema.keys["test_field_str"].string is not None + assert schema.keys["test_field_str"].string.string_inverted_index is not None + assert schema.keys["test_field_str"].string.string_inverted_index.enabled is False + + # Test delete with Key type + schema.delete_index(config=int_config, key=Key("test_field_key")) + + # Verify the index was disabled with Key type (should be stored as string internally) + assert "test_field_key" in schema.keys + assert schema.keys["test_field_key"].int_value is not None + assert schema.keys["test_field_key"].int_value.int_inverted_index is not None + assert schema.keys["test_field_key"].int_value.int_inverted_index.enabled is False + + # Test that both approaches produce equivalent results + schema2 = Schema() + schema2.delete_index(config=string_config, key="same_field") + + schema3 = Schema() + schema3.delete_index(config=string_config, key=Key("same_field")) + + # Both should have the same configuration + assert schema2.keys["same_field"].string is not None + assert schema2.keys["same_field"].string.string_inverted_index is not None + assert schema3.keys["same_field"].string is not None + assert schema3.keys["same_field"].string.string_inverted_index is not None + assert schema2.keys["same_field"].string.string_inverted_index.enabled == \ + schema3.keys["same_field"].string.string_inverted_index.enabled + + +def test_create_index_rejects_special_keys() -> None: + """Test that create_index rejects special keys like Key.DOCUMENT and Key.EMBEDDING.""" + schema = Schema() + string_config = StringInvertedIndexConfig() + + # Test that Key.DOCUMENT is rejected (first check catches it) + with pytest.raises(ValueError, match="Cannot create index on special key '#document'"): + schema.create_index(config=string_config, key=Key.DOCUMENT) + + # Test that Key.EMBEDDING is rejected (first check catches it) + with pytest.raises(ValueError, match="Cannot create index on special key '#embedding'"): + schema.create_index(config=string_config, key=Key.EMBEDDING) + + # Test that string "#document" is also rejected (for consistency) + with pytest.raises(ValueError, match="Cannot create index on special key '#document'"): + schema.create_index(config=string_config, key="#document") + + # Test that any other key starting with # is rejected (second check) + with pytest.raises(ValueError, match="key cannot begin with '#'"): + schema.create_index(config=string_config, key="#custom_key") + + # Test with Key object for custom special key + with pytest.raises(ValueError, match="key cannot begin with '#'"): + schema.create_index(config=string_config, key=Key("#custom")) + + +def test_delete_index_rejects_special_keys() -> None: + """Test that delete_index rejects special keys like Key.DOCUMENT and Key.EMBEDDING.""" + schema = Schema() + string_config = StringInvertedIndexConfig() + + # Test that Key.DOCUMENT is rejected (first check catches it) + with pytest.raises(ValueError, match="Cannot delete index on special key '#document'"): + schema.delete_index(config=string_config, key=Key.DOCUMENT) + + # Test that Key.EMBEDDING is rejected (first check catches it) + with pytest.raises(ValueError, match="Cannot delete index on special key '#embedding'"): + schema.delete_index(config=string_config, key=Key.EMBEDDING) + + # Test that string "#embedding" is also rejected (for consistency) + with pytest.raises(ValueError, match="Cannot delete index on special key '#embedding'"): + schema.delete_index(config=string_config, key="#embedding") + + # Test that any other key starting with # is rejected (second check) + with pytest.raises(ValueError, match="key cannot begin with '#'"): + schema.delete_index(config=string_config, key="#custom_key") + + # Test with Key object for custom special key + with pytest.raises(ValueError, match="key cannot begin with '#'"): + schema.delete_index(config=string_config, key=Key("#custom")) + + +def test_vector_index_config_source_key_accepts_key_type() -> None: + """Test that VectorIndexConfig.source_key accepts both str and Key types.""" + # Test with string + config1 = VectorIndexConfig(source_key="my_field") + assert config1.source_key == "my_field" + assert isinstance(config1.source_key, str) + + # Test with Key object + config2 = VectorIndexConfig(source_key=Key("my_field")) # type: ignore[arg-type] + assert config2.source_key == "my_field" + assert isinstance(config2.source_key, str) + + # Test with Key.DOCUMENT + config3 = VectorIndexConfig(source_key=Key.DOCUMENT) # type: ignore[arg-type] + assert config3.source_key == "#document" + assert isinstance(config3.source_key, str) + + # Test that both approaches produce the same result + config4 = VectorIndexConfig(source_key="test") + config5 = VectorIndexConfig(source_key=Key("test")) # type: ignore[arg-type] + assert config4.source_key == config5.source_key + + # Test with None + config6 = VectorIndexConfig(source_key=None) + assert config6.source_key is None + + # Test serialization works correctly + config7 = VectorIndexConfig(source_key=Key("serialize_test")) # type: ignore[arg-type] + config_dict = config7.model_dump() + assert config_dict["source_key"] == "serialize_test" + assert isinstance(config_dict["source_key"], str) + + +def test_sparse_vector_index_config_source_key_accepts_key_type() -> None: + """Test that SparseVectorIndexConfig.source_key accepts both str and Key types.""" + # Test with string + config1 = SparseVectorIndexConfig(source_key="my_field") + assert config1.source_key == "my_field" + assert isinstance(config1.source_key, str) + + # Test with Key object + config2 = SparseVectorIndexConfig(source_key=Key("my_field")) # type: ignore[arg-type] + assert config2.source_key == "my_field" + assert isinstance(config2.source_key, str) + + # Test with Key.DOCUMENT + config3 = SparseVectorIndexConfig(source_key=Key.DOCUMENT) # type: ignore[arg-type] + assert config3.source_key == "#document" + assert isinstance(config3.source_key, str) + + # Test that both approaches produce the same result + config4 = SparseVectorIndexConfig(source_key="test") + config5 = SparseVectorIndexConfig(source_key=Key("test")) # type: ignore[arg-type] + assert config4.source_key == config5.source_key + + # Test with None + config6 = SparseVectorIndexConfig(source_key=None) + assert config6.source_key is None + + # Test serialization works correctly + config7 = SparseVectorIndexConfig(source_key=Key("serialize_test")) # type: ignore[arg-type] + config_dict = config7.model_dump() + assert config_dict["source_key"] == "serialize_test" + assert isinstance(config_dict["source_key"], str) + + +def test_config_source_key_rejects_invalid_types() -> None: + """Test that config validators reject invalid types for source_key.""" + # Test VectorIndexConfig rejects invalid types + with pytest.raises(ValueError, match="source_key must be str or Key"): + VectorIndexConfig(source_key=123) # type: ignore[arg-type] + + with pytest.raises(ValueError, match="source_key must be str or Key"): + VectorIndexConfig(source_key=["not", "valid"]) # type: ignore[arg-type] + + # Test SparseVectorIndexConfig rejects invalid types + with pytest.raises(ValueError, match="source_key must be str or Key"): + SparseVectorIndexConfig(source_key=123) # type: ignore[arg-type] + + with pytest.raises(ValueError, match="source_key must be str or Key"): + SparseVectorIndexConfig(source_key={"not": "valid"}) # type: ignore[arg-type] + + +def test_config_source_key_validates_special_keys() -> None: + """Test that source_key only allows #document, rejects other special keys.""" + # Test VectorIndexConfig + # #document is allowed (string) + config1 = VectorIndexConfig(source_key="#document") + assert config1.source_key == "#document" + + # #document is allowed (Key) + config2 = VectorIndexConfig(source_key=Key.DOCUMENT) # type: ignore[arg-type] + assert config2.source_key == "#document" + + # #embedding is rejected (string) + with pytest.raises(ValueError, match="source_key cannot begin with '#'"): + VectorIndexConfig(source_key="#embedding") + + # #embedding is rejected (Key) + with pytest.raises(ValueError, match="source_key cannot begin with '#'"): + VectorIndexConfig(source_key=Key.EMBEDDING) # type: ignore[arg-type] + + # #metadata is rejected + with pytest.raises(ValueError, match="source_key cannot begin with '#'"): + VectorIndexConfig(source_key="#metadata") + + # #score is rejected + with pytest.raises(ValueError, match="source_key cannot begin with '#'"): + VectorIndexConfig(source_key="#score") + + # Any other key starting with # is rejected + with pytest.raises(ValueError, match="source_key cannot begin with '#'"): + VectorIndexConfig(source_key="#custom") + + # Regular keys (no #) are allowed + config3 = VectorIndexConfig(source_key="my_field") + assert config3.source_key == "my_field" + + # Test SparseVectorIndexConfig + # #document is allowed (string) + config4 = SparseVectorIndexConfig(source_key="#document") + assert config4.source_key == "#document" + + # #document is allowed (Key) + config5 = SparseVectorIndexConfig(source_key=Key.DOCUMENT) # type: ignore[arg-type] + assert config5.source_key == "#document" + + # #embedding is rejected (string) + with pytest.raises(ValueError, match="source_key cannot begin with '#'"): + SparseVectorIndexConfig(source_key="#embedding") + + # #embedding is rejected (Key) + with pytest.raises(ValueError, match="source_key cannot begin with '#'"): + SparseVectorIndexConfig(source_key=Key.EMBEDDING) # type: ignore[arg-type] + + # #metadata is rejected + with pytest.raises(ValueError, match="source_key cannot begin with '#'"): + SparseVectorIndexConfig(source_key="#metadata") + + # Regular keys (no #) are allowed + config6 = SparseVectorIndexConfig(source_key="my_field") + assert config6.source_key == "my_field" diff --git a/chromadb/test/api/test_schema_e2e.py b/chromadb/test/api/test_schema_e2e.py index 5fc02e17027..95903626605 100644 --- a/chromadb/test/api/test_schema_e2e.py +++ b/chromadb/test/api/test_schema_e2e.py @@ -13,6 +13,7 @@ EmbeddingFunction, Embeddings, ) +from chromadb.execution.expression.operator import Key from chromadb.test.conftest import ( ClientFactories, is_spann_disabled_mode, @@ -1662,3 +1663,312 @@ def test_conflicting_embedding_functions_in_schema_and_config_fails( # Verify the error message indicates the conflict error_msg = str(exc_info.value) assert "schema" in error_msg.lower() or "conflict" in error_msg.lower() + + +@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled) +def test_create_index_with_key_type( + client_factories: "ClientFactories", +) -> None: + """Test that create_index accepts both str and Key types.""" + # Create schema using both str and Key types + schema = Schema() + schema.create_index(config=StringInvertedIndexConfig(), key="field_str") + schema.create_index(config=IntInvertedIndexConfig(), key=Key("field_key")) + + collection, _ = _create_isolated_collection(client_factories, schema=schema) + + # Verify both fields are in the schema + assert collection.schema is not None + assert "field_str" in collection.schema.keys + assert "field_key" in collection.schema.keys + + # Verify both indexes work + collection.add( + ids=["key-test-1"], + documents=["test doc"], + metadatas=[{"field_str": "value1", "field_key": 42}], + ) + + # Test string field filter + str_result = collection.get(where={"field_str": "value1"}) + assert set(str_result["ids"]) == {"key-test-1"} + + # Test int field filter + int_result = collection.get(where={"field_key": 42}) + assert set(int_result["ids"]) == {"key-test-1"} + + +@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled) +def test_delete_index_with_key_type( + client_factories: "ClientFactories", +) -> None: + """Test that delete_index accepts both str and Key types.""" + # Disable indexes using both str and Key types + schema = Schema() + schema.delete_index(config=StringInvertedIndexConfig(), key="disabled_str") + schema.delete_index(config=IntInvertedIndexConfig(), key=Key("disabled_key")) + + collection, _ = _create_isolated_collection(client_factories, schema=schema) + + # Verify both fields have disabled indexes + assert collection.schema is not None + assert "disabled_str" in collection.schema.keys + assert "disabled_key" in collection.schema.keys + + # Add data + collection.add( + ids=["disable-test-1"], + documents=["test doc"], + metadatas=[{"disabled_str": "value", "disabled_key": 100}], + ) + + # Verify filtering on disabled fields raises errors + with pytest.raises(Exception): + collection.get(where={"disabled_str": "value"}) + + with pytest.raises(Exception): + collection.get(where={"disabled_key": 100}) + + +@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled) +def test_vector_index_config_with_key_document_source( + client_factories: "ClientFactories", +) -> None: + """Test that VectorIndexConfig source_key accepts Key.DOCUMENT.""" + schema = Schema() + schema.create_index( + config=VectorIndexConfig( + source_key=Key.DOCUMENT, # type: ignore[arg-type] + embedding_function=SimpleEmbeddingFunction(dim=5), + ) + ) + + collection, _ = _create_isolated_collection( + client_factories, + schema=schema, + embedding_function=SimpleEmbeddingFunction(dim=5), + ) + + # Verify source_key was properly converted to "#document" + assert collection.schema is not None + embedding_config = collection.schema.keys["#embedding"].float_list + assert embedding_config is not None + assert embedding_config.vector_index is not None + assert embedding_config.vector_index.config.source_key == "#document" + + # Add test data + collection.add( + ids=["vec-1", "vec-2", "vec-3"], + documents=["apple fruit", "banana fruit", "car vehicle"], + ) + + # Verify embeddings were generated + result = collection.get(ids=["vec-1"], include=["embeddings"]) + assert result["embeddings"] is not None + assert len(result["embeddings"][0]) == 5 # dim=5 + + # Perform vector search + search_result = collection.search( + Search().rank(Knn(query=[0.0, 1.0, 2.0, 3.0, 4.0], limit=2)) + ) + assert len(search_result["ids"]) > 0 + assert len(search_result["ids"][0]) <= 2 # limit=2 + + +@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled) +def test_sparse_vector_index_config_with_key_types( + client_factories: "ClientFactories", +) -> None: + """Test that SparseVectorIndexConfig source_key accepts both str and Key types.""" + sparse_ef = DeterministicSparseEmbeddingFunction(label="key_types") + + # Test with Key.DOCUMENT + schema1 = Schema().create_index( + key="sparse1", + config=SparseVectorIndexConfig( + source_key=Key.DOCUMENT, # type: ignore[arg-type] + embedding_function=sparse_ef, + ), + ) + + collection1, _ = _create_isolated_collection(client_factories, schema=schema1) + assert collection1.schema is not None + sparse1_config = collection1.schema.keys["sparse1"].sparse_vector + assert sparse1_config is not None + assert sparse1_config.sparse_vector_index is not None + assert sparse1_config.sparse_vector_index.config.source_key == "#document" + + # Add data and verify sparse embeddings were generated from documents + collection1.add( + ids=["s1", "s2", "s3"], + documents=["apple", "banana", "orange"], + ) + + # Verify sparse embeddings in metadata + result1 = collection1.get(ids=["s1"], include=["metadatas"]) + assert result1["metadatas"] is not None + assert "sparse1" in result1["metadatas"][0] + sparse_vec = cast(SparseVector, result1["metadatas"][0]["sparse1"]) + + # Perform sparse vector search + search_result1 = collection1.search( + Search().rank(Knn(key="sparse1", query=cast(Any, sparse_vec), limit=2)) + ) + assert len(search_result1["ids"]) > 0 + assert "s1" in search_result1["ids"][0] # Should find itself + + # Test with Key("field_name") + schema2 = Schema().create_index( + key="sparse2", + config=SparseVectorIndexConfig( + source_key=Key("text_field"), # type: ignore[arg-type] + embedding_function=sparse_ef, + ), + ) + + collection2, _ = _create_isolated_collection(client_factories, schema=schema2) + assert collection2.schema is not None + sparse2_config = collection2.schema.keys["sparse2"].sparse_vector + assert sparse2_config is not None + assert sparse2_config.sparse_vector_index is not None + assert sparse2_config.sparse_vector_index.config.source_key == "text_field" + + # Add data with metadata source field + collection2.add( + ids=["sparse-key-1", "sparse-key-2"], + documents=["doc1", "doc2"], + metadatas=[{"text_field": "content one"}, {"text_field": "content two"}], + ) + + # Verify sparse embeddings were generated from text_field + result2 = collection2.get(ids=["sparse-key-1", "sparse-key-2"], include=["metadatas"]) + assert result2["metadatas"] is not None + assert "sparse2" in result2["metadatas"][0] + assert "sparse2" in result2["metadatas"][1] + + # Get the sparse vector for search + sparse_query = cast(SparseVector, result2["metadatas"][0]["sparse2"]) + + # Perform sparse vector search + search_result2 = collection2.search( + Search().rank(Knn(key="sparse2", query=cast(Any, sparse_query), limit=1)) + ) + assert len(search_result2["ids"]) > 0 + assert "sparse-key-1" in search_result2["ids"][0] # Should find itself + + +def test_schema_rejects_special_key_in_create_index() -> None: + """Test that create_index rejects keys starting with # (except system keys).""" + # Test with string starting with # + schema = Schema() + with pytest.raises(ValueError, match="key cannot begin with '#'"): + schema.create_index(config=StringInvertedIndexConfig(), key="#custom_field") + + # Test with Key object starting with # + with pytest.raises(ValueError, match="Cannot create index on special key '#embedding'"): + schema.create_index(config=StringInvertedIndexConfig(), key=Key.EMBEDDING) + + +def test_schema_rejects_special_key_in_delete_index() -> None: + """Test that delete_index rejects keys starting with # (except system keys).""" + # Test with string starting with # + schema = Schema() + with pytest.raises(ValueError, match="key cannot begin with '#'"): + schema.delete_index(config=StringInvertedIndexConfig(), key="#custom_field") + + # Test with Key object starting with # + with pytest.raises(ValueError, match="Cannot delete index on special key '#document'"): + schema.delete_index(config=StringInvertedIndexConfig(), key=Key.DOCUMENT) + + +def test_schema_rejects_invalid_source_key_in_configs() -> None: + """Test that config validators reject invalid source_key values.""" + # Test VectorIndexConfig rejects non-#document special keys + with pytest.raises(ValueError, match="source_key cannot begin with '#'"): + VectorIndexConfig(source_key="#embedding") + + with pytest.raises(ValueError, match="source_key cannot begin with '#'"): + VectorIndexConfig(source_key=Key.EMBEDDING) # type: ignore[arg-type] + + # Test SparseVectorIndexConfig rejects non-#document special keys + with pytest.raises(ValueError, match="source_key cannot begin with '#'"): + SparseVectorIndexConfig(source_key="#embedding") + + with pytest.raises(ValueError, match="source_key cannot begin with '#'"): + SparseVectorIndexConfig(source_key=Key.EMBEDDING) # type: ignore[arg-type] + + with pytest.raises(ValueError, match="source_key cannot begin with '#'"): + SparseVectorIndexConfig(source_key="#metadata") + + +@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled) +def test_server_validates_schema_with_special_keys( + client_factories: "ClientFactories", +) -> None: + """Test that server-side validation rejects schemas with invalid special keys.""" + client = client_factories.create_client_from_system() + client.reset() + + collection_name = f"server_validate_{uuid4().hex}" + + # Try to create collection with invalid key in schema + # This should be caught server-side by validate_schema() + schema = Schema() + # Bypass client-side validation by directly manipulating schema.keys + from chromadb.api.types import ValueTypes, StringValueType, StringInvertedIndexType, StringInvertedIndexConfig + schema.keys["#invalid_key"] = ValueTypes( + string=StringValueType( + string_inverted_index=StringInvertedIndexType( + enabled=True, + config=StringInvertedIndexConfig(), + ) + ) + ) + + # Server should reject this + with pytest.raises(Exception) as exc_info: + client.create_collection(name=collection_name, schema=schema) + + # Verify server caught the invalid key + error_msg = str(exc_info.value) + assert "#" in error_msg or "key" in error_msg.lower() or "invalid" in error_msg.lower() + + +@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled) +def test_server_validates_invalid_source_key_in_sparse_vector_config( + client_factories: "ClientFactories", +) -> None: + """Test that server-side validation rejects invalid source_key in SparseVectorIndexConfig.""" + client = client_factories.create_client_from_system() + client.reset() + + collection_name = f"server_source_key_{uuid4().hex}" + + # Create schema with invalid source_key + # Bypass client-side validation by directly creating the config + from chromadb.api.types import ValueTypes, SparseVectorValueType, SparseVectorIndexType + + schema = Schema() + # Manually construct config with invalid source_key using model_construct to bypass validation + invalid_config = SparseVectorIndexConfig.model_construct( + embedding_function=None, + source_key="#embedding", # Invalid - should be rejected + bm25=None + ) + + schema.keys["test_sparse"] = ValueTypes( + sparse_vector=SparseVectorValueType( + sparse_vector_index=SparseVectorIndexType( + enabled=True, + config=invalid_config, + ) + ) + ) + + # Server should reject this + with pytest.raises(Exception) as exc_info: + client.create_collection(name=collection_name, schema=schema) + + # Verify server caught the invalid source_key + error_msg = str(exc_info.value) + assert "source_key" in error_msg.lower() or "#" in error_msg or "document" in error_msg.lower() diff --git a/chromadb/types.py b/chromadb/types.py index 1205213485c..be48337e314 100644 --- a/chromadb/types.py +++ b/chromadb/types.py @@ -174,6 +174,13 @@ def get_model_fields(self) -> Dict[Any, Any]: except AttributeError: return self.__fields__ # pydantic 1.x + def pretty_schema(self) -> str: + """Returns a pretty-printed version of the serialized schema.""" + if self.serialized_schema is None: + return "No schema" + import json + return json.dumps(self.serialized_schema, indent=2) + @classmethod @override def from_json(cls, json_map: Dict[str, Any]) -> Self: diff --git a/rust/types/src/validators.rs b/rust/types/src/validators.rs index b476271de17..be1a916654d 100644 --- a/rust/types/src/validators.rs +++ b/rust/types/src/validators.rs @@ -216,6 +216,14 @@ pub fn validate_schema(schema: &Schema) -> Result<(), ValidationError> { return Err(ValidationError::new("schema").with_message("Full text search / regular expression index cannot be enabled by default. It can only be enabled on #document field.".into())); } for (key, config) in &schema.keys { + // Validate that keys cannot start with # (except system keys) + if key.starts_with('#') && key != DOCUMENT_KEY && key != EMBEDDING_KEY { + return Err(ValidationError::new("schema").with_message( + format!("key cannot begin with '#'. Keys starting with '#' are reserved for system use: {key}") + .into(), + )); + } + if key == DOCUMENT_KEY && (config.boolean.is_some() || config.float.is_some() @@ -260,17 +268,27 @@ pub fn validate_schema(schema: &Schema) -> Result<(), ValidationError> { .with_message("Vector index can only source from #document".into())); } } - if config + if let Some(svit) = config .sparse_vector .as_ref() - .is_some_and(|vt| vt.sparse_vector_index.as_ref().is_some_and(|it| it.enabled)) + .and_then(|vt| vt.sparse_vector_index.as_ref()) { - sparse_index_keys.push(key); - if sparse_index_keys.len() > 1 { - return Err(ValidationError::new("schema").with_message( - format!("At most one sparse vector index is allowed for the collection: {sparse_index_keys:?}") - .into(), - )); + if svit.enabled { + sparse_index_keys.push(key); + if sparse_index_keys.len() > 1 { + return Err(ValidationError::new("schema").with_message( + format!("At most one sparse vector index is allowed for the collection: {sparse_index_keys:?}") + .into(), + )); + } + } + // Validate source_key for sparse vector index + if let Some(source_key) = &svit.config.source_key { + if source_key.starts_with('#') && source_key != DOCUMENT_KEY { + return Err(ValidationError::new("schema").with_message( + "source_key cannot begin with '#'. The only valid key starting with '#' is Key.DOCUMENT or '#document'.".into(), + )); + } } } if config