Commit 893678f

Add light-weight MLP (#372)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 94bf7a3 commit 893678f

7 files changed: 166 additions and 7 deletions

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
@@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

+- Added light-weight MLP ([#372](https://github.com/pyg-team/pytorch-frame/pull/372))
+
 ### Changed

 ### Deprecated
@@ -39,7 +41,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed

 - Fixed bug in empty `MultiNestedTensor` handling ([#369](https://github.com/pyg-team/pytorch-frame/pull/369))
-
 - Fixed the split of `DataFrameTextBenchmark` ([#358](https://github.com/pyg-team/pytorch-frame/pull/358))
 - Fixed empty `MultiNestedTensor` col indexing ([#355](https://github.com/pyg-team/pytorch-frame/pull/355))

benchmark/README.md

Lines changed: 7 additions & 2 deletions
@@ -11,8 +11,8 @@ pip install lightgbm

 Then run
 ```bash
-# Specify the model from [TabNet, FTTransformer, ResNet, TabTransformer, Trompt
-# ExcelFormer, FTTransformerBucket, XGBoost, CatBoost, LightGBM]
+# Specify the model from [TabNet, FTTransformer, ResNet, MLP, TabTransformer,
+# Trompt, ExcelFormer, FTTransformerBucket, XGBoost, CatBoost, LightGBM]
 model_type=TabNet

 # Specify the task type from [binary_classification, regression,
@@ -64,6 +64,7 @@ Experimental setting: 20 Optuna search trials. 50 epochs of training.
 | LightGBM | **0.931±0.000 (15s)** | 0.999±0.000 (1s) | 0.943±0.000 (23s) | 0.943±0.000 (14s) | **0.887±0.000 (5s)** | **0.972±0.000 (11s)** | **0.862±0.000 (6s)** | 0.774±0.000 (3s) | 0.979±0.000 (41s) | 0.732±0.000 (13s) | 0.787±0.000 (3s) | 0.951±0.000 (13s) | 0.999±0.000 (10s) | **0.927±0.000 (24s)** |
 | Trompt | 0.919±0.000 (9627s) | **1.000±0.000 (5341s)** | **0.945±0.000 (14679s)** | 0.942±0.001 (2752s) | 0.881±0.001 (2640s) | 0.964±0.001 (5173s) | 0.855±0.002 (4249s) | 0.778±0.002 (8789s) | 0.933±0.001 (9353s) | 0.686±0.008 (3105s) | 0.793±0.002 (8255s) | **0.952±0.001 (4876s)** | **1.000±0.000 (3558s)** | 0.916±0.001 (30002s) |
 | ResNet | 0.917±0.000 (615s) | **1.000±0.000 (71s)** | 0.937±0.001 (787s) | 0.938±0.002 (230s) | 0.865±0.001 (183s) | 0.960±0.001 (349s) | 0.828±0.001 (248s) | 0.768±0.002 (205s) | 0.925±0.002 (958s) | 0.665±0.006 (140s) | **0.794±0.002 (76s)** | 0.946±0.002 (145s) | **1.000±0.000 (93s)** | 0.911±0.001 (880s) |
+| MLP | 0.913±0.001 (112s) | **1.000±0.000 (45s)** | 0.934±0.001 (274s) | 0.938±0.001 (66s) | 0.863±0.002 (61s) | 0.953±0.000 (92s) | 0.830±0.001 (68s) | 0.769±0.002 (56s) | 0.903±0.002 (159s) | 0.666±0.015 (58s) | 0.789±0.001 (48s) | 0.940±0.002 (107s) | **1.000±0.000 (48s)** | 0.910±0.001 (149s) |
 | FTTransformerBucket | 0.915±0.001 (690s) | **0.999±0.001 (354s)** | 0.936±0.002 (1705s) | 0.939±0.002 (484s) | 0.876±0.002 (321s) | 0.960±0.001 (746s) | 0.857±0.000 (549s) | 0.771±0.003 (654s) | 0.909±0.002 (1177s) | 0.636±0.012 (244s) | 0.788±0.002 (710s) | 0.950±0.001 (510s) | **0.999±0.000 (634s)** | 0.913±0.001 (1164s) |
 | ExcelFormer | 0.918±0.001 (1587s) | **1.000±0.000 (634s)** | 0.939±0.001 (1827s) | 0.939±0.002 (378s) | 0.878±0.003 (251s) | 0.969±0.000 (678s) | 0.833±0.011 (435s) | **0.780±0.002 (938s)** | 0.921±0.005 (1131s) | 0.649±0.008 (519s) | 0.794±0.003 (683s) | 0.950±0.001 (405s) | **0.999±0.000 (1169s)** | 0.919±0.001 (1798s) |
 | FTTransformer | 0.918±0.001 (871s) | **1.000±0.000 (571s)** | 0.940±0.001 (1371s) | 0.936±0.001 (458s) | 0.874±0.002 (200s) | 0.959±0.001 (622s) | 0.828±0.001 (339s) | 0.773±0.002 (521s) | 0.909±0.002 (1488s) | 0.635±0.011 (392s) | 0.790±0.001 (556s) | 0.949±0.002 (374s) | **1.000±0.000 (713s)** | 0.912±0.000 (1855s) |
@@ -81,6 +82,7 @@ Experimental setting: 20 Optuna search trials for XGBoost and CatBoost. 5 Optuna
 | LightGBM | **0.639±0.000 (49s)** | 0.955±0.000 (126s) | 0.652±0.000 (7s) | **0.986±0.000 (99s)** | **0.723±0.000 (16s)** | **0.997±0.000 (172s)** | 0.881±0.000 (83s) | 0.914±0.000 (86s) | **0.809±0.000 (76s)** |
 | Trompt | OOM | 0.950±0.000 (28212s) | **0.652±0.000 (5962s)** | 0.982±0.000 (19936s) | 0.716±0.000 (7110s) | 0.966±0.000 (106916s) | **0.882±0.000 (13644s)** | 0.883±0.000 (17863s) | 0.705±0.006 (11563s) |
 | ResNet | 0.637±0.000 (810s) | 0.948±0.000 (1051s) | 0.649±0.000 (185s) | 0.983±0.000 (239s) | 0.705±0.001 (226s) | 0.989±0.000 (1967s) | 0.871±0.001 (173s) | 0.890±0.001 (315s) | 0.719±0.001 (245s) |
+| MLP | 0.634±0.002 (392s) | 0.946±0.001 (2306s) | 0.650±0.000 (263s) | 0.978±0.000 (468s) | 0.699±0.001 (357s) | 0.991±0.000 (2491s) | 0.869±0.001 (449s) | 0.883±0.001 (695s) | 0.727±0.002 (368s) |
 | FTTransformerBucket | 0.637±0.000 (8032s) | 0.947±0.000 (6571s) | 0.649±0.001 (714s) | **0.986±0.000 (2138s)** | 0.651±0.060 (1473s) | 0.832±0.153 (8248s) | 0.866±0.001 (1531s) | 0.877±0.000 (2960s) | 0.688±0.001 (1983s) |
 | ExcelFormer | OOM | 0.948±0.000 (6278s) | **0.653±0.000 (515s)** | 0.982±0.000 (2691s) | 0.716±0.001 (1263s) | 0.985±0.000 (34917s) | 0.877±0.001 (3388s) | 0.886±0.001 (3237s) | 0.708±0.001 (2138s) |
 | FTTransformer | 0.632±0.001 (7669s) | 0.946±0.001 (4613s) | **0.652±0.000 (587s)** | 0.981±0.000 (3048s) | 0.704±0.001 (980s) | 0.984±0.001 (15615s) | 0.871±0.002 (1424s) | 0.878±0.002 (2933s) | 0.713±0.001 (1656s) |
@@ -105,6 +107,7 @@ Experimental setting: 20 Optuna search trials. 50 epochs of training.
 | LightGBM | 0.253±0.000 (38s) | 0.054±0.000 (24s) | **0.112±0.000 (10s)** | 0.302±0.000 (30s) | 0.325±0.000 (30s) | **0.384±0.000 (23s)** | 0.295±0.000 (15s) | **0.272±0.000 (26s)** | 0.877±0.000 (16s) | 0.011±0.000 (12s) | 0.702±0.000 (13s) | **0.863±0.000 (5s)** | **0.395±0.000 (40s)** |
 | Trompt | 0.261±0.003 (8390s) | **0.015±0.005 (3792s)** | 0.118±0.001 (3836s) | **0.262±0.001 (10037s)** | **0.323±0.001 (9255s)** | 0.418±0.003 (9071s) | 0.329±0.009 (2977s) | 0.312±0.002 (21967s) | OOM | **0.008±0.001 (1889s)** | 0.779±0.006 (775s) | 0.874±0.004 (3723s) | 0.424±0.005 (3185s) |
 | ResNet | 0.288±0.006 (220s) | 0.018±0.003 (187s) | 0.124±0.001 (135s) | 0.268±0.001 (330s) | 0.335±0.001 (471s) | 0.434±0.004 (345s) | 0.325±0.012 (178s) | 0.324±0.004 (365s) | 0.895±0.005 (142s) | 0.036±0.002 (172s) | 0.794±0.006 (120s) | 0.875±0.004 (122s) | 0.468±0.004 (303s) |
+| MLP | 0.300±0.002 (108s) | 0.141±0.015 (76s) | 0.125±0.001 (44s) | 0.272±0.002 (69s) | 0.348±0.001 (103s) | 0.435±0.002 (33s) | 0.331±0.008 (43s) | 0.380±0.004 (125s) | 0.893±0.002 (69s) | 0.017±0.001 (48s) | 0.784±0.007 (29s) | 0.881±0.005 (30s) | 0.467±0.003 (92s) |
 | FTTransformerBucket | 0.325±0.008 (619s) | 0.096±0.005 (290s) | 0.360±0.354 (332s) | 0.284±0.005 (768s) | 0.342±0.004 (757s) | 0.441±0.003 (835s) | 0.345±0.007 (191s) | 0.339±0.003 (3321s) | OOM | 0.105±0.011 (199s) | 0.807±0.010 (156s) | 0.885±0.008 (820s) | 0.468±0.006 (706s) |
 | ExcelFormer | 0.302±0.003 (703s) | 0.099±0.003 (490s) | 0.145±0.003 (587s) | 0.382±0.011 (504s) | 0.344±0.002 (1096s) | 0.411±0.005 (469s) | 0.359±0.016 (207s) | 0.336±0.008 (5522s) | OOM | 0.192±0.014 (317s) | 0.794±0.005 (189s) | 0.890±0.003 (1186s) | 0.445±0.005 (550s) |
 | FTTransformer | 0.335±0.010 (338s) | 0.161±0.022 (370s) | 0.140±0.002 (244s) | 0.277±0.004 (516s) | 0.335±0.003 (973s) | 0.445±0.003 (599s) | 0.361±0.018 (286s) | 0.345±0.005 (2443s) | OOM | 0.106±0.012 (150s) | 0.826±0.005 (121s) | 0.896±0.007 (832s) | 0.461±0.003 (647s) |
@@ -122,6 +125,7 @@ Experimental setting: 20 Optuna search trials for XGBoost and CatBoost. 5 Optuna
 | LightGBM | **0.660±0.000 (199s)** | 0.015±0.000 (86s) | **0.085±0.000 (39s)** | 0.141±0.000 (35s) | **0.524±0.000 (148s)** | **0.895±0.000 (7s)** |
 | Trompt | OOM | **0.014±0.000 (19976s)** | 0.092±0.001 (4060s) | **0.140±0.000 (3487s)** | 0.537±0.000 (26520s) | 0.901±0.000 (2333s) |
 | ResNet | 0.676±0.000 (894s) | **0.016±0.000 (548s)** | 0.101±0.001 (176s) | 0.147±0.000 (503s) | 0.555±0.003 (1121s) | 0.903±0.000 (116s) |
+| MLP | 0.680±0.001 (907s) | **0.016±0.000 (1015s)** | 0.105±0.000 (254s) | **0.140±0.000 (313s)** | 0.558±0.001 (1756s) | 0.905±0.001 (240s) |
 | FTTransformerBucket | 0.738±0.029 (17223s) | 0.023±0.000 (2573s) | 0.113±0.002 (645s) | 0.147±0.000 (970s) | 0.545±0.000 (3009s) | 0.908±0.000 (360s) |
 | ExcelFormer | **0.667±0.000 (35946s)** | 0.064±0.019 (2355s) | 0.119±0.003 (594s) | 0.220±0.009 (1285s) | 0.563±0.002 (2772s) | 0.902±0.000 (288s) |
 | FTTransformer | 0.673±0.000 (18524s) | 0.056±0.003 (3348s) | 0.119±0.003 (396s) | **0.141±0.000 (1049s)** | 0.561±0.001 (2403s) | 0.907±0.002 (302s) |
@@ -150,6 +154,7 @@ Experimental setting: 20 Optuna search trials for XGBoost and CatBoost. 5 Optuna
 | LightGBM | Too slow\* | Too slow\* | Too slow\* |
 | Trompt | OOM | 0.373±0.004 (9114s) | OOM |
 | ResNet | **0.951±0.000 (419s)** | **0.378±0.001 (171s)**| 0.723±0.001 (257s) |
+| MLP | 0.947±0.001 (1133s) | 0.371±0.002 (462s) | 0.723±0.002 (495s) |
 | FTTransformerBucket | 0.879±0.006 (9104s) | 0.365±0.002 (1067s) | 0.722±0.001 (2366s) |
 | ExcelFormer | OOM | 0.375±0.004 (2168s) | **0.732±0.000 (4138s)**|
 | FTTransformer | 0.923±0.003 (14517s) | 0.357±0.001 (754s) | 0.724±0.004 (2621s) |

benchmark/data_frame_benchmark.py

Lines changed: 14 additions & 1 deletion
@@ -19,6 +19,7 @@
 from torch_frame.gbdt import CatBoost, LightGBM, XGBoost
 from torch_frame.nn.encoder import EmbeddingEncoder, LinearBucketEncoder
 from torch_frame.nn.models import (
+    MLP,
     ExcelFormer,
     FTTransformer,
     ResNet,
@@ -50,7 +51,7 @@
     help='Number of repeated training and eval on the best config.')
 parser.add_argument(
     '--model_type', type=str, default='TabNet', choices=[
-        'TabNet', 'FTTransformer', 'ResNet', 'TabTransformer', 'Trompt',
+        'TabNet', 'FTTransformer', 'ResNet', 'MLP', 'TabTransformer', 'Trompt',
         'ExcelFormer', 'FTTransformerBucket', 'XGBoost', 'CatBoost', 'LightGBM'
     ])
 parser.add_argument('--seed', type=int, default=0)
@@ -153,6 +154,18 @@
     }
     model_cls = ResNet
     col_stats = dataset.col_stats
+elif args.model_type == 'MLP':
+    model_search_space = {
+        'channels': [64, 128, 256],
+        'num_layers': [1, 2, 4],
+    }
+    train_search_space = {
+        'batch_size': [256, 512],
+        'base_lr': [0.0001, 0.001],
+        'gamma_rate': [0.9, 0.95, 1.],
+    }
+    model_cls = MLP
+    col_stats = dataset.col_stats
 elif args.model_type == 'TabTransformer':
     model_search_space = {
         'channels': [16, 32, 64, 128],
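
To get a feel for the size of the MLP search space added in the hunk above, the sketch below enumerates every combination of the listed candidate values. It assumes each list is a set of discrete choices; how the benchmark script actually samples from these dictionaries is not visible in this diff, and the helper `enumerate_configs` is a hypothetical name used only for illustration.

```python
from itertools import product

# Candidate values copied from the MLP branch added in this commit.
model_search_space = {
    'channels': [64, 128, 256],
    'num_layers': [1, 2, 4],
}
train_search_space = {
    'batch_size': [256, 512],
    'base_lr': [0.0001, 0.001],
    'gamma_rate': [0.9, 0.95, 1.],
}


def enumerate_configs(space):
    """Return every combination of candidate values as a list of dicts.

    Hypothetical helper for illustration; not part of the benchmark script.
    """
    keys = list(space)
    return [dict(zip(keys, values)) for values in product(*space.values())]


model_cfgs = enumerate_configs(model_search_space)
train_cfgs = enumerate_configs(train_search_space)

print(len(model_cfgs))                    # 9 model configurations
print(len(train_cfgs))                    # 12 training configurations
print(len(model_cfgs) * len(train_cfgs))  # 108 joint configurations
```

Since the README describes the experiments in terms of a fixed number of Optuna search trials, a run would presumably visit only a subset of these 108 joint configurations.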

test/nn/models/test_mlp.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+import pytest
+
+from torch_frame.data.dataset import Dataset
+from torch_frame.datasets import FakeDataset
+from torch_frame.nn import MLP
+
+
+@pytest.mark.parametrize('batch_size', [0, 5])
+def test_mlp(batch_size):
+    channels = 8
+    out_channels = 1
+    num_layers = 3
+    dataset: Dataset = FakeDataset(num_rows=10, with_nan=False)
+    dataset.materialize()
+    tensor_frame = dataset.tensor_frame[:batch_size]
+    # Feature-based embeddings
+    model = MLP(
+        channels=channels,
+        out_channels=out_channels,
+        num_layers=num_layers,
+        col_stats=dataset.col_stats,
+        col_names_dict=tensor_frame.col_names_dict,
+    )
+    model.reset_parameters()
+    out = model(tensor_frame)
+    assert out.shape == (batch_size, out_channels)

torch_frame/nn/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,7 @@
 from .tabnet import TabNet
 from .resnet import ResNet
 from .tab_transformer import TabTransformer
+from .mlp import MLP

 __all__ = classes = [
     'Trompt',
@@ -13,4 +14,5 @@
     'TabNet',
     'ResNet',
     'TabTransformer',
+    'MLP',
 ]

torch_frame/nn/models/mlp.py

Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
+from __future__ import annotations
+
+from typing import Any
+
+import torch
+from torch import Tensor
+from torch.nn import (
+    BatchNorm1d,
+    Dropout,
+    LayerNorm,
+    Linear,
+    Module,
+    ReLU,
+    Sequential,
+)
+
+import torch_frame
+from torch_frame import TensorFrame, stype
+from torch_frame.data.stats import StatType
+from torch_frame.nn.encoder.stype_encoder import (
+    EmbeddingEncoder,
+    LinearEncoder,
+    StypeEncoder,
+)
+from torch_frame.nn.encoder.stypewise_encoder import StypeWiseFeatureEncoder
+
+
+class MLP(Module):
+    r"""The light-weight MLP model that mean-pools column embeddings and
+    applies MLP over it.
+
+    Args:
+        channels (int): The number of channels in the backbone layers.
+        out_channels (int): The number of output channels in the decoder.
+        num_layers (int): The number of layers in the backbone.
+        col_stats(dict[str,Dict[:class:`torch_frame.data.stats.StatType`,Any]]):
+            A dictionary that maps column name into stats.
+            Available as :obj:`dataset.col_stats`.
+        col_names_dict (dict[:class:`torch_frame.stype`, List[str]]): A
+            dictionary that maps stype to a list of column names. The column
+            names are sorted based on the ordering that appear in
+            :obj:`tensor_frame.feat_dict`. Available as
+            :obj:`tensor_frame.col_names_dict`.
+        stype_encoder_dict
+            (dict[:class:`torch_frame.stype`,
+            :class:`torch_frame.nn.encoder.StypeEncoder`], optional):
+            A dictionary mapping stypes into their stype encoders.
+            (default: :obj:`None`, will call :obj:`EmbeddingEncoder()`
+            for categorical feature and :obj:`LinearEncoder()` for
+            numerical feature)
+        normalization (str, optional): The type of normalization to use.
+            :obj:`batchnorm`, :obj:`layernorm`, or :obj:`None`.
+            (default: :obj:`layernorm`)
+        dropout_prob (float): The dropout probability (default: `0.2`).
+    """
+    def __init__(
+        self,
+        channels: int,
+        out_channels: int,
+        num_layers: int,
+        col_stats: dict[str, dict[StatType, Any]],
+        col_names_dict: dict[torch_frame.stype, list[str]],
+        stype_encoder_dict: dict[torch_frame.stype, StypeEncoder]
+        | None = None,
+        normalization: str | None = "layernorm",
+        dropout_prob: float = 0.2,
+    ) -> None:
+        super().__init__()
+
+        if stype_encoder_dict is None:
+            stype_encoder_dict = {
+                stype.categorical: EmbeddingEncoder(),
+                stype.numerical: LinearEncoder(),
+            }
+
+        self.encoder = StypeWiseFeatureEncoder(
+            out_channels=channels,
+            col_stats=col_stats,
+            col_names_dict=col_names_dict,
+            stype_encoder_dict=stype_encoder_dict,
+        )
+
+        self.mlp = Sequential()
+        norm_cls = LayerNorm if normalization == "layernorm" else BatchNorm1d
+        for _ in range(num_layers - 1):
+            self.mlp.append(Linear(channels, channels))
+            self.mlp.append(norm_cls(channels))
+            self.mlp.append(ReLU())
+            self.mlp.append(Dropout(p=dropout_prob))
+        self.mlp.append(Linear(channels, out_channels))
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        self.encoder.reset_parameters()
+        for param in self.mlp:
+            if (isinstance(param, Linear) or isinstance(param, BatchNorm1d)
+                    or isinstance(param, LayerNorm)):
+                param.reset_parameters()
+
+    def forward(self, tf: TensorFrame) -> Tensor:
+        r"""Transforming :class:`TensorFrame` object into output prediction.
+
+        Args:
+            tf (TensorFrame): Input :class:`TensorFrame` object.
+
+        Returns:
+            torch.Tensor: Output of shape [batch_size, out_channels].
+        """
+        x, _ = self.encoder(tf)
+
+        x = torch.mean(x, dim=1)
+
+        out = self.mlp(x)
+        return out
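
The structure of the new model can be traced without installing PyTorch. The dependency-free sketch below mirrors two parts of `mlp.py` as it appears in this diff: the order in which layers are appended to `self.mlp`, and the mean-pooling step in `forward`. The function names `mlp_layer_names` and `mean_pool_columns` are hypothetical, and the `[batch_size][num_cols][channels]` input layout is an assumption inferred from `torch.mean(x, dim=1)` feeding a `Linear(channels, ...)` layer.

```python
def mlp_layer_names(num_layers, normalization="layernorm"):
    """List layer types in the order the diff appends them to the MLP.

    Hypothetical helper for illustration; it only returns class names.
    """
    # As written in the diff, any value other than "layernorm"
    # (including None) selects BatchNorm1d.
    norm_name = "LayerNorm" if normalization == "layernorm" else "BatchNorm1d"
    names = []
    for _ in range(num_layers - 1):
        names += ["Linear", norm_name, "ReLU", "Dropout"]
    names.append("Linear")  # final projection to out_channels
    return names


def mean_pool_columns(x):
    """Average over the column axis of a nested list.

    Pure-Python stand-in for torch.mean(x, dim=1), assuming x is laid
    out as [batch_size][num_cols][channels].
    """
    pooled = []
    for row in x:
        num_cols = len(row)
        channels = len(row[0])
        pooled.append([
            sum(col[c] for col in row) / num_cols for c in range(channels)
        ])
    return pooled


print(mlp_layer_names(1))        # ['Linear']
print(len(mlp_layer_names(3)))   # 9
print(mlp_layer_names(2, None))  # ['Linear', 'BatchNorm1d', 'ReLU', 'Dropout', 'Linear']
print(mean_pool_columns([[[1.0, 2.0], [3.0, 4.0]]]))  # [[2.0, 3.0]]
```

Two things stand out from this trace. With `num_layers=1`, which is one of the values in the benchmark search space, the backbone reduces to a single `Linear` layer. And although the docstring lists `None` as a normalization option, the constructor as written would still insert `BatchNorm1d` in that case.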

torch_frame/nn/models/resnet.py

Lines changed: 0 additions & 3 deletions
@@ -144,9 +144,6 @@ def __init__(
         dropout_prob: float = 0.2,
     ) -> None:
         super().__init__()
-        if num_layers <= 0:
-            raise ValueError(
-                f"num_layers must be a positive integer (got {num_layers})")

         if stype_encoder_dict is None:
             stype_encoder_dict = {
