diff --git a/.vscode/settings.json b/.vscode/settings.json
index 3a284f0..991d0e2 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -1,6 +1,7 @@
 {
     "python.testing.pytestArgs": [
-        "tests"
+        "tests",
+        "-v"
     ],
     "python.testing.unittestEnabled": false,
     "python.testing.pytestEnabled": true,
diff --git a/pytidb/table.py b/pytidb/table.py
index 24c4d74..d717176 100644
--- a/pytidb/table.py
+++ b/pytidb/table.py
@@ -272,6 +272,45 @@ def insert(self, data: Union[T, dict]) -> T:
             db_session.refresh(data)
             return data
 
+    def save(self, data: Union[T, dict]) -> T:
+        if not isinstance(data, self._table_model) and not isinstance(data, dict):
+            raise ValueError(
+                f"Invalid data type: {type(data)}, expected {self._table_model}, dict"
+            )
+
+        # Convert dict to table model instance.
+        if isinstance(data, dict):
+            data = self._table_model(**data)
+
+        # Auto embedding.
+        for field_name, config in self._auto_embedding_configs.items():
+            # Skip if vector embedding is provided.
+            if getattr(data, field_name) is not None:
+                continue
+
+            # Skip if source field is not provided.
+            if not hasattr(data, config["source_field_name"]):
+                continue
+
+            # Skip if source field is None or empty.
+            embedding_source = getattr(data, config["source_field_name"])
+            if embedding_source is None or embedding_source == "":
+                setattr(data, field_name, None)
+                continue
+
+            source_type = config.get("source_type", "text")
+            vector_embedding = config["embed_fn"].get_source_embedding(
+                embedding_source,
+                source_type=source_type,
+            )
+            setattr(data, field_name, vector_embedding)
+
+        with self._client.session() as db_session:
+            merged_data = db_session.merge(data)
+            db_session.flush()
+            db_session.refresh(merged_data)
+            return merged_data
+
     def bulk_insert(self, data: List[Union[T, dict]]) -> List[T]:
         if not isinstance(data, list):
             raise ValueError(
diff --git a/tests/conftest.py b/tests/conftest.py
index fb00ba9..c8b90a3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,6 +7,7 @@
 from dotenv import load_dotenv
 
 from pytidb import TiDBClient
+from pytidb.embeddings import EmbeddingFunction
 
 logger = logging.getLogger(__name__)
 
@@ -71,3 +72,8 @@ def isolated_client(env) -> Generator[TiDBClient, None, None]:
             client.disconnect()
         except Exception as e:
             logger.error(f"Failed to drop test database {db_name}: {e}")
+
+
+@pytest.fixture(scope="session", autouse=True)
+def text_embed():
+    return EmbeddingFunction("openai/text-embedding-3-small", timeout=20)
diff --git a/tests/test_auto_embedding.py b/tests/test_auto_embedding.py
index f85ef1e..dfdc26f 100644
--- a/tests/test_auto_embedding.py
+++ b/tests/test_auto_embedding.py
@@ -4,98 +4,89 @@
 from pytidb.schema import TableModel, Field
 
 
-def test_auto_embedding(shared_client: TiDBClient):
-    text_embed_small = EmbeddingFunction("openai/text-embedding-3-small", timeout=20)
-    test_table_name = "test_auto_embedding"
-
-    class Chunk(TableModel):
-        __tablename__ = test_table_name
+def test_auto_embedding(shared_client: TiDBClient, text_embed: EmbeddingFunction):
+    class ChunkForAutoEmbedding(TableModel):
+        __tablename__ = "test_auto_embedding"
         id: int = Field(primary_key=True)
         text: str = Field()
-        text_vec: Optional[list[float]] = text_embed_small.VectorField(
+        text_vec: Optional[list[float]] = text_embed.VectorField(
             source_field="text",
             index=False,
         )
         user_id: int = Field()
 
+    Chunk = ChunkForAutoEmbedding
     tbl = shared_client.create_table(schema=Chunk, if_exists="overwrite")
 
-    tbl.insert(Chunk(id=1, text="foo", user_id=1))
-    tbl.bulk_insert(
-        [
-            Chunk(id=2, text="bar", user_id=2),
-            Chunk(id=3, text="baz", user_id=3),
-            Chunk(
-                id=4,
-                text="",  # Empty string will skip auto embedding.
-                user_id=4,
-            ),
-        ]
-    )
-    chunks = tbl.query(filters=Chunk.user_id == 3).to_pydantic()
-    assert len(chunks) == 1
-    assert chunks[0].text == "baz"
-    assert len(chunks[0].text_vec) == 1536
+    # Test insert with auto embedding
+    chunk = tbl.insert(Chunk(id=1, text="foo", user_id=1))
+    assert len(chunk.text_vec) == 1536
+
+    # Test dict insert with auto embedding
+    chunk = tbl.insert({"id": 2, "text": "bar", "user_id": 1})
+    assert len(chunk.text_vec) == 1536
+
+    # Test bulk_insert with auto embedding (including empty text case)
+    chunks_via_model_instance = [
+        Chunk(id=3, text="baz", user_id=2),
+        Chunk(id=4, text="", user_id=2),  # Empty string will skip auto embedding.
+    ]
+    chunks_via_dict = [
+        {
+            "id": 5,
+            "text": "qux",
+            "user_id": 3,
+        },
+        {"id": 6, "text": "", "user_id": 3},  # Empty string will skip auto embedding.
+    ]
+    chunks = tbl.bulk_insert(chunks_via_model_instance + chunks_via_dict)
+    for chunk in chunks:
+        if chunk.text == "":
+            assert chunk.text_vec is None
+        else:
+            assert len(chunk.text_vec) == 1536
 
+    # Test vector search with auto embedding
     results = tbl.search("bar").limit(1).to_pydantic(with_score=True)
     assert len(results) == 1
     assert results[0].id == 2
     assert results[0].text == "bar"
     assert results[0].similarity_score >= 0.9
 
-    # Test dict insert
-    dict_chunk = tbl.insert({"id": 5, "text": "dict_test", "user_id": 5})
-    assert dict_chunk.id == 5
-    assert dict_chunk.text == "dict_test"
-    assert dict_chunk.user_id == 5
-    assert len(dict_chunk.text_vec) == 1536
-
-    # Test dict bulk_insert
-    dict_chunks = tbl.bulk_insert(
-        [
-            {"id": 6, "text": "dict_bulk_1", "user_id": 6},
-            {"id": 7, "text": "dict_bulk_2", "user_id": 7},
-            {
-                "id": 8,
-                "text": "",
-                "user_id": 8,
-            },  # Empty string will skip auto embedding
-        ]
-    )
-    assert len(dict_chunks) == 3
-    assert dict_chunks[0].id == 6
-    assert dict_chunks[0].text == "dict_bulk_1"
-    assert len(dict_chunks[0].text_vec) == 1536
-    assert dict_chunks[1].id == 7
-    assert dict_chunks[1].text == "dict_bulk_2"
-    assert len(dict_chunks[1].text_vec) == 1536
-    assert dict_chunks[2].id == 8
-    assert dict_chunks[2].text == ""
-    assert dict_chunks[2].text_vec is None
-
-    # Test mixed bulk_insert (dict and model instances)
-    mixed_chunks = tbl.bulk_insert(
-        [
-            Chunk(id=9, text="model_instance", user_id=9),
-            {"id": 10, "text": "dict_mixed", "user_id": 10},
-        ]
-    )
-    assert len(mixed_chunks) == 2
-    assert mixed_chunks[0].id == 9
-    assert mixed_chunks[0].text == "model_instance"
-    assert len(mixed_chunks[0].text_vec) == 1536
-    assert mixed_chunks[1].id == 10
-    assert mixed_chunks[1].text == "dict_mixed"
-    assert len(mixed_chunks[1].text_vec) == 1536
-
-    # Update,
+    # Test update with auto embedding, from empty to non-empty string
     chunk = tbl.get(4)
     assert chunk.text == ""
     assert chunk.text_vec is None
-    tbl.update(
-        values={"text": "qux"},
-        filters={"id": 4},
+    tbl.update(values={"text": "another baz"}, filters={"id": 4})
+    updated_chunk = tbl.get(4)
+    assert updated_chunk.text == "another baz"
+    assert len(updated_chunk.text_vec) == 1536
+
+    # Test update with auto embedding, from non-empty to empty string
+    tbl.update(values={"text": ""}, filters={"id": 4})
+    updated_chunk = tbl.get(4)
+    assert updated_chunk.text == ""
+    assert updated_chunk.text_vec is None
+
+    # Test save with auto embedding
+    saved_chunk = tbl.save(Chunk(id=7, text="save_test", user_id=4))
+    assert saved_chunk.text == "save_test"
+    assert len(saved_chunk.text_vec) == 1536
+
+    # Test save with empty string - should skip auto embedding
+    save_empty = tbl.save(Chunk(id=8, text="", user_id=4))
+    assert save_empty.text == ""
+    assert save_empty.text_vec is None
+
+    # Test save with pre-existing vector field - should skip auto embedding
+    existing_vector = [0.1] * 1536
+    save_with_vector = tbl.save(
+        Chunk(id=9, text="save_with_vector", text_vec=existing_vector, user_id=4)
     )
-    chunk = tbl.get(4)
-    assert chunk.text == "qux"
-    assert chunk.text_vec is not None
+    assert len(save_with_vector.text_vec) == 1536
+
+    # Test save update from empty to text - should trigger auto embedding
+    saved_chunk = tbl.get(6)
+    saved_chunk.text = "another qux"
+    tbl.save(saved_chunk)
+    assert len(saved_chunk.text_vec) == 1536
diff --git a/tests/test_crud.py b/tests/test_data.py
similarity index 68%
rename from tests/test_crud.py
rename to tests/test_data.py
index e44cf14..1a9518d 100644
--- a/tests/test_crud.py
+++ b/tests/test_data.py
@@ -1,7 +1,8 @@
 import logging
-from typing import Any
+from typing import Any, Optional
 
 import numpy as np
+from pytidb import TiDBClient
 from pytidb.schema import TableModel, Field, VectorField
 
 
@@ -10,7 +11,7 @@
 
 def test_table_crud(shared_client):
     class Chunk(TableModel, table=True):
-        __tablename__ = "test_crud_table"
+        __tablename__ = "test_table_crud"
         id: int = Field(primary_key=True)
         text: str = Field(max_length=20)
         text_vec: Any = VectorField(dimensions=3, index=False)
@@ -60,7 +61,7 @@ class Chunk(TableModel, table=True):
 
 def test_table_query(shared_client):
     class Chunk(TableModel):
-        __tablename__ = "test_query_table"
+        __tablename__ = "test_table_query"
         id: int = Field(primary_key=True)
         text: str = Field(max_length=20)
         text_vec: Any = VectorField(dimensions=3, index=False)
@@ -117,3 +118,40 @@ class Chunk(TableModel):
     assert chunks[0]["id"] == 1
     assert chunks[0]["text"] == "foo"
     assert np.array_equal(chunks[0]["text_vec"], [1, 2, 3])
+
+
+def test_table_save(shared_client: TiDBClient):
+    class RecordForSaveData(TableModel):
+        __tablename__ = "test_table_save"
+        id: int = Field(primary_key=True)
+        text: str = Field()
+        text_vec: Optional[list[float]] = VectorField(dimensions=3)
+        user_id: int = Field()
+
+    Record = RecordForSaveData
+
+    tbl = shared_client.create_table(schema=Record, if_exists="overwrite")
+
+    # Test save - insert new record
+    new_record = Record(id=1, text="hello world", user_id=1)
+    saved_record = tbl.save(new_record)
+    assert saved_record.id == 1
+    assert saved_record.text == "hello world"
+
+    # Test save - update existing record
+    updated_record = Record(id=1, text="hello updated", user_id=1)
+    saved_record = tbl.save(updated_record)
+    assert saved_record.id == 1
+    assert saved_record.text == "hello updated"
+
+    # Test save with dict
+    dict_record = {"id": 2, "text": "dict insert", "user_id": 2}
+    saved_dict = tbl.save(dict_record)
+    assert saved_dict.id == 2
+    assert saved_dict.text == "dict insert"
+
+    # Test save update with dict
+    dict_update = {"id": 2, "text": "dict updated", "user_id": 2}
+    saved_dict = tbl.save(dict_update)
+    assert saved_dict.id == 2
+    assert saved_dict.text == "dict updated"