Skip to content

Commit e90f32a

Browse files
author
Zecheng Zhang
authored
Add OpenAI embedding to text benchmark script (#367)
1 parent f4e5130 commit e90f32a

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1111
- Avoided for-loop in `EmbeddingEncoder` ([#366](https://github.com/pyg-team/pytorch-frame/pull/366))
1212
- Added `image_embedded` and one tabular image dataset ([#344](https://github.com/pyg-team/pytorch-frame/pull/344))
1313
- Added benchmarking suite for encoders ([#360](https://github.com/pyg-team/pytorch-frame/pull/360))
14-
- Added dataframe text benchmark script ([#354](https://github.com/pyg-team/pytorch-frame/pull/354))
14+
- Added dataframe text benchmark script ([#354](https://github.com/pyg-team/pytorch-frame/pull/354), [#367](https://github.com/pyg-team/pytorch-frame/pull/367))
1515
- Added `DataFrameTextBenchmark` dataset ([#349](https://github.com/pyg-team/pytorch-frame/pull/349))
1616
- Added support for empty `TensorFrame` ([#339](https://github.com/pyg-team/pytorch-frame/pull/339))
1717

benchmark/data_frame_text_benchmark.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,13 @@
8181
"google/electra-large-discriminator",
8282
"sentence-transformers/all-distilroberta-v1",
8383
"sentence-transformers/average_word_embeddings_glove.6B.300d",
84+
"sentence-transformers/all-roberta-large-v1",
85+
"text-embedding-3-large",
8486
],
8587
)
8688
parser.add_argument("--finetune", action="store_true")
8789
parser.add_argument('--result_path', type=str, default='')
90+
parser.add_argument("--api_key", type=str, default=None)
8891
args = parser.parse_args()
8992

9093
model_out_channels = {
@@ -188,6 +191,26 @@ def tokenize(self, sentences: list[str]) -> TextTokenizationOutputs:
188191
return_tensors="pt")
189192

190193

class OpenAIEmbedding:
    """Embeds a batch of sentences with the OpenAI embeddings API.

    Exposes the same callable interface as ``TextToEmbedding`` so it can be
    passed as a drop-in ``text_encoder`` in the benchmark script.

    Args:
        model: OpenAI embedding model name (e.g. ``"text-embedding-3-large"``).
        api_key: OpenAI API key used to authenticate the client.
    """
    def __init__(self, model: str, api_key: str):
        # Please run `pip install openai` to install the package
        from openai import OpenAI

        self.client = OpenAI(api_key=api_key)
        self.model = model

    def __call__(self, sentences: list[str]) -> Tensor:
        # NOTE: the previous `from openai import Embedding` (used only for a
        # local annotation) is dropped — `Embedding` is not a supported
        # top-level export in openai>=1.0 (it lives in `openai.types`), and
        # the import served no runtime purpose.
        # One API call for the whole batch; `data` preserves input order.
        items = self.client.embeddings.create(
            input=sentences, model=self.model).data
        # The API must return exactly one embedding per input sentence.
        assert len(items) == len(sentences)
        # Stack the per-sentence vectors into a
        # (num_sentences, embedding_dim) float tensor.
        return torch.stack([
            torch.tensor(item.embedding, dtype=torch.float) for item in items
        ])
191214
def mean_pooling(last_hidden_state: Tensor, attention_mask: Tensor) -> Tensor:
192215
input_mask_expanded = (attention_mask.unsqueeze(-1).expand(
193216
last_hidden_state.size()).float())
@@ -360,7 +383,13 @@ def main_torch(
360383
path = osp.join(osp.dirname(osp.realpath(__file__)), "..", "data")
361384

362385
if not args.finetune:
363-
text_encoder = TextToEmbedding(model=args.text_model, device=device)
386+
if args.text_model == "text-embedding-3-large":
387+
assert isinstance(args.api_key, str)
388+
text_encoder = OpenAIEmbedding(model=args.text_model,
389+
api_key=args.api_key)
390+
else:
391+
text_encoder = TextToEmbedding(model=args.text_model,
392+
device=device)
364393
text_stype = torch_frame.text_embedded
365394
kwargs = {
366395
"text_stype":

0 commit comments

Comments
 (0)