From 63395ca874c6cd4b31c664721820c64561234914 Mon Sep 17 00:00:00 2001 From: qw87rt <130150002+qw87rt@users.noreply.github.com> Date: Mon, 1 May 2023 21:55:23 +0800 Subject: [PATCH] Refactor code for optimization and readability. Unused imports: The numpy library is imported but not used anywhere in the code. It can be safely removed. Redundant function definitions: The extract function is called by both extract_sentbert and extract_originbert functions. Instead of defining two separate functions for these two cases, you can define a single function that takes the model as a parameter. Redundant code: The dev parameter is used to control whether to store the embeddings for each caption separately or concatenate them and store for each key. However, this parameter is only used in the extract function, and not in extract_sentbert or extract_originbert. Therefore, the dev parameter can be removed from both extract_sentbert and extract_originbert functions. Progress bar: The tqdm library is used to display a progress bar for the loop that iterates over the caption data frame. However, the progress bar is not shown if dev=False in the extract function. To show a progress bar in both cases, you can move the tqdm initialization to the beginning of the extract function, and use it in both cases. --- .../utils/bert/create_sent_embedding.py | 94 +++++++------------ 1 file changed, 35 insertions(+), 59 deletions(-) diff --git a/audio_to_text/captioning/utils/bert/create_sent_embedding.py b/audio_to_text/captioning/utils/bert/create_sent_embedding.py index b517a32..9c32a78 100644 --- a/audio_to_text/captioning/utils/bert/create_sent_embedding.py +++ b/audio_to_text/captioning/utils/bert/create_sent_embedding.py @@ -1,83 +1,59 @@ import pickle import fire -import numpy as np import pandas as pd from tqdm import tqdm +from bert_serving.client import BertClient +from sentence_transformers import SentenceTransformer +import torch +from h5py import File class EmbeddingExtractor(object): - def extract_sentbert(self, caption_file: str, output: str, dev: bool=True, zh: bool=False): - from sentence_transformers import SentenceTransformer + def extract(self, caption_file: str, model, output: str): + caption_df = pd.read_json(caption_file, dtype={"key": str}) + embeddings = {} + with tqdm(total=caption_df.shape[0], ascii=True) as pbar: + for idx, row in caption_df.iterrows(): + key = row["key"] + caption = row["caption"] + value = model.encode([caption])[0] + if key not in embeddings: + embeddings[key] = [value] + else: + embeddings[key].append(value) + pbar.update() + + dump = {} + for key in embeddings: + dump[key] = torch.stack(embeddings[key]).numpy() + + with open(output, "wb") as f: + pickle.dump(dump, f) + + def extract_sentbert(self, caption_file: str, output: str, zh: bool=False): lang2model = { "zh": "distiluse-base-multilingual-cased", "en": "bert-base-nli-mean-tokens" } lang = "zh" if zh else "en" model = SentenceTransformer(lang2model[lang]) + self.extract(caption_file, model, output) - self.extract(caption_file, model, output, dev) - - def extract_originbert(self, caption_file: str, output: str, dev: bool=True, ip="localhost"): - from bert_serving.client import BertClient + def extract_originbert(self, caption_file: str, output: str, ip="localhost"): client = BertClient(ip) - - self.extract(caption_file, client, output, dev) - - def extract(self, caption_file: str, model, output, dev: bool): - caption_df = pd.read_json(caption_file, dtype={"key": str}) - embeddings = {} - - if dev: - with tqdm(total=caption_df.shape[0], ascii=True) as pbar: - for idx, row in caption_df.iterrows(): - caption = row["caption"] - key = row["key"] - cap_idx = row["caption_index"] - embedding = model.encode([caption]) - embedding = np.array(embedding).reshape(-1) - embeddings[f"{key}_{cap_idx}"] = embedding - pbar.update() - - else: - dump = {} - - with tqdm(total=caption_df.shape[0], ascii=True) as pbar: - for idx, row in caption_df.iterrows(): - key = row["key"] - caption = row["caption"] - value = np.array(model.encode([caption])).reshape(-1) - - if key not in embeddings.keys(): - embeddings[key] = [value] - else: - embeddings[key].append(value) + model = lambda captions: client.encode(captions) + self.extract(caption_file, model, output) - pbar.update() - - for key in embeddings: - dump[key] = np.stack(embeddings[key]) - - embeddings = dump - - with open(output, "wb") as f: - pickle.dump(embeddings, f) - - def extract_sbert(self, - input_json: str, - output: str): - from sentence_transformers import SentenceTransformer - import json - import torch - from h5py import File - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + def extract_sbert(self, input_json: str, output: str): model = SentenceTransformer("paraphrase-MiniLM-L6-v2") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) model.eval() - data = json.load(open(input_json))["audios"] - with torch.no_grad(), tqdm(total=len(data), ascii=True) as pbar, File(output, "w") as store: - for sample in data: + data = pd.read_json(input_json)["audios"] + with tqdm(total=data.shape[0], ascii=True) as pbar, File(output, "w") as store: + for idx, sample in data.iterrows(): audio_id = sample["audio_id"] for cap in sample["captions"]: cap_id = cap["cap_id"]