     StageNetTensorProcessor,
     TensorProcessor,
     TimeseriesProcessor,
+    DeepNestedSequenceProcessor,
+    DeepNestedFloatsProcessor,
 )
 from .base_model import BaseModel
 
-
 class EmbeddingModel(BaseModel):
     """
     EmbeddingModel is responsible for creating embedding layers for different types of input data.
 
     This model automatically creates appropriate embedding transformations based on the processor type:
 
-    - SequenceProcessor: Creates nn.Embedding for categorical sequences (e.g., diagnosis codes)
-        Input: (batch, seq_len) with integer indices
-        Output: (batch, seq_len, embedding_dim)
+    - SequenceProcessor: nn.Embedding
+        Input: (batch, seq_len)
+        Output: (batch, seq_len, embedding_dim)
 
-    - TimeseriesProcessor: Creates nn.Linear for time series features
-        Input: (batch, seq_len, num_features)
-        Output: (batch, seq_len, embedding_dim)
+    - NestedSequenceProcessor: nn.Embedding
+        Input: (batch, num_visits, max_codes_per_visit)
+        Output: (batch, num_visits, max_codes_per_visit, embedding_dim)
 
-    - TensorProcessor: Creates nn.Linear for fixed-size numerical features
-        Input: (batch, feature_size)
-        Output: (batch, embedding_dim)
+    - DeepNestedSequenceProcessor: nn.Embedding
+        Input: (batch, num_groups, num_visits, max_codes_per_visit)
+        Output: (batch, num_groups, num_visits, max_codes_per_visit, embedding_dim)
 
-    - MultiHotProcessor: Creates nn.Linear for multi-hot encoded categorical features
-        Input: (batch, num_categories) binary tensor
-        Output: (batch, embedding_dim)
-        Note: Converts sparse categorical representations to dense embeddings
+    - TimeseriesProcessor / NestedFloatsProcessor / DeepNestedFloatsProcessor / StageNetTensorProcessor:
+        nn.Linear over the last dimension
+        Input: (..., size)
+        Output: (..., embedding_dim)
 
-    - Other processors with size(): Creates nn.Linear if processor reports a positive size
-        Input: (batch, size)
-        Output: (batch, embedding_dim)
+    - TensorProcessor: nn.Linear (size inferred from the first sample)
 
-    Attributes:
-        dataset (SampleDataset): The dataset containing input processors.
-        embedding_layers (nn.ModuleDict): A dictionary of embedding layers for each input field.
-        embedding_dim (int): The target embedding dimension for all features.
+    - MultiHotProcessor: nn.Linear over the multi-hot vector
     """
 
     def __init__(self, dataset: SampleDataset, embedding_dim: int = 128):
54- """
55- Initializes the EmbeddingModel with the given dataset and embedding dimension.
56-
57- Args:
58- dataset (SampleDataset): The dataset containing input processors.
59- embedding_dim (int): The dimension of the embedding space. Default is 128.
60- """
6150 super ().__init__ (dataset )
6251 self .embedding_dim = embedding_dim
6352 self .embedding_layers = nn .ModuleDict ()
53+
6454 for field_name , processor in self .dataset .input_processors .items ():
+            # Categorical code sequences (flat, nested, or deep nested) -> nn.Embedding (adds an embedding dim)
             if isinstance(
                 processor,
                 (
                     SequenceProcessor,
                     StageNetProcessor,
                     NestedSequenceProcessor,
+                    DeepNestedSequenceProcessor,
                 ),
             ):
-                # Categorical codes -> use nn.Embedding
                 vocab_size = len(processor.code_vocab)
-                # For NestedSequenceProcessor, don't use padding_idx
-                # because empty visits need non-zero embeddings
-                if isinstance(processor, NestedSequenceProcessor):
+
+                # For NestedSequenceProcessor and DeepNestedSequenceProcessor, don't use padding_idx
+                # because empty visits/groups need non-zero embeddings.
+                if isinstance(processor, (NestedSequenceProcessor, DeepNestedSequenceProcessor)):
                     self.embedding_layers[field_name] = nn.Embedding(
                         num_embeddings=vocab_size,
                         embedding_dim=embedding_dim,
@@ -86,22 +80,25 @@ def __init__(self, dataset: SampleDataset, embedding_dim: int = 128):
                         embedding_dim=embedding_dim,
                         padding_idx=0,
                     )
+
+            # Numeric features (including deep nested floats) -> nn.Linear over the last dim
             elif isinstance(
                 processor,
                 (
                     TimeseriesProcessor,
                     StageNetTensorProcessor,
                     NestedFloatsProcessor,
+                    DeepNestedFloatsProcessor,
                 ),
             ):
-                # Numeric features -> use nn.Linear
-                # Both processors have .size attribute
+                # Assuming processor.size() returns the last-dim size
+                in_features = processor.size()
                 self.embedding_layers[field_name] = nn.Linear(
-                    in_features=processor.size, out_features=embedding_dim
+                    in_features=in_features, out_features=embedding_dim
                 )
+
             elif isinstance(processor, TensorProcessor):
-                # For tensor processor, we need to determine the input size
-                # from the first sample in the dataset
+                # Infer size from first sample
                 sample_tensor = None
                 for sample in dataset.samples:
                     if field_name in sample:
@@ -114,43 +111,51 @@ def __init__(self, dataset: SampleDataset, embedding_dim: int = 128):
                 self.embedding_layers[field_name] = nn.Linear(
                     in_features=input_size, out_features=embedding_dim
                 )
+
             elif isinstance(processor, MultiHotProcessor):
-                # MultiHotProcessor produces fixed-size binary vectors
-                # Use processor.size() to get the vocabulary size (num_categories)
                 num_categories = processor.size()
                 self.embedding_layers[field_name] = nn.Linear(
                     in_features=num_categories, out_features=embedding_dim
                 )
+
             else:
                 print(
                     "Warning: No embedding created for field due to lack of compatible processor:",
                     field_name,
                 )
 
-    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-        """
-        Forward pass to compute embeddings for the input data.
-
-        Args:
-            inputs (Dict[str, torch.Tensor]): A dictionary of input tensors.
-
-        Returns:
-            Dict[str, torch.Tensor]: A dictionary of embedded tensors.
-        """
-        embedded = {}
+    def forward(
+        self,
+        inputs: Dict[str, torch.Tensor],
+        output_mask: bool = False,
+    ) -> Dict[str, torch.Tensor] | tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
+        embedded: Dict[str, torch.Tensor] = {}
+        masks: Dict[str, torch.Tensor] = {}
+
         for field_name, tensor in inputs.items():
-            if field_name in self.embedding_layers:
-                tensor = tensor.to(self.device)
-                embedded[field_name] = self.embedding_layers[field_name](tensor)
-            else:
-                embedded[field_name] = tensor  # passthrough for continuous features
-        return embedded
+            processor = self.dataset.input_processors.get(field_name, None)
+
+            if field_name not in self.embedding_layers:
+                # No embedding layer -> passthrough
+                embedded[field_name] = tensor
+                continue
+
+            tensor = tensor.to(self.device)
+            embedded[field_name] = self.embedding_layers[field_name](tensor)
+
+            if output_mask:
+                # Generate a mask for this field
+                if hasattr(processor, "code_vocab"):
+                    pad_idx = processor.code_vocab.get("<pad>", 0)
+                else:
+                    pad_idx = 0
+
+                masks[field_name] = (tensor != pad_idx)
+
+        if output_mask:
+            return embedded, masks
+        else:
+            return embedded
 
     def __repr__(self) -> str:
-        """
-        Returns a string representation of the EmbeddingModel.
-
-        Returns:
-            str: A string representation of the model.
-        """
         return f"EmbeddingModel(embedding_layers={self.embedding_layers})"
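
For context, a minimal usage sketch of the new output_mask path. Everything here is hypothetical, not part of this commit: it assumes a SampleDataset whose "conditions" field is handled by a SequenceProcessor (vocabulary large enough for the indices shown) and whose "labs" field is handled by a TimeseriesProcessor with size() == 7.

import torch

# Hypothetical fields: "conditions" (SequenceProcessor), "labs" (TimeseriesProcessor)
model = EmbeddingModel(dataset, embedding_dim=128)

batch = {
    "conditions": torch.tensor([[5, 12, 0, 0]]),  # (batch=1, seq_len=4); index 0 is padding
    "labs": torch.randn(1, 4, 7),                 # (batch=1, seq_len=4, size=7)
}

embedded, masks = model(batch, output_mask=True)
print(embedded["conditions"].shape)  # torch.Size([1, 4, 128]) via nn.Embedding
print(embedded["labs"].shape)        # torch.Size([1, 4, 128]) via nn.Linear over the last dim
print(masks["conditions"])           # tensor([[ True,  True, False, False]])

Note that for float fields like "labs" the mask is computed elementwise against pad_idx 0, so it keeps the raw input shape (1, 4, 7) rather than giving one flag per time step; callers that want a per-step mask would need to reduce it themselves.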
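
Because padding_idx is deliberately skipped for the nested processors, padded positions still receive non-zero embeddings; the returned masks are what let downstream models exclude them. A sketch of masked mean-pooling that consumes these outputs (masked_mean is illustrative, not part of the codebase):

import torch

def masked_mean(emb: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # emb: (..., seq_len, embedding_dim); mask: (..., seq_len) boolean
    mask = mask.unsqueeze(-1).float()        # (..., seq_len, 1)
    summed = (emb * mask).sum(dim=-2)        # sum over non-padded positions only
    count = mask.sum(dim=-2).clamp(min=1.0)  # guard against all-padding rows
    return summed / count                    # (..., embedding_dim)

pooled = masked_mean(embedded["conditions"], masks["conditions"])  # (1, 128)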