Draft
Changes from all commits (37 commits)
c8f03bd
Add PSTH based metrics
davcrom Jan 17, 2025
0eae5b4
Allow flexible input dtype
davcrom Jan 17, 2025
8ebb16d
Update n_unique_samples docstring
davcrom Jan 17, 2025
a9750f7
Normalize n_unique_samples by total samples
davcrom Jan 17, 2025
8d4e5cc
Spike detection for both positive and negative jumps
davcrom Jan 17, 2025
34ca19e
Add check for dead signals
davcrom Jan 17, 2025
b2782eb
Allow AR(n) models in ar_score
davcrom Mar 12, 2025
80a3166
Revert to original spike detection
davcrom Mar 12, 2025
9f7c092
ruff format
davcrom Mar 27, 2025
168445c
Add dt_violations metric
davcrom Apr 4, 2025
aff9a50
Add interleaved acquisition metric
davcrom Apr 4, 2025
058e887
Add sliding_deviance metric
davcrom Apr 4, 2025
e09eed3
Add f_unique_samples
davcrom Apr 4, 2025
e0e332c
Fix bug in remove spikes
davcrom Apr 4, 2025
787b07b
Duplicate n_spikes for dt and dy
davcrom Apr 4, 2025
51fa9c0
Suggest simple nan fill method with local median
davcrom Apr 4, 2025
5f519b9
Add sliding_expmax_violation
davcrom Apr 4, 2025
c4ddc8a
Add photobleaching_amp (WIP)
davcrom Apr 4, 2025
3c1fd7b
ruff format
davcrom Apr 4, 2025
c07651f
merge not present conflict?
grg2rsr Apr 8, 2025
9109b38
Remove always True conditional
davcrom Apr 8, 2025
287b11f
Fix bug in n_spikes_dt
davcrom Apr 8, 2025
db07a5a
Make expmax_violation safe for no outliers case
davcrom Apr 8, 2025
f1ab68a
Fix merge conflict
davcrom Apr 8, 2025
b596663
Housekeeping
davcrom Apr 18, 2025
8d41f34
Add separate functions for dy and dt spike detection
davcrom Apr 18, 2025
6cd8811
Replace global median with local median in remove_spikes
davcrom Apr 18, 2025
fa58dfa
Update detection of "spikes" caused by early sampling
davcrom Apr 18, 2025
d5222ac
Add deviance metric
davcrom Apr 18, 2025
23d9790
Update expmax violation metrics
davcrom Apr 18, 2025
ae5c8cb
Bugfix in model fitting
davcrom Apr 18, 2025
63b1f41
WIP: bleaching_amp metric
davcrom Apr 18, 2025
1f11178
Update low_freq_power_ratio to accept dt as kwarg
davcrom Apr 18, 2025
81d6c0b
Handle edge cases with too few data points in ar_score
davcrom Apr 18, 2025
a024876
WIP: add two methods for sliding_robust_zscore
davcrom Apr 18, 2025
1d7d9dc
Major update to qc_series sliding metric handling
davcrom Apr 18, 2025
50995f9
Ruff format
davcrom Apr 18, 2025
232 changes: 205 additions & 27 deletions src/iblphotometry/metrics.py
@@ -2,17 +2,59 @@
import numpy as np
import pandas as pd
from scipy import stats
from scipy.signal import medfilt

from iblphotometry.processing import (
z,
Regression,
ExponDecay,
detect_spikes,
detect_spikes_dt,

Check failure (GitHub Actions / ruff): F401 `iblphotometry.processing.detect_spikes_dt` imported but unused (metrics.py:11:5)
detect_spikes_dy,
detect_outliers,
find_early_samples,
find_repeated_samples,
)
from iblphotometry.behavior import psth


def dt_violations(A: pd.DataFrame | pd.Series, atol: float = 1e-3) -> str:
t = A.index.values
dts = np.diff(t)
dt = np.median(dts)
n_violations = np.sum(np.abs(dts - dt) > atol)
## TODO: make ibllib wrappers to convert metrics to QC vals
Reviewer comment (Contributor): this is indeed a huge todo. The structure I see for this is a separate definition of the ranges for the metric and the corresponding label. I would split this away from the function definition, since we might change the interpretation of the metric while its definition stays standalone; to be decided. I also want to study the QC system for the ephys more first, to see what works well there and what doesn't.

Reviewer comment (Contributor): kick dt violations
# if n_violations == 0:
# outcome = QC.PASS
# elif n_violations <= 3:
# outcome = QC.WARNING
# elif n_violations <= 10:
# outcome = QC.CRITICAL
# else:
# outcome = QC.FAIL
# return outcome, n_violations
return n_violations

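A minimal sketch (not part of the diff) of the reviewer's suggested split above: keep the metric definition standalone and define the value-to-outcome mapping separately. The QC enum here is a hypothetical stand-in, and the thresholds are simply copied from the commented-out draft above.

from enum import Enum


class QC(Enum):  # hypothetical stand-in for the eventual QC outcome enum
    PASS = 'PASS'
    WARNING = 'WARNING'
    CRITICAL = 'CRITICAL'
    FAIL = 'FAIL'


# the interpretation (thresholds -> outcome) lives next to the QC definition,
# not inside the metric; values taken from the commented-out draft above
DT_VIOLATIONS_BINS = ((0, QC.PASS), (3, QC.WARNING), (10, QC.CRITICAL))


def qc_outcome(value: float, bins=DT_VIOLATIONS_BINS) -> QC:
    """Map a metric value to a QC outcome using inclusive upper bounds."""
    for upper, outcome in bins:
        if value <= upper:
            return outcome
    return QC.FAIL


# usage: qc_outcome(dt_violations(A))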

# def interleaved_acquisition(A: pd.DataFrame | pd.Series) -> bool:
# if sum(A['name'] == '') > 0:
# A = _fill_missing_channel_names(A)
# a = A['name'].values if isinstance(A, pd.DataFrame) else A.values
# even_check = np.all(a[::2] == a[0])
# odd_check = np.all(a[1::2] == a[1])
# return bool(even_check & odd_check)


def n_early_samples(A: pd.DataFrame | pd.Series, dt_tol: float = 0.001) -> int:
Reviewer comment (Contributor): please add a docstring and a definition of what an early sample is.

Reviewer comment (Contributor): move to processing.
return find_early_samples(A, dt_tol=dt_tol).sum()


def n_repeated_samples(
Reviewer comment (Contributor): please add docstring.
A: pd.DataFrame,
dt_tol: float = 0.001,
) -> int:
return find_repeated_samples(A, dt_tol=dt_tol).sum()


def percentile_dist(A: pd.Series | np.ndarray, pc: tuple = (50, 95), axis=-1) -> float:
"""the distance between two percentiles in units of z. Captures the magnitude of transients.
@@ -33,6 +75,23 @@
return P[1] - P[0]


def deviance(
Reviewer comment (Contributor): add docstring. Unused argument w_len; it also looks like you are using w_len in samples. In sliding operations, I followed a syntax such as:

    y, t = F.values, F.index.values
    fs = 1 / np.median(np.diff(t)) if fs is None else fs
    w_size = int(w_len * fs)
A: pd.Series | np.ndarray,
w_len: int = 151,
) -> float:
a = A.values if isinstance(A, pd.Series) else A
return np.median(np.abs(a - np.median(a)) / np.median(a))


def sliding_deviance(
Reviewer comment (Contributor): if possible I would argue for not having "sliding" variants of metrics, but instead using a framework defined elsewhere (such as sliding operations) for evaluating metrics in a sliding manner.

Reviewer comment (Contributor): kick it.

A: pd.Series | np.ndarray,
w_len: int = 151,
) -> float:
a = A.values if isinstance(A, pd.Series) else A
running_median = medfilt(a, kernel_size=w_len)
return np.median(np.abs(a - running_median) / running_median)

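A minimal sketch (not part of the diff) of the reviewer's alternative: one generic sliding-window helper defined once, so that per-metric sliding_* variants are unnecessary. The name sliding_metric and its signature are assumptions, following the w_len-in-seconds convention from the comment above.

import numpy as np
import pandas as pd


def sliding_metric(F: pd.Series, metric, w_len: float = 10.0, fs: float | None = None) -> np.ndarray:
    """Evaluate `metric` on consecutive, non-overlapping windows of w_len seconds."""
    y, t = F.values, F.index.values
    fs = 1 / np.median(np.diff(t)) if fs is None else fs
    w_size = int(w_len * fs)
    starts = np.arange(0, len(y) - w_size + 1, w_size)
    return np.array([metric(y[i : i + w_size]) for i in starts])


# usage: per-window deviance without a dedicated sliding_deviance
# window_deviances = sliding_metric(F, deviance, w_len=10.0)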

def signal_asymmetry(A: pd.Series | np.ndarray, pc_comp: int = 95, axis=-1) -> float:
"""the ratio between the distance of two percentiles to the median. Proportional to the the signal to noise.
@@ -63,15 +122,32 @@


def n_unique_samples(A: pd.Series | np.ndarray) -> int:
"""number of unique samples in the signal. Low values indicate that the signal during acquisition was not within the range of the digitizer."""
"""
Number of unique samples in the signal. Low values indicate that the signal
was not within the range of the digitizer during acquisition.
"""
a = A.values if isinstance(A, pd.Series) else A
return np.unique(a).shape[0]


def n_spikes(A: pd.Series | np.ndarray, sd: int = 5):
def f_unique_samples(A: pd.Series | np.ndarray) -> int:
"""
Wrapper that converts n_unique_samples to a fraction of the total number
of samples.
"""
return n_unique_samples(A) / len(A)


# def n_spikes_dt(A: pd.Series | np.ndarray, sd: int = 5):
# """count the number of spike artifacts in the recording."""
# t = A.index.values if isinstance(A, pd.Series) else A
# return detect_spikes(t, sd=sd).shape[0]


def n_spikes_dy(A: pd.Series | np.ndarray, sd: int = 5):
"""count the number of spike artifacts in the recording."""
a = A.values if isinstance(A, pd.Series) else A
return detect_spikes(a, sd=sd).shape[0]
y = A.values if isinstance(A, pd.Series) else A
return detect_spikes_dy(y, sd=sd).shape[0]


def n_outliers(
@@ -84,14 +160,52 @@
return detect_outliers(a, w_size=w_size, alpha=alpha).shape[0]


def _expected_max_gauss(x):
"""
https://math.stackexchange.com/questions/89030/expectation-of-the-maximum-of-gaussian-random-variables
"""
return np.mean(x) + np.std(x) * np.sqrt(2 * np.log(len(x)))

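For reference (not part of the diff): the function above approximates the expected maximum of n i.i.d. Gaussian samples as mean + std * sqrt(2 * ln(n)). A quick, purely illustrative check:

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(10_000)
approx = np.mean(x) + np.std(x) * np.sqrt(2 * np.log(len(x)))  # ~4.3 for n = 10,000
print(approx, x.max())  # the approximation typically sits somewhat above the observed max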

def n_expmax_violations(A: pd.Series | np.ndarray) -> int:
Reviewer comment (Contributor): please docstring.

Reviewer comment (Contributor): merge to expmax_violations.
a = A.values if isinstance(A, pd.Series) else A
exp_max = _expected_max_gauss(a)
return sum(np.abs(a) > exp_max)


def expmax_violation(A: pd.Series | np.ndarray) -> float:
Reviewer comment (Contributor): please docstring.
a = A.values if isinstance(A, pd.Series) else A
exp_max = _expected_max_gauss(a)
n_violations = sum(np.abs(a) > exp_max)
if n_violations == 0:
return 0.0
else:
return np.sum(np.abs(a[np.abs(a) > exp_max]) - exp_max) / n_violations


def bleaching_tau(A: pd.Series) -> float:
"""overall tau of bleaching."""
y, t = A.values, A.index.values
reg = Regression(model=ExponDecay())
reg.fit(y, t)
reg.fit(t, y)
return reg.popt[1]


def bleaching_amp(A: pd.Series | np.ndarray) -> float:
Reviewer comment (Contributor): I am not sure if this is a meaningful metric, as we don't do a proper quantitative measurement of photon count or similar.

Reviewer comment (Contributor): kick it.
"""overall amplitude of bleaching."""
y = A.values if isinstance(A, pd.Series) else A
reg = Regression(model=LinearModel())

Check failure (GitHub Actions / ruff): F821 Undefined name `LinearModel` (metrics.py:197:28)
try:
reg.fit(np.arange(len(a)), y)

Check failure (GitHub Actions / ruff): F821 Undefined name `a` (metrics.py:199:31)
slope = reg.popt[0]
except:

Check failure (GitHub Actions / ruff): E722 Do not use bare `except` (metrics.py:201:5)
slope = np.nan
return slope
# rolling_mean = A.rolling(window=1000).mean().dropna()
# ## TODO: mean rolling std, not across whole signal
# return (rolling_mean.iloc[0] - rolling_mean.iloc[-1]) / A.std()


def ttest_pre_post(
A: pd.Series,
trials: pd.DataFrame,
@@ -179,7 +293,39 @@
return np.any(res)


def low_freq_power_ratio(A: pd.Series, f_cutoff: float = 3.18) -> float:
def response_variability_ratio(
Reviewer comment (Contributor): docstring.
A: pd.Series, events: np.ndarray, window: tuple = (0, 1)
):
signal = A.values.squeeze()
assert signal.ndim == 1
tpts = A.index.values
dt = np.median(np.diff(tpts))
events = events[events + window[1] < tpts.max()]
event_inds = tpts.searchsorted(events)
i0s = event_inds - int(window[0] / dt)
i1s = event_inds + int(window[1] / dt)
responses = np.row_stack([signal[i0:i1] for i0, i1 in zip(i0s, i1s)])
responses = (responses.T - signal[event_inds]).T
return (responses).mean(axis=0).var() / (responses).var(axis=0).mean()

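A synthetic usage sketch (not part of the diff; the event times and signal are made up) illustrating what the ratio captures, assuming response_variability_ratio above is in scope:

import numpy as np
import pandas as pd

fs = 30.0  # assumed sampling rate in Hz
t = np.arange(0, 100.0, 1 / fs)
events = np.arange(5.0, 90.0, 5.0)  # hypothetical event times in seconds
y = 0.1 * np.random.default_rng(0).standard_normal(t.size)
for ev in events:
    m = (t >= ev) & (t < ev + 1.0)
    y[m] += np.exp(-(t[m] - ev) / 0.2)  # consistent decaying transient after each event

# variance of the mean response over the mean trial-to-trial variance:
# noticeably larger for consistent event-locked transients than for pure noise
print(response_variability_ratio(pd.Series(y, index=t), events, window=(0, 1)))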

def response_magnitude(A: pd.Series, events: np.ndarray, window: tuple = (0, 1)):
Reviewer comment (Contributor): docstring.

Reviewer comment (Contributor): leftovers from pynapple, fixme.

Reviewer comment (Contributor): kick it.
signal = A.values.squeeze()
assert signal.ndim == 1
tpts = A.index.values
dt = np.median(np.diff(tpts))
events = events[events + window[1] < tpts.max()]
event_inds = tpts.searchsorted(events)
i0s = event_inds - int(window[0] / dt)
i1s = event_inds + int(window[1] / dt)
responses = np.row_stack([signal[i0:i1] for i0, i1 in zip(i0s, i1s)])
responses = (responses.T - signal[event_inds]).T
return np.abs(responses.mean(axis=0)).sum()


def low_freq_power_ratio(
A: pd.Series | np.ndarray, dt: float | None = None, f_cutoff: float = 3.18
) -> float:
"""
Fraction of the total signal power contained below a given cutoff frequency.
@@ -192,20 +338,24 @@
cutoff frequency, default value of 3.18 esitmated using the formula
1 / (2 * pi * tau) and an approximate tau_rise for GCaMP6f of 0.05s.
"""
signal = A.copy()
if isinstance(A, pd.Series):
signal = A.values
dt = np.median(np.diff(A.index))
else:
assert dt is not None
signal = A
assert signal.ndim == 1 # only 1D for now
# Get frequency bins
tpts = signal.index.values
dt = np.median(np.diff(tpts))
n_pts = len(signal)
freqs = np.fft.rfftfreq(n_pts, dt)
freqs = np.fft.rfftfreq(len(signal), dt)
# Compute power spectral density
psd = np.abs(np.fft.rfft(signal - signal.mean())) ** 2
# Return the ratio of power contained in low freqs
# Return proportion of total power in low freqs
return psd[freqs <= f_cutoff].sum() / psd.sum()

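A brief usage sketch (not part of the diff; the sampling rate and test signal are made up), assuming low_freq_power_ratio above is in scope:

import numpy as np
import pandas as pd

fs = 30.0  # assumed sampling rate in Hz
t = np.arange(0, 60.0, 1 / fs)
y = np.sin(2 * np.pi * 0.5 * t) + 0.1 * np.random.default_rng(0).standard_normal(t.size)

print(low_freq_power_ratio(pd.Series(y, index=t)))  # dt inferred from the Series index
print(low_freq_power_ratio(y, dt=1 / fs))           # ndarray input requires dt
# a slow 0.5 Hz oscillation plus weak noise -> ratio close to 1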

def spectral_entropy(A: pd.Series, eps: float = np.finfo('float').eps) -> float:
def spectral_entropy(
A: pd.Series | np.ndarray, eps: float = np.finfo('float').eps
) -> float:
"""
Compute the normalized entropy of the signal power spectral density and
return a metric (1 - entropy) that is low (0) if all frequency components
@@ -220,7 +370,7 @@
eps :
small number added to the PSD for numerical stability
"""
signal = A.copy()
signal = A.values if isinstance(A, pd.Series) else A
assert signal.ndim == 1 # only 1D for now
# Compute power spectral density
psd = np.abs(np.fft.rfft(signal - signal.mean())) ** 2
@@ -234,24 +384,52 @@
return 1 - norm_entropy


def ar_score(A: pd.Series) -> float:
def ar_score(A: pd.Series | np.ndarray, order: int = 2) -> float:
"""
R-squared from an AR(1) model fit to the signal as a measure of the temporal
R-squared from an AR(n) model fit to the signal as a measure of the temporal
structure present in the signal.
Parameters
----------
A :
the signal time series with signal values in the columns and sample
times in the index
A : pd.Series or np.ndarray
The signal time series with signal values in the columns and sample
times in the index.
order : int, optional
The order of the AR model. Default is 2.
Returns
-------
float
The R-squared value indicating the variance explained by the AR model.
Returns NaN if the signal is constant.
"""
# Pull signal out of pandas series
signal = A.values
assert signal.ndim == 1 # only 1D for now
X = signal[:-1]
y = signal[1:]
res = stats.linregress(X, y)
return res.rvalue**2
# Pull signal out of pandas Series if needed
signal = A.values if isinstance(A, pd.Series) else A
assert signal.ndim == 1, 'Signal must be 1-dimensional.'

# Handle constant signal case
if len(np.unique(signal)) == 1:
return np.nan

# Create design matrix X and target vector y based on AR order
X = np.column_stack([signal[i : len(signal) - order + i] for i in range(order)])
y = signal[order:]

try:
# Fit linear regression using least squares
_, residual, _, _ = np.linalg.lstsq(X, y)
except np.linalg.LinAlgError:
return np.nan

if residual:
# Calculate R-squared using residuals
ss_residual = residual[0]
ss_total = np.sum((y - np.mean(y)) ** 2)
r_squared = 1 - (ss_residual / ss_total)
else:
r_squared = np.nan

return r_squared

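An illustrative usage sketch (not part of the diff; the quoted outcomes are indicative only), assuming ar_score above is in scope:

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
white = rng.standard_normal(5_000)
smooth = np.convolve(white, np.ones(25) / 25, mode='same')  # adds temporal structure

print(ar_score(pd.Series(white)))   # near 0: little temporal structure to explain
print(ar_score(pd.Series(smooth)))  # near 1: the AR(2) fit captures the smoothed signal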

def noise_simulation(