Skip to content

feat: support table.save() API #147

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jul 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{
"python.testing.pytestArgs": [
"tests"
"tests",
"-v"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
Expand Down
39 changes: 39 additions & 0 deletions pytidb/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,45 @@ def insert(self, data: Union[T, dict]) -> T:
db_session.refresh(data)
return data

def save(self, data: Union[T, dict]) -> T:
if not isinstance(data, self._table_model) and not isinstance(data, dict):
raise ValueError(
f"Invalid data type: {type(data)}, expected {self._table_model}, dict"
)

# Convert dict to table model instance.
if isinstance(data, dict):
data = self._table_model(**data)

# Auto embedding.
for field_name, config in self._auto_embedding_configs.items():
# Skip if vector embeddings is provided.
if getattr(data, field_name) is not None:
continue

# Skip if source field is not provided.
if not hasattr(data, config["source_field_name"]):
continue

# Skip if source field is None or empty.
embedding_source = getattr(data, config["source_field_name"])
if embedding_source is None or embedding_source == "":
setattr(data, field_name, None)
continue

source_type = config.get("source_type", "text")
vector_embedding = config["embed_fn"].get_source_embedding(
embedding_source,
source_type=source_type,
)
setattr(data, field_name, vector_embedding)

with self._client.session() as db_session:
merged_data = db_session.merge(data)
db_session.flush()
db_session.refresh(merged_data)
return merged_data

def bulk_insert(self, data: List[Union[T, dict]]) -> List[T]:
if not isinstance(data, list):
raise ValueError(
Expand Down
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dotenv import load_dotenv

from pytidb import TiDBClient
from pytidb.embeddings import EmbeddingFunction

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -71,3 +72,8 @@ def isolated_client(env) -> Generator[TiDBClient, None, None]:
client.disconnect()
except Exception as e:
logger.error(f"Failed to drop test database {db_name}: {e}")


@pytest.fixture(scope="session", autouse=True)
def text_embed():
return EmbeddingFunction("openai/text-embedding-3-small", timeout=20)
141 changes: 66 additions & 75 deletions tests/test_auto_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,98 +4,89 @@
from pytidb.schema import TableModel, Field


def test_auto_embedding(shared_client: TiDBClient):
text_embed_small = EmbeddingFunction("openai/text-embedding-3-small", timeout=20)
test_table_name = "test_auto_embedding"

class Chunk(TableModel):
__tablename__ = test_table_name
def test_auto_embedding(shared_client: TiDBClient, text_embed: EmbeddingFunction):
class ChunkForAutoEmbedding(TableModel):
__tablename__ = "test_auto_embedding"
id: int = Field(primary_key=True)
text: str = Field()
text_vec: Optional[list[float]] = text_embed_small.VectorField(
text_vec: Optional[list[float]] = text_embed.VectorField(
source_field="text",
index=False,
)
user_id: int = Field()

Chunk = ChunkForAutoEmbedding
tbl = shared_client.create_table(schema=Chunk, if_exists="overwrite")

tbl.insert(Chunk(id=1, text="foo", user_id=1))
tbl.bulk_insert(
[
Chunk(id=2, text="bar", user_id=2),
Chunk(id=3, text="baz", user_id=3),
Chunk(
id=4,
text="", # Empty string will skip auto embedding.
user_id=4,
),
]
)
chunks = tbl.query(filters=Chunk.user_id == 3).to_pydantic()
assert len(chunks) == 1
assert chunks[0].text == "baz"
assert len(chunks[0].text_vec) == 1536
# Test insert with auto embedding
chunk = tbl.insert(Chunk(id=1, text="foo", user_id=1))
assert len(chunk.text_vec) == 1536

# Test dict insert with auto embedding
chunk = tbl.insert({"id": 2, "text": "bar", "user_id": 1})
assert len(chunk.text_vec) == 1536

# Test bulk_insert with auto embedding (including empty text case)
chunks_via_model_instance = [
Chunk(id=3, text="baz", user_id=2),
Chunk(id=4, text="", user_id=2), # Empty string will skip auto embedding.
]
chunks_via_dict = [
{
"id": 5,
"text": "qux",
"user_id": 3,
}, # Empty string will skip auto embedding.
{"id": 6, "text": "", "user_id": 3},
]
chunks = tbl.bulk_insert(chunks_via_model_instance + chunks_via_dict)
for chunk in chunks:
if chunk.text == "":
assert chunk.text_vec is None
else:
assert len(chunk.text_vec) == 1536

# Test vector search with auto embedding
results = tbl.search("bar").limit(1).to_pydantic(with_score=True)
assert len(results) == 1
assert results[0].id == 2
assert results[0].text == "bar"
assert results[0].similarity_score >= 0.9

# Test dict insert
dict_chunk = tbl.insert({"id": 5, "text": "dict_test", "user_id": 5})
assert dict_chunk.id == 5
assert dict_chunk.text == "dict_test"
assert dict_chunk.user_id == 5
assert len(dict_chunk.text_vec) == 1536

# Test dict bulk_insert
dict_chunks = tbl.bulk_insert(
[
{"id": 6, "text": "dict_bulk_1", "user_id": 6},
{"id": 7, "text": "dict_bulk_2", "user_id": 7},
{
"id": 8,
"text": "",
"user_id": 8,
}, # Empty string will skip auto embedding
]
)
assert len(dict_chunks) == 3
assert dict_chunks[0].id == 6
assert dict_chunks[0].text == "dict_bulk_1"
assert len(dict_chunks[0].text_vec) == 1536
assert dict_chunks[1].id == 7
assert dict_chunks[1].text == "dict_bulk_2"
assert len(dict_chunks[1].text_vec) == 1536
assert dict_chunks[2].id == 8
assert dict_chunks[2].text == ""
assert dict_chunks[2].text_vec is None

# Test mixed bulk_insert (dict and model instances)
mixed_chunks = tbl.bulk_insert(
[
Chunk(id=9, text="model_instance", user_id=9),
{"id": 10, "text": "dict_mixed", "user_id": 10},
]
)
assert len(mixed_chunks) == 2
assert mixed_chunks[0].id == 9
assert mixed_chunks[0].text == "model_instance"
assert len(mixed_chunks[0].text_vec) == 1536
assert mixed_chunks[1].id == 10
assert mixed_chunks[1].text == "dict_mixed"
assert len(mixed_chunks[1].text_vec) == 1536

# Update,
# Test update with auto embedding, from empty to non-empty string
chunk = tbl.get(4)
assert chunk.text == ""
assert chunk.text_vec is None
tbl.update(
values={"text": "qux"},
filters={"id": 4},
tbl.update(values={"text": "another baz"}, filters={"id": 4})
updated_chunk = tbl.get(4)
assert updated_chunk.text == "another baz"
assert len(updated_chunk.text_vec) == 1536

# Test update with auto embedding, from non-empty to empty string
tbl.update(values={"text": ""}, filters={"id": 4})
updated_chunk = tbl.get(4)
assert updated_chunk.text == ""
assert updated_chunk.text_vec is None

# Test save with auto embedding
saved_chunk = tbl.save(Chunk(id=7, text="save_test", user_id=4))
assert saved_chunk.text == "save_test"
assert len(saved_chunk.text_vec) == 1536

# Test save with empty string - should skip auto embedding
save_empty = tbl.save(Chunk(id=8, text="", user_id=4))
assert save_empty.text == ""
assert save_empty.text_vec is None

# Test save with pre-existing vector field - should skip auto embedding
existing_vector = [0.1] * 1536
save_with_vector = tbl.save(
Chunk(id=9, text="save_with_vector", text_vec=existing_vector, user_id=4)
)
chunk = tbl.get(4)
assert chunk.text == "qux"
assert chunk.text_vec is not None
assert len(save_with_vector.text_vec) == 1536

# Test save update from empty to text - should trigger auto embedding
saved_chunk = tbl.get(6)
saved_chunk.text = "another qux"
tbl.save(saved_chunk)
assert len(saved_chunk.text_vec) == 1536
44 changes: 41 additions & 3 deletions tests/test_crud.py → tests/test_data.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import logging
from typing import Any
from typing import Any, Optional
import numpy as np

from pytidb import TiDBClient
from pytidb.schema import TableModel, Field, VectorField


Expand All @@ -10,7 +11,7 @@

def test_table_crud(shared_client):
class Chunk(TableModel, table=True):
__tablename__ = "test_crud_table"
__tablename__ = "test_table_crud"
id: int = Field(primary_key=True)
text: str = Field(max_length=20)
text_vec: Any = VectorField(dimensions=3, index=False)
Expand Down Expand Up @@ -60,7 +61,7 @@ class Chunk(TableModel, table=True):

def test_table_query(shared_client):
class Chunk(TableModel):
__tablename__ = "test_query_table"
__tablename__ = "test_table_query"
id: int = Field(primary_key=True)
text: str = Field(max_length=20)
text_vec: Any = VectorField(dimensions=3, index=False)
Expand Down Expand Up @@ -117,3 +118,40 @@ class Chunk(TableModel):
assert chunks[0]["id"] == 1
assert chunks[0]["text"] == "foo"
assert np.array_equal(chunks[0]["text_vec"], [1, 2, 3])


def test_table_save(shared_client: TiDBClient):
class RecordForSaveData(TableModel):
__tablename__ = "test_table_save"
id: int = Field(primary_key=True)
text: str = Field()
text_vec: Optional[list[float]] = VectorField(dimensions=3)
user_id: int = Field()

Record = RecordForSaveData

tbl = shared_client.create_table(schema=Record, if_exists="overwrite")

# Test save - insert new record
new_record = Record(id=1, text="hello world", user_id=1)
saved_record = tbl.save(new_record)
assert saved_record.id == 1
assert saved_record.text == "hello world"

# Test save - update existing record
updated_record = Record(id=1, text="hello updated", user_id=1)
saved_record = tbl.save(updated_record)
assert saved_record.id == 1
assert saved_record.text == "hello updated"

# Test save with dict
dict_record = {"id": 2, "text": "dict insert", "user_id": 2}
saved_dict = tbl.save(dict_record)
assert saved_dict.id == 2
assert saved_dict.text == "dict insert"

# Test save update with dict
dict_update = {"id": 2, "text": "dict updated", "user_id": 2}
saved_dict = tbl.save(dict_update)
assert saved_dict.id == 2
assert saved_dict.text == "dict updated"
Loading