@@ -401,8 +401,7 @@ def delete(self, ids: List[str]) -> None:
401401
402402 def query (
403403 self ,
404- query_vector : Optional [Union [List [float ], np .ndarray ]] = None ,
405- query_vectors : Optional [Union [np .ndarray , List [List [float ]]]] = None ,
404+ query_vectors : Optional [Union [np .ndarray , List [List [float ]], List [float ]]] = None ,
406405 query_contents : Optional [str ] = None ,
407406 top_k : int = 100 ,
408407 n_probes : int = 1 ,
@@ -423,16 +422,34 @@ def query(
423422
424423 # Determine the correct vector input
425424 vector_list = None
425+ is_single_query = False
426426
427- if query_vector is not None or query_contents is not None :
428- if isinstance (query_vector , np .ndarray ):
429- if query_vector .ndim != 1 :
430- raise ValueError ("Expected 1D NumPy array for `query_vector`." )
431- vector_list = query_vector .tolist ()
432- elif isinstance (query_vector , list ):
433- if query_vector and isinstance (query_vector [0 ], (list , np .ndarray )):
434- raise ValueError ("Received nested list in `query_vector`; did you mean to use `query_vectors`?" )
435- vector_list = list (map (float , query_vector )) # Ensure float type
427+ if query_vectors is not None :
428+ if isinstance (query_vectors , np .ndarray ):
429+ if query_vectors .ndim == 1 :
430+ # Single vector as 1D NumPy array
431+ is_single_query = True
432+ vector_list = query_vectors .tolist ()
433+ elif query_vectors .ndim == 2 :
434+ # Batch of vectors as 2D NumPy array
435+ vector_list = query_vectors .tolist ()
436+ else :
437+ raise ValueError ("Expected 1D or 2D NumPy array for `query_vectors`." )
438+ elif isinstance (query_vectors , list ):
439+ if not query_vectors :
440+ raise ValueError ("Empty list provided for `query_vectors`." )
441+ if isinstance (query_vectors [0 ], (list , np .ndarray )):
442+ # Batch of vectors as list of lists
443+ vector_list = [list (map (float , v )) if isinstance (v , list ) else v .tolist () for v in query_vectors ]
444+ else :
445+ # Single vector as flat list
446+ is_single_query = True
447+ vector_list = list (map (float , query_vectors ))
448+ else :
449+ raise ValueError ("Invalid type for `query_vectors`" )
450+
451+ if is_single_query or query_contents is not None :
452+ # Use QueryRequest for single vector or content-based query
436453 query_request = QueryRequest (
437454 index_key = self ._key_to_hex (),
438455 index_name = self ._index_name ,
@@ -444,16 +461,8 @@ def query(
444461 filters = filters ,
445462 include = include ,
446463 )
447-
448- elif query_vectors is not None :
449- if isinstance (query_vectors , list ):
450- query_vectors = np .array (query_vectors , dtype = np .float32 )
451- if isinstance (query_vectors , np .ndarray ):
452- if query_vectors .ndim != 2 :
453- raise ValueError ("Expected 2D NumPy array or list of lists for `query_vectors`." )
454- vector_list = query_vectors .tolist ()
455- else :
456- raise ValueError ("Invalid type for `query_vectors`" )
464+ else :
465+ # Use BatchQueryRequest for multiple vectors
457466 query_request = BatchQueryRequest (
458467 index_key = self ._key_to_hex (),
459468 index_name = self ._index_name ,
@@ -466,9 +475,6 @@ def query(
466475 include = include ,
467476 )
468477
469- elif query_contents is None :
470- raise ValueError ("You must provide `query_vector`, `query_vectors`, or `query_contents`." )
471-
472478 request = Request (query_request )
473479
474480 # Execute query via REST
0 commit comments