Skip to content

Commit d539b2b

Browse files
committed
Add a check_data_type function to base_client/utils, and check data_type in search_params before converting queries to bytes
1 parent 4582b91 commit d539b2b

File tree

4 files changed

+31
-15
lines changed

4 files changed

+31
-15
lines changed

engine/base_client/search.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
import numpy as np
88
import tqdm
99
import os
10+
from ml_dtypes import bfloat16
1011

1112
from dataset_reader.base_reader import Query
13+
from engine.base_client.utils import check_data_type
1214

1315
DEFAULT_TOP = 10
1416
MAX_QUERIES = int(os.getenv("MAX_QUERIES", -1))
@@ -66,6 +68,11 @@ def search_all(
6668
):
6769
parallel = self.search_params.get("parallel", 1)
6870
top = self.search_params.get("top", None)
71+
single_search_params = self.search_params.get("search_params", None)
72+
if single_search_params:
73+
data_type = check_data_type(single_search_params.get("data_type", "FLOAT32").upper())
74+
else:
75+
data_type = np.float32 # Default data type if not specified
6976
# setup_search may require initialized client
7077
self.init_client(
7178
self.host, distance, self.connection_params, self.search_params
@@ -78,7 +85,7 @@ def search_all(
7885
# Also, converts query vectors to bytes beforehand, preparing them for sending to client without affecting search time measurements
7986
queries_list = []
8087
for query in queries:
81-
query.vector = np.array(query.vector).astype(np.float32).tobytes()
88+
query.vector = np.array(query.vector).astype(data_type).tobytes()
8289
queries_list.append(query)
8390

8491
# Handle MAX_QUERIES environment variable

engine/base_client/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from typing import Any, Iterable
22

3+
from ml_dtypes import bfloat16
4+
import numpy as np
5+
36
from dataset_reader.base_reader import Record
47

58

@@ -18,3 +21,19 @@ def iter_batches(records: Iterable[Record], n: int) -> Iterable[Any]:
1821
ids, vectors, metadata = [], [], []
1922
if len(ids) > 0:
2023
yield [ids, vectors, metadata]
24+
25+
26+
def check_data_type(data_type: str):
    """Validate a vector data type name and return the matching dtype.

    The name is matched case-insensitively against the supported names:
    FLOAT32, FLOAT64, FLOAT16 and BFLOAT16.

    Args:
        data_type: Name of the data type, e.g. ``"FLOAT32"`` or ``"float16"``.

    Returns:
        ``np.float32``, ``np.float64``, ``np.float16`` or ``bfloat16``,
        according to the given name.

    Raises:
        ValueError: If ``data_type`` is not one of the supported names.
    """
    valid_data_types = ["FLOAT32", "FLOAT64", "FLOAT16", "BFLOAT16"]
    # Normalize once so that validation and lookup agree: a lowercase name
    # that passes validation must also resolve to a dtype instead of falling
    # through every comparison and implicitly returning None.
    normalized = data_type.upper()
    if normalized not in valid_data_types:
        raise ValueError(
            f"Invalid data type: {data_type}. Valid options are: {valid_data_types}"
        )
    # bfloat16 is provided by ml_dtypes rather than numpy, so it is resolved
    # on its own branch and only touched when actually requested.
    if normalized == "BFLOAT16":
        return bfloat16
    numpy_types = {
        "FLOAT32": np.float32,
        "FLOAT64": np.float64,
        "FLOAT16": np.float16,
    }
    return numpy_types[normalized]

engine/clients/redis/search.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55
from redis import Redis, RedisCluster
66
from redis.commands.search.query import Query
7+
from engine.base_client.utils import check_data_type
78
from engine.base_client.search import BaseSearcher
89
from engine.clients.redis.config import (
910
REDIS_PORT,
@@ -44,13 +45,7 @@ def init_client(cls, host, distance, connection_params: dict, search_params: dic
4445
cls.data_type = (
4546
cls.search_params["search_params"].get("data_type", "FLOAT32").upper()
4647
)
47-
cls.np_data_type = np.float32
48-
if cls.data_type == "FLOAT64":
49-
cls.np_data_type = np.float64
50-
if cls.data_type == "FLOAT16":
51-
cls.np_data_type = np.float16
52-
if cls.data_type == "BFLOAT16":
53-
cls.np_data_type = bfloat16
48+
cls.np_data_type = check_data_type(cls.data_type)
5449
cls._is_cluster = True if REDIS_CLUSTER else False
5550

5651
# In the case of CLUSTER API enabled we randomly select the starting primary shard

engine/clients/redis/upload.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numpy as np
88
from redis import Redis, RedisCluster
99
from engine.base_client.upload import BaseUploader
10+
from engine.base_client.utils import check_data_type
1011
from engine.clients.redis.config import (
1112
REDIS_PORT,
1213
REDIS_AUTH,
@@ -42,13 +43,7 @@ def init_client(cls, host, distance, connection_params, upload_params):
4243
cls.upload_params = upload_params
4344
cls.algorithm = cls.upload_params.get("algorithm", "hnsw").upper()
4445
cls.data_type = cls.upload_params.get("data_type", "FLOAT32").upper()
45-
cls.np_data_type = np.float32
46-
if cls.data_type == "FLOAT64":
47-
cls.np_data_type = np.float64
48-
if cls.data_type == "FLOAT16":
49-
cls.np_data_type = np.float16
50-
if cls.data_type == "BFLOAT16":
51-
cls.np_data_type = bfloat16
46+
cls.np_data_type = check_data_type(cls.data_type)
5247
cls._is_cluster = True if REDIS_CLUSTER else False
5348

5449
@classmethod

0 commit comments

Comments
 (0)