@@ -284,39 +284,45 @@ def __init__(
284284
def init_modules(self):
    """Build one shared embedding table for all categorical columns.

    Instead of one ``Embedding`` per column, every column's categories
    are stored in a single table. Row 0 is the padding row used for NaN;
    column ``i`` occupies the contiguous row range that starts at
    ``offset[i] + 1``. The per-column start offsets are registered as the
    ``offset`` buffer so they follow the module across devices and are
    saved with its state.
    """
    super().init_modules()

    # Leading 0 so that the cumulative sum below yields the *start*
    # offset of each column (the first column starts at offset 0).
    # NOTE(review): ``len(stats[StatType.COUNT][0])`` is taken as the
    # number of categories of a column -- confirm against the stats
    # schema.
    num_categories_list = [0]
    num_categories_list.extend(
        len(stats[StatType.COUNT][0]) for stats in self.stats_list)

    # Single embedding module that stores embeddings of all categories
    # across all categorical columns.
    # 0-th category is for NaN.
    self.emb = Embedding(
        sum(num_categories_list) + 1,
        self.out_channels,
        padding_idx=0,
    )

    # [num_cols, ]
    # Dropping the last count and accumulating gives, for every column,
    # the number of categories that precede it in the shared table.
    self.register_buffer(
        "offset",
        torch.cumsum(
            torch.tensor(num_categories_list[:-1], dtype=torch.long),
            dim=0,
        ),
    )
    self.reset_parameters()
298306
def reset_parameters(self):
    """Re-initialize the parent's parameters and the embedding table."""
    super().reset_parameters()
    # A single shared table replaces the former per-column embeddings,
    # so there is only one module to reset.
    self.emb.reset_parameters()
303310
304311 def encode_forward (
305312 self ,
306313 feat : Tensor ,
307314 col_names : list [str ] | None = None ,
308315 ) -> Tensor :
309- # TODO: Make this more efficient.
310- # Increment the index by one so that NaN index (-1) becomes 0
311- # (padding_idx)
312316 # feat: [batch_size, num_cols]
313- feat = feat + 1
314- xs = []
315- for i , emb in enumerate (self .embs ):
316- xs .append (emb (feat [:, i ]))
317- # [batch_size, num_cols, hidden_channels]
318- x = torch .stack (xs , dim = 1 )
319- return x
317+ # Get NaN mask
318+ na_mask = feat < 0
319+ # Increment the index by one not to conflict with the padding idx
320+ # Also add offset for each column to avoid embedding conflict
321+ feat = feat + self .offset + 1
322+ # Use 0th index for NaN
323+ feat [na_mask ] = 0
324+ # [batch_size, num_cols, channels]
325+ return self .emb (feat )
320326
321327
322328class MultiCategoricalEmbeddingEncoder (StypeEncoder ):
0 commit comments