Skip to content

Commit b934e93

Browse files
author
Zecheng Zhang
authored
Add image embedded and one tabular image dataset (#344)
1 parent e5a384b commit b934e93

File tree

15 files changed

+468
-4
lines changed

15 files changed

+468
-4
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
88

99
### Added
1010

11+
- Added `image_embedded` and one tabular image dataset ([#344](https://github.com/pyg-team/pytorch-frame/pull/344))
1112
- Added benchmarking suite for encoders ([#360](https://github.com/pyg-team/pytorch-frame/pull/360))
1213
- Added dataframe text benchmark script ([#354](https://github.com/pyg-team/pytorch-frame/pull/354))
1314
- Added `DataFrameTextBenchmark` dataset ([#349](https://github.com/pyg-team/pytorch-frame/pull/349))

examples/tabular_image.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
import argparse
2+
import os
3+
import os.path as osp
4+
5+
import torch
6+
import torch.nn.functional as F
7+
from PIL import Image
8+
from torch import Tensor
9+
from tqdm import tqdm
10+
from transformers import AutoImageProcessor, AutoModel
11+
12+
from torch_frame import stype
13+
from torch_frame.config import ImageEmbedder, ImageEmbedderConfig
14+
from torch_frame.data import DataLoader
15+
from torch_frame.datasets import DiamondImages
16+
from torch_frame.nn import (
17+
EmbeddingEncoder,
18+
FTTransformer,
19+
LinearEmbeddingEncoder,
20+
LinearEncoder,
21+
)
22+
23+
def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the tabular image example."""
    cli = argparse.ArgumentParser()
    # Scalar hyper-parameters, registered in the order ``--help`` lists them.
    for flag, kind, default in (
        ("--channels", int, 256),
        ("--num_layers", int, 4),
        ("--batch_size", int, 512),
        ("--lr", float, 0.0001),
        ("--epochs", int, 30),
        ("--seed", int, 0),
    ):
        cli.add_argument(flag, type=kind, default=default)
    # Checkpoint used to embed the images; restricted to the listed choices.
    cli.add_argument(
        "--model",
        type=str,
        default="google/vit-base-patch16-224-in21k",
        choices=[
            "microsoft/resnet-18",
            "google/vit-base-patch16-224-in21k",
            "microsoft/swin-base-patch4-window7-224-in22k",
        ],
    )
    return cli


parser = _build_parser()

args = parser.parse_args()
42+
43+
# Image Embedded
44+
# ================ ResNet ===================
45+
# Best Val Acc: 0.2864, Best Test Acc: 0.2789
46+
# ================== ViT ====================
47+
# Best Val Acc: 0.4173, Best Test Acc: 0.4110
48+
# ================= Swin ====================
49+
# Best Val Acc: 0.4345, Best Test Acc: 0.4274
50+
51+
52+
class ImageToEmbedding(ImageEmbedder):
    """Image embedder backed by a pretrained Hugging Face vision model.

    Args:
        model_name: Checkpoint identifier passed to
            ``AutoImageProcessor.from_pretrained`` and
            ``AutoModel.from_pretrained``.
        device: Device the model is moved to and on which the forward pass
            runs.
    """
    def __init__(self, model_name: str, device: torch.device):
        super().__init__()
        self.model_name = model_name
        self.preprocess = AutoImageProcessor.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(device)
        # The model is only used for inference, never trained here.
        self.model.eval()
        self.device = device

    def forward_embed(self, images: list[Image.Image]) -> Tensor:
        """Embed a batch of images with the pretrained model.

        Args:
            images: PIL images to embed.

        Returns:
            The model's ``pooler_output`` moved to CPU. For ResNet
            checkpoints the trailing dimensions are flattened so the result
            is two-dimensional.
        """
        inputs = self.preprocess(images, return_tensors="pt")
        inputs["pixel_values"] = inputs["pixel_values"].to(self.device)
        # No gradients are tracked, so no explicit ``detach`` is needed.
        with torch.no_grad():
            res = self.model(**inputs).pooler_output.cpu()
        if "resnet" in self.model_name:
            # The ResNet pooler output keeps two trailing spatial dimensions
            # (assumed to be 1x1 -- TODO confirm). ``flatten`` removes them
            # without relying on tuple-``dim`` ``squeeze``, which older
            # PyTorch releases do not support.
            res = res.flatten(start_dim=1)
        return res
69+
70+
71+
# Seed PyTorch's RNG from the CLI and use the GPU when one is available.
torch.manual_seed(args.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Prepare datasets
path = osp.join(
    osp.dirname(osp.realpath(__file__)),
    "..",
    "data",
    "diamond_images",
)
os.makedirs(path, exist_ok=True)

col_to_image_embedder_cfg = ImageEmbedderConfig(
    image_embedder=ImageToEmbedding(args.model, device),
    batch_size=10,
)
dataset = DiamondImages(
    path,
    col_to_image_embedder_cfg=col_to_image_embedder_cfg,
)

# The model id contains "/", so it is stripped before the id is used as part
# of the file name handed to ``materialize``.
model_name = args.model.replace('/', '')
filename = f"{model_name}_data.pt"
dataset.materialize(path=osp.join(path, filename))
dataset = dataset.shuffle()
# Fractional slices: presumably 80% / 10% / 10% of the shuffled rows.
train_dataset = dataset[:0.8]
val_dataset = dataset[0.8:0.9]
test_dataset = dataset[0.9:]

train_tensor_frame = train_dataset.tensor_frame
val_tensor_frame = val_dataset.tensor_frame
test_tensor_frame = test_dataset.tensor_frame
# Only the training loader shuffles its batches.
train_loader = DataLoader(
    train_tensor_frame,
    batch_size=args.batch_size,
    shuffle=True,
)
val_loader = DataLoader(val_tensor_frame, batch_size=args.batch_size)
test_loader = DataLoader(test_tensor_frame, batch_size=args.batch_size)

# One encoder per semantic type; the image embeddings are looked up through
# the parent stype of ``image_embedded``.
stype_encoder_dict = {
    stype.categorical: EmbeddingEncoder(),
    stype.numerical: LinearEncoder(),
    stype.image_embedded.parent: LinearEmbeddingEncoder(),
}

model = FTTransformer(
    channels=args.channels,
    out_channels=dataset.num_classes,
    num_layers=args.num_layers,
    col_stats=dataset.col_stats,
    col_names_dict=train_tensor_frame.col_names_dict,
    stype_encoder_dict=stype_encoder_dict,
)
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
114+
115+
116+
def train(epoch: int) -> float:
    """Run one optimisation pass over ``train_loader``.

    Args:
        epoch: Epoch number, used only for the progress-bar label.

    Returns:
        The cross-entropy loss averaged over all samples seen in the pass.
    """
    model.train()
    weighted_loss = 0
    num_samples = 0

    progress = tqdm(train_loader, desc=f"Epoch: {epoch}")
    for batch in progress:
        batch = batch.to(device)
        logits = model(batch)
        loss = F.cross_entropy(logits, batch.y)
        optimizer.zero_grad()
        loss.backward()
        # Weight the batch loss by the batch size so the final value is a
        # per-sample average even when the last batch is smaller.
        batch_len = len(batch.y)
        weighted_loss += float(loss) * batch_len
        num_samples += batch_len
        optimizer.step()
    return weighted_loss / num_samples
130+
131+
132+
@torch.no_grad()
def test(loader: DataLoader) -> float:
    """Compute the classification accuracy of ``model`` over ``loader``.

    Args:
        loader: Loader yielding the batches to evaluate.

    Returns:
        The fraction of samples whose arg-max prediction equals the label.
    """
    model.eval()
    num_correct = 0
    num_samples = 0

    for batch in loader:
        batch = batch.to(device)
        predicted = model(batch).argmax(dim=-1)
        num_correct += float((batch.y == predicted).sum())
        num_samples += len(batch.y)

    return num_correct / num_samples
146+
147+
148+
metric = "Acc"
best_val_metric = best_test_metric = 0

for epoch in range(1, args.epochs + 1):
    train_loss = train(epoch)
    # Evaluate on the three splits in train / val / test order.
    train_metric, val_metric, test_metric = map(
        test, (train_loader, val_loader, test_loader))
    # Keep the test score of the epoch with the best validation score.
    if val_metric > best_val_metric:
        best_val_metric, best_test_metric = val_metric, test_metric

    print(f"Train Loss: {train_loss:.4f}, Train {metric}: {train_metric:.4f}, "
          f"Val {metric}: {val_metric:.4f}, Test {metric}: {test_metric:.4f}")

print(f"Best Val {metric}: {best_val_metric:.4f}, "
      f"Best Test {metric}: {best_test_metric:.4f}")

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ dependencies=[
3333
"torch",
3434
"tqdm",
3535
"pyarrow",
36+
"Pillow",
3637
]
3738

3839
[project.optional-dependencies]

test/data/test_dataset.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
import torch
55

66
import torch_frame
7+
from torch_frame.config.image_embedder import ImageEmbedderConfig
78
from torch_frame.config.text_embedder import TextEmbedderConfig
89
from torch_frame.data import DataFrameToTensorFrameConverter, Dataset
910
from torch_frame.data.dataset import canonicalize_col_to_pattern
1011
from torch_frame.data.stats import StatType
1112
from torch_frame.datasets import FakeDataset
13+
from torch_frame.testing.image_embedder import RandomImageEmbedder
1214
from torch_frame.testing.text_embedder import HashTextEmbedder
1315
from torch_frame.typing import TaskType
1416

@@ -93,11 +95,16 @@ def test_dataset_inductive_transform():
9395
-1).all()
9496

9597

96-
def test_materalization_and_converter():
98+
def test_materalization_and_converter(tmpdir):
99+
tmp_path = str(tmpdir.mkdir("image"))
97100
text_embedder_cfg = TextEmbedderConfig(
98101
text_embedder=HashTextEmbedder(1),
99102
batch_size=8,
100103
)
104+
image_embedder_cfg = ImageEmbedderConfig(
105+
image_embedder=RandomImageEmbedder(1),
106+
batch_size=8,
107+
)
101108
dataset_stypes = [
102109
torch_frame.categorical,
103110
torch_frame.numerical,
@@ -106,11 +113,14 @@ def test_materalization_and_converter():
106113
torch_frame.timestamp,
107114
torch_frame.text_embedded,
108115
torch_frame.embedding,
116+
torch_frame.image_embedded,
109117
]
110118
dataset = FakeDataset(
111119
num_rows=10,
112120
stypes=dataset_stypes,
113121
col_to_text_embedder_cfg=text_embedder_cfg,
122+
col_to_image_embedder_cfg=image_embedder_cfg,
123+
tmp_path=tmp_path,
114124
).materialize()
115125
expected_parent_feat_size: dict[torch_frame.stype, int] = dict()
116126
stype_num_cols: dict[torch_frame.stype, int] = dict()
@@ -144,6 +154,7 @@ def test_materalization_and_converter():
144154
col_to_sep=dataset.col_to_sep,
145155
col_to_time_format=dataset.col_to_time_format,
146156
col_to_text_embedder_cfg=dataset.col_to_text_embedder_cfg,
157+
col_to_image_embedder_cfg=dataset.col_to_image_embedder_cfg,
147158
)
148159
tf = convert_to_tensor_frame(dataset.df)
149160
assert tf.col_names_dict == convert_to_tensor_frame.col_names_dict

test/nn/encoder/test_stype_encoder.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import torch_frame
99
from torch_frame import NAStrategy, stype
1010
from torch_frame.config import ModelConfig
11+
from torch_frame.config.image_embedder import ImageEmbedderConfig
1112
from torch_frame.config.text_embedder import TextEmbedderConfig
1213
from torch_frame.config.text_tokenizer import TextTokenizerConfig
1314
from torch_frame.data.dataset import Dataset
@@ -29,6 +30,7 @@
2930
TimestampEncoder,
3031
)
3132
from torch_frame.nn.encoding import CyclicEncoding
33+
from torch_frame.testing.image_embedder import RandomImageEmbedder
3234
from torch_frame.testing.text_embedder import HashTextEmbedder
3335
from torch_frame.testing.text_tokenizer import (
3436
RandomTextModel,
@@ -435,6 +437,47 @@ def test_text_tokenized_encoder():
435437
tensor_frame.feat_dict[stype.text_tokenized][key].offset)
436438

437439

440+
def test_image_embedded_encoder(tmpdir):
    """Check ``LinearEmbeddingEncoder`` on image-embedded columns.

    The test materializes a fake dataset with a single ``image_embedded``
    stype, encodes the resulting ``stype.embedding`` features, and verifies
    that the input is not modified in place and that the output has shape
    ``(num_rows, num_cols, out_channels)``.
    """
    num_rows, out_channels = 20, 5
    image_dir = str(tmpdir.mkdir("image"))
    embedder_cfg = ImageEmbedderConfig(
        image_embedder=RandomImageEmbedder(out_channels=out_channels),
        batch_size=None,
    )
    dataset = FakeDataset(
        num_rows=num_rows,
        stypes=[torch_frame.image_embedded],
        tmp_path=image_dir,
        col_to_image_embedder_cfg=embedder_cfg,
    )
    dataset.materialize()
    tensor_frame = dataset.tensor_frame

    # The image-embedded columns are read back under ``stype.embedding``.
    col_names = tensor_frame.col_names_dict[stype.embedding]
    source_feat = tensor_frame.feat_dict[stype.embedding]

    encoder = LinearEmbeddingEncoder(
        out_channels=out_channels,
        stats_list=[dataset.col_stats[name] for name in col_names],
        stype=stype.embedding,
    )
    feat_emb = source_feat.clone()
    x = encoder(feat_emb, col_names)

    # Make sure no in-place modification
    assert torch.allclose(feat_emb.values, source_feat.values)
    assert torch.allclose(feat_emb.offset, source_feat.offset)
    expected_shape = (num_rows, len(col_names), out_channels)
    assert x.shape == expected_shape
479+
480+
438481
def test_linear_model_encoder():
439482
num_rows = 20
440483
out_channels = 8

test/test_stype.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44
def test_stype():
5-
assert len(torch_frame.stype) == 8
5+
assert len(torch_frame.stype) == 9
66
assert torch_frame.numerical == torch_frame.stype('numerical')
77
assert not torch_frame.numerical.is_text_stype
88
assert torch_frame.categorical == torch_frame.stype('categorical')
@@ -17,5 +17,7 @@ def test_stype():
1717
assert torch_frame.text_embedded.is_text_stype
1818
assert torch_frame.text_tokenized == torch_frame.stype('text_tokenized')
1919
assert torch_frame.text_tokenized.is_text_stype
20+
assert torch_frame.image_embedded == torch_frame.stype('image_embedded')
21+
assert torch_frame.image_embedded.is_image_stype
2022
assert torch_frame.embedding == torch_frame.stype('embedding')
2123
assert torch_frame.embedding.use_multi_embedding_tensor

torch_frame/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
multicategorical,
99
sequence_numerical,
1010
timestamp,
11+
image_embedded,
1112
embedding,
1213
)
1314
from .data import TensorFrame
@@ -30,6 +31,7 @@
3031
'multicategorical',
3132
'sequence_numerical',
3233
'timestamp',
34+
'image_embedded',
3335
'embedding',
3436
'TaskType',
3537
'Metric',

torch_frame/_stype.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class stype(Enum):
2525
sequence_numerical: Sequence of numerical values.
2626
embedding: Embedding columns.
2727
timestamp: Timestamp columns.
28+
image_embedded: Pre-computed embeddings of image columns.
2829
"""
2930
numerical = 'numerical'
3031
categorical = 'categorical'
@@ -33,12 +34,17 @@ class stype(Enum):
3334
multicategorical = 'multicategorical'
3435
sequence_numerical = 'sequence_numerical'
3536
timestamp = 'timestamp'
37+
image_embedded = 'image_embedded'
3638
embedding = 'embedding'
3739

3840
@property
3941
def is_text_stype(self) -> bool:
4042
return self in [stype.text_embedded, stype.text_tokenized]
4143

44+
@property
45+
def is_image_stype(self) -> bool:
46+
return self in [stype.image_embedded]
47+
4248
@property
4349
def use_multi_nested_tensor(self) -> bool:
4450
r"""This property indicates if the data of an stype is stored in
@@ -51,7 +57,9 @@ def use_multi_embedding_tensor(self) -> bool:
5157
r"""This property indicates if the data of an stype is stored in
5258
:class:`torch_frame.data.MultiEmbeddingTensor`.
5359
"""
54-
return self in [stype.text_embedded, stype.embedding]
60+
return self in [
61+
stype.text_embedded, stype.image_embedded, stype.embedding
62+
]
5563

5664
@property
5765
def use_dict_multi_nested_tensor(self) -> bool:
@@ -79,6 +87,8 @@ def parent(self):
7987
"""
8088
if self == stype.text_embedded:
8189
return stype.embedding
90+
elif self == stype.image_embedded:
91+
return stype.embedding
8292
else:
8393
return self
8494

@@ -93,4 +103,5 @@ def __str__(self) -> str:
93103
multicategorical = stype('multicategorical')
94104
sequence_numerical = stype('sequence_numerical')
95105
timestamp = stype('timestamp')
106+
image_embedded = stype('image_embedded')
96107
embedding = stype('embedding')

torch_frame/config/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22
from .text_embedder import TextEmbedderConfig
33
from .text_tokenizer import TextTokenizerConfig
44
from .model import ModelConfig
5+
from .image_embedder import ImageEmbedderConfig, ImageEmbedder
56

67
__all__ = classes = [
78
'TextEmbedderConfig',
89
'TextTokenizerConfig',
910
'ModelConfig',
11+
'ImageEmbedderConfig',
12+
'ImageEmbedder',
1013
]

0 commit comments

Comments
 (0)