21 changes: 21 additions & 0 deletions examples/ehr_uncertainty_analysis_example/README.md
@@ -0,0 +1,21 @@
# Example: Model Uncertainty Analysis with Custom Data Wrapper for PyHealth

This example demonstrates how custom preprocessed sequential data (e.g., for EHR analysis) can be integrated with the PyHealth library.

## Files Included:

1. **`pyhealth_custom_dataset_wrapper.py`**:
    * Defines `CustomSequentialEHRDataPyHealth`, a Python class inheriting from `pyhealth.datasets.SampleEHRDataset`.
    * This class serves as a wrapper, taking lists of preprocessed sequence tensors and label tensors as input.
    * It structures this data into the `samples` format expected by `SampleEHRDataset` and uses a `task_fn` to prepare individual samples.
    * This demonstrates a method for making custom data formats compatible with PyHealth's data handling system.

2. **`uncertainty_wrapper_example.ipynb`**:
    * A Jupyter notebook showcasing the usage of `CustomSequentialEHRDataPyHealth`.
    * It first generates minimal synthetic sequential data (for illustrative purposes).
    * It then instantiates the `CustomSequentialEHRDataPyHealth` wrapper with this data.
    * Finally, it demonstrates accessing a sample from the PyHealth-compatible dataset.
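
The `samples` structure that the wrapper builds can be sketched without any dependencies. The snippet below is only an illustration: `build_samples` is a hypothetical helper (it is not part of PyHealth or of this example's files), and it uses nested Python lists where the wrapper receives PyTorch tensors. The dictionary keys match the ones used in `pyhealth_custom_dataset_wrapper.py`.

```python
from typing import Dict, List, Sequence


def build_samples(
    sequences: Sequence[List[List[float]]],
    labels: Sequence[float],
) -> List[Dict]:
    """Convert parallel lists of sequences and labels into sample dicts.

    Hypothetical helper: it mirrors the loop in the wrapper's __init__,
    but works on nested Python lists instead of PyTorch tensors.
    """
    if len(sequences) != len(labels):
        raise ValueError("Sequences and labels lists must have the same length.")

    samples = []
    for i, (seq, label) in enumerate(zip(sequences, labels)):
        samples.append({
            "patient_id": str(i),                          # index used as the ID
            "sequence_data": [list(step) for step in seq],  # [seq_len][feature_dim]
            "label": float(label),                          # scalar label
        })
    return samples


# Two patients with different sequence lengths and 3 features per time step.
sequences = [
    [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
    [[1.0, 1.1, 1.2]],
]
labels = [1.0, 0.0]

samples = build_samples(sequences, labels)
print(samples[0])
```

Each sample is a plain dictionary, so sequences of different lengths can sit side by side in the same list.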

## Purpose

The primary goal is to provide a simple example of integrating external, preprocessed sequential data into the PyHealth ecosystem using `SampleEHRDataset`. The wrapped dataset can then be passed to other PyHealth functionality, where the sample format is compatible, or to custom analysis pipelines (such as model uncertainty studies) while remaining compatible with PyHealth data structures.
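
Because the sequences in this example have different lengths, a custom pipeline usually needs to pad them before batching. The following is a minimal sketch of that step under stated assumptions: `pad_batch` is a hypothetical helper (not a PyHealth function), the samples are plain dictionaries with the `sequence_data` and `label` keys used in this example, and zero is chosen arbitrarily as the padding value.

```python
from typing import Dict, List, Tuple


def pad_batch(
    samples: List[Dict],
    pad_value: float = 0.0,
) -> Tuple[List[List[List[float]]], List[int], List[float]]:
    """Pad every sequence in a batch to the length of the longest one.

    Hypothetical helper for illustration. Returns the padded sequences,
    the original lengths (so padding can be masked later), and the labels.
    """
    if not samples:
        raise ValueError("Cannot pad an empty batch.")

    max_len = max(len(s["sequence_data"]) for s in samples)
    feature_dim = len(samples[0]["sequence_data"][0])

    padded, lengths, labels = [], [], []
    for s in samples:
        seq = s["sequence_data"]
        # Append rows of pad_value until the sequence reaches max_len.
        pad_rows = [[pad_value] * feature_dim for _ in range(max_len - len(seq))]
        padded.append([list(step) for step in seq] + pad_rows)
        lengths.append(len(seq))
        labels.append(s["label"])
    return padded, lengths, labels


# Two illustrative samples with sequence lengths 2 and 1.
batch = [
    {"patient_id": "0", "sequence_data": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], "label": 1.0},
    {"patient_id": "1", "sequence_data": [[1.0, 1.1, 1.2]], "label": 0.0},
]

padded, lengths, labels = pad_batch(batch)
print(lengths)
```

Keeping the original lengths alongside the padded data lets a downstream model ignore the padded time steps.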
@@ -0,0 +1,53 @@
from pyhealth.datasets import SampleEHRDataset
import torch

def basic_task_fn(patient_sample_from_list):
    """
    Basic task function for SampleEHRDataset.
    Processes a single sample dictionary from the pre-cached list
    and converts relevant data back into PyTorch tensors.
    """
    return {
        "patient_id": patient_sample_from_list["patient_id"],
        "sequence_data": torch.tensor(patient_sample_from_list["sequence_data"], dtype=torch.float32),
        "label": torch.tensor([patient_sample_from_list["label"]], dtype=torch.float32),
    }

class CustomSequentialEHRDataPyHealth(SampleEHRDataset):
    """
    PyHealth Dataset wrapper for custom pre-processed sequential data.

    This class demonstrates how to integrate pre-processed sequential data
    (provided as lists of sequences and labels) into the PyHealth ecosystem
    by inheriting from `pyhealth.datasets.SampleEHRDataset`.

    It converts the input lists into the `samples` format expected by
    `SampleEHRDataset` during initialization and uses a `task_fn`
    (basic_task_fn) to process samples when they are accessed.

    Args:
        list_of_patient_sequences (list): List of PyTorch tensors, where each
            tensor represents a patient's sequence (shape: [seq_len, feature_dim]).
        list_of_patient_labels (list): List of PyTorch tensors, where each
            tensor is a patient's label (e.g., tensor([0.]) or tensor([1.])).
        root (str): Root directory path required by PyHealth datasets (e.g., ".").
        dataset_name (str): Name for this dataset instance.
    """
    def __init__(
        self,
        list_of_patient_sequences,
        list_of_patient_labels,
        root=".",
        dataset_name="custom_ehr_example",
    ):
        # Validate the inputs before building any samples.
        if len(list_of_patient_sequences) != len(list_of_patient_labels):
            raise ValueError("Sequences and labels lists must have the same length.")

        # Store plain Python values; basic_task_fn converts them back to tensors.
        pyhealth_samples = []
        for i in range(len(list_of_patient_labels)):
            pyhealth_samples.append({
                "patient_id": str(i),
                "sequence_data": list_of_patient_sequences[i].tolist(),
                "label": list_of_patient_labels[i].item()
            })

        super().__init__(
            samples=pyhealth_samples,
            task_fn=basic_task_fn,
            dataset_name=dataset_name,
            root=root
        )
@@ -0,0 +1,134 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "088fc0c5",
"metadata": {},
"source": [
"# Example: Wrapping Custom Sequential Data for PyHealth\n",
"\n",
"This notebook demonstrates how preprocessed sequential data (like synthetic EHR data) can be wrapped using a custom class inheriting from `pyhealth.datasets.SampleEHRDataset`.\n",
"\n",
"This allows custom data formats to be used within the PyHealth ecosystem, for example, as input to uncertainty quantification studies or other downstream tasks.\n",
"\n",
"**Focus:** The primary goal here is to show the data wrapping mechanism using `CustomSequentialEHRDataPyHealth` from the accompanying `.py` file. The actual model training or analysis part is omitted for brevity in this specific example, but this dataset could be used as input for such tasks."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e2bcfa6a",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import numpy as np\n",
"from pyhealth_custom_dataset_wrapper import CustomSequentialEHRDataPyHealth"
]
},
{
"cell_type": "markdown",
"id": "6ba165cf",
"metadata": {},
"source": [
"## 1. Generate Minimal Synthetic Sequential Data\n",
"\n",
"First, we create some dummy sequential data (sequences and labels) in the format our wrapper expects: lists of PyTorch tensors. In a real scenario, this data would come from your preprocessing pipeline."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2708a07",
"metadata": {},
"outputs": [],
"source": [
"num_patients = 20\n",
"input_dim = 3\n",
"\n",
"\n",
"patient_sequences = []\n",
"for _ in range(num_patients):\n",
"    seq_len = np.random.randint(5, 15)  # variable length in [5, 15)\n",
"    sequence = torch.randn(seq_len, input_dim)  # shape: [seq_len, input_dim]\n",
"    patient_sequences.append(sequence)\n",
"\n",
"\n",
"patient_labels = [torch.tensor([float(np.random.rand() > 0.5)]) for _ in range(num_patients)]\n",
"\n",
"print(f\"Generated {len(patient_sequences)} synthetic patient sequences.\")\n",
"if patient_sequences:\n",
"    print(f\"Example sequence shape: {patient_sequences[0].shape}\")\n",
"    print(f\"Example label: {patient_labels[0]}\")"
]
},
{
"cell_type": "markdown",
"id": "0a9f3650",
"metadata": {},
"source": [
"## 2. Wrap Data using PyHealth Custom Dataset Wrapper\n",
"\n",
"Now, we instantiate our `CustomSequentialEHRDataPyHealth` class (which uses `pyhealth.datasets.SampleEHRDataset` internally) with the generated data lists."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2763a476",
"metadata": {},
"outputs": [],
"source": [
"print(\"\\n--- Wrapping Data with PyHealth Custom Dataset Class ---\")\n",
"try:\n",
"    # Instantiate the PyHealth-compatible dataset wrapper\n",
"    pyhealth_dataset = CustomSequentialEHRDataPyHealth(\n",
"        list_of_patient_sequences=patient_sequences,\n",
"        list_of_patient_labels=patient_labels,\n",
"        root=\".\"  # Use current directory for any potential caching by PyHealth\n",
"    )\n",
"    print(\"Successfully wrapped data into CustomSequentialEHRDataPyHealth.\")\n",
"    print(f\"Number of samples in PyHealth dataset: {len(pyhealth_dataset)}\")\n",
"\n",
"    if len(pyhealth_dataset) > 0:\n",
"        # Demonstrate getting a sample (processed by the task_fn)\n",
"        first_sample = pyhealth_dataset[0]\n",
"        print(\"\\nExample of first sample retrieved via the PyHealth wrapper:\")\n",
"        print(f\"  Patient ID: {first_sample['patient_id']}\")\n",
"        print(f\"  Sequence Data Shape: {first_sample['sequence_data'].shape}\")\n",
"        print(f\"  Label: {first_sample['label']}\")\n",
"\n",
"except Exception as e:\n",
"    print(f\"An error occurred during dataset wrapping: {e}\")\n",
"\n",
"print(\"--- PyHealth Dataset Wrapper Demonstration Complete ---\")"
]
},
{
"cell_type": "markdown",
"id": "ff798c06",
"metadata": {},
"source": [
"## 3. Next Steps\n",
"\n",
"From here, the `pyhealth_dataset` object can be passed to PyHealth's data loaders or models, provided they accept the output format defined in our `basic_task_fn`.\n",
"\n",
"Alternatively, since this example primarily demonstrates data integration, one can keep using standard PyTorch DataLoaders (created from the original `patient_sequences` and `patient_labels`) and custom models for downstream tasks such as model training and uncertainty analysis, while retaining compatibility with PyHealth's data structures."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "pyhealth_env",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.12.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}