Skip to content

Commit 0f9a2aa

Browse files
Update add item (#34)
* add embedding as a parameter for add text * added test * update test * updated formatting * update test * update test * update * update for linter
1 parent c8e8bcf commit 0f9a2aa

File tree

3 files changed

+350
-30
lines changed

3 files changed

+350
-30
lines changed

cyborgdb/integrations/langchain.py

Lines changed: 82 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,21 @@
88
pip install cyborgdb-py[langchain]
99
"""
1010

11-
import uuid
1211
import json
12+
import uuid
1313
import warnings
14+
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
15+
1416
import numpy as np
15-
from typing import List, Dict, Any, Optional, Tuple, Union, Iterable
1617

1718
try:
18-
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
1919
from langchain_core.documents import Document
2020
from langchain_core.embeddings import Embeddings
21+
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
2122
from sentence_transformers import SentenceTransformer
2223

2324
# Import CyborgDB components
24-
from cyborgdb import Client, EncryptedIndex, IndexIVF, IndexIVFPQ, IndexIVFFlat
25+
from cyborgdb import Client, EncryptedIndex, IndexIVF, IndexIVFFlat, IndexIVFPQ
2526

2627
class CyborgVectorStore(VectorStore):
2728
"""
@@ -325,13 +326,16 @@ def add_texts(
325326
texts: Texts to add
326327
metadatas: Optional metadata for each text
327328
ids: Optional IDs for each text (generated if not provided)
328-
**kwargs: Additional arguments (unused)
329+
**kwargs: Additional arguments including:
330+
- embeddings: Optional pre-computed embeddings for each text. If provided,
331+
skips embedding generation. Should be a list of vectors or numpy array
332+
with shape (num_texts, dimension)
329333
330334
Returns:
331335
List of IDs for the added texts
332336
333337
Raises:
334-
ValueError: If lengths of texts, metadatas, or ids don't match
338+
ValueError: If lengths of texts, metadatas, ids, or embeddings don't match
335339
"""
336340
texts_list = list(texts)
337341
num_texts = len(texts_list)
@@ -351,25 +355,41 @@ def add_texts(
351355
if metadatas is not None and len(metadatas) != num_texts:
352356
raise ValueError("Length of metadatas must match length of texts")
353357

354-
# Generate embeddings
355-
embeddings = self.get_embeddings(texts_list)
358+
# Handle embeddings - either use provided vectors or generate them
359+
embeddings = kwargs.get("embeddings", None)
360+
if embeddings is not None:
361+
# Validate provided embeddings
362+
if isinstance(embeddings, np.ndarray):
363+
if embeddings.shape[0] != num_texts:
364+
raise ValueError(
365+
f"Number of embeddings ({embeddings.shape[0]}) must match number of texts ({num_texts})"
366+
)
367+
vectors = embeddings
368+
else:
369+
# Assume it's a list of lists
370+
if len(embeddings) != num_texts:
371+
raise ValueError(
372+
f"Number of embeddings ({len(embeddings)}) must match number of texts ({num_texts})"
373+
)
374+
vectors = np.array(embeddings, dtype=np.float32)
375+
else:
376+
# Generate embeddings from texts
377+
vectors = self.get_embeddings(texts_list)
356378

357379
# Build items for upsert
358380
items = []
359381
for i in range(num_texts):
360-
# Handle both numpy arrays and lists for embeddings
361-
if hasattr(embeddings, "shape"):
382+
# Handle both numpy arrays and lists for vectors
383+
if hasattr(vectors, "shape"):
362384
vector = (
363-
embeddings[i].tolist()
364-
if len(embeddings.shape) > 1
365-
else embeddings.tolist()
385+
vectors[i].tolist()
386+
if len(vectors.shape) > 1
387+
else vectors.tolist()
366388
)
367389
else:
368-
# embeddings is likely a list of lists or a list
390+
# vectors is likely a list of lists or a list
369391
vector = (
370-
embeddings[i]
371-
if isinstance(embeddings[i], list)
372-
else [embeddings[i]]
392+
vectors[i] if isinstance(vectors[i], list) else [vectors[i]]
373393
)
374394
item = {"id": id_list[i], "vector": vector}
375395

@@ -390,15 +410,21 @@ def add_texts(
390410
return id_list
391411

392412
def add_documents(
393-
self, documents: List[Document], ids: Optional[List[str]] = None, **kwargs
413+
self,
414+
documents: List[Document],
415+
ids: Optional[List[str]] = None,
416+
**kwargs,
394417
) -> List[str]:
395418
"""
396419
Add documents to the vector store.
397420
398421
Args:
399422
documents: Documents to add
400423
ids: Optional IDs for documents
401-
**kwargs: Additional arguments
424+
**kwargs: Additional arguments including:
425+
- embeddings: Optional pre-computed embeddings for each document. If provided,
426+
skips embedding generation. Should be a list of vectors or numpy array
427+
with shape (num_documents, dimension)
402428
403429
Returns:
404430
List of IDs for the added documents
@@ -718,17 +744,24 @@ async def aadd_texts(
718744
ids: Optional[List[str]] = None,
719745
**kwargs,
720746
) -> List[str]:
721-
"""Async version of add_texts."""
747+
"""Async version of add_texts with optional pre-computed embeddings."""
722748
import asyncio
723749

724750
return await asyncio.to_thread(
725-
self.add_texts, texts, metadatas=metadatas, ids=ids, **kwargs
751+
self.add_texts,
752+
texts,
753+
metadatas=metadatas,
754+
ids=ids,
755+
**kwargs,
726756
)
727757

728758
async def aadd_documents(
729-
self, documents: List[Document], ids: Optional[List[str]] = None, **kwargs
759+
self,
760+
documents: List[Document],
761+
ids: Optional[List[str]] = None,
762+
**kwargs,
730763
) -> List[str]:
731-
"""Async version of add_documents."""
764+
"""Async version of add_documents with optional pre-computed embeddings."""
732765
import asyncio
733766

734767
return await asyncio.to_thread(
@@ -793,6 +826,7 @@ def from_texts(
793826
texts: List[str],
794827
embedding: Union[str, Embeddings, SentenceTransformer],
795828
metadatas: Optional[List[Dict]] = None,
829+
embeddings: Optional[Union[List[List[float]], np.ndarray]] = None,
796830
**kwargs,
797831
) -> "CyborgVectorStore":
798832
"""
@@ -802,6 +836,9 @@ def from_texts(
802836
texts: List of texts to add
803837
embedding: Embedding model
804838
metadatas: Optional metadata for each text
839+
embeddings: Optional pre-computed embeddings for each text. If provided,
840+
skips embedding generation. Should be a list of vectors or numpy array
841+
with shape (num_texts, dimension)
805842
**kwargs: Additional arguments including:
806843
- index_name: Name for the index
807844
- index_key: 32-byte encryption key
@@ -841,6 +878,19 @@ def from_texts(
841878
if key in kwargs:
842879
index_config_params[key] = kwargs.pop(key)
843880

881+
# If embeddings are provided and dimension is not, infer it
882+
if embeddings is not None and dimension is None:
883+
if isinstance(embeddings, np.ndarray):
884+
dimension = (
885+
embeddings.shape[1]
886+
if len(embeddings.shape) > 1
887+
else len(embeddings)
888+
)
889+
elif isinstance(embeddings, list) and len(embeddings) > 0:
890+
dimension = (
891+
len(embeddings[0]) if isinstance(embeddings[0], list) else 1
892+
)
893+
844894
# Create vector store
845895
store = cls(
846896
index_name=index_name,
@@ -857,7 +907,10 @@ def from_texts(
857907

858908
# Add texts if provided
859909
if texts:
860-
store.add_texts(texts, metadatas, ids=ids)
910+
if embeddings is not None:
911+
store.add_texts(texts, metadatas, ids=ids, embeddings=embeddings)
912+
else:
913+
store.add_texts(texts, metadatas, ids=ids)
861914

862915
if not store.index.is_trained():
863916
warnings.warn("Not enough data to train index.")
@@ -877,7 +930,11 @@ def from_documents(
877930
Args:
878931
documents: List of documents to add
879932
embedding: Embedding model
880-
**kwargs: Additional arguments (see from_texts)
933+
**kwargs: Additional arguments including:
934+
- embeddings: Optional pre-computed embeddings for each document. If provided,
935+
skips embedding generation. Should be a list of vectors or numpy array
936+
with shape (num_documents, dimension)
937+
- Other arguments (see from_texts)
881938
882939
Returns:
883940
CyborgVectorStore instance

tests/test_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,10 @@ def test_upsert_and_query(self):
6363
query_vector = np.random.rand(dimension).astype(np.float32)
6464
results = self.index.query(query_vectors=query_vector, top_k=10)
6565

66-
# Check results
67-
self.assertEqual(len(results[0]), 10)
68-
self.assertTrue("id" in results[0][0])
69-
self.assertTrue("distance" in results[0][0])
66+
# Check results - results is a flat list, not nested
67+
self.assertEqual(len(results), 10)
68+
self.assertTrue("id" in results[0])
69+
self.assertTrue("distance" in results[0])
7070

7171
def test_health_check(self):
7272
"""Test the health check endpoint."""

0 commit comments

Comments
 (0)