
Commit 1cb3e7d

Add SLAy auto-merge preset (#4190)

Authored by chrishalcrow, alejoe91, and samuelgarcia
Co-authored-by: Alessio Buccino <alejoe9187@gmail.com>
Co-authored-by: Garcia Samuel <sam.garcia.die@gmail.com>

1 parent d9c2f65 commit 1cb3e7d

3 files changed: +280 -4 lines changed

doc/references.rst

Lines changed: 6 additions & 1 deletion
@@ -87,7 +87,10 @@ please following citations:
 
 Curation Module
 ---------------
-If you use the :code:`get_potential_auto_merge` method from the curation module, please cite [Llobet]_
+
+If you use the default "similarity_correlograms" preset in the :code:`compute_merge_unit_groups` method from the curation module, please cite [Llobet]_
+
+If you use the "slay" preset in the :code:`compute_merge_unit_groups` method, please cite [Koukuntla]_
 
 
 If you use :code:`auto_label_units` or :code:`train_model`, please cite [Jain]_
@@ -139,6 +142,8 @@ References
 
 .. [Jia] `High-density extracellular probes reveal dendritic backpropagation and facilitate neuron classification. 2019 <https://journals.physiology.org/doi/full/10.1152/jn.00680.2018>`_
 
+.. [Koukuntla] `SLAy-ing oversplitting errors in high-density electrophysiology spike sorting. 2025. <https://www.biorxiv.org/content/10.1101/2025.06.20.660590v1>`_
+
 .. [Lee] `YASS: Yet another spike sorter. 2017. <https://www.biorxiv.org/content/10.1101/151928v1>`_
 
 .. [Lemon] Methods for neuronal recording in conscious animals. IBRO Handbook Series. 1984.

src/spikeinterface/curation/auto_merge.py

Lines changed: 272 additions & 1 deletion
@@ -53,6 +53,10 @@
         "knn",
         "quality_score",
     ],
+    "slay": [
+        "template_similarity",
+        "slay_score",
+    ],
 }
 
 _required_extensions = {
@@ -61,6 +65,7 @@
     "snr": ["templates", "noise_levels"],
     "template_similarity": ["templates", "template_similarity"],
     "knn": ["templates", "spike_locations", "spike_amplitudes"],
+    "slay_score": ["correlograms", "template_similarity"],
 }
 
 
@@ -85,6 +90,7 @@
         "censored_period_ms": 0.3,
     },
     "quality_score": {"firing_contamination_balance": 1.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3},
+    "slay_score": {"k1": 0.25, "k2": 1, "slay_threshold": 0.5},
 }
 
 
@@ -114,6 +120,8 @@ def compute_merge_unit_groups(
         * "cross_contamination": the cross-contamination is not significant (`cc_thresh` and `p_value`)
         * "knn": the two units are close in the feature space
         * "quality_score": the unit "quality score" is increased after the merge
+        * "slay_score": a combined score, factoring in a template similarity measure, a cross-correlation significance measure
+          and a sliding refractory period violation measure, based on the SLAy algorithm.
 
     The "quality score" factors in the increase in firing rate (**f**) due to the merge and a possible increase in
     contamination (**C**), wheighted by a factor **k** (`firing_contamination_balance`).
@@ -145,6 +153,9 @@ def compute_merge_unit_groups(
         * | "feature_neighbors": focused on finding unit pairs whose spikes are close in the feature space using kNN.
           | It uses the following steps: "num_spikes", "snr", "remove_contaminated", "unit_locations",
           | "knn", "quality_score"
+        * | "slay": an approximate implementation of SLAy, original implementation at https://github.com/saikoukunt/SLAy.
+          | The spikeinterface version uses `template_similarity`, rather than an auto-encoder.
+          | It uses the following steps: "template_similarity", "slay_score"
 
     If `preset` is None, you can specify the steps manually with the `steps` parameter.
     resolve_graph : bool, default: True
@@ -363,6 +374,14 @@ def compute_merge_unit_groups(
             )
             outs["pairs_decreased_score"] = pairs_decreased_score
 
+        elif step == "slay_score":
+
+            M_ij = compute_slay_matrix(
+                sorting_analyzer, params["k1"], params["k2"], templates_diff=outs["templates_diff"], pair_mask=pair_mask
+            )
+
+            pair_mask = pair_mask & (M_ij > params["slay_threshold"])
+
     # FINAL STEP : create the final list from pair_mask boolean matrix
     ind1, ind2 = np.nonzero(pair_mask)
     merge_unit_groups = list(zip(unit_ids[ind1], unit_ids[ind2]))
@@ -550,6 +569,7 @@ def get_potential_auto_merge(
         * "cross_contamination": the cross-contamination is not significant (`cc_thresh` and `p_value`)
         * "knn": the two units are close in the feature space
         * "quality_score": the unit "quality score" is increased after the merge
+        * "slay_score": a combined score, factoring in a template similarity measure, a cross-correlation significance measure and a sliding refractory period violation measure, based on the SLAy algorithm.
 
     The "quality score" factors in the increase in firing rate (**f**) due to the merge and a possible increase in
     contamination (**C**), wheighted by a factor **k** (`firing_contamination_balance`).
@@ -566,7 +586,7 @@
     ----------
     sorting_analyzer : SortingAnalyzer
         The SortingAnalyzer
-    preset : "similarity_correlograms" | "x_contaminations" | "temporal_splits" | "feature_neighbors" | None, default: "similarity_correlograms"
+    preset : "similarity_correlograms" | "x_contaminations" | "temporal_splits" | "feature_neighbors" | "slay" | None, default: "similarity_correlograms"
         The preset to use for the auto-merge. Presets combine different steps into a recipe and focus on:
 
         * | "similarity_correlograms": mainly focused on template similarity and correlograms.
@@ -581,6 +601,9 @@
         * | "feature_neighbors": focused on finding unit pairs whose spikes are close in the feature space using kNN.
           | It uses the following steps: "num_spikes", "snr", "remove_contaminated", "unit_locations",
           | "knn", "quality_score"
+        * | "slay": an approximate implementation of SLAy, original implementation at https://github.com/saikoukunt/SLAy.
+          | The spikeinterface version uses `template_similarity`, rather than an auto-encoder.
+          | It uses the following steps: "template_similarity", "slay_score"
 
     If `preset` is None, you can specify the steps manually with the `steps` parameter.
     resolve_graph : bool, default: False
@@ -1525,3 +1548,251 @@ def estimate_cross_contamination(
     )
 
     return estimation, p_value
+
+
+def compute_slay_matrix(
+    sorting_analyzer: SortingAnalyzer,
+    k1: float,
+    k2: float,
+    templates_diff: np.ndarray | None,
+    pair_mask: np.ndarray | None = None,
+):
+    """
+    Computes the "merge decision metric" from the SLAy method, made from combining
+    a template similarity measure, a cross-correlation significance measure and a
+    sliding refractory period violation measure. A large M suggests that two
+    units should be merged.
+
+    Paramters
+    ---------
+    sorting_analyzer : SortingAnalyzer
+        The sorting analyzer object containing the spike sorting data
+    k1 : float
+        Coefficient determining the importance of the cross-correlation significance
+    k2 : float
+        Coefficient determining the importance of the sliding rp violation
+    templates_diff : np.ndarray | None
+        Pre-computed template similarity difference matrix. If None, it will be retrieved from the sorting_analyzer.
+    pair_mask : None | np.ndarray, default: None
+        A bool matrix describing which pairs are possible merges based on previous steps
+
+
+    References
+    ----------
+    Based on computation originally implemented in SLAy [Koukuntla]_.
+
+    Implementation is based on one of the original implementations written by Sai Koukuntla,
+    found at https://github.com/saikoukunt/SLAy.
+    """
+
+    num_units = sorting_analyzer.get_num_units()
+
+    if pair_mask is None:
+        pair_mask = np.triu(np.arange(num_units), 1) > 0
+
+    if templates_diff is not None:
+        sigma_ij = 1 - templates_diff
+    else:
+        sigma_ij = sorting_analyzer.get_extension("template_similarity").get_data()
+    rho_ij, eta_ij = compute_xcorr_and_rp(sorting_analyzer, pair_mask)
+
+    M_ij = sigma_ij + k1 * rho_ij - k2 * eta_ij
+
+    return M_ij
+
+
+def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, pair_mask: np.ndarray):
+    """
+    Computes a cross-correlation significance measure and a sliding refractory period violation
+    measure for all units in the `sorting_analyzer`.
+
+    Paramters
+    ---------
+    sorting_analyzer : SortingAnalyzer
+        The sorting analyzer object containing the spike sorting data
+    pair_mask : np.ndarray
+        A bool matrix describing which pairs are possible merges based on previous steps
+    """
+
+    correlograms_extension = sorting_analyzer.get_extension("correlograms")
+    ccgs, _ = correlograms_extension.get_data()
+
+    # convert to seconds for SLAy functions
+    bin_size_ms = correlograms_extension.params["bin_ms"]
+
+    rho_ij = np.zeros([len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)])
+    eta_ij = np.zeros([len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)])
+
+    for unit_index_1, _ in enumerate(sorting_analyzer.unit_ids):
+        for unit_index_2, _ in enumerate(sorting_analyzer.unit_ids):
+
+            # Don't waste time computing the other metrics if units not candidates merges
+            if not pair_mask[unit_index_1, unit_index_2]:
+                continue
+
+            xgram = ccgs[unit_index_1, unit_index_2, :]
+
+            rho_ij[unit_index_1, unit_index_2] = _compute_xcorr_pair(
+                xgram, bin_size_s=bin_size_ms / 1000, min_xcorr_rate=0
+            )
+            eta_ij[unit_index_1, unit_index_2] = _sliding_RP_viol_pair(xgram, bin_size_ms=bin_size_ms)
+
+    return rho_ij, eta_ij
+
+
+def _compute_xcorr_pair(
+    xgram,
+    bin_size_s: float,
+    min_xcorr_rate: float,
+) -> float:
+    """
+    Calculates a cross-correlation significance metric for a cluster pair.
+
+    Uses the wasserstein distance between an observed cross-correlogram and a null
+    distribution as an estimate of how significant the dependence between
+    two neurons is. Low spike count cross-correlograms have large wasserstein
+    distances from null by chance, so we first try to expand the window size. If
+    that fails to yield enough spikes, we apply a penalty to the metric.
+
+    Ported from https://github.com/saikoukunt/SLAy.
+
+    Parameters
+    ----------
+    xgram : np.array
+        The raw cross-correlogram for the cluster pair.
+    bin_size_s : float
+        The width in seconds of the bin size of the input ccgs.
+    min_xcorr_rate : float
+        The minimum ccg firing rate in Hz.
+
+    Returns
+    -------
+    sig : float
+        The calculated cross-correlation significance metric.
+    """
+
+    from scipy.signal import butter, find_peaks_cwt, sosfiltfilt
+    from scipy.stats import wasserstein_distance
+
+    # calculate low-pass filtered second derivative of ccg
+    fs = 1 / bin_size_s
+    cutoff_freq = 100
+    nyqist = fs / 2
+    cutoff = cutoff_freq / nyqist
+    peak_width = 0.002 / bin_size_s
+
+    xgram_2d = np.diff(xgram, 2)
+    sos = butter(4, cutoff, output="sos")
+    xgram_2d = sosfiltfilt(sos, xgram_2d)
+
+    if xgram.sum() == 0:
+        return 0
+
+    # find negative peaks of second derivative of ccg, these are the edges of dips in ccg
+    peaks = find_peaks_cwt(-xgram_2d, peak_width, noise_perc=90) + 1
+    # if no peaks are found, return a very low significance
+    if peaks.shape[0] == 0:
+        return -4
+    peaks = np.abs(peaks - xgram.shape[0] / 2)
+    peaks = peaks[peaks > 0.5 * peak_width]
+    min_peaks = np.sort(peaks)
+
+    # start with peaks closest to 0 and move to the next set of peaks if the event count is too low
+    window_width = min_peaks * 1.5
+    starts = np.maximum(xgram.shape[0] / 2 - window_width, 0)
+    ends = np.minimum(xgram.shape[0] / 2 + window_width, xgram.shape[0] - 1)
+    ind = 0
+    xgram_window = xgram[int(starts[0]) : int(ends[0] + 1)]
+    xgram_sum = xgram_window.sum()
+    window_size = xgram_window.shape[0] * bin_size_s
+    while (xgram_sum < (min_xcorr_rate * window_size * 10)) and (ind < starts.shape[0]):
+        xgram_window = xgram[int(starts[ind]) : int(ends[ind] + 1)]
+        xgram_sum = xgram_window.sum()
+        window_size = xgram_window.shape[0] * bin_size_s
+        ind += 1
+    # use the whole ccg if peak finding fails
+    if ind == starts.shape[0]:
+        xgram_window = xgram
+
+    # TODO: was getting error messges when xgram_window was all zero. Why was this happening?
+    if np.abs(xgram_window).sum() == 0:
+        return 0
+
+    sig = (
+        wasserstein_distance(
+            np.arange(xgram_window.shape[0]) / xgram_window.shape[0],
+            np.arange(xgram_window.shape[0]) / xgram_window.shape[0],
+            xgram_window,
+            np.ones_like(xgram_window),
+        )
+        * 4
+    )
+
+    if xgram_window.sum() < (min_xcorr_rate * window_size):
+        sig *= (xgram_window.sum() / (min_xcorr_rate * window_size)) ** 2
+
+    # if sig < 0.04 and xgram_window.sum() < (min_xcorr_rate * window_size):
+    if xgram_window.sum() < (min_xcorr_rate / 4 * window_size):
+        sig = -4  # don't merge if the event count is way too low
+
+    return sig
+
+
+def _sliding_RP_viol_pair(
+    correlogram,
+    bin_size_ms: float,
+    accept_threshold: float = 0.15,
+) -> float:
+    """
+    Calculate the sliding refractory period violation confidence for a cluster.
+
+    Ported from https://github.com/saikoukunt/SLAy.
+
+    Parameters
+    ----------
+    correlogram : np.array
+        The auto-correlogram of the cluster.
+    bin_size_ms : float
+        The width in ms of the bin size of the input ccgs.
+    accept_threshold : float, default: 0.15
+        The minimum ccg firing rate in Hz.
+
+    Returns
+    -------
+    sig : float
+        The refractory period violation confidence for the cluster.
+    """
+    from scipy.signal import butter, sosfiltfilt
+    from scipy.stats import poisson
+
+    # create various refractory periods sizes to test (between 0 and 20x bin size)
+    all_refractory_periods = np.arange(0, 21 * bin_size_ms, bin_size_ms) / 1000
+    test_refractory_period_indices = np.array([1, 2, 4, 6, 8, 12, 16, 20], dtype="int8")
+    test_refractory_periods = [
+        all_refractory_periods[test_rp_index] for test_rp_index in test_refractory_period_indices
+    ]
+
+    # calculate and avg halves of acg to ensure symmetry
+    # keep only second half of acg, refractory period violations are compared from the center of acg
+    half_len = int(correlogram.shape[0] / 2)
+    correlogram = (correlogram[half_len:] + correlogram[:half_len][::-1]) / 2
+
+    acg_cumsum = np.cumsum(correlogram)
+    sum_res = acg_cumsum[test_refractory_period_indices - 1]  # -1 bc 0th bin corresponds to 0-bin_size ms
+
+    # low-pass filter acg and use max as baseline event rate
+    order = 4  # Hz
+    cutoff_freq = 250  # Hz
+    fs = 1 / bin_size_ms * 1000
+    nyqist = fs / 2
+    cutoff = cutoff_freq / nyqist
+    sos = butter(order, cutoff, btype="low", output="sos")
+    smoothed_acg = sosfiltfilt(sos, correlogram)
+
+    bin_rate_max = np.max(smoothed_acg)
+    max_conts_max = np.array(test_refractory_periods) / bin_size_ms * 1000 * (bin_rate_max * accept_threshold)
+    # compute confidence of less than acceptThresh contamination at each refractory period
+    confs = 1 - poisson.cdf(sum_res, max_conts_max)
+    rp_viol = 1 - confs.max()
+
+    return rp_viol
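
For orientation, here is a minimal usage sketch of the new preset once this commit is installed. It is not part of the diff: the `recording` and `sorting` variables and the `create_sorting_analyzer` / `analyzer.compute` calls are assumed from the wider spikeinterface API, while the preset name "slay", the required extensions ("correlograms", "template_similarity") and the default parameters (k1=0.25, k2=1, slay_threshold=0.5) come from the changes above.

# Hedged usage sketch, not part of this commit. Assumes `recording` and `sorting`
# already exist and that the standard SortingAnalyzer workflow is available.
from spikeinterface.core import create_sorting_analyzer
from spikeinterface.curation import compute_merge_unit_groups

analyzer = create_sorting_analyzer(sorting, recording)
# "slay_score" needs the "correlograms" and "template_similarity" extensions
# (see _required_extensions above); "random_spikes"/"templates" are the usual
# prerequisites for computing template_similarity.
analyzer.compute(["random_spikes", "templates", "template_similarity", "correlograms"])

# The "slay" preset runs "template_similarity" then "slay_score", keeping pairs whose
# merge decision metric M_ij = sigma_ij + k1 * rho_ij - k2 * eta_ij exceeds slay_threshold
# (defaults above: k1=0.25, k2=1, slay_threshold=0.5).
merge_unit_groups = compute_merge_unit_groups(analyzer, preset="slay")
print(merge_unit_groups)  # groups of unit ids suggested for merging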

src/spikeinterface/curation/tests/test_auto_merge.py

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@
 
 
 @pytest.mark.parametrize(
-    "preset", ["x_contaminations", "feature_neighbors", "temporal_splits", "similarity_correlograms", None]
+    "preset", ["x_contaminations", "feature_neighbors", "temporal_splits", "similarity_correlograms", "slay", None]
 )
 def test_compute_merge_unit_groups(sorting_analyzer_with_splits, preset):
 
@@ -59,7 +59,7 @@ def test_compute_merge_unit_groups(sorting_analyzer_with_splits, preset):
 
 
 @pytest.mark.parametrize(
-    "preset", ["x_contaminations", "feature_neighbors", "temporal_splits", "similarity_correlograms"]
+    "preset", ["x_contaminations", "feature_neighbors", "temporal_splits", "similarity_correlograms", "slay"]
 )
 def test_compute_merge_unit_groups_multi_segment(sorting_analyzer_multi_segment_for_curation, preset):
     job_kwargs = dict(n_jobs=-1)
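
As a quick check of the updated parametrization, the snippet below would run only the new "slay" cases of these tests. It is illustrative: it assumes a spikeinterface source checkout with pytest and the test fixtures available, and the path simply mirrors the file shown above.

# Illustrative only: select the "slay" parametrizations of the updated tests.
# Assumes a spikeinterface source checkout with pytest installed.
import pytest

pytest.main(["-v", "-k", "slay", "src/spikeinterface/curation/tests/test_auto_merge.py"])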
