Skip to content

Commit 620e556

Browse files
authored
fix: improve (re)insertion speed (#80)
1 parent 8052840 commit 620e556

File tree

2 files changed

+24
-31
lines changed

(summary repeated above: 2 files changed, +24 −31 lines)

src/raglite/_database.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ class Document(SQLModel, table=True):
5050
metadata_: dict[str, Any] = Field(default_factory=dict, sa_column=Column("metadata", JSON))
5151

5252
# Add relationships so we can access document.chunks and document.evals.
53-
chunks: list["Chunk"] = Relationship(back_populates="document")
54-
evals: list["Eval"] = Relationship(back_populates="document")
53+
chunks: list["Chunk"] = Relationship(back_populates="document", cascade_delete=True)
54+
evals: list["Eval"] = Relationship(back_populates="document", cascade_delete=True)
5555

5656
@staticmethod
5757
def from_path(doc_path: Path, **kwargs: Any) -> "Document":
@@ -76,15 +76,15 @@ class Chunk(SQLModel, table=True):
7676

7777
# Table columns.
7878
id: ChunkId = Field(..., primary_key=True)
79-
document_id: DocumentId = Field(..., foreign_key="document.id", index=True)
79+
document_id: DocumentId = Field(..., foreign_key="document.id", index=True, ondelete="CASCADE")
8080
index: int = Field(..., index=True)
8181
headings: str
8282
body: str
8383
metadata_: dict[str, Any] = Field(default_factory=dict, sa_column=Column("metadata", JSON))
8484

8585
# Add relationships so we can access chunk.document and chunk.embeddings.
8686
document: Document = Relationship(back_populates="chunks")
87-
embeddings: list["ChunkEmbedding"] = Relationship(back_populates="chunk")
87+
embeddings: list["ChunkEmbedding"] = Relationship(back_populates="chunk", cascade_delete=True)
8888

8989
@staticmethod
9090
def from_body(
@@ -230,7 +230,7 @@ class ChunkEmbedding(SQLModel, table=True):
230230

231231
# Table columns.
232232
id: int = Field(..., primary_key=True)
233-
chunk_id: ChunkId = Field(..., foreign_key="chunk.id", index=True)
233+
chunk_id: ChunkId = Field(..., foreign_key="chunk.id", index=True, ondelete="CASCADE")
234234
embedding: FloatVector = Field(..., sa_column=Column(Embedding(dim=-1)))
235235

236236
# Add relationship so we can access embedding.chunk.
@@ -285,7 +285,7 @@ class Eval(SQLModel, table=True):
285285

286286
# Table columns.
287287
id: EvalId = Field(..., primary_key=True)
288-
document_id: DocumentId = Field(..., foreign_key="document.id", index=True)
288+
document_id: DocumentId = Field(..., foreign_key="document.id", index=True, ondelete="CASCADE")
289289
chunk_ids: list[ChunkId] = Field(default_factory=list, sa_column=Column(JSON))
290290
question: str
291291
contexts: list[str] = Field(default_factory=list, sa_column=Column(JSON))

src/raglite/_insert.py

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,16 @@ def insert_document(doc_path: Path, *, config: RAGLiteConfig | None = None) -> N
6969
"""Insert a document into the database and update the index."""
7070
# Use the default config if not provided.
7171
config = config or RAGLiteConfig()
72-
db_backend = make_url(config.db_url).get_backend_name()
7372
# Preprocess the document into chunks and chunk embeddings.
74-
with tqdm(total=5, unit="step", dynamic_ncols=True) as pbar:
73+
with tqdm(total=6, unit="step", dynamic_ncols=True) as pbar:
7574
pbar.set_description("Initializing database")
7675
engine = create_database_engine(config)
76+
document_record = Document.from_path(doc_path)
77+
with Session(engine) as session: # Exit early if the document is already in the database.
78+
if session.get(Document, document_record.id) is not None:
79+
pbar.update(6)
80+
pbar.close()
81+
return
7782
pbar.update(1)
7883
pbar.set_description("Converting to Markdown")
7984
doc = document_to_markdown(doc_path)
@@ -92,32 +97,20 @@ def insert_document(doc_path: Path, *, config: RAGLiteConfig | None = None) -> N
9297
max_size=config.chunk_max_size,
9398
)
9499
pbar.update(1)
95-
# Create and store the chunk records.
96-
with Session(engine) as session:
97-
# Add the document to the document table.
98-
document_record = Document.from_path(doc_path)
99-
if session.get(Document, document_record.id) is None:
100+
pbar.set_description("Updating database")
101+
with Session(engine) as session:
100102
session.add(document_record)
103+
for chunk_record, chunk_embedding_record_list in zip(
104+
*_create_chunk_records(document_record.id, chunks, chunk_embeddings, config),
105+
strict=True,
106+
):
107+
session.add(chunk_record)
108+
session.add_all(chunk_embedding_record_list)
101109
session.commit()
102-
# Create the chunk records to insert into the chunk table.
103-
chunk_records, chunk_embedding_records = _create_chunk_records(
104-
document_record.id, chunks, chunk_embeddings, config
105-
)
106-
# Store the chunk and chunk embedding records.
107-
for chunk_record, chunk_embedding_record_list in tqdm(
108-
zip(chunk_records, chunk_embedding_records, strict=True),
109-
desc="Inserting chunks",
110-
total=len(chunk_records),
111-
unit="chunk",
112-
dynamic_ncols=True,
113-
):
114-
if session.get(Chunk, chunk_record.id) is not None:
115-
continue
116-
session.add(chunk_record)
117-
session.add_all(chunk_embedding_record_list)
118-
session.commit()
110+
pbar.update(1)
111+
pbar.close()
119112
# Manually update the vector search chunk index for SQLite.
120-
if db_backend == "sqlite":
113+
if make_url(config.db_url).get_backend_name() == "sqlite":
121114
from pynndescent import NNDescent
122115

123116
with Session(engine) as session:

0 commit comments

Comments (0)