Skip to content
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mne_bids_pipeline/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,6 +1367,12 @@
false-alarm rate increases dramatically.
"""

ica_use_icalabel = False
"""
Whether to use MNE-ICALabel to automatically label ICA components. Only available for
EEG data.
"""

# Rejection based on peak-to-peak amplitude
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
131 changes: 92 additions & 39 deletions mne_bids_pipeline/steps/preprocessing/_06a_run_ica.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pandas as pd
import numpy as np
import autoreject
from mne_icalabel import label_components

import mne
from mne.report import Report
Expand Down Expand Up @@ -78,6 +79,10 @@ def fit_ica(
algorithm = "infomax"
fit_params = dict(extended=True)

if cfg.ica_use_icalabel:
# The ICALabel network was trained on extended-Infomax ICA decompositions
assert algorithm in ["picard-extended_infomax", "extended_infomax"]

ica = ICA(
method=algorithm,
random_state=cfg.random_state,
Expand Down Expand Up @@ -184,7 +189,7 @@ def make_eog_epochs(
return eog_epochs


def detect_bad_components(
def detect_bad_components_mne(
*,
cfg,
which: Literal["eog", "ecg"],
Expand All @@ -195,7 +200,7 @@ def detect_bad_components(
session: Optional[str],
) -> Tuple[List[int], np.ndarray]:
artifact = which.upper()
msg = f"Performing automated {artifact} artifact detection …"
msg = f"Performing automated {artifact} artifact detection (MNE) …"
logger.info(**gen_log_kwargs(message=msg))

if which == "eog":
Expand Down Expand Up @@ -395,7 +400,18 @@ def run_ica(

# Set an EEG reference
if "eeg" in cfg.ch_types:
projection = True if cfg.eeg_reference == "average" else False
if cfg.ica_use_icalabel:
assert cfg.eeg_reference == "average"
projection = False # Avg. ref. needs to be applied for MNE-ICALabel
elif cfg.eeg_reference == "average":
projection = True
else:
projection = False

if not projection:
msg = "Applying average reference to EEG epochs used for ICA fitting."
logger.info(**gen_log_kwargs(message=msg))

epochs.set_eeg_reference(cfg.eeg_reference, projection=projection)

if cfg.ica_reject == "autoreject_local":
Expand Down Expand Up @@ -446,38 +462,56 @@ def run_ica(
if cfg.task is not None:
title += f", task-{cfg.task}"

# ECG and EOG component detection
if epochs_ecg:
ecg_ics, ecg_scores = detect_bad_components(
cfg=cfg,
which="ecg",
epochs=epochs_ecg,
ica=ica,
ch_names=None, # we currently don't allow for custom channels
subject=subject,
session=session,
)
else:
ecg_ics = ecg_scores = []
# Run MNE-ICALabel if requested.
if cfg.ica_use_icalabel:
msg = "Performing automated artifact detection (MNE-ICALabel) …"
logger.info(**gen_log_kwargs(message=msg))

if epochs_eog:
eog_ics, eog_scores = detect_bad_components(
cfg=cfg,
which="eog",
epochs=epochs_eog,
ica=ica,
ch_names=cfg.eog_channels,
subject=subject,
session=session,
)
label_results = label_components(inst=epochs, ica=ica, method="iclabel")
icalabel_ics = []
icalabel_labels = []
for idx, label in enumerate(label_results["labels"]):
if label not in ["brain", "other"]:
icalabel_ics.append(idx)
icalabel_labels.append(label)

msg = f"Detected {len(icalabel_ics)} artifact-related ICs in {len(epochs)} epochs."
logger.info(**gen_log_kwargs(message=msg))
ica.exclude = sorted(icalabel_ics)
else:
eog_ics = eog_scores = []
# Run MNE's built-in ECG and EOG component detection
if epochs_ecg:
ecg_ics, ecg_scores = detect_bad_components_mne(
cfg=cfg,
which="ecg",
epochs=epochs_ecg,
ica=ica,
ch_names=None, # we currently don't allow for custom channels
subject=subject,
session=session,
)
else:
ecg_ics = ecg_scores = []

if epochs_eog:
eog_ics, eog_scores = detect_bad_components_mne(
cfg=cfg,
which="eog",
epochs=epochs_eog,
ica=ica,
ch_names=cfg.eog_channels,
subject=subject,
session=session,
)
else:
eog_ics = eog_scores = []

ica.exclude = sorted(set(ecg_ics + eog_ics))

# Save ICA to disk.
# We also store the automatically identified ECG- and EOG-related ICs.
msg = "Saving ICA solution and detected artifacts to disk."
logger.info(**gen_log_kwargs(message=msg))
ica.exclude = sorted(set(ecg_ics + eog_ics))
ica.save(out_files["ica"], overwrite=True)
_update_for_splits(out_files, "ica")

Expand All @@ -492,15 +526,27 @@ def run_ica(
)
)

for component in ecg_ics:
row_idx = tsv_data["component"] == component
tsv_data.loc[row_idx, "status"] = "bad"
tsv_data.loc[row_idx, "status_description"] = "Auto-detected ECG artifact"

for component in eog_ics:
row_idx = tsv_data["component"] == component
tsv_data.loc[row_idx, "status"] = "bad"
tsv_data.loc[row_idx, "status_description"] = "Auto-detected EOG artifact"
if cfg.ica_use_icalabel:
for component, label in zip(icalabel_ics, icalabel_labels):
row_idx = tsv_data["component"] == component
tsv_data.loc[row_idx, "status"] = "bad"
tsv_data.loc[
row_idx, "status_description"
] = f"Auto-detected {label} (MNE-ICALabel)"
else:
for component in ecg_ics:
row_idx = tsv_data["component"] == component
tsv_data.loc[row_idx, "status"] = "bad"
tsv_data.loc[
row_idx, "status_description"
] = "Auto-detected ECG artifact (MNE)"

for component in eog_ics:
row_idx = tsv_data["component"] == component
tsv_data.loc[row_idx, "status"] = "bad"
tsv_data.loc[
row_idx, "status_description"
] = "Auto-detected EOG artifact (MNE)"

tsv_data.to_csv(out_files["components"], sep="\t", index=False)

Expand All @@ -510,10 +556,16 @@ def run_ica(
logger.info(**gen_log_kwargs(message=msg))

report = Report(info_fname=epochs, title=title, verbose=False)

ecg_evoked = None if epochs_ecg is None else epochs_ecg.average()
eog_evoked = None if epochs_eog is None else epochs_eog.average()
ecg_scores = None if len(ecg_scores) == 0 else ecg_scores
eog_scores = None if len(eog_scores) == 0 else eog_scores

if cfg.ica_use_icalabel:
# We didn't run MNE's scoring
ecg_scores = eog_scores = None
else:
ecg_scores = None if len(ecg_scores) == 0 else ecg_scores
eog_scores = None if len(eog_scores) == 0 else eog_scores

with _agg_backend():
if cfg.ica_reject == "autoreject_local":
Expand Down Expand Up @@ -588,6 +640,7 @@ def get_config(
ica_reject=config.ica_reject,
ica_eog_threshold=config.ica_eog_threshold,
ica_ctps_ecg_threshold=config.ica_ctps_ecg_threshold,
ica_use_icalabel=config.ica_use_icalabel,
autoreject_n_interpolate=config.autoreject_n_interpolate,
random_state=config.random_state,
ch_types=config.ch_types,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ dependencies = [
"autoreject",
"mne[hdf5] >=1.2",
"mne-bids[full]",
"mna-icalabel",
"filelock",
"setuptools >=65",
]
Expand Down