diff --git a/src/iblphotometry/metrics.py b/src/iblphotometry/metrics.py index 77f00d1..a7163e6 100644 --- a/src/iblphotometry/metrics.py +++ b/src/iblphotometry/metrics.py @@ -2,17 +2,59 @@ import numpy as np import pandas as pd from scipy import stats +from scipy.signal import medfilt from iblphotometry.processing import ( z, Regression, ExponDecay, - detect_spikes, + detect_spikes_dt, + detect_spikes_dy, detect_outliers, + find_early_samples, + find_repeated_samples, ) from iblphotometry.behavior import psth +def dt_violations(A: pd.DataFrame | pd.Series, atol: float = 1e-3) -> str: + t = A.index.values + dts = np.diff(t) + dt = np.median(dts) + n_violations = np.sum(np.abs(dts - dt) > atol) + ## TODO: make ibllib wrappers to convert metrics to QC vals + # if n_violations == 0: + # outcome = QC.PASS + # elif n_violations <= 3: + # outcome = QC.WARNING + # elif n_violations <= 10: + # outcome = QC.CRITICAL + # else: + # outcome = QC.FAIL + # return outcome, n_violations + return n_violations + + +# def interleaved_acquisition(A: pd.DataFrame | pd.Series) -> bool: +# if sum(A['name'] == '') > 0: +# A = _fill_missing_channel_names(A) +# a = A['name'].values if isinstance(A, pd.DataFrame) else A.values +# even_check = np.all(a[::2] == a[0]) +# odd_check = np.all(a[1::2] == a[1]) +# return bool(even_check & odd_check) + + +def n_early_samples(A: pd.DataFrame | pd.Series, dt_tol: float = 0.001) -> int: + return find_early_samples(A, dt_tol=dt_tol).sum() + + +def n_repeated_samples( + A: pd.DataFrame, + dt_tol: float = 0.001, +) -> int: + return find_repeated_samples(A, dt_tol=dt_tol).sum() + + def percentile_dist(A: pd.Series | np.ndarray, pc: tuple = (50, 95), axis=-1) -> float: """the distance between two percentiles in units of z. Captures the magnitude of transients. @@ -33,6 +75,23 @@ def percentile_dist(A: pd.Series | np.ndarray, pc: tuple = (50, 95), axis=-1) -> return P[1] - P[0] +def deviance( + A: pd.Series | np.ndarray, + w_len: int = 151, +) -> float: + a = A.values if isinstance(A, pd.Series) else A + return np.median(np.abs(a - np.median(a)) / np.median(a)) + + +def sliding_deviance( + A: pd.Series | np.ndarray, + w_len: int = 151, +) -> float: + a = A.values if isinstance(A, pd.Series) else A + running_median = medfilt(a, kernel_size=w_len) + return np.median(np.abs(a - running_median) / running_median) + + def signal_asymmetry(A: pd.Series | np.ndarray, pc_comp: int = 95, axis=-1) -> float: """the ratio between the distance of two percentiles to the median. Proportional to the the signal to noise. @@ -63,15 +122,32 @@ def signal_skew(A: pd.Series | np.ndarray, axis=-1) -> float: def n_unique_samples(A: pd.Series | np.ndarray) -> int: - """number of unique samples in the signal. Low values indicate that the signal during acquisition was not within the range of the digitizer.""" + """ + Number of unique samples in the signal. Low values indicate that the signal + was not within the range of the digitizer during acquisition. + """ a = A.values if isinstance(A, pd.Series) else A return np.unique(a).shape[0] -def n_spikes(A: pd.Series | np.ndarray, sd: int = 5): +def f_unique_samples(A: pd.Series | np.ndarray) -> int: + """ + Wrapper that converts n_unique_samples to a fraction of the total number + of samples. 
+ """ + return n_unique_samples(A) / len(A) + + +# def n_spikes_dt(A: pd.Series | np.ndarray, sd: int = 5): +# """count the number of spike artifacts in the recording.""" +# t = A.index.values if isinstance(A, pd.Series) else A +# return detect_spikes(t, sd=sd).shape[0] + + +def n_spikes_dy(A: pd.Series | np.ndarray, sd: int = 5): """count the number of spike artifacts in the recording.""" - a = A.values if isinstance(A, pd.Series) else A - return detect_spikes(a, sd=sd).shape[0] + y = A.values if isinstance(A, pd.Series) else A + return detect_spikes_dy(y, sd=sd).shape[0] def n_outliers( @@ -84,14 +160,52 @@ def n_outliers( return detect_outliers(a, w_size=w_size, alpha=alpha).shape[0] +def _expected_max_gauss(x): + """ + https://math.stackexchange.com/questions/89030/expectation-of-the-maximum-of-gaussian-random-variables + """ + return np.mean(x) + np.std(x) * np.sqrt(2 * np.log(len(x))) + + +def n_expmax_violations(A: pd.Series | np.ndarray) -> int: + a = A.values if isinstance(A, pd.Series) else A + exp_max = _expected_max_gauss(a) + return sum(np.abs(a) > exp_max) + + +def expmax_violation(A: pd.Series | np.ndarray) -> float: + a = A.values if isinstance(A, pd.Series) else A + exp_max = _expected_max_gauss(a) + n_violations = sum(np.abs(a) > exp_max) + if n_violations == 0: + return 0.0 + else: + return np.sum(np.abs(a[np.abs(a) > exp_max]) - exp_max) / n_violations + + def bleaching_tau(A: pd.Series) -> float: """overall tau of bleaching.""" y, t = A.values, A.index.values reg = Regression(model=ExponDecay()) - reg.fit(y, t) + reg.fit(t, y) return reg.popt[1] +def bleaching_amp(A: pd.Series | np.ndarray) -> float: + """overall amplitude of bleaching.""" + y = A.values if isinstance(A, pd.Series) else A + reg = Regression(model=LinearModel()) + try: + reg.fit(np.arange(len(a)), y) + slope = reg.popt[0] + except: + slope = np.nan + return slope + # rolling_mean = A.rolling(window=1000).mean().dropna() + # ## TODO: mean rolling std, not across whole signal + # return (rolling_mean.iloc[0] - rolling_mean.iloc[-1]) / A.std() + + def ttest_pre_post( A: pd.Series, trials: pd.DataFrame, @@ -179,7 +293,39 @@ def has_responses( return np.any(res) -def low_freq_power_ratio(A: pd.Series, f_cutoff: float = 3.18) -> float: +def response_variability_ratio( + A: pd.Series, events: np.ndarray, window: tuple = (0, 1) +): + signal = A.values.squeeze() + assert signal.ndim == 1 + tpts = A.index.values + dt = np.median(np.diff(tpts)) + events = events[events + window[1] < tpts.max()] + event_inds = tpts.searchsorted(events) + i0s = event_inds - int(window[0] / dt) + i1s = event_inds + int(window[1] / dt) + responses = np.row_stack([signal[i0:i1] for i0, i1 in zip(i0s, i1s)]) + responses = (responses.T - signal[event_inds]).T + return (responses).mean(axis=0).var() / (responses).var(axis=0).mean() + + +def response_magnitude(A: pd.Series, events: np.ndarray, window: tuple = (0, 1)): + signal = A.values.squeeze() + assert signal.ndim == 1 + tpts = A.index.values + dt = np.median(np.diff(tpts)) + events = events[events + window[1] < tpts.max()] + event_inds = tpts.searchsorted(events) + i0s = event_inds - int(window[0] / dt) + i1s = event_inds + int(window[1] / dt) + responses = np.row_stack([signal[i0:i1] for i0, i1 in zip(i0s, i1s)]) + responses = (responses.T - signal[event_inds]).T + return np.abs(responses.mean(axis=0)).sum() + + +def low_freq_power_ratio( + A: pd.Series | np.ndarray, dt: float | None = None, f_cutoff: float = 3.18 +) -> float: """ Fraction of the total signal power contained below a 
given cutoff frequency. @@ -192,20 +338,24 @@ def low_freq_power_ratio(A: pd.Series, f_cutoff: float = 3.18) -> float: cutoff frequency, default value of 3.18 esitmated using the formula 1 / (2 * pi * tau) and an approximate tau_rise for GCaMP6f of 0.05s. """ - signal = A.copy() + if isinstance(A, pd.Series): + signal = A.values + dt = np.median(np.diff(A.index)) + else: + assert dt is not None + signal = A assert signal.ndim == 1 # only 1D for now # Get frequency bins - tpts = signal.index.values - dt = np.median(np.diff(tpts)) - n_pts = len(signal) - freqs = np.fft.rfftfreq(n_pts, dt) + freqs = np.fft.rfftfreq(len(signal), dt) # Compute power spectral density psd = np.abs(np.fft.rfft(signal - signal.mean())) ** 2 - # Return the ratio of power contained in low freqs + # Return proportion of total power in low freqs return psd[freqs <= f_cutoff].sum() / psd.sum() -def spectral_entropy(A: pd.Series, eps: float = np.finfo('float').eps) -> float: +def spectral_entropy( + A: pd.Series | np.ndarray, eps: float = np.finfo('float').eps +) -> float: """ Compute the normalized entropy of the signal power spectral density and return a metric (1 - entropy) that is low (0) if all frequency components @@ -220,7 +370,7 @@ def spectral_entropy(A: pd.Series, eps: float = np.finfo('float').eps) -> float: eps : small number added to the PSD for numerical stability """ - signal = A.copy() + signal = A.values if isinstance(A, pd.Series) else A assert signal.ndim == 1 # only 1D for now # Compute power spectral density psd = np.abs(np.fft.rfft(signal - signal.mean())) ** 2 @@ -234,24 +384,52 @@ def spectral_entropy(A: pd.Series, eps: float = np.finfo('float').eps) -> float: return 1 - norm_entropy -def ar_score(A: pd.Series) -> float: +def ar_score(A: pd.Series | np.ndarray, order: int = 2) -> float: """ - R-squared from an AR(1) model fit to the signal as a measure of the temporal + R-squared from an AR(n) model fit to the signal as a measure of the temporal structure present in the signal. Parameters ---------- - A : - the signal time series with signal values in the columns and sample - times in the index + A : pd.Series or np.ndarray + The signal time series with signal values in the columns and sample + times in the index. + order : int, optional + The order of the AR model. Default is 2. + + Returns + ------- + float + The R-squared value indicating the variance explained by the AR model. + Returns NaN if the signal is constant. """ - # Pull signal out of pandas series - signal = A.values - assert signal.ndim == 1 # only 1D for now - X = signal[:-1] - y = signal[1:] - res = stats.linregress(X, y) - return res.rvalue**2 + # Pull signal out of pandas Series if needed + signal = A.values if isinstance(A, pd.Series) else A + assert signal.ndim == 1, 'Signal must be 1-dimensional.' 
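+    # AR(order) fit: each sample is regressed (with no intercept term) onto the
+    # `order` samples immediately preceding it, e.g. for order=2 each row of X
+    # is [x[t-2], x[t-1]] with target x[t]; the R-squared of this fit measures
+    # how much temporal structure the signal carries.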
+ + # Handle constant signal case + if len(np.unique(signal)) == 1: + return np.nan + + # Create design matrix X and target vector y based on AR order + X = np.column_stack([signal[i : len(signal) - order + i] for i in range(order)]) + y = signal[order:] + + try: + # Fit linear regression using least squares + _, residual, _, _ = np.linalg.lstsq(X, y) + except np.linalg.LinAlgError: + return np.nan + + if residual: + # Calculate R-squared using residuals + ss_residual = residual[0] + ss_total = np.sum((y - np.mean(y)) ** 2) + r_squared = 1 - (ss_residual / ss_total) + else: + r_squared = np.nan + + return r_squared def noise_simulation( diff --git a/src/iblphotometry/processing.py b/src/iblphotometry/processing.py index 52d3a2a..dc2e5d9 100644 --- a/src/iblphotometry/processing.py +++ b/src/iblphotometry/processing.py @@ -7,6 +7,7 @@ from iblutil.numerical import rcoeff from ibldsp.utils import WindowGenerator +from numpy.lib.stride_tricks import as_strided from scipy.optimize import minimize from scipy.stats.distributions import norm @@ -582,54 +583,124 @@ def fillnan_kde(y: np.ndarray, w: int = 25): def remove_outliers( - F: pd.Series, - w_len: float = 60, - alpha: float = 0.005, - w: int = 25, - fs=None, - max_it=100, + F: pd.Series, w_size: int = 1000, alpha: float = 0.005, w: int = 25 ): y, t = F.values, F.index.values - fs = 1 / np.median(np.diff(t)) if fs is None else fs - w_size = int(w_len * fs) - y = copy(y) outliers = detect_outliers(y, w_size=w_size, alpha=alpha) - j = 0 while len(outliers) > 0: y[outliers] = np.nan y = fillnan_kde(y, w=w) outliers = detect_outliers(y, w_size=w_size, alpha=alpha) - if j > max_it: - break - else: - j += 1 - return pd.Series(y, index=t) -def detect_spikes(t: np.ndarray, sd: int = 5): - dt = np.diff(t) - bad_inds = dt < np.average(dt) - sd * np.std(dt) - return np.where(bad_inds)[0] +def detect_spikes_dt(t: np.ndarray, atol: float = 0.001): + dts = np.diff(t) + dt = np.median(dts) + # bad_inds = dt < np.average(dt) - sd * np.std(dt) + # return np.where(bad_inds)[0] + return np.where(np.abs(dts - dt) > atol)[0] + + +def detect_spikes_dy(y: np.ndarray, sd: float = 5.0): + dy = np.abs(np.diff(y)) + # bad_inds = dt < np.average(dt) - sd * np.std(dt) + return np.where(dy > np.average(dy) + sd * np.std(dy))[0] -def remove_spikes(F: pd.Series, sd: int = 5, w: int = 25): +def remove_spikes(F: pd.Series, delta: str = 't', sd: int = 5, w: int = 25): y, t = F.values, F.index.values y = copy(y) - outliers = detect_spikes(y, sd=sd) + if delta == 't': + outliers = detect_spikes_dt(t, atol=0.001) + elif delta == 'y': + outliers = detect_spikes_dy(y, sd=sd) + else: + raise ValueError('delta must be "t" or "y"') y[outliers] = np.nan try: y = fillnan_kde(y, w=w) - except np.linalg.LinAlgError: - if np.all(pd.isna(y[outliers])): # all are NaN! 
- y[:] = 0 - warnings.warn('all values NaN, setting to zeros') # TODO logger + # except np.linalg.LinAlgError: + except: + i0s = (outliers - w).clip(0) + i1s = outliers + w + y[outliers] = [np.nanmedian(y[i0:i1]) for i0, i1 in zip(i0s, i1s)] + warnings.warn('KDE fillnan failed, using local median') # TODO logger + return pd.Series(y, index=t) + + +## TODO: consider this simple interpolation method that uses the local median +# def remove_spikes(F: pd.Series, sd: int = 5, w: int = 5): +# f = F.copy() +# y, t = f.values, f.index.values +# outliers = detect_spikes(y, sd=sd) +# outliers = np.unique(np.concatenate([outliers - 1, outliers])) +# i0s = (outliers - w).clip(0) +# i1s = outliers + w +# y[outliers] = [np.nanmedian(y[i0:i1]) for i0, i1 in zip(i0s, i1s)] +# return pd.Series(y, index=t) + + +def find_early_samples( + A: pd.DataFrame | pd.Series, dt_tol: float = 0.001 +) -> np.ndarray: + dt = np.median(np.diff(A.index)) + return dt - A.index.diff() > dt_tol + + +def _fill_missing_channel_names(A: np.ndarray) -> np.ndarray: + missing_inds = np.where(A == '')[0] + name_alternator = {'GCaMP': 'Isosbestic', 'Isosbestic': 'GCaMP', '': ''} + for i in missing_inds: + if i == 0: + A[i] = name_alternator[A[i + 1]] else: - y[outliers] = np.nanmedian(y) - warnings.warn('KDE fillnan failed, using global median') # TODO logger + A[i] = name_alternator[A[i - 1]] + return A - return pd.Series(y, index=t) + +def find_repeated_samples( + A: pd.DataFrame, + dt_tol: float = 0.001, +) -> int: + if any(A['name'] == ''): + A['name'] = _fill_missing_channel_names(A['name'].values) + else: + A + repeated_sample_mask = A['name'].iloc[1:].values == A['name'].iloc[:-1].values + repeated_samples = A.iloc[1:][repeated_sample_mask] + dt = np.median(np.diff(A.index)) + early_samples = A[find_early_samples(A, dt_tol=dt_tol)] + if not all([idx in early_samples.index for idx in repeated_samples.index]): + print('WARNING: repeated samples found without early sampling') + return repeated_sample_mask + + +def fix_repeated_sampling( + A: pd.DataFrame, dt_tol: float = 0.001, w_size: int = 10, roi: str | None = None +) -> int: + ## TODO: avoid this by explicitly handling multiple channels + assert roi is not None + # Drop first samples if channel labels are missing + A.loc[A['name'].replace({'': np.nan}).first_valid_index() :] + # Fix remaining missing channel labels + if any(A['name'] == ''): + A['name'] = _fill_missing_channel_names(A['name'].values) + repeated_sample_mask = find_repeated_samples(A, dt_tol=dt_tol) + name_alternator = {'GCaMP': 'Isosbestic', 'Isosbestic': 'GCaMP'} + for i in np.where(repeated_sample_mask)[0] + 1: + name = A.iloc[i]['name'] + value = A.iloc[i][roi] + i0, i1 = A.index[i - w_size], A.index[i] + same = A.loc[i0:i1].query('name == @name')[roi].mean() + other_name = name_alternator[name] + other = A.loc[i0:i1].query('name == @other_name')[roi].mean() + assert np.abs(value - same) > np.abs(value - other) + A.loc[A.index[i] :, 'name'] = [ + name_alternator[name] for name in A.loc[A.index[i] :, 'name'] + ] + return A """ @@ -743,7 +814,7 @@ def sliding_z(F: pd.Series, w_len: float, fs=None, weights=None): return pd.Series(d, index=t) -def sliding_mad(F: pd.Series, w_len: float = None, fs=None, overlap=90): +def sliding_mad(F: pd.DataFrame, w_len: float = None, fs=None, overlap=90): y, t = F.values, F.index.values fs = 1 / np.median(np.diff(t)) if fs is None else fs w_size = int(w_len * fs) @@ -756,3 +827,144 @@ def sliding_mad(F: pd.Series, w_len: float = None, fs=None, overlap=90): gain = 
np.nanmedian(np.abs(y)) / np.nanmedian(np.abs(rmswin), axis=0) gain = np.interp(t, trms, gain) return pd.Series(y * gain, index=t) + + +def sliding_robust_zscore(F: pd.Series, w_len: float, scale: bool = True) -> pd.Series: + """ + Compute a robust z-score for each data point in a pandas Series using a sliding window. + + For each data point at which a full window (centered around that point) + is available, compute the robust z-score: + + z = (x - median(window)) / MAD(window) + + where MAD is the median absolute deviation of the window. If scale=True, + the MAD is multiplied by 1.4826 to approximate the standard deviation under normality. + + Parameters + ---------- + F : pd.Series + Input time-series with a numeric index (time) and signal values. + w_len : float + Window length in seconds. + scale : bool, optional + Whether to scale the MAD by 1.4826. Default is True. + + Returns + ------- + pd.Series + A new Series of the same length as F containing the robust z-scores. + Data points near the boundaries without a full window are NaN. + """ + # Ensure the index is numeric (time in seconds) + times = F.index.values.astype(float) + dt = np.median(np.diff(times)) + + # Number of samples corresponding to w_len in seconds. + w_size = int(w_len // dt) + # Make window size odd so that a window can be centered + if w_size % 2 == 0: + w_size += 1 + half_win = w_size // 2 + + a = F.values # Underlying data + n = len(a) + + # We can only compute a full (centered) window where there's enough data on both sides. + # Valid center positions are indices half_win to n - half_win - 1. + n_valid = n - 2 * half_win + if n_valid <= 0: + raise ValueError('Window length is too long for the given series.') + + # Using step size of 1: each valid index gets its own window. + # Create a 2D view of the signal: + # windows shape: (n_valid, w_size) + windows = as_strided( + a[half_win : n - half_win], + shape=(n_valid, w_size), + strides=(a.strides[0], a.strides[0]), + ) + # However, the above would take contiguous blocks from a[half_win: n - half_win] only. + # To get a sliding window centered at each valid index, we need a trick: + # We'll use as_strided on the full array, starting at index 0, then select the valid windows: + windows_full = as_strided( + a, shape=(n - w_size + 1, w_size), strides=(a.strides[0], a.strides[0]) + ) + # The center of the k-th window in windows_full is at index k + half_win. + # We want windows centered at indices half_win, half_win+1, ..., n - half_win - 1. + # Thus, we select: + windows = windows_full[0 + 0 : 0 + n_valid] # shape (n_valid, w_size) + + # Compute the median for each window (row-wise). + medians = np.median(windows, axis=1) + # Compute the MAD for each window. + mads = np.median(np.abs(windows - medians[:, None]), axis=1) + if scale: + mads *= 1.4826 # Scale MAD to approximate standard deviation under a normal distribution. + + # Avoid division by zero: if MAD is zero, set those z-scores to 0. + safe_mads = np.where(mads == 0, np.nan, mads) + + # Compute robust z-scores for the center value of each window. + # The center value for the k-th window is at index: k + half_win in the original array. + centers = a[half_win : n - half_win] + z_scores_valid = (centers - medians) / safe_mads + + # Pre-allocate result (all values NaN) + robust_z = np.full(n, np.nan) + # Fill in the computed z-scores at valid indices. 
+ valid_idx = np.arange(half_win, n - half_win) + robust_z[valid_idx] = z_scores_valid + + # Return as a Series with the original index + return pd.Series(robust_z, index=F.index) + + +def sliding_robust_zscore_rolling( + F: pd.Series, w_len: float, scale: bool = True +) -> pd.Series: + """ + Compute a robust z-score for each data point using a sliding window via pandas’ rolling(). + + For each point where a full, centered window is available, compute: + z = (x_center - median(window)) / MAD(window) + where MAD is the median absolute deviation and, if scale=True, MAD is scaled by 1.4826. + + Parameters + ---------- + F : pd.Series + Input time-series with a numeric index (time in seconds) and signal values. + w_len : float + The window length in seconds. + scale : bool, optional + If True, multiply MAD by 1.4826 (default is True). + + Returns + ------- + pd.Series + A Series containing the robust z-scores at the center of each window. + Points for which a full window cannot be computed will be NaN. + """ + # Get the sample interval from the index + times = F.index.values.astype(float) + dt = np.median(np.diff(times)) + # Compute window size in samples and ensure it is odd (so there's a unique center) + w_size = int(w_len / dt) + if w_size % 2 == 0: + w_size += 1 + + def robust_zscore(window): + # window is passed as a NumPy array (raw=True) + center = window[len(window) // 2] + med = np.median(window) + mad = np.median(np.abs(window - med)) + if mad == 0: + return np.nan + if scale: + mad *= 1.4826 + return (center - med) / mad + + # Use rolling window with center=True so that the result corresponds to the window center. + F_proc = F.rolling(window=w_size, center=True).apply(robust_zscore, raw=True) + # Return only valid (non-NaN) portions of the transformed signal + return F_proc.loc[F_proc.first_valid_index() : F_proc.last_valid_index()] diff --git a/src/iblphotometry/qc.py b/src/iblphotometry/qc.py index 2f43684..6f7ea12 100644 --- a/src/iblphotometry/qc.py +++ b/src/iblphotometry/qc.py @@ -5,11 +5,13 @@ import logging import numpy as np +from numpy.lib.stride_tricks import as_strided import pandas as pd from scipy.stats import linregress from iblphotometry.processing import make_sliding_window from iblphotometry.pipelines import run_pipeline +import iblphotometry.metrics as metrics logger = logging.getLogger() @@ -51,25 +53,57 @@ def sliding_metric( return pd.Series(m, index=t[inds + int(w_size / 2)]) +def _eval_metric_sliding( + F: pd.Series, + metric: Callable, + w_len: float = 60, + metric_kwargs: dict | None = None, +) -> pd.Series: + metric_kwargs = {} if metric_kwargs is None else metric_kwargs + dt = np.median(np.diff(F.index)) + w_size = int(w_len // dt) + step_size = int(w_size // 2) + n_windows = int((len(F) - w_size) // step_size + 1) + if n_windows <= 2: + return + a = F.values + windows = as_strided( + a, shape=(n_windows, w_size), strides=(step_size * a.strides[0], a.strides[0]) + ) + S_values = np.apply_along_axis( + lambda w: metric(w, **metric_kwargs), axis=1, arr=windows + ) + S_times = F.index.values[ + np.linspace(step_size, n_windows * step_size, n_windows).astype(int) + ] + return pd.Series(S_values, index=S_times) + + # eval pipleline will be here def eval_metric( F: pd.Series, metric: Callable, metric_kwargs: dict | None = None, sliding_kwargs: dict | None = None, + full_output=True, ): - m = metric(F, **metric_kwargs) if metric_kwargs is not None else metric(F) - - if sliding_kwargs is not None: - S = sliding_metric( - F, metric=metric, **sliding_kwargs, 
metric_kwargs=metric_kwargs - ) - r, p = linregress(S.index.values, S.values)[2:4] - else: - r = np.nan - p = np.nan - - return dict(value=m, rval=r, pval=p) + results_vals = ['value', 'sliding_values', 'sliding_timepoints', 'r', 'p'] + result = {k: np.nan for k in results_vals} + metric_func = getattr(metrics, metric) + result['value'] = ( + metric_func(F) if metric_kwargs is None else metric_func(F, **metric_kwargs) + ) + sliding_kwargs = {} if sliding_kwargs is None else sliding_kwargs + if sliding_kwargs: + S = _eval_metric_sliding(F, metric_func, sliding_kwargs['w_len'], metric_kwargs) + if S is None: + pass + else: + result['r'], result['p'] = linregress(S.index.values, S.values)[2:4] + if full_output: + result['sliding_values'] = S.values + result['sliding_timepoints'] = S.index.values + return result def qc_series( @@ -81,23 +115,45 @@ def qc_series( brain_region: str = None, # FIXME but left as is for now just to keep the logger happy ) -> dict: if isinstance(F, pd.DataFrame): - raise TypeError('F can not be a dataframe') + raise TypeError('F cannot be a dataframe') + + # if sliding_kwargs is None: # empty dicts indicate no sliding application + # sliding_kwargs = {metric:{} for metric in qc_metrics.keys()} + # elif ( + # isinstance(sliding_kwargs, dict) and + # not sorted(sliding_kwargs.keys()) == sorted(qc_metrics.keys()) + # ): # the same sliding kwargs will be applied to all metrics + # sliding_kwargs = {metric:sliding_kwargs for metric in qc_metrics.keys()} + # elif ( + # isinstance(sliding_kwargs, dict) and + # sorted(sliding_kwargs.keys()) == sorted(qc_metrics.keys()) + # ): # each metric has it's own sliding kwargs + # pass + # else: # is not None, a simple dict, or a nested dict + # raise TypeError( + # 'sliding_kwargs must be None, dict of kwargs, or nested dict with same keys as qc_metrics' + # ) + sliding_kwargs = {} if sliding_kwargs is None else sliding_kwargs # should cover all cases qc_results = {} - for metric, params in qc_metrics: - try: - if trials is not None: # if trials are passed - params['trials'] = trials - res = eval_metric(F, metric, params, sliding_kwargs) - qc_results[f'{metric.__name__}'] = res['value'] - if sliding_kwargs: - qc_results[f'{metric.__name__}_r'] = res['rval'] - qc_results[f'{metric.__name__}_p'] = res['pval'] - except Exception as e: - logger.warning( - f'{eid}, {brain_region}: metric {metric.__name__} failure: {type(e).__name__}:{e}' - ) + for metric, params in qc_metrics.items(): + # try: + if trials is not None: # if trials are passed + params['trials'] = trials + res = eval_metric( + F, metric, metric_kwargs=params, sliding_kwargs=sliding_kwargs[metric] + ) + qc_results[f'{metric}'] = res['value'] + if sliding_kwargs[metric]: + qc_results[f'_{metric}_values'] = res['sliding_values'] + qc_results[f'_{metric}_times'] = res['sliding_timepoints'] + qc_results[f'_{metric}_r'] = res['r'] + qc_results[f'_{metric}_p'] = res['p'] + # except Exception as e: + # logger.warning( + # f'{eid}, {brain_region}: metric {metric.__name__} failure: {type(e).__name__}:{e}' + # ) return qc_results
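
Usage note (reviewer sketch, not part of the patch): the reworked qc_series resolves each metric by name via getattr(metrics, metric) and indexes sliding_kwargs[metric] for every entry, so callers are expected to pass a dict mapping metric names to their kwargs plus a per-metric sliding dict, where an empty dict disables the sliding evaluation for that metric. The snippet below is a minimal sketch of that assumed call pattern; the signal, metric choices, and parameter values are illustrative, and arguments are passed by keyword because the full qc_series signature is not shown in this diff (add eid / brain_region if the unchanged signature requires them).

import numpy as np
import pandas as pd
from iblphotometry.qc import qc_series

# placeholder photometry trace: ~30 Hz sampling, time in seconds as the index
t = np.arange(0, 600, 1 / 30)
F = pd.Series(np.random.randn(t.size) + 10.0, index=t)  # offset keeps the median nonzero for deviance

qc_metrics = {                    # keys must be function names in iblphotometry.metrics
    'deviance': {},
    'n_spikes_dy': {'sd': 5},
    'ar_score': {'order': 2},
}
sliding_kwargs = {                # one entry per metric; {} -> single value, no sliding trend
    'deviance': {'w_len': 60},    # 60 s windows, also yields _deviance_r / _deviance_p
    'n_spikes_dy': {},
    'ar_score': {},
}

results = qc_series(F, qc_metrics=qc_metrics, sliding_kwargs=sliding_kwargs)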