88 pip install cyborgdb-py[langchain]
99"""
1010
11- import uuid
1211import json
12+ import uuid
1313import warnings
14+ from typing import Any , Dict , Iterable , List , Optional , Tuple , Union
15+
1416import numpy as np
15- from typing import List , Dict , Any , Optional , Tuple , Union , Iterable
1617
1718try :
18- from langchain_core .vectorstores import VectorStore , VectorStoreRetriever
1919 from langchain_core .documents import Document
2020 from langchain_core .embeddings import Embeddings
21+ from langchain_core .vectorstores import VectorStore , VectorStoreRetriever
2122 from sentence_transformers import SentenceTransformer
2223
2324 # Import CyborgDB components
24- from cyborgdb import Client , EncryptedIndex , IndexIVF , IndexIVFPQ , IndexIVFFlat
25+ from cyborgdb import Client , EncryptedIndex , IndexIVF , IndexIVFFlat , IndexIVFPQ
2526
2627 class CyborgVectorStore (VectorStore ):
2728 """
@@ -325,13 +326,16 @@ def add_texts(
325326 texts: Texts to add
326327 metadatas: Optional metadata for each text
327328 ids: Optional IDs for each text (generated if not provided)
328- **kwargs: Additional arguments (unused)
329+ **kwargs: Additional arguments including:
330+ - embeddings: Optional pre-computed embeddings for each text. If provided,
331+ skips embedding generation. Should be a list of vectors or numpy array
332+ with shape (num_texts, dimension)
329333
330334 Returns:
331335 List of IDs for the added texts
332336
333337 Raises:
334- ValueError: If lengths of texts, metadatas, or ids don't match
338+ ValueError: If lengths of texts, metadatas, ids, or embeddings don't match
335339 """
336340 texts_list = list (texts )
337341 num_texts = len (texts_list )
@@ -351,25 +355,41 @@ def add_texts(
351355 if metadatas is not None and len (metadatas ) != num_texts :
352356 raise ValueError ("Length of metadatas must match length of texts" )
353357
354- # Generate embeddings
355- embeddings = self .get_embeddings (texts_list )
358+ # Handle embeddings - either use provided vectors or generate them
359+ embeddings = kwargs .get ("embeddings" , None )
360+ if embeddings is not None :
361+ # Validate provided embeddings
362+ if isinstance (embeddings , np .ndarray ):
363+ if embeddings .shape [0 ] != num_texts :
364+ raise ValueError (
365+ f"Number of embeddings ({ embeddings .shape [0 ]} ) must match number of texts ({ num_texts } )"
366+ )
367+ vectors = embeddings
368+ else :
369+ # Assume it's a list of lists
370+ if len (embeddings ) != num_texts :
371+ raise ValueError (
372+ f"Number of embeddings ({ len (embeddings )} ) must match number of texts ({ num_texts } )"
373+ )
374+ vectors = np .array (embeddings , dtype = np .float32 )
375+ else :
376+ # Generate embeddings from texts
377+ vectors = self .get_embeddings (texts_list )
356378
357379 # Build items for upsert
358380 items = []
359381 for i in range (num_texts ):
360- # Handle both numpy arrays and lists for embeddings
361- if hasattr (embeddings , "shape" ):
382+ # Handle both numpy arrays and lists for vectors
383+ if hasattr (vectors , "shape" ):
362384 vector = (
363- embeddings [i ].tolist ()
364- if len (embeddings .shape ) > 1
365- else embeddings .tolist ()
385+ vectors [i ].tolist ()
386+ if len (vectors .shape ) > 1
387+ else vectors .tolist ()
366388 )
367389 else :
368- # embeddings is likely a list of lists or a list
390+ # vectors is likely a list of lists or a list
369391 vector = (
370- embeddings [i ]
371- if isinstance (embeddings [i ], list )
372- else [embeddings [i ]]
392+ vectors [i ] if isinstance (vectors [i ], list ) else [vectors [i ]]
373393 )
374394 item = {"id" : id_list [i ], "vector" : vector }
375395
@@ -390,15 +410,21 @@ def add_texts(
390410 return id_list
391411
392412 def add_documents (
393- self , documents : List [Document ], ids : Optional [List [str ]] = None , ** kwargs
413+ self ,
414+ documents : List [Document ],
415+ ids : Optional [List [str ]] = None ,
416+ ** kwargs ,
394417 ) -> List [str ]:
395418 """
396419 Add documents to the vector store.
397420
398421 Args:
399422 documents: Documents to add
400423 ids: Optional IDs for documents
401- **kwargs: Additional arguments
424+ **kwargs: Additional arguments including:
425+ - embeddings: Optional pre-computed embeddings for each document. If provided,
426+ skips embedding generation. Should be a list of vectors or numpy array
427+ with shape (num_documents, dimension)
402428
403429 Returns:
404430 List of IDs for the added documents
@@ -718,17 +744,24 @@ async def aadd_texts(
718744 ids : Optional [List [str ]] = None ,
719745 ** kwargs ,
720746 ) -> List [str ]:
721- """Async version of add_texts."""
747+ """Async version of add_texts with optional pre-computed embeddings."""
722748 import asyncio
723749
724750 return await asyncio .to_thread (
725- self .add_texts , texts , metadatas = metadatas , ids = ids , ** kwargs
751+ self .add_texts ,
752+ texts ,
753+ metadatas = metadatas ,
754+ ids = ids ,
755+ ** kwargs ,
726756 )
727757
728758 async def aadd_documents (
729- self , documents : List [Document ], ids : Optional [List [str ]] = None , ** kwargs
759+ self ,
760+ documents : List [Document ],
761+ ids : Optional [List [str ]] = None ,
762+ ** kwargs ,
730763 ) -> List [str ]:
731- """Async version of add_documents."""
764+ """Async version of add_documents with optional pre-computed embeddings."""
732765 import asyncio
733766
734767 return await asyncio .to_thread (
@@ -793,6 +826,7 @@ def from_texts(
793826 texts : List [str ],
794827 embedding : Union [str , Embeddings , SentenceTransformer ],
795828 metadatas : Optional [List [Dict ]] = None ,
829+ embeddings : Optional [Union [List [List [float ]], np .ndarray ]] = None ,
796830 ** kwargs ,
797831 ) -> "CyborgVectorStore" :
798832 """
@@ -802,6 +836,9 @@ def from_texts(
802836 texts: List of texts to add
803837 embedding: Embedding model
804838 metadatas: Optional metadata for each text
839+ embeddings: Optional pre-computed embeddings for each text. If provided,
840+ skips embedding generation. Should be a list of vectors or numpy array
841+ with shape (num_texts, dimension)
805842 **kwargs: Additional arguments including:
806843 - index_name: Name for the index
807844 - index_key: 32-byte encryption key
@@ -841,6 +878,19 @@ def from_texts(
841878 if key in kwargs :
842879 index_config_params [key ] = kwargs .pop (key )
843880
881+ # If embeddings are provided and dimension is not, infer it
882+ if embeddings is not None and dimension is None :
883+ if isinstance (embeddings , np .ndarray ):
884+ dimension = (
885+ embeddings .shape [1 ]
886+ if len (embeddings .shape ) > 1
887+ else len (embeddings )
888+ )
889+ elif isinstance (embeddings , list ) and len (embeddings ) > 0 :
890+ dimension = (
891+ len (embeddings [0 ]) if isinstance (embeddings [0 ], list ) else 1
892+ )
893+
844894 # Create vector store
845895 store = cls (
846896 index_name = index_name ,
@@ -857,7 +907,10 @@ def from_texts(
857907
858908 # Add texts if provided
859909 if texts :
860- store .add_texts (texts , metadatas , ids = ids )
910+ if embeddings is not None :
911+ store .add_texts (texts , metadatas , ids = ids , embeddings = embeddings )
912+ else :
913+ store .add_texts (texts , metadatas , ids = ids )
861914
862915 if not store .index .is_trained ():
863916 warnings .warn ("Not enough data to train index." )
@@ -877,7 +930,11 @@ def from_documents(
877930 Args:
878931 documents: List of documents to add
879932 embedding: Embedding model
880- **kwargs: Additional arguments (see from_texts)
933+ **kwargs: Additional arguments including:
934+ - embeddings: Optional pre-computed embeddings for each document. If provided,
935+ skips embedding generation. Should be a list of vectors or numpy array
936+ with shape (num_documents, dimension)
937+ - Other arguments (see from_texts)
881938
882939 Returns:
883940 CyborgVectorStore instance
0 commit comments