Skip to content
Merged
2 changes: 2 additions & 0 deletions doc/changes/devel/13364.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fail early with a clear error when non-finite values (NaN/Inf) are present
in PSD (Welch) and in ICA.fit, avoiding deep assertion failures by :newcontrib: `Emma Zhang`
7 changes: 7 additions & 0 deletions mne/preprocessing/ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,13 @@ def _pre_whiten(self, data):

def _fit(self, data, fit_type):
"""Aux function."""
if not np.isfinite(data).all():
raise ValueError(
"Input data contains non-finite values (NaN/Inf). "
"Please clean your data (e.g., high-pass filter, interpolate or drop "
"contaminated segments) before calling ICA.fit()."
)

random_state = check_random_state(self.random_state)
n_channels, n_samples = data.shape
self._compute_pre_whitener(data)
Expand Down
27 changes: 27 additions & 0 deletions mne/preprocessing/tests/test_ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -1746,3 +1746,30 @@ def test_ica_get_sources_concatenated():
# get sources
raw_sources = ica.get_sources(raw_concat) # but this only has 3 seconds of data
assert raw_concat.n_times == raw_sources.n_times # this will fail


@pytest.mark.filterwarnings(
"ignore:The data has not been high-pass filtered.:RuntimeWarning"
)
@pytest.mark.filterwarnings(
"ignore:invalid value encountered in subtract:RuntimeWarning"
)
def test_ica_rejects_nonfinite():
"""ICA.fit should fail early on NaN/Inf in the input data."""
info = create_info(["Fz", "Cz", "Pz", "Oz"], sfreq=100.0, ch_types="eeg")
rng = np.random.RandomState(1)
data = rng.randn(4, 1000)

# Case 1: NaN
raw = RawArray(data.copy(), info)
raw._data[0, 25] = np.nan
ica = ICA(n_components=2, random_state=0, method="fastica", max_iter="auto")
with pytest.raises(ValueError, match=r"Input data contains non[- ]?finite values"):
ica.fit(raw)

# Case 2: Inf
raw = RawArray(data.copy(), info)
raw._data[1, 50] = np.inf
ica = ICA(n_components=2, random_state=0, method="fastica", max_iter="auto")
with pytest.raises(ValueError, match=r"Input data contains non[- ]?finite values"):
ica.fit(raw)
17 changes: 16 additions & 1 deletion mne/time_frequency/psd.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,16 @@ def psd_array_welch(
del freq_mask
freqs = freqs[freq_sl]

# Hard error on Inf inside analyzed samples
step = max(n_per_seg - n_overlap, 1)
n_segments = 1 + (n_times - n_per_seg) // step if n_times >= n_per_seg else 0
analyzed_end = step * (n_segments - 1) + n_per_seg if n_segments > 0 else 0
if analyzed_end > 0 and np.isinf(x[..., :analyzed_end]).any():
raise ValueError(
"Input data contains non-finite values (Inf) in the analyzed time span. "
"Clean or drop bad segments before computing the PSD."
)

# Parallelize across first N-1 dimensions
logger.debug(
f"Spectogram using {n_fft}-point FFT on {n_per_seg} samples with "
Expand All @@ -221,7 +231,12 @@ def psd_array_welch(
good_mask = ~np.isnan(x)
# NaNs originate from annot, so must match for all channels. Note that we CANNOT
# use np.testing.assert_allclose() here; it is strict about shapes/broadcasting
assert np.allclose(good_mask, good_mask[[0]], equal_nan=True)
if not np.allclose(good_mask, good_mask[[0]], equal_nan=True):
raise ValueError(
"Input data contains NaN masks that are not aligned across channels; "
"make NaN spans consistent across channels or clean/drop bad segments."
)
# assert np.allclose(good_mask, good_mask[[0]], equal_nan=True)
t_onsets, t_offsets = _mask_to_onsets_offsets(good_mask[0])
x_splits = [x[..., t_ons:t_off] for t_ons, t_off in zip(t_onsets, t_offsets)]
# weights reflect the number of samples used from each span. For spans longer
Expand Down
22 changes: 22 additions & 0 deletions mne/time_frequency/tests/test_psd.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,25 @@ def test_psd_array_welch_n_jobs():
data = np.zeros((1, 2048))
psd_array_welch(data, 1024, n_jobs=1)
psd_array_welch(data, 1024, n_jobs=2)


def test_psd_raises_on_inf_in_analyzed_window_array():
"""psd_array_welch should fail if +Inf lies inside analyzed samples."""
n_samples, n_fft, n_overlap = 2048, 256, 128
rng = np.random.RandomState(0)
x = rng.randn(1, n_samples)
# Put +Inf inside the series; this falls within Welch windows
x[0, 800] = np.inf
with pytest.raises(ValueError, match="non[- ]?finite|NaN|Inf"):
psd_array_welch(x, float(n_fft), n_fft=n_fft, n_overlap=n_overlap)


def test_psd_raises_on_misaligned_nan_across_channels():
"""If NaNs are present but masks are NOT aligned across channels, raise."""
n_samples, n_fft, n_overlap = 2048, 256, 128
rng = np.random.RandomState(42)
x = rng.randn(2, n_samples)
# NaN only in ch0; ch1 has no NaN => masks not aligned -> should raise
x[0, 500] = np.nan
with pytest.raises(ValueError, match="aligned|not aligned|non[- ]?finite|NaN|Inf"):
psd_array_welch(x, float(n_fft), n_fft=n_fft, n_overlap=n_overlap)