diff --git a/.circleci/config.yml b/.circleci/config.yml index 2e030bc88..3c7c802dc 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -50,8 +50,7 @@ jobs: pip install --upgrade --progress-bar off pip # TODO: Restore https://api.github.com/repos/mne-tools/mne-bids/zipball/main pending https://github.com/mne-tools/mne-bids/pull/1349/files#r1885104885 pip install --upgrade --progress-bar off "autoreject @ https://api.github.com/repos/autoreject/autoreject/zipball/master" "mne[hdf5] @ git+https://github.com/mne-tools/mne-python@main" "mne-bids[full] @ git+https://github.com/mne-tools/mne-bids@main" numba - pip install -ve .[tests] - pip install "PyQt6!=6.6.1" "PyQt6-Qt6!=6.6.1,!=6.6.2,!=6.6.3,!=6.7.0" + pip install -ve .[tests] onnxruntime "PyQt6!=6.6.1" "PyQt6-Qt6!=6.6.1,!=6.6.2,!=6.6.3,!=6.7.0" - run: name: Check Qt command: | diff --git a/docs/source/dev.md.inc b/docs/source/dev.md.inc index ca7e7f241..72c935734 100644 --- a/docs/source/dev.md.inc +++ b/docs/source/dev.md.inc @@ -2,6 +2,7 @@ ### :new: New features & enhancements +- Support for using `mne-icalabel` to automatically label ICA components. This requires the `mne-icalabel` package to be installed. (#1018 and #812 by @jschepers, @behinger, @hoechenberger, and @larsoner) - It is now possible to use separate MRIs for each session within a subject, as in longitudinal studies. This is achieved by creating separate "subject" folders for each subject-session combination, with the naming convention `sub-XXX_ses-YYY`, in the freesurfer `SUBJECTS_DIR`. (#987 by @drammock) - New config option [`allow_missing_sessions`][mne_bids_pipeline._config.allow_missing_sessions] allows to continue when not all sessions are present for all subjects. (#1000 by @drammock) - New config option [`mf_extra_kws`][mne_bids_pipeline._config.mf_extra_kws] passes additional keyword arguments to `mne.preprocessing.maxwell_filter`. 
(#1038 by @drammock) diff --git a/docs/source/settings/gen_settings.py b/docs/source/settings/gen_settings.py index cd5d264c1..310d2794f 100755 --- a/docs/source/settings/gen_settings.py +++ b/docs/source/settings/gen_settings.py @@ -111,12 +111,14 @@ "^" # The line starts, then is followed by r"(\w+): " # annotation syntax (name captured by the first group), "(?:" # then the rest of the line can be (in a non-capturing group): - ".+ = .+" # 1. a standard assignment - "|" # 2. or - r"Literal\[" # 3. the start of a multiline type annotation like "a: Literal[" - "|" # 4. or - r"\(" # 5. the start of a multiline 3.9+ type annotation like "a: (" - ")" # Then the end of our group + ".+ = .+" # 1. a standard assignment + "|" # 2. or + r"Literal\[" # 3. the start of a multiline type annotation like "a: Literal[" + "|" # 4. or + r"Annotated\[" # 5. the start of a multiline annotated type like "a: Annotated[" + "|" # 6. or + r"\(" # 7. the start of a multiline 3.9+ type annotation like "a: (" + ")" # Then the end of our group "$", # and immediately the end of the line. re.MULTILINE, ) diff --git a/mne_bids_pipeline/_config.py b/mne_bids_pipeline/_config.py index 9f4fe5a51..6852f94c5 100644 --- a/mne_bids_pipeline/_config.py +++ b/mne_bids_pipeline/_config.py @@ -13,6 +13,7 @@ DigMontageType, FloatArrayLike, PathLike, + UniqueSequence, ) # %% @@ -1430,6 +1431,11 @@ us so we can discuss. """ +ica_h_freq: float | None = None +""" +The cutoff frequency of the low-pass filter to apply before running ICA. +""" + ica_max_iterations: int = 500 """ Maximum number of iterations to decompose the data into independent @@ -1476,12 +1482,22 @@ `1` or `None` to not perform any decimation. """ +ica_use_ecg_detection: bool = True +""" +Whether to use the MNE ECG detection on the ICA components. +""" + ica_ecg_threshold: float = 0.1 """ The cross-trial phase statistics (CTPS) threshold parameter used for detecting ECG-related ICs. 
""" +ica_use_eog_detection: bool = True +""" +Whether to use the MNE EOG detection on the ICA components. +""" + ica_eog_threshold: float = 3.0 """ The threshold to use during automated EOG classification. Lower values mean @@ -1489,6 +1505,48 @@ false-alarm rate increases dramatically. """ +ica_use_icalabel: bool = False +""" +Whether to use MNE-ICALabel to automatically label ICA components. Only available for +EEG data. + +!!! info + Using MNE-ICALabel mandates that you also set: + ```python + eeg_reference = "average" + ica_l_freq = 1 + ica_h_freq = 100 + ``` + It will also apply the average reference to the data before running or applying ICA. + +!!! info + Using this requires `mne-icalabel` package, which in turn will require you to + install a suitable runtime (`onnxruntime` or `pytorch`). +""" + +ica_icalabel_include: Annotated[ + UniqueSequence[ + Literal[ + "brain", + "muscle artifact", + "eye blink", + "heart beat", + "line noise", + "channel noise", + "other", + ] + ], + Len(1, 7), +] = ("brain", "other") +""" +Which independent components (ICs) to keep based on the labels given by ICLabel. 
+Possible labels are: + +``` +["brain", "muscle artifact", "eye blink", "heart beat", "line noise", "channel noise", "other"] +``` +""" # noqa: E501 + # ### Amplitude-based artifact rejection # # ???+ info "Good Practice / Advice" diff --git a/mne_bids_pipeline/_config_import.py b/mne_bids_pipeline/_config_import.py index 3c50a2be8..333203697 100644 --- a/mne_bids_pipeline/_config_import.py +++ b/mne_bids_pipeline/_config_import.py @@ -426,6 +426,27 @@ def _check_config(config: SimpleNamespace, config_path: PathLike | None) -> None f"but got shape {destination.shape}" ) + # From: https://github.com/mne-tools/mne-bids-pipeline/pull/812 + # MNE-ICALabel + if config.ica_use_icalabel: + pre = "When using MNE-ICALabel, you must set" + if config.ica_l_freq != 1.0 or config.ica_h_freq != 100.0: + raise ValueError( + f"{pre} ica_l_freq=1 and h_freq=100, " + f"but got: ica_l_freq={config.ica_l_freq} and " + f"ica_h_freq={config.ica_h_freq}" + ) + if config.eeg_reference != "average": + raise ValueError( + f'{pre} eeg_reference="average", but got: ' + f"eeg_reference={config.eeg_reference}" + ) + if config.ica_algorithm not in ("picard-extended_infomax", "extended_infomax"): + raise ValueError( + f'{pre} ica_algorithm="picard-extended_infomax" or "extended_infomax", ' + f"but got: ica_algorithm={repr(config.ica_algorithm)}" + ) + def _default_factory(key: str, val: Any) -> Any: # convert a default to a default factory if needed, having an explicit @@ -435,6 +456,7 @@ def _default_factory(key: str, val: Any) -> Any: {"custom": (8, 24.0, 40)}, # decoding_csp_freqs ["evoked"], # inverse_targets [4, 8, 16], # autoreject_n_interpolate + ("brain", "other"), # ica_icalabel_include ] def default_factory() -> Any: diff --git a/mne_bids_pipeline/steps/preprocessing/_06a1_fit_ica.py b/mne_bids_pipeline/steps/preprocessing/_06a1_fit_ica.py index 18135f66e..77793c1fd 100644 --- a/mne_bids_pipeline/steps/preprocessing/_06a1_fit_ica.py +++ 
b/mne_bids_pipeline/steps/preprocessing/_06a1_fit_ica.py @@ -79,6 +79,14 @@ def run_ica( """Run ICA.""" import matplotlib.pyplot as plt + if cfg.ica_use_icalabel: + # The ICALabel network was trained on extended-Infomax ICA decompositions fit + # on data filtered between 1 and 100 Hz. + assert cfg.ica_algorithm in ["picard-extended_infomax", "extended_infomax"] + assert cfg.ica_l_freq == 1.0 + assert cfg.ica_h_freq == 100.0 + assert cfg.eeg_reference == "average" + raw_fnames = [in_files.pop(f"raw_run-{run}") for run in cfg.runs] out_files = dict() bids_basename = raw_fnames[0].copy().update(processing=None, split=None, run=None) @@ -105,20 +113,38 @@ def run_ica( # Sanity check – make sure we're using the correct data! if cfg.raw_resample_sfreq is not None: assert np.allclose(raw.info["sfreq"], cfg.raw_resample_sfreq) - if cfg.l_freq is not None: - assert np.allclose(raw.info["highpass"], cfg.l_freq) if idx == 0: - if cfg.ica_l_freq is None: + # We have to do some gymnastics here to permit for example 128 Hz-sampled + # data to be used with mne-icalabel, which wants data low-pass filtered + # at 100 Hz + h_freq = cfg.ica_h_freq + nyq = raw.info["sfreq"] / 2.0 + if h_freq is not None and h_freq >= nyq: msg = ( - f"Not applying high-pass filter (data is already filtered, " - f"cutoff: {raw.info['highpass']} Hz)." + f"Low-pass filter cutoff {h_freq} Hz is higher " + f"than Nyquist {nyq} Hz" ) + if cfg.ica_use_icalabel: + msg += ", setting to None for compatibility with MNE-ICALabel." 
+ logger.warning(**gen_log_kwargs(message=msg)) + h_freq = None + else: + raise ValueError(msg) + msg = "" + if cfg.ica_l_freq is not None and h_freq is not None: + msg = ( + f"Applying band-pass filter with {cfg.ica_l_freq}-{h_freq} " + "Hz cutoffs" + ) + elif cfg.ica_l_freq is not None: + msg = f"Applying high-pass filter with {cfg.ica_l_freq} Hz cutoff" + elif h_freq is not None: + msg = f"Applying low-pass filter with {h_freq} Hz cutoff" + if cfg.ica_l_freq is not None or h_freq is not None: logger.info(**gen_log_kwargs(message=msg)) - else: - msg = f"Applying high-pass filter with {cfg.ica_l_freq} Hz cutoff …" - logger.info(**gen_log_kwargs(message=msg)) - raw.filter(l_freq=cfg.ica_l_freq, h_freq=None, n_jobs=1) + raw.filter(l_freq=cfg.ica_l_freq, h_freq=h_freq, n_jobs=1) + del nyq, h_freq # Only keep the subset of the mapping that applies to the current run event_id = event_name_to_code_map.copy() @@ -167,6 +193,8 @@ def run_ica( if "eeg" in cfg.ch_types: projection = True if cfg.eeg_reference == "average" else False epochs.set_eeg_reference(cfg.eeg_reference, projection=projection) + if cfg.ica_use_icalabel: + epochs.apply_proj() # Apply the reference projection ar_reject_log = ar_n_interpolate_ = None if cfg.ica_reject == "autoreject_local": @@ -333,16 +361,17 @@ def get_config( conditions=config.conditions, runs=get_runs(config=config, subject=subject), task_is_rest=config.task_is_rest, + ica_h_freq=config.ica_h_freq, ica_l_freq=config.ica_l_freq, ica_algorithm=config.ica_algorithm, ica_n_components=config.ica_n_components, ica_max_iterations=config.ica_max_iterations, ica_decim=config.ica_decim, ica_reject=config.ica_reject, + ica_use_icalabel=config.ica_use_icalabel, autoreject_n_interpolate=config.autoreject_n_interpolate, random_state=config.random_state, ch_types=config.ch_types, - l_freq=config.l_freq, epochs_decim=config.epochs_decim, raw_resample_sfreq=config.raw_resample_sfreq, event_repeated=config.event_repeated, diff --git 
a/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py b/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py index c2211476c..98b922e83 100644 --- a/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py +++ b/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py @@ -10,10 +10,12 @@ from types import SimpleNamespace from typing import Literal +import matplotlib.pyplot as plt import mne import numpy as np import pandas as pd from mne.preprocessing import create_ecg_epochs, create_eog_epochs +from mne.viz import plot_ica_components from mne_bids import BIDSPath from mne_bids_pipeline._config_utils import ( @@ -157,114 +159,148 @@ def find_ica_artifacts( epochs_ecg = None ecg_ics: list[int] = [] ecg_scores: FloatArrayT = np.zeros(0) - for ri, raw_fname in enumerate(raw_fnames): - # Have the channels needed to make ECG epochs - raw = mne.io.read_raw(raw_fname, preload=False) - # ECG epochs - if not ( - "ecg" in raw.get_channel_types() - or "meg" in cfg.ch_types - or "mag" in cfg.ch_types - ): - msg = ( - "No ECG or magnetometer channels are present, cannot " - "automate artifact detection for ECG." + if cfg.ica_use_ecg_detection: + for ri, raw_fname in enumerate(raw_fnames): + # Have the channels needed to make ECG epochs + raw = mne.io.read_raw(raw_fname, preload=False) + if cfg.ica_use_icalabel: + raw.set_eeg_reference("average", projection=True).apply_proj() + # ECG epochs + if not ( + "ecg" in raw.get_channel_types() + or "meg" in cfg.ch_types + or "mag" in cfg.ch_types + ): + msg = ( + "No ECG or magnetometer channels are present, cannot " + "automate artifact detection for ECG." 
+ ) + logger.info(**gen_log_kwargs(message=msg)) + break + elif ri == 0: + msg = "Creating ECG epochs …" + logger.info(**gen_log_kwargs(message=msg)) + + # We want to extract a total of 5 min of data for ECG epochs generation + # (across all runs) + total_ecg_dur = 5 * 60 + ecg_dur_per_run = total_ecg_dur / len(raw_fnames) + t_mid = (raw.times[-1] + raw.times[0]) / 2 + raw = raw.crop( + tmin=max(t_mid - 1 / 2 * ecg_dur_per_run, 0), + tmax=min(t_mid + 1 / 2 * ecg_dur_per_run, raw.times[-1]), + ).load_data() + + these_ecg_epochs = create_ecg_epochs( + raw, + baseline=(None, -0.2), + tmin=-0.5, + tmax=0.5, ) - logger.info(**gen_log_kwargs(message=msg)) - break - elif ri == 0: - msg = "Creating ECG epochs …" - logger.info(**gen_log_kwargs(message=msg)) - - # We want to extract a total of 5 min of data for ECG epochs generation - # (across all runs) - total_ecg_dur = 5 * 60 - ecg_dur_per_run = total_ecg_dur / len(raw_fnames) - t_mid = (raw.times[-1] + raw.times[0]) / 2 - raw = raw.crop( - tmin=max(t_mid - 1 / 2 * ecg_dur_per_run, 0), - tmax=min(t_mid + 1 / 2 * ecg_dur_per_run, raw.times[-1]), - ).load_data() - - these_ecg_epochs = create_ecg_epochs( - raw, - baseline=(None, -0.2), - tmin=-0.5, - tmax=0.5, - ) - del raw # Free memory - if len(these_ecg_epochs): - if epochs.reject is not None: - these_ecg_epochs.drop_bad(reject=epochs.reject) + del raw # Free memory if len(these_ecg_epochs): - if epochs_ecg is None: - epochs_ecg = these_ecg_epochs - else: - epochs_ecg = mne.concatenate_epochs( - [epochs_ecg, these_ecg_epochs], on_mismatch="warn" - ) - del these_ecg_epochs - else: # did not break so had usable channels - ecg_ics, ecg_scores = detect_bad_components( - cfg=cfg, - which="ecg", - epochs=epochs_ecg, - ica=ica, - ch_names=None, # we currently don't allow for custom channels - subject=subject, - session=session, - ) + if epochs.reject is not None: + these_ecg_epochs.drop_bad(reject=epochs.reject) + if len(these_ecg_epochs): + if epochs_ecg is None: + epochs_ecg = 
these_ecg_epochs + else: + epochs_ecg = mne.concatenate_epochs( + [epochs_ecg, these_ecg_epochs], on_mismatch="warn" + ) + del these_ecg_epochs + else: # did not break so had usable channels + ecg_ics, ecg_scores = detect_bad_components( + cfg=cfg, + which="ecg", + epochs=epochs_ecg, + ica=ica, + ch_names=None, # we currently don't allow for custom channels + subject=subject, + session=session, + ) # EOG component detection epochs_eog = None eog_ics: list[int] = [] eog_scores: FloatArrayT = np.zeros(0) - for ri, raw_fname in enumerate(raw_fnames): - raw = mne.io.read_raw_fif(raw_fname, preload=True) - if cfg.eog_channels: - ch_names = cfg.eog_channels - assert all([ch_name in raw.ch_names for ch_name in ch_names]) - else: - eog_picks = mne.pick_types(raw.info, meg=False, eog=True) - ch_names = [raw.ch_names[pick] for pick in eog_picks] - if not ch_names: - msg = "No EOG channel is present, cannot automate IC detection for EOG." - logger.info(**gen_log_kwargs(message=msg)) - break - elif ri == 0: - msg = "Creating EOG epochs …" - logger.info(**gen_log_kwargs(message=msg)) - these_eog_epochs = create_eog_epochs( - raw, - ch_name=ch_names, - baseline=(None, -0.2), - ) - if len(these_eog_epochs): - if epochs.reject is not None: - these_eog_epochs.drop_bad(reject=epochs.reject) + if cfg.ica_use_eog_detection: + for ri, raw_fname in enumerate(raw_fnames): + raw = mne.io.read_raw_fif(raw_fname, preload=True) + if cfg.ica_use_icalabel: + raw.set_eeg_reference("average", projection=True).apply_proj() + if cfg.eog_channels: + ch_names = cfg.eog_channels + assert all([ch_name in raw.ch_names for ch_name in ch_names]) + else: + eog_picks = mne.pick_types(raw.info, meg=False, eog=True) + ch_names = [raw.ch_names[pick] for pick in eog_picks] + if not ch_names: + msg = "No EOG channel is present, cannot automate IC detection for EOG." 
+ logger.info(**gen_log_kwargs(message=msg)) + break + elif ri == 0: + msg = "Creating EOG epochs …" + logger.info(**gen_log_kwargs(message=msg)) + these_eog_epochs = create_eog_epochs( + raw, + ch_name=ch_names, + baseline=(None, -0.2), + ) if len(these_eog_epochs): - if epochs_eog is None: - epochs_eog = these_eog_epochs - else: - epochs_eog = mne.concatenate_epochs( - [epochs_eog, these_eog_epochs], on_mismatch="warn" - ) - else: # did not break - eog_ics, eog_scores = detect_bad_components( - cfg=cfg, - which="eog", - epochs=epochs_eog, - ica=ica, - ch_names=cfg.eog_channels, - subject=subject, - session=session, + if epochs.reject is not None: + these_eog_epochs.drop_bad(reject=epochs.reject) + if len(these_eog_epochs): + if epochs_eog is None: + epochs_eog = these_eog_epochs + else: + epochs_eog = mne.concatenate_epochs( + [epochs_eog, these_eog_epochs], on_mismatch="warn" + ) + else: # did not break + eog_ics, eog_scores = detect_bad_components( + cfg=cfg, + which="eog", + epochs=epochs_eog, + ica=ica, + ch_names=cfg.eog_channels, + subject=subject, + session=session, + ) + + # Run MNE-ICALabel if requested. + icalabel_ics = [] + icalabel_labels = [] + icalabel_prob = [] + if cfg.ica_use_icalabel: + import mne_icalabel + + msg = "Performing automated artifact detection (MNE-ICALabel) …" + logger.info(**gen_log_kwargs(message=msg)) + + label_results = mne_icalabel.label_components( + inst=epochs, ica=ica, method="iclabel" + ) + for idx, (label, prob) in enumerate( + zip(label_results["labels"], label_results["y_pred_proba"]) + ): + if label not in cfg.ica_icalabel_include: + icalabel_ics.append(idx) + icalabel_labels.append(label) + icalabel_prob.append(prob) + + msg = ( + f"Detected {len(icalabel_ics)} artifact-related independent component(s) " + f"in {len(epochs)} epochs: {icalabel_labels}" ) + logger.info(**gen_log_kwargs(message=msg)) + + ica.exclude = sorted(set(ecg_ics + eog_ics + icalabel_ics)) # Save updated ICA to disk. 
# We also store the automatically identified ECG- and EOG-related ICs. msg = "Saving ICA solution and detected artifacts to disk." logger.info(**gen_log_kwargs(message=msg)) - ica.exclude = sorted(set(ecg_ics + eog_ics)) ica.save(out_files["ica"], overwrite=True) # Create TSV. @@ -278,15 +314,27 @@ def find_ica_artifacts( ) ) - for component in ecg_ics: - row_idx = tsv_data["component"] == component - tsv_data.loc[row_idx, "status"] = "bad" - tsv_data.loc[row_idx, "status_description"] = "Auto-detected ECG artifact" - - for component in eog_ics: - row_idx = tsv_data["component"] == component - tsv_data.loc[row_idx, "status"] = "bad" - tsv_data.loc[row_idx, "status_description"] = "Auto-detected EOG artifact" + if cfg.ica_use_icalabel: + for component, label in zip(icalabel_ics, icalabel_labels): + row_idx = tsv_data["component"] == component + tsv_data.loc[row_idx, "status"] = "bad" + tsv_data.loc[row_idx, "status_description"] = ( + f"Auto-detected {label} (MNE-ICALabel)" + ) + if cfg.ica_use_ecg_detection: + for component in ecg_ics: + row_idx = tsv_data["component"] == component + tsv_data.loc[row_idx, "status"] = "bad" + tsv_data.loc[row_idx, "status_description"] = ( + "Auto-detected ECG artifact (MNE)" + ) + if cfg.ica_use_eog_detection: + for component in eog_ics: + row_idx = tsv_data["component"] == component + tsv_data.loc[row_idx, "status"] = "bad" + tsv_data.loc[row_idx, "status_description"] = ( + "Auto-detected EOG artifact (MNE)" + ) tsv_data.to_csv(out_files_components, sep="\t", index=False) @@ -308,7 +356,8 @@ def find_ica_artifacts( del artifact_name, artifact_evoked - title = "ICA: components" + section = "ICA: components" + tags = ("ica",) with _open_report( cfg=cfg, exec_params=exec_params, @@ -316,10 +365,10 @@ def find_ica_artifacts( session=session, task=cfg.task, ) as report: - logger.info(**gen_log_kwargs(message=f'Adding "{title}" to report.')) + logger.info(**gen_log_kwargs(message=f'Adding "{section}" to report.')) report.add_ica( 
ica=ica, - title=title, + title=section, inst=epochs, ecg_evoked=ecg_evoked, eog_evoked=eog_evoked, @@ -327,9 +376,34 @@ def find_ica_artifacts( eog_scores=eog_scores if len(eog_scores) else None, replace=True, n_jobs=1, # avoid automatic parallelization - tags=("ica",), # the default but be explicit + tags=tags, # the default but be explicit ) + # Add a plot for each excluded IC together with the given label and the prob + if cfg.ica_use_icalabel and len(icalabel_ics): + msg = "Adding icalabel components to report." + logger.info(**gen_log_kwargs(message=msg)) + figs = list() + for ic, label, prob in zip(icalabel_ics, icalabel_labels, icalabel_prob): + fig = plot_ica_components(ica=ica, picks=ic) + fig.axes[0].text( + 0, + -0.15, + f"Label: {label} \n Probability: {prob:.3f}", + ha="center", + fontsize=8, + bbox={"facecolor": "orange", "alpha": 0.5, "pad": 5}, + ) + figs.append(fig) + report.add_figure( + fig=figs, + title="ICA components from icalabel", + section=section, + replace=True, + ) + for fig in figs: + plt.close(fig) + msg = 'Carefully review the extracted ICs and mark components "bad" in:' logger.info(**gen_log_kwargs(message=msg, emoji="🛑")) logger.info(**gen_log_kwargs(message=str(out_files_components), emoji="🛑")) @@ -350,8 +424,12 @@ def get_config( task_is_rest=config.task_is_rest, ica_l_freq=config.ica_l_freq, ica_reject=config.ica_reject, + ica_use_eog_detection=config.ica_use_eog_detection, ica_eog_threshold=config.ica_eog_threshold, + ica_use_ecg_detection=config.ica_use_ecg_detection, ica_ecg_threshold=config.ica_ecg_threshold, + ica_use_icalabel=config.ica_use_icalabel, + ica_icalabel_include=config.ica_icalabel_include, autoreject_n_interpolate=config.autoreject_n_interpolate, random_state=config.random_state, ch_types=config.ch_types, diff --git a/mne_bids_pipeline/steps/preprocessing/_08a_apply_ica.py b/mne_bids_pipeline/steps/preprocessing/_08a_apply_ica.py index be3b5ac44..1281604bf 100644 --- 
a/mne_bids_pipeline/steps/preprocessing/_08a_apply_ica.py +++ b/mne_bids_pipeline/steps/preprocessing/_08a_apply_ica.py @@ -138,6 +138,8 @@ def apply_ica_epochs( logger.info(**gen_log_kwargs(message=msg)) epochs = mne.read_epochs(in_files.pop("epochs"), preload=True) + if cfg.ica_use_icalabel: + epochs.set_eeg_reference("average", projection=True).apply_proj() # Now actually reject the components. msg = ( @@ -212,6 +214,8 @@ def apply_ica_raw( msg = f"Writing {out_files[in_key].basename} …" logger.info(**gen_log_kwargs(message=msg)) raw = mne.io.read_raw_fif(raw_fname, preload=True) + if cfg.ica_use_icalabel: + raw.set_eeg_reference("average", projection=True).apply_proj() ica.apply(raw) raw.save(out_files[in_key], overwrite=True, split_size=cfg._raw_split_size) _update_for_splits(out_files, in_key) @@ -245,6 +249,7 @@ def get_config( cfg = SimpleNamespace( baseline=config.baseline, ica_reject=config.ica_reject, + ica_use_icalabel=config.ica_use_icalabel, processing="filt" if config.regress_artifact is None else "regress", _epochs_split_size=config._epochs_split_size, **_import_data_kwargs(config=config, subject=subject), diff --git a/mne_bids_pipeline/tests/configs/config_ERP_CORE.py b/mne_bids_pipeline/tests/configs/config_ERP_CORE.py index e536450fa..c8ea4cd34 100644 --- a/mne_bids_pipeline/tests/configs/config_ERP_CORE.py +++ b/mne_bids_pipeline/tests/configs/config_ERP_CORE.py @@ -75,8 +75,12 @@ spatial_filter = None reject = "autoreject_local" autoreject_n_interpolate = [2, 4] -elif task == "N170": # test autoreject local before ICA +elif task == "N170": # test autoreject local before ICA, and MNE-ICALabel spatial_filter = "ica" + ica_algorithm = "picard-extended_infomax" + ica_use_icalabel = True + ica_h_freq = 100 + ica_l_freq = 1 ica_reject = "autoreject_local" reject = "autoreject_global" autoreject_n_interpolate = [2, 4] @@ -293,7 +297,6 @@ "O2", ] - ica_n_components = 30 - 1 for i in range(1, 180 + 1): orig_name = f"stimulus/{i}" @@ -316,6 +319,7 @@ 
conditions = ["stimulus/face/normal", "stimulus/car/normal"] contrasts = [("stimulus/face/normal", "stimulus/car/normal")] elif task == "P3": + ica_n_components = 30 - 1 # 29 channels rename_events = { "response/201": "response/correct", "response/202": "response/incorrect", diff --git a/mne_bids_pipeline/tests/test_documented.py b/mne_bids_pipeline/tests/test_documented.py index bc3eee1e3..4c5ade44e 100644 --- a/mne_bids_pipeline/tests/test_documented.py +++ b/mne_bids_pipeline/tests/test_documented.py @@ -66,7 +66,7 @@ def test_options_documented() -> None: why = f"Duplicate docs in {fname} and {other} for {val}" assert val not in in_doc[other], why in_doc[fname].add(val) - what = "docs/source/settings doc" + what = "docs/source/settings/*.md docs created by gen_settings.py" in_doc_all = set() for vals in in_doc.values(): in_doc_all.update(vals) diff --git a/mne_bids_pipeline/typing.py b/mne_bids_pipeline/typing.py index ebe2bcec3..fa873981c 100644 --- a/mne_bids_pipeline/typing.py +++ b/mne_bids_pipeline/typing.py @@ -2,7 +2,8 @@ import pathlib import sys -from typing import Annotated, Any, Literal, TypeAlias +from collections.abc import Hashable, Sequence +from typing import Annotated, Any, Literal, TypeAlias, TypeVar if sys.version_info < (3, 12): from typing_extensions import TypedDict @@ -13,7 +14,8 @@ import numpy as np from mne_bids import BIDSPath from numpy.typing import ArrayLike -from pydantic import PlainValidator +from pydantic import AfterValidator, Field, PlainValidator +from pydantic_core import PydanticCustomError PathLike = str | pathlib.Path @@ -82,3 +84,18 @@ def assert_dig_montage(val: mne.channels.DigMontage) -> mne.channels.DigMontage: mne.channels.DigMontage, PlainValidator(assert_dig_montage), ] + +T = TypeVar("T", bound=Hashable) + + +def _validate_unique_sequence(v: Sequence[T]) -> Sequence[T]: + if len(v) != len(set(v)): + raise PydanticCustomError("unique_sequence", "Sequence items must be unique") + return v + + +UniqueSequence = 
Annotated[ + Sequence[T], + AfterValidator(_validate_unique_sequence), + Field(json_schema_extra={"uniqueItems": True}), +] diff --git a/pyproject.toml b/pyproject.toml index e72787611..078b9ba6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ dependencies = [ "autoreject", "mne[hdf5] >=1.7", "mne-bids[full]", + "mne-icalabel", "filelock", "meegkit" ]