13 changes: 6 additions & 7 deletions .github/workflows/test.yml
@@ -1,6 +1,7 @@
name: CI

on:
workflow_dispatch:
pull_request:
push:
paths-ignore:
@@ -14,20 +15,18 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.11']
python-version: ['3.9', '3.11']
steps:
- name: Checkout reposistory
uses: actions/checkout@v3
- name: Checkout repository
uses: actions/checkout@v4
with:
submodules: recursive

- name: Set up Python
uses: actions/setup-python@v3
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

- name: Display Python version
run: python -c "import sys; print(sys.version)"
cache: 'pip' # caching pip dependencies

- name: Install dependent packages
run: 'make deps'
7 changes: 3 additions & 4 deletions makefile
@@ -15,14 +15,13 @@ PY_TEST_GLOB ?= test_metrics.py
# install dependencies
.PHONY: deps
deps:
pip install -r requirements-nlp.txt
python -m pip install -r requirements.txt

# run the unit test cases
.PHONY: test
test:
@echo "Running tests in $(PY_TEST_DIR)/$(PY_TEST_GLOB)"
python -m unittest discover \
-s $(PY_TEST_DIR) -p '$(PY_TEST_GLOB)' -v
@echo "Running tests in $(PY_TEST_DIR)/"
python -m pytest $(PY_TEST_DIR)/

# clean derived objects
.PHONY: clean
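
The makefile change above moves the test target from unittest discovery over a single file glob to pytest collection over the whole test directory, and points deps at requirements.txt. A minimal before/after sketch of the test invocations from Python, assuming PY_TEST_DIR defaults to pyhealth/unittests (the variable's value is not visible in this hunk):

import unittest

import pytest  # new test runner; assumed to be installed via requirements.txt

# Before: unittest only discovered tests matching a single pattern.
suite = unittest.defaultTestLoader.discover(
    start_dir="pyhealth/unittests", pattern="test_metrics.py"
)
unittest.TextTestRunner(verbosity=2).run(suite)  # -v maps to verbosity=2

# After: pytest collects every test module under the directory.
pytest.main(["pyhealth/unittests/"])
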
60 changes: 41 additions & 19 deletions pyhealth/data/data.py
@@ -1,3 +1,4 @@
from typing import Any, Iterable
import operator
from dataclasses import dataclass, field
from datetime import datetime
@@ -35,13 +36,17 @@ def from_dict(cls, d: Dict[str, any]) -> "Event":
timestamp: datetime = d["timestamp"]
event_type: str = d["event_type"]
attr_dict: Dict[str, any] = {
k.split("/", 1)[1]: v
for k, v in d.items()
if k.split("/")[0] == event_type
k.split("/", 1)[1]: v for k, v in d.items() if k.split("/")[0] == event_type
}
return cls(event_type=event_type, timestamp=timestamp, attr_dict=attr_dict)

def __getitem__(self, key: str) -> any:
def to_dict(self) -> dict[str, Any]:
res = {f"{self.event_type}/{k}": v for k, v in self.attr_dict.items()}
res["timestamp"] = self.timestamp
res["event_type"] = self.event_type
return res

def __getitem__(self, key: str) -> Any:
"""Get an attribute by key.

Args:
@@ -108,17 +113,29 @@ def __init__(self, patient_id: str, data_source: pl.DataFrame) -> None:
"""
self.patient_id = patient_id
self.data_source = data_source.sort("timestamp")
self.event_type_partitions = self.data_source.partition_by("event_type", maintain_order=True, as_dict=True)
self.event_type_partitions = self.data_source.partition_by(
"event_type", maintain_order=True, as_dict=True
)

def _filter_by_time_range_regular(self, df: pl.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> pl.DataFrame:
@classmethod
def from_events(cls, patient_id: str, events: Iterable[Event]) -> "Patient":
return cls(
patient_id=patient_id, data_source=pl.DataFrame(e.to_dict() for e in events)
)

def _filter_by_time_range_regular(
self, df: pl.DataFrame, start: Optional[datetime], end: Optional[datetime]
) -> pl.DataFrame:
"""Regular filtering by time. Time complexity: O(n)."""
if start is not None:
df = df.filter(pl.col("timestamp") >= start)
if end is not None:
df = df.filter(pl.col("timestamp") <= end)
return df

def _filter_by_time_range_fast(self, df: pl.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> pl.DataFrame:
def _filter_by_time_range_fast(
self, df: pl.DataFrame, start: Optional[datetime], end: Optional[datetime]
) -> pl.DataFrame:
"""Fast filtering by time using binary search on sorted timestamps. Time complexity: O(log n)."""
if start is None and end is None:
return df
@@ -132,13 +149,17 @@ def _filter_by_time_range_fast(self, df: pl.DataFrame, start: Optional[datetime]
end_idx = np.searchsorted(ts_col, end, side="right")
return df.slice(start_idx, end_idx - start_idx)

def _filter_by_event_type_regular(self, df: pl.DataFrame, event_type: Optional[str]) -> pl.DataFrame:
def _filter_by_event_type_regular(
self, df: pl.DataFrame, event_type: Optional[str]
) -> pl.DataFrame:
"""Regular filtering by event type. Time complexity: O(n)."""
if event_type:
df = df.filter(pl.col("event_type") == event_type)
return df

def _filter_by_event_type_fast(self, df: pl.DataFrame, event_type: Optional[str]) -> pl.DataFrame:
def _filter_by_event_type_fast(
self, df: pl.DataFrame, event_type: Optional[str]
) -> pl.DataFrame:
"""Fast filtering by event type using pre-built event type index. Time complexity: O(1)."""
if event_type:
return self.event_type_partitions.get((event_type,), df[:0])
@@ -150,7 +171,7 @@ def get_events(
event_type: Optional[str] = None,
start: Optional[datetime] = None,
end: Optional[datetime] = None,
filters: Optional[List[tuple]] = None,
filters: Optional[List[tuple[str, str, Any]]] = None,
return_df: bool = False,
) -> Union[pl.DataFrame, List[Event]]:
"""Get events with optional type and time filters.
@@ -159,14 +180,14 @@ def get_events(
event_type (Optional[str]): Type of events to filter.
start (Optional[datetime]): Start time for filtering events.
end (Optional[datetime]): End time for filtering events.
return_df (bool): Whether to return a DataFrame or a list of
return_df (bool): Whether to return a DataFrame or a list of
Event objects.
filters (Optional[List[tuple]]): Additional filters as [(attr, op, value), ...], e.g.:
[("attr1", "!=", "abnormal"), ("attr2", "!=", 1)]. Filters are applied after type
[("attr1", "!=", "abnormal"), ("attr2", "!=", 1)]. Filters are applied after type
and time filters. The logic is "AND" between different filters.

Returns:
Union[pl.DataFrame, List[Event]]: Filtered events as a DataFrame
Union[pl.DataFrame, List[Event]]: Filtered events as a DataFrame
or a list of Event objects.
"""
# faster filtering (by default)
@@ -177,14 +198,15 @@ def get_events(
# df = self._filter_by_event_type_regular(self.data_source, event_type)
# df = self._filter_by_time_range_regular(df, start, end)

if filters:
assert event_type is not None, "event_type must be provided if filters are provided"
else:
filters = []
if filters and event_type is None:
raise ValueError("event_type must be provided if filters are provided")

exprs = []
for filt in filters:
for filt in filters or []:
if not (isinstance(filt, tuple) and len(filt) == 3):
raise ValueError(f"Invalid filter format: {filt} (must be tuple of (attr, op, value))")
raise ValueError(
f"Invalid filter format: {filt} (must be tuple of (attr, op, value))"
)
attr, op, val = filt
col_expr = pl.col(f"{event_type}/{attr}")
# Build operator expression
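
The data.py changes add a round trip between Event objects and the patient-level polars DataFrame (Event.to_dict and Patient.from_events) and replace the assertion in get_events with a ValueError when filters are passed without an event_type. A usage sketch based only on the signatures visible in this diff; the event type and attribute names are illustrative:

from datetime import datetime

from pyhealth.data.data import Event, Patient

# Per Event.to_dict, each attribute key is namespaced as
# "<event_type>/<attr>" next to "timestamp" and "event_type" columns.
events = [
    Event(
        event_type="labevents",
        timestamp=datetime(2023, 1, 1),
        attr_dict={"itemid": "50912", "flag": "abnormal"},
    ),
    Event(
        event_type="labevents",
        timestamp=datetime(2023, 1, 2),
        attr_dict={"itemid": "50912", "flag": "normal"},
    ),
]
patient = Patient.from_events(patient_id="p001", events=events)

# Filters are (attr, op, value) triples, ANDed together and applied
# after the type/time filters; event_type is required alongside them.
normal = patient.get_events(
    event_type="labevents",
    start=datetime(2023, 1, 1),
    filters=[("flag", "!=", "abnormal")],
)
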
5 changes: 4 additions & 1 deletion pyhealth/datasets/configs/mimic3.yaml
@@ -41,7 +41,7 @@ tables:
- "first_careunit"
- "dbsource"
- "last_careunit"
- "outtime"
- "outtime"

diagnoses_icd:
file_path: "DIAGNOSES_ICD.csv.gz"
@@ -54,6 +54,7 @@ tables:
- "dischtime"
timestamp: "dischtime"
attributes:
- "hadm_id"
- "icd9_code"
- "seq_num"

@@ -62,6 +63,7 @@ tables:
patient_id: "subject_id"
timestamp: "startdate"
attributes:
- "hadm_id"
- "drug"
- "drug_type"
- "drug_name_poe"
@@ -88,6 +90,7 @@ tables:
- "dischtime"
timestamp: "dischtime"
attributes:
- "hadm_id"
- "icd9_code"
- "seq_num"

3 changes: 3 additions & 0 deletions pyhealth/datasets/configs/mimic4_ehr.yaml
@@ -48,6 +48,7 @@ tables:
- "dischtime"
timestamp: "dischtime"
attributes:
- "hadm_id"
- "icd_code"
- "icd_version"
- "seq_num"
@@ -63,6 +64,7 @@ tables:
- "dischtime"
timestamp: "dischtime"
attributes:
- "hadm_id"
- "icd_code"
- "icd_version"
- "seq_num"
@@ -93,6 +95,7 @@ tables:
- "category"
timestamp: "charttime"
attributes:
- "hadm_id"
- "itemid"
- "label"
- "fluid"
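
Both YAML diffs above follow one pattern: "hadm_id" is added to the attribute lists of admission-linked tables, so parsed events of those types also carry their admission id. Given the attribute namespacing in data.py, a diagnoses_icd event should then expose the id roughly as follows (values are illustrative):

from datetime import datetime

from pyhealth.data.data import Event

# Hypothetical event reflecting the updated attribute list.
event = Event(
    event_type="diagnoses_icd",
    timestamp=datetime(2023, 1, 1),
    attr_dict={"hadm_id": "100001", "icd9_code": "4019", "seq_num": "1"},
)
assert event.to_dict()["diagnoses_icd/hadm_id"] == "100001"
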
7 changes: 5 additions & 2 deletions pyhealth/datasets/mimic4.py
@@ -1,7 +1,7 @@
import logging
import os
import warnings
from typing import Dict, List, Optional
from typing import List, Optional

import pandas as pd
import polars as pl
@@ -218,7 +218,7 @@ def __init__(
self.dataset_name = dataset_name
self.sub_datasets = {}
self.root = None
self.tables = None
self.tables = []
self.config = None
# Dev flag is only used in the MIMIC4Dataset class
# to ensure the same set of patients are used for all sub-datasets.
@@ -241,6 +241,7 @@ def __init__(
tables=ehr_tables,
config_path=ehr_config_path,
)
self.tables.extend(self.sub_datasets["ehr"].tables)
log_memory_usage("After EHR dataset initialization")

# Initialize Notes dataset if root is provided
@@ -251,6 +252,7 @@ def __init__(
tables=note_tables,
config_path=note_config_path,
)
self.tables.extend(self.sub_datasets["note"].tables)
log_memory_usage("After Note dataset initialization")

# Initialize CXR dataset if root is provided
@@ -261,6 +263,7 @@ def __init__(
tables=cxr_tables,
config_path=cxr_config_path,
)
self.tables.extend(self.sub_datasets["cxr"].tables)
log_memory_usage("After CXR dataset initialization")

# Combine data from all sub-datasets
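
In mimic4.py, self.tables now starts as an empty list and accumulates each sub-dataset's table names as the EHR, note, and CXR sub-datasets are initialized, so the combined dataset reports every loaded table instead of None. The aggregation is roughly equivalent to this standalone sketch (a hypothetical helper, not part of the PR):

def collect_tables(sub_datasets: dict) -> list:
    # Gather table names across whichever sub-datasets were initialized.
    tables = []
    for key in ("ehr", "note", "cxr"):
        if key in sub_datasets:
            tables.extend(sub_datasets[key].tables)
    return tables
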
62 changes: 0 additions & 62 deletions pyhealth/unittests/test_data/test_data.py

This file was deleted.
