|
81 | 81 | "google/electra-large-discriminator", |
82 | 82 | "sentence-transformers/all-distilroberta-v1", |
83 | 83 | "sentence-transformers/average_word_embeddings_glove.6B.300d", |
| 84 | + "sentence-transformers/all-roberta-large-v1", |
| 85 | + "text-embedding-3-large", |
84 | 86 | ], |
85 | 87 | ) |
86 | 88 | parser.add_argument("--finetune", action="store_true") |
87 | 89 | parser.add_argument('--result_path', type=str, default='') |
| 90 | +parser.add_argument("--api_key", type=str, default=None) |
88 | 91 | args = parser.parse_args() |
89 | 92 |
|
90 | 93 | model_out_channels = { |
@@ -188,6 +191,26 @@ def tokenize(self, sentences: list[str]) -> TextTokenizationOutputs: |
188 | 191 | return_tensors="pt") |
189 | 192 |
|
190 | 193 |
|
class OpenAIEmbedding:
    r"""Embeds a list of sentences with the OpenAI embeddings endpoint.

    Args:
        model (str): Name of the OpenAI embedding model, e.g.,
            ``"text-embedding-3-large"``.
        api_key (str): OpenAI API key used to authenticate the client.
    """
    def __init__(self, model: str, api_key: str):
        # Please run `pip install openai` to install the package
        from openai import OpenAI

        self.client = OpenAI(api_key=api_key)
        self.model = model

    def __call__(self, sentences: list[str]) -> Tensor:
        r"""Returns a ``[len(sentences), embedding_dim]`` float tensor.

        Note: the legacy ``from openai import Embedding`` import was removed;
        it raises ``ImportError`` on ``openai>=1.0`` (the resource class was
        dropped from the top-level namespace) and re-ran on every call. The
        v1 response items expose ``.embedding`` directly, which is all that
        is needed here.
        """
        items = self.client.embeddings.create(
            input=sentences, model=self.model).data
        # One response item per input sentence; fail loudly on a mismatch.
        assert len(items) == len(sentences)
        # Each `item.embedding` is a list[float]; make each a 1 x dim row
        # and stack the rows into a single 2D tensor.
        embeddings = [
            torch.FloatTensor(item.embedding).view(1, -1) for item in items
        ]
        return torch.cat(embeddings, dim=0)
| 212 | + |
| 213 | + |
191 | 214 | def mean_pooling(last_hidden_state: Tensor, attention_mask: Tensor) -> Tensor: |
192 | 215 | input_mask_expanded = (attention_mask.unsqueeze(-1).expand( |
193 | 216 | last_hidden_state.size()).float()) |
@@ -360,7 +383,13 @@ def main_torch( |
360 | 383 | path = osp.join(osp.dirname(osp.realpath(__file__)), "..", "data") |
361 | 384 |
|
362 | 385 | if not args.finetune: |
363 | | - text_encoder = TextToEmbedding(model=args.text_model, device=device) |
| 386 | + if args.text_model == "text-embedding-3-large": |
| 387 | + assert isinstance(args.api_key, str) |
| 388 | + text_encoder = OpenAIEmbedding(model=args.text_model, |
| 389 | + api_key=args.api_key) |
| 390 | + else: |
| 391 | + text_encoder = TextToEmbedding(model=args.text_model, |
| 392 | + device=device) |
364 | 393 | text_stype = torch_frame.text_embedded |
365 | 394 | kwargs = { |
366 | 395 | "text_stype": |
|
0 commit comments