Skip to content

Commit 3052f3f

Browse files
kkopczynski-cyborgdupontcyborg
authored andcommitted
removed query_vector as a parameter of query
1 parent 8afd19d commit 3052f3f

File tree

3 files changed

+33
-27
lines changed

3 files changed

+33
-27
lines changed

cyborgdb/client/encrypted_index.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -401,8 +401,7 @@ def delete(self, ids: List[str]) -> None:
401401

402402
def query(
403403
self,
404-
query_vector: Optional[Union[List[float], np.ndarray]] = None,
405-
query_vectors: Optional[Union[np.ndarray, List[List[float]]]] = None,
404+
query_vectors: Optional[Union[np.ndarray, List[List[float]], List[float]]] = None,
406405
query_contents: Optional[str] = None,
407406
top_k: int = 100,
408407
n_probes: int = 1,
@@ -423,16 +422,34 @@ def query(
423422

424423
# Determine the correct vector input
425424
vector_list = None
425+
is_single_query = False
426426

427-
if query_vector is not None or query_contents is not None:
428-
if isinstance(query_vector, np.ndarray):
429-
if query_vector.ndim != 1:
430-
raise ValueError("Expected 1D NumPy array for `query_vector`.")
431-
vector_list = query_vector.tolist()
432-
elif isinstance(query_vector, list):
433-
if query_vector and isinstance(query_vector[0], (list, np.ndarray)):
434-
raise ValueError("Received nested list in `query_vector`; did you mean to use `query_vectors`?")
435-
vector_list = list(map(float, query_vector)) # Ensure float type
427+
if query_vectors is not None:
428+
if isinstance(query_vectors, np.ndarray):
429+
if query_vectors.ndim == 1:
430+
# Single vector as 1D NumPy array
431+
is_single_query = True
432+
vector_list = query_vectors.tolist()
433+
elif query_vectors.ndim == 2:
434+
# Batch of vectors as 2D NumPy array
435+
vector_list = query_vectors.tolist()
436+
else:
437+
raise ValueError("Expected 1D or 2D NumPy array for `query_vectors`.")
438+
elif isinstance(query_vectors, list):
439+
if not query_vectors:
440+
raise ValueError("Empty list provided for `query_vectors`.")
441+
if isinstance(query_vectors[0], (list, np.ndarray)):
442+
# Batch of vectors as list of lists
443+
vector_list = [list(map(float, v)) if isinstance(v, list) else v.tolist() for v in query_vectors]
444+
else:
445+
# Single vector as flat list
446+
is_single_query = True
447+
vector_list = list(map(float, query_vectors))
448+
else:
449+
raise ValueError("Invalid type for `query_vectors`")
450+
451+
if is_single_query or query_contents is not None:
452+
# Use QueryRequest for single vector or content-based query
436453
query_request = QueryRequest(
437454
index_key=self._key_to_hex(),
438455
index_name=self._index_name,
@@ -444,16 +461,8 @@ def query(
444461
filters=filters,
445462
include=include,
446463
)
447-
448-
elif query_vectors is not None:
449-
if isinstance(query_vectors, list):
450-
query_vectors = np.array(query_vectors, dtype=np.float32)
451-
if isinstance(query_vectors, np.ndarray):
452-
if query_vectors.ndim != 2:
453-
raise ValueError("Expected 2D NumPy array or list of lists for `query_vectors`.")
454-
vector_list = query_vectors.tolist()
455-
else:
456-
raise ValueError("Invalid type for `query_vectors`")
464+
else:
465+
# Use BatchQueryRequest for multiple vectors
457466
query_request = BatchQueryRequest(
458467
index_key=self._key_to_hex(),
459468
index_name=self._index_name,
@@ -466,9 +475,6 @@ def query(
466475
include=include,
467476
)
468477

469-
elif query_contents is None:
470-
raise ValueError("You must provide `query_vector`, `query_vectors`, or `query_contents`.")
471-
472478
request = Request(query_request)
473479

474480
# Execute query via REST

cyborgdb/integrations/langchain.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,7 @@ def _execute_query(
458458
# Text query - generate embedding
459459
embedding = self.get_embeddings(query)
460460
results = self.index.query(
461-
query_vector=embedding,
461+
query_vectors=embedding,
462462
top_k=k,
463463
n_probes=n_probes,
464464
filters=filter,
@@ -467,7 +467,7 @@ def _execute_query(
467467
else:
468468
# Vector query
469469
results = self.index.query(
470-
query_vector=query,
470+
query_vectors=query,
471471
top_k=k,
472472
n_probes=n_probes,
473473
filters=filter,

tests/test_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def test_upsert_and_query(self):
5454

5555
# Query a vector
5656
query_vector = np.random.rand(dimension).astype(np.float32)
57-
results = self.index.query(query_vector=query_vector, top_k=10)
57+
results = self.index.query(query_vectors=query_vector, top_k=10)
5858

5959
# Check results
6060
self.assertEqual(len(results[0]), 10)

0 commit comments

Comments
 (0)