From 9debead2c2f8c2e40adf799d8748d115149aa2ed Mon Sep 17 00:00:00 2001 From: filipecosta90 Date: Thu, 13 Jun 2024 19:10:31 +0100 Subject: [PATCH] Enable specifying the max_optimization_threads on post_upload() for qdrant client --- engine/clients/qdrant/config.py | 21 +++++++++++++++++++++ engine/clients/qdrant/configure.py | 8 ++++++-- engine/clients/qdrant/upload.py | 12 ++++++++---- 3 files changed, 35 insertions(+), 6 deletions(-) diff --git a/engine/clients/qdrant/config.py b/engine/clients/qdrant/config.py index 31d34007..c9f1aade 100644 --- a/engine/clients/qdrant/config.py +++ b/engine/clients/qdrant/config.py @@ -1,3 +1,24 @@ import os +import random +import time QDRANT_COLLECTION_NAME = os.getenv("QDRANT_COLLECTION_NAME", "benchmark") +QDRANT_MAX_OPTIMIZATION_THREADS = os.getenv("QDRANT_MAX_OPTIMIZATION_THREADS", None) + + +def retry_with_exponential_backoff( + func, *args, max_retries=10, base_delay=1, max_delay=90, **kwargs +): + retries = 0 + while retries < max_retries: + try: + return func(*args, **kwargs) + except Exception as e: + delay = min(base_delay * 2**retries + random.uniform(0, 1), max_delay) + time.sleep(delay) + retries += 1 + print(f"received the following exception on try #{retries}: {e.__str__}") + if retries == max_retries: + raise e + else: + print("retrying...") diff --git a/engine/clients/qdrant/configure.py b/engine/clients/qdrant/configure.py index 668914b8..3d2afca1 100644 --- a/engine/clients/qdrant/configure.py +++ b/engine/clients/qdrant/configure.py @@ -4,7 +4,10 @@ from benchmark.dataset import Dataset from engine.base_client.configure import BaseConfigurator from engine.base_client.distances import Distance -from engine.clients.qdrant.config import QDRANT_COLLECTION_NAME +from engine.clients.qdrant.config import ( + QDRANT_COLLECTION_NAME, + retry_with_exponential_backoff, +) class QdrantConfigurator(BaseConfigurator): @@ -52,7 +55,8 @@ def recreate(self, dataset: Dataset, collection_params): ) } - self.client.recreate_collection( + retry_with_exponential_backoff( + self.client.recreate_collection, collection_name=QDRANT_COLLECTION_NAME, **vectors_config, **self.collection_params diff --git a/engine/clients/qdrant/upload.py b/engine/clients/qdrant/upload.py index a5c2dbbe..96a6f7c7 100644 --- a/engine/clients/qdrant/upload.py +++ b/engine/clients/qdrant/upload.py @@ -13,7 +13,10 @@ from dataset_reader.base_reader import Record from engine.base_client.upload import BaseUploader -from engine.clients.qdrant.config import QDRANT_COLLECTION_NAME +from engine.clients.qdrant.config import ( + QDRANT_COLLECTION_NAME, + QDRANT_MAX_OPTIMIZATION_THREADS, +) class QdrantUploader(BaseUploader): @@ -58,14 +61,15 @@ def upload_batch(cls, batch: List[Record]): @classmethod def post_upload(cls, _distance): + max_optimization_threads = QDRANT_MAX_OPTIMIZATION_THREADS + if max_optimization_threads is not None: + max_optimization_threads = int(max_optimization_threads) cls.client.update_collection( collection_name=QDRANT_COLLECTION_NAME, optimizer_config=OptimizersConfigDiff( - # indexing_threshold=10_000, - max_optimization_threads=1, + max_optimization_threads=max_optimization_threads, ), ) - cls.wait_collection_green() return {}