diff --git a/audio_to_text/captioning/utils/bert/create_sent_embedding.py b/audio_to_text/captioning/utils/bert/create_sent_embedding.py
index b517a32..9c32a78 100644
--- a/audio_to_text/captioning/utils/bert/create_sent_embedding.py
+++ b/audio_to_text/captioning/utils/bert/create_sent_embedding.py
@@ -1,83 +1,59 @@
 import pickle
 import fire
-import numpy as np
 import pandas as pd
 from tqdm import tqdm
+from bert_serving.client import BertClient
+from sentence_transformers import SentenceTransformer
+import torch
+from h5py import File
 
 
 class EmbeddingExtractor(object):
 
-    def extract_sentbert(self, caption_file: str, output: str, dev: bool=True, zh: bool=False):
-        from sentence_transformers import SentenceTransformer
+    def extract(self, caption_file: str, model, output: str):
+        caption_df = pd.read_json(caption_file, dtype={"key": str})
+        embeddings = {}
+        with tqdm(total=caption_df.shape[0], ascii=True) as pbar:
+            for idx, row in caption_df.iterrows():
+                key = row["key"]
+                caption = row["caption"]
+                value = model.encode([caption])[0]
+                if key not in embeddings:
+                    embeddings[key] = [value]
+                else:
+                    embeddings[key].append(value)
+                pbar.update()
+
+        dump = {}
+        for key in embeddings:
+            dump[key] = torch.stack(embeddings[key]).numpy()
+
+        with open(output, "wb") as f:
+            pickle.dump(dump, f)
+
+    def extract_sentbert(self, caption_file: str, output: str, zh: bool=False):
         lang2model = {
             "zh": "distiluse-base-multilingual-cased",
             "en": "bert-base-nli-mean-tokens"
         }
         lang = "zh" if zh else "en"
         model = SentenceTransformer(lang2model[lang])
+        self.extract(caption_file, model, output)
-        self.extract(caption_file, model, output, dev)
 
-    def extract_originbert(self, caption_file: str, output: str, dev: bool=True, ip="localhost"):
-        from bert_serving.client import BertClient
+    def extract_originbert(self, caption_file: str, output: str, ip="localhost"):
         client = BertClient(ip)
-
-        self.extract(caption_file, client, output, dev)
-
-    def extract(self, caption_file: str, model, output, dev: bool):
-        caption_df = pd.read_json(caption_file, dtype={"key": str})
-        embeddings = {}
-
-        if dev:
-            with tqdm(total=caption_df.shape[0], ascii=True) as pbar:
-                for idx, row in caption_df.iterrows():
-                    caption = row["caption"]
-                    key = row["key"]
-                    cap_idx = row["caption_index"]
-                    embedding = model.encode([caption])
-                    embedding = np.array(embedding).reshape(-1)
-                    embeddings[f"{key}_{cap_idx}"] = embedding
-                    pbar.update()
-
-        else:
-            dump = {}
-
-            with tqdm(total=caption_df.shape[0], ascii=True) as pbar:
-                for idx, row in caption_df.iterrows():
-                    key = row["key"]
-                    caption = row["caption"]
-                    value = np.array(model.encode([caption])).reshape(-1)
-
-                    if key not in embeddings.keys():
-                        embeddings[key] = [value]
-                    else:
-                        embeddings[key].append(value)
+        model = lambda captions: client.encode(captions)
+        self.extract(caption_file, model, output)
-
-                    pbar.update()
-
-            for key in embeddings:
-                dump[key] = np.stack(embeddings[key])
-
-            embeddings = dump
-
-        with open(output, "wb") as f:
-            pickle.dump(embeddings, f)
 
-    def extract_sbert(self,
-                      input_json: str,
-                      output: str):
-        from sentence_transformers import SentenceTransformer
-        import json
-        import torch
-        from h5py import File
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    def extract_sbert(self, input_json: str, output: str):
         model = SentenceTransformer("paraphrase-MiniLM-L6-v2")
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         model = model.to(device)
         model.eval()
-        data = json.load(open(input_json))["audios"]
-        with torch.no_grad(), tqdm(total=len(data), ascii=True) as pbar, File(output, "w") as store:
-            for sample in data:
+        data = pd.read_json(input_json)["audios"]
+        with tqdm(total=data.shape[0], ascii=True) as pbar, File(output, "w") as store:
+            for idx, sample in data.iterrows():
                 audio_id = sample["audio_id"]
                 for cap in sample["captions"]:
                     cap_id = cap["cap_id"]
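The refactor above funnels every backend (SentenceTransformer, bert-serving's `BertClient`) through the new `EmbeddingExtractor.extract(caption_file, model, output)` helper, which only needs an object exposing `encode(captions)` that returns one vector per caption, and writes a pickle mapping each `key` to a stacked `(num_captions, embedding_dim)` matrix. Below is a minimal usage sketch against that interface, assuming the package is importable from the repository root; `DummyEncoder` and the file names are hypothetical stand-ins, and the stand-in returns `torch` tensors because `extract` stacks the per-caption vectors with `torch.stack`.

```python
# Usage sketch only (not part of the diff): exercises the refactored
# extract() interface with a stand-in encoder. DummyEncoder and the file
# names below are hypothetical.
import pickle

import pandas as pd
import torch

from audio_to_text.captioning.utils.bert.create_sent_embedding import EmbeddingExtractor


class DummyEncoder:
    """Stand-in for SentenceTransformer / BertClient.

    extract() calls model.encode([caption])[0] and later stacks the
    per-caption vectors with torch.stack, so encode() returns a tensor
    of shape (num_captions, embedding_dim).
    """

    def encode(self, captions):
        return torch.rand(len(captions), 768)


if __name__ == "__main__":
    # Minimal caption table: one row per caption with "key" and "caption"
    # columns, matching what pd.read_json(...) expects inside extract().
    pd.DataFrame([
        {"key": "audio_1", "caption": "a dog barks in the distance"},
        {"key": "audio_1", "caption": "an animal barks repeatedly"},
        {"key": "audio_2", "caption": "rain falls on a metal roof"},
    ]).to_json("captions.json")

    EmbeddingExtractor().extract("captions.json", DummyEncoder(), "embeddings.pkl")

    with open("embeddings.pkl", "rb") as f:
        embeddings = pickle.load(f)
    for key, mat in embeddings.items():
        print(key, mat.shape)  # e.g. audio_1 (2, 768), audio_2 (1, 768)
```

The module imports `fire`, so in the repository the class is presumably exposed as a CLI (something like `python create_sent_embedding.py extract_sentbert <caption_file> <output>`); that entry point lies outside the hunk shown above.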