Commit c8fc421
[Bounty] Adacare PyHealth 2.0 (#613)
* Add DeepNestedSequenceProcessor and DeepNestedFloatsProcessor for handling deeply nested sequences (see the sketch after this list)
* Add unit tests for DeepNestedSequenceProcessor and DeepNestedFloatsProcessor
* Refactor EmbeddingModel to support DeepNestedSequenceProcessor and DeepNestedFloatsProcessor, and to support mask outputs
* Refactor the AdaCare model to integrate the new dataset and processor classes, and update its input handling and forward-propagation logic
* Remove output-mask handling for passthrough tensors in the EmbeddingModel forward method
* Add unit tests for the AdaCare model, covering initialization, forward and backward passes, loss checks, and output shapes
* Update the mortality prediction example script for MIMIC-III to use AdaCare
* Remove the property decorator from the size function in the StageNetTensorProcessor class
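For orientation: per the updated EmbeddingModel docstring further down, DeepNestedSequenceProcessor handles categorical data nested one level deeper than NestedSequenceProcessor, i.e. (batch, num_groups, num_visits, max_codes_per_visit). A minimal sketch of what one such sample might look like; the field name and code values are illustrative, not taken from the commit:

```python
# Hypothetical sample for a field handled by DeepNestedSequenceProcessor:
# groups -> visits -> codes. After vocabulary lookup and padding, this becomes
# a (num_groups, num_visits, max_codes_per_visit) integer index tensor, which
# EmbeddingModel then maps to (..., embedding_dim) via nn.Embedding.
sample = {
    "conditions": [
        [["ICD9_428", "ICD9_401"], ["ICD9_250"]],  # group 1: two visits
        [["ICD9_428"]],                            # group 2: one visit
    ],
}
```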
1 parent 10f5719 commit c8fc421

File tree

9 files changed (+4147 −328 lines)


examples/mortality_mimic3_adacare.ipynb

Lines changed: 3040 additions & 0 deletions
Large diffs are not rendered by default.

examples/mortality_mimic3_adacare.py

Lines changed: 0 additions & 51 deletions
This file was deleted.

pyhealth/models/adacare.py

Lines changed: 94 additions & 214 deletions
Large diffs are not rendered by default.

pyhealth/models/embedding.py

Lines changed: 67 additions & 62 deletions
```diff
@@ -13,68 +13,62 @@
     StageNetTensorProcessor,
     TensorProcessor,
     TimeseriesProcessor,
+    DeepNestedSequenceProcessor,
+    DeepNestedFloatsProcessor,
 )
 from .base_model import BaseModel
 
-
 class EmbeddingModel(BaseModel):
     """
     EmbeddingModel is responsible for creating embedding layers for different types of input data.
 
     This model automatically creates appropriate embedding transformations based on the processor type:
 
-    - SequenceProcessor: Creates nn.Embedding for categorical sequences (e.g., diagnosis codes)
-      Input: (batch, seq_len) with integer indices
-      Output: (batch, seq_len, embedding_dim)
+    - SequenceProcessor: nn.Embedding
+      Input: (batch, seq_len)
+      Output: (batch, seq_len, embedding_dim)
 
-    - TimeseriesProcessor: Creates nn.Linear for time series features
-      Input: (batch, seq_len, num_features)
-      Output: (batch, seq_len, embedding_dim)
+    - NestedSequenceProcessor: nn.Embedding
+      Input: (batch, num_visits, max_codes_per_visit)
+      Output: (batch, num_visits, max_codes_per_visit, embedding_dim)
 
-    - TensorProcessor: Creates nn.Linear for fixed-size numerical features
-      Input: (batch, feature_size)
-      Output: (batch, embedding_dim)
+    - DeepNestedSequenceProcessor: nn.Embedding
+      Input: (batch, num_groups, num_visits, max_codes_per_visit)
+      Output: (batch, num_groups, num_visits, max_codes_per_visit, embedding_dim)
 
-    - MultiHotProcessor: Creates nn.Linear for multi-hot encoded categorical features
-      Input: (batch, num_categories) binary tensor
-      Output: (batch, embedding_dim)
-      Note: Converts sparse categorical representations to dense embeddings
+    - TimeseriesProcessor / NestedFloatsProcessor / DeepNestedFloatsProcessor / StageNetTensorProcessor:
+      nn.Linear over the last dimension
+      Input: (..., size)
+      Output: (..., embedding_dim)
 
-    - Other processors with size(): Creates nn.Linear if processor reports a positive size
-      Input: (batch, size)
-      Output: (batch, embedding_dim)
+    - TensorProcessor: nn.Linear (size inferred from first sample)
 
-    Attributes:
-        dataset (SampleDataset): The dataset containing input processors.
-        embedding_layers (nn.ModuleDict): A dictionary of embedding layers for each input field.
-        embedding_dim (int): The target embedding dimension for all features.
+    - MultiHotProcessor: nn.Linear over multi-hot vector
     """
 
     def __init__(self, dataset: SampleDataset, embedding_dim: int = 128):
-        """
-        Initializes the EmbeddingModel with the given dataset and embedding dimension.
-
-        Args:
-            dataset (SampleDataset): The dataset containing input processors.
-            embedding_dim (int): The dimension of the embedding space. Default is 128.
-        """
         super().__init__(dataset)
         self.embedding_dim = embedding_dim
         self.embedding_layers = nn.ModuleDict()
+
         for field_name, processor in self.dataset.input_processors.items():
+            # Categorical sequences (regular, nested, deep nested) -> nn.Embedding
             if isinstance(
                 processor,
                 (
                     SequenceProcessor,
                     StageNetProcessor,
                     NestedSequenceProcessor,
+                    DeepNestedSequenceProcessor,
                 ),
             ):
-                # Categorical codes -> use nn.Embedding
                 vocab_size = len(processor.code_vocab)
-                # For NestedSequenceProcessor, don't use padding_idx
-                # because empty visits need non-zero embeddings
-                if isinstance(processor, NestedSequenceProcessor):
+                # For NestedSequenceProcessor and DeepNestedSequenceProcessor, don't use
+                # padding_idx because empty visits/groups need non-zero embeddings.
+                if isinstance(processor, (NestedSequenceProcessor, DeepNestedSequenceProcessor)):
                     self.embedding_layers[field_name] = nn.Embedding(
                         num_embeddings=vocab_size,
                         embedding_dim=embedding_dim,
@@ -86,22 +80,25 @@ def __init__(self, dataset: SampleDataset, embedding_dim: int = 128):
                         embedding_dim=embedding_dim,
                         padding_idx=0,
                     )
+
+            # Numeric features (including deep nested floats) -> nn.Linear over last dim
             elif isinstance(
                 processor,
                 (
                     TimeseriesProcessor,
                     StageNetTensorProcessor,
                     NestedFloatsProcessor,
+                    DeepNestedFloatsProcessor,
                 ),
             ):
-                # Numeric features -> use nn.Linear
-                # Both processors have .size attribute
+                # processor.size() returns the last-dim size
+                in_features = processor.size()
                 self.embedding_layers[field_name] = nn.Linear(
-                    in_features=processor.size, out_features=embedding_dim
+                    in_features=in_features, out_features=embedding_dim
                 )
+
             elif isinstance(processor, TensorProcessor):
-                # For tensor processor, we need to determine the input size
-                # from the first sample in the dataset
+                # Infer size from first sample
                 sample_tensor = None
                 for sample in dataset.samples:
                     if field_name in sample:
@@ -114,43 +111,51 @@ def __init__(self, dataset: SampleDataset, embedding_dim: int = 128):
                 self.embedding_layers[field_name] = nn.Linear(
                     in_features=input_size, out_features=embedding_dim
                 )
+
             elif isinstance(processor, MultiHotProcessor):
-                # MultiHotProcessor produces fixed-size binary vectors
-                # Use processor.size() to get the vocabulary size (num_categories)
                 num_categories = processor.size()
                 self.embedding_layers[field_name] = nn.Linear(
                     in_features=num_categories, out_features=embedding_dim
                 )
+
             else:
                 print(
                     "Warning: No embedding created for field due to lack of compatible processor:",
                     field_name,
                 )
 
-    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-        """
-        Forward pass to compute embeddings for the input data.
-
-        Args:
-            inputs (Dict[str, torch.Tensor]): A dictionary of input tensors.
-
-        Returns:
-            Dict[str, torch.Tensor]: A dictionary of embedded tensors.
-        """
-        embedded = {}
+    def forward(
+        self,
+        inputs: Dict[str, torch.Tensor],
+        output_mask: bool = False,
+    ) -> Dict[str, torch.Tensor] | tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
+        embedded: Dict[str, torch.Tensor] = {}
+        masks: Dict[str, torch.Tensor] = {} if output_mask else None
+
         for field_name, tensor in inputs.items():
-            if field_name in self.embedding_layers:
-                tensor = tensor.to(self.device)
-                embedded[field_name] = self.embedding_layers[field_name](tensor)
-            else:
-                embedded[field_name] = tensor  # passthrough for continuous features
-        return embedded
+            processor = self.dataset.input_processors.get(field_name, None)
+
+            if field_name not in self.embedding_layers:
+                # No embedding layer -> passthrough
+                embedded[field_name] = tensor
+                continue
+
+            tensor = tensor.to(self.device)
+            embedded[field_name] = self.embedding_layers[field_name](tensor)
+
+            if output_mask:
+                # Generate a mask for this field from the raw index tensor
+                if hasattr(processor, "code_vocab"):
+                    pad_idx = processor.code_vocab.get("<pad>", 0)
+                else:
+                    pad_idx = 0
+
+                masks[field_name] = (tensor != pad_idx)
+
+        if output_mask:
+            return embedded, masks
+        else:
+            return embedded
 
     def __repr__(self) -> str:
-        """
-        Returns a string representation of the EmbeddingModel.
-
-        Returns:
-            str: A string representation of the model.
-        """
         return f"EmbeddingModel(embedding_layers={self.embedding_layers})"
```

pyhealth/processors/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -30,6 +30,10 @@ def get_processor(name: str):
     NestedFloatsProcessor,
     NestedSequenceProcessor,
 )
+from .deep_nested_sequence_processor import (
+    DeepNestedFloatsProcessor,
+    DeepNestedSequenceProcessor,
+)
 from .raw_processor import RawProcessor
 from .sequence_processor import SequenceProcessor
 from .signal_processor import SignalProcessor
```
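With this re-export, both new processors become importable directly from pyhealth.processors; a quick sanity check, assuming a checkout at this commit:

```python
# Both names should resolve to the classes defined in
# pyhealth/processors/deep_nested_sequence_processor.py.
from pyhealth.processors import (
    DeepNestedFloatsProcessor,
    DeepNestedSequenceProcessor,
)

print(DeepNestedSequenceProcessor, DeepNestedFloatsProcessor)
```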
