Skip to content

Commit f4e5130

Browse files
authored
Avoid for-loop in EmbeddingEncoder (#366)
1 parent b934e93 commit f4e5130

File tree

2 files changed

+27
-20
lines changed

2 files changed

+27
-20
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
88

99
### Added
1010

11+
- Avoided for-loop in `EmbeddingEncoder` ([#366](https://github.com/pyg-team/pytorch-frame/pull/366))
1112
- Added `image_embedded` and one tabular image dataset ([#344](https://github.com/pyg-team/pytorch-frame/pull/344))
1213
- Added benchmarking suite for encoders ([#360](https://github.com/pyg-team/pytorch-frame/pull/360))
1314
- Added dataframe text benchmark script ([#354](https://github.com/pyg-team/pytorch-frame/pull/354))

torch_frame/nn/encoder/stype_encoder.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -284,39 +284,45 @@ def __init__(
284284

285285
def init_modules(self):
    """Build one shared embedding table covering every categorical column.

    Rather than keeping a separate ``Embedding`` per column, all columns'
    categories are packed into a single table. A per-column ``offset``
    buffer maps each column's local category index into the shared index
    space, and row 0 is reserved as the padding (NaN) slot.
    """
    super().init_modules()
    # Per-column category counts; the leading 0 makes the cumulative sum
    # below yield each column's starting offset into the shared table.
    counts = [0]
    for stats in self.stats_list:
        counts.append(len(stats[StatType.COUNT][0]))
    # Single embedding module that stores embeddings of all categories
    # across all categorical columns; +1 row reserved for NaN at index 0.
    self.emb = Embedding(
        sum(counts) + 1,
        self.out_channels,
        padding_idx=0,
    )
    # offset: [num_cols] — start index of each column's block in the table.
    self.register_buffer(
        "offset",
        torch.cumsum(torch.tensor(counts[:-1], dtype=torch.long), dim=0),
    )
    self.reset_parameters()
298306

299307
def reset_parameters(self):
    """Re-initialize the shared embedding table's weights."""
    super().reset_parameters()
    self.emb.reset_parameters()
303310

304311
def encode_forward(
    self,
    feat: Tensor,
    col_names: list[str] | None = None,
) -> Tensor:
    """Embed a categorical feature matrix with one vectorized lookup.

    Args:
        feat: [batch_size, num_cols] integer category indices, where a
            negative value (-1) marks a missing (NaN) cell.
        col_names: Unused here; kept for the ``StypeEncoder`` interface.

    Returns:
        [batch_size, num_cols, channels] embeddings; NaN cells map to the
        all-zero padding embedding at index 0.
    """
    # Remember which cells are NaN (encoded as negative indices).
    na_mask = feat < 0
    # Shift by +1 past the padding slot and by each column's offset so
    # that distinct columns never share embedding rows.
    shifted = feat + self.offset + 1
    # Route NaN cells to padding_idx=0 (a zero-initialized, frozen row).
    shifted[na_mask] = 0
    # [batch_size, num_cols, channels]
    return self.emb(shifted)
320326

321327

322328
class MultiCategoricalEmbeddingEncoder(StypeEncoder):

0 commit comments

Comments
 (0)