2 changes: 1 addition & 1 deletion requirements.txt
@@ -15,7 +15,7 @@ google-auth-httplib2
google-auth-oauthlib
ratelimit
backoff
kdbai-client
kdbai-client>=1.4.0
sentry-sdk[opentelemetry]
halo
sentence-transformers>=2.6.1
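For context, the floor bump to kdbai-client 1.4 matches the client API the rewritten exporter targets below: a Session now hands out a named database, and tables hang off that database object. A minimal sketch of that connection flow, assuming the >= 1.4 surface as used in this PR (the import alias is the conventional one; the environment-variable names are the same ones the exporter reads):

# Sketch of the kdbai-client >= 1.4 connection flow the exporter relies on.
import os

import kdbai_client as kdbai

session = kdbai.Session(
    api_key=os.environ.get("KDBAI_API_KEY"),
    endpoint=os.environ.get("KDBAI_ENDPOINT"),
)
db = session.database("default")        # 1.x groups tables under a named database
print([t.name for t in db.tables])      # Table objects expose a .name attribute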
106 changes: 59 additions & 47 deletions src/vdf_io/export_vdf/kdbai_export.py
@@ -17,7 +17,6 @@
standardize_metric,
)


load_dotenv()


@@ -27,66 +26,85 @@ class ExportKDBAI(ExportVDB):
@classmethod
def make_parser(cls, subparsers):
parser_kdbai = subparsers.add_parser(
cls.DB_NAME_SLUG, help="Export data from KDB.AI"
cls.DB_NAME_SLUG,
help="Export data from KDB.AI",
)
parser_kdbai.add_argument(
"--kdbai_endpoint",
type=str,
help="KDB.AI cloud endpoint to connect.",
)
parser_kdbai.add_argument(
"-u",
"--url",
"--kdbai_api_key",
type=str,
help="KDB.AI cloud endpoint to connect",
help="KDB.AI cloud endpoint to connect.",
)
parser_kdbai.add_argument(
"-t", "--tables", type=str, help="KDB.AI tables to export (comma-separated)"
"--database_name",
type=str,
help="Name of the KDB.AI database to write into.",
)
parser_kdbai.add_argument(
"--tables_names",
type=str,
help="Names of the KDB.AI tables to export (comma-separated).",
)

@classmethod
def export_vdb(cls, args):
"""
Export data from KDBAI
"""
set_arg_from_input(
args,
"url",
"Enter the KDB.AI endpoint instance: ",
str,
env_var="KDBAI_ENDPOINT",
)
set_arg_from_password(
args, "kdbai_api_key", "Enter your KDB.AI API key: ", "KDBAI_API_KEY"
)
if args.get("kdbai_endpoint") is None:
set_arg_from_input(
args,
"kdbai_endpoint",
"Enter the KDB.AI endpoint instance: ",
str,
env_var="KDBAI_ENDPOINT",
)

if args.get("kdbai_api_key") is None:
set_arg_from_password(
args, "kdbai_api_key", "Enter your KDB.AI API key: ", "KDBAI_API_KEY"
)

kdbai_export = ExportKDBAI(args)
set_arg_from_input(
args,
"tables",
"Enter the name of table to export:",
str,
choices=kdbai_export.get_all_index_names(),
)
if args.get("tables", None) == "":
args["tables"] = ",".join(kdbai_export.get_all_index_names())

if args.get("tables_names") is None:
set_arg_from_input(
args,
"tables_names",
"Enter the name of table to export:",
str,
choices=kdbai_export.get_all_index_names(),
)

if args.get("tables_names", None) == "":
args["tables_names"] = ",".join(kdbai_export.get_all_index_names())
kdbai_export.get_data()
return kdbai_export

def __init__(self, args):
super().__init__(args)
api_key = args.get("kdbai_api_key")
endpoint = args.get("url")
self.session = kdbai.Session(api_key=api_key, endpoint=endpoint)
self.model = args.get("model_name")
endpoint = args.get("kdbai_endpoint")
session = kdbai.Session(api_key=api_key, endpoint=endpoint)
self.db = session.database("default")

def get_index_names(self):
if "tables" not in self.args or self.args["tables"] is None:
if "tables_names" not in self.args or self.args["tables_names"] is None:
return self.get_all_index_names()
return self.args["tables"].split(",")
return self.args["tables_names"].split(",")

def get_all_index_names(self):
return self.session.list()
return [name.name for name in self.db.tables]

def get_data(self):
if "tables" not in self.args or self.args["tables"] is None:
if "tables_names" not in self.args or self.args["tables_names"] is None:
table_names = self.get_all_index_names()
else:
table_names = self.args["tables"].split(",")
table_names = self.args["tables_names"].split(",")
index_metas: Dict[str, List[NamespaceMeta]] = {}
for table_name in tqdm(table_names, desc="Fetching indexes"):
index_metas[table_name] = self.export_table(table_name)
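Stepping back to the export_vdb() flow above: after the flag rename, the exporter reads "kdbai_endpoint", "kdbai_api_key", "database_name", and "tables_names" instead of "url" and "tables". A minimal sketch of a dict-style args payload for it; every value shown is a placeholder rather than something taken from the PR:

# Hypothetical args for ExportKDBAI.export_vdb(); keys follow the renamed flags.
args = {
    "kdbai_endpoint": "https://cloud.kdb.ai/instance/abc",  # placeholder endpoint URL
    "kdbai_api_key": None,     # None -> prompted, or taken from KDBAI_API_KEY
    "database_name": "default",
    "tables_names": "",        # empty string -> export every table in the database
}
exporter = ExportKDBAI.export_vdb(args)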
@@ -106,10 +124,9 @@ def get_data(self):
json_file.write(meta_json_text)

def export_table(self, table_name):
model = self.model
vectors_directory = self.create_vec_dir(table_name)

table = self.session.table(table_name)
table = self.db.table(table_name)
table_res = table.query()
save_path = f"{vectors_directory}/{table_name}.parquet"
table_res.to_parquet(save_path, index=False)
@@ -118,18 +135,13 @@ def export_table(self, table_name):
# vectors = table_res["vector"].apply(pd.Series)
# metadata = table_res.drop(columns=["vector"]).to_dict(orient="records")
# self.save_vectors_to_parquet(vectors, metadata, vectors_directory)
embedding_name = None
embedding_dims = None
embedding_dist = None
tab_schema = table.schema()

for i in range(len(tab_schema["columns"])):
if "vectorIndex" in tab_schema["columns"][i].keys():
embedding_name = tab_schema["columns"][i]["name"]
embedding_dims = tab_schema["columns"][i]["vectorIndex"]["dims"]
embedding_dist = standardize_metric(
tab_schema["columns"][i]["vectorIndex"]["metric"], self.DB_NAME_SLUG
)

Review comment: you can remove the commented out code above as well

model = table.indexes[0]["type"]
embedding_name = table.indexes[0]["column"]
embedding_dims = table.indexes[0]["params"]["dims"]
embedding_dist = standardize_metric(
table.indexes[0]["params"]["metric"], self.DB_NAME_SLUG
)

namespace_meta = NamespaceMeta(
namespace="",
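For context on the rewritten export_table(): index metadata now comes from table.indexes rather than from scanning table.schema()["columns"] for a vectorIndex entry, as the deleted loop did. A minimal sketch of the structure the new code assumes (the table name and example values are illustrative):

# Index-metadata lookup pattern used by the new export_table(), assuming
# kdbai-client >= 1.4 where a table exposes its indexes as a list of dicts.
table = db.table("documents")               # hypothetical table name
index = table.indexes[0]                    # first (vector) index on the table
model = index["type"]                       # index type, e.g. "flat" or "hnsw"
embedding_name = index["column"]            # column holding the embeddings
embedding_dims = index["params"]["dims"]    # embedding dimensionality
embedding_dist = index["params"]["metric"]  # metric, later passed to standardize_metric()

Note that the new code reads only indexes[0], so a table with more than one index would have only its first index's metadata exported.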