5353 "knn" ,
5454 "quality_score" ,
5555 ],
56+ "slay" : [
57+ "template_similarity" ,
58+ "slay_score" ,
59+ ],
5660}
5761
5862_required_extensions = {
6165 "snr" : ["templates" , "noise_levels" ],
6266 "template_similarity" : ["templates" , "template_similarity" ],
6367 "knn" : ["templates" , "spike_locations" , "spike_amplitudes" ],
68+ "slay_score" : ["correlograms" , "template_similarity" ],
6469}
6570
6671
8590 "censored_period_ms" : 0.3 ,
8691 },
8792 "quality_score" : {"firing_contamination_balance" : 1.5 , "refractory_period_ms" : 1.0 , "censored_period_ms" : 0.3 },
93+ "slay_score" : {"k1" : 0.25 , "k2" : 1 , "slay_threshold" : 0.5 },
8894}
8995
9096
@@ -114,6 +120,8 @@ def compute_merge_unit_groups(
         * "cross_contamination": the cross-contamination is not significant (`cc_thresh` and `p_value`)
         * "knn": the two units are close in the feature space
         * "quality_score": the unit "quality score" is increased after the merge
+        * "slay_score": a combined score factoring in a template similarity measure, a cross-correlation significance
+          measure, and a sliding refractory period violation measure, based on the SLAy algorithm.
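+          The score is computed as M = sigma + k1 * rho - k2 * eta, where sigma is the template similarity,
+          rho is the cross-correlation significance, and eta is the sliding refractory period violation;
+          pairs with M above `slay_threshold` are considered merge candidates.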

     The "quality score" factors in the increase in firing rate (**f**) due to the merge and a possible increase in
     contamination (**C**), weighted by a factor **k** (`firing_contamination_balance`).
@@ -145,6 +153,9 @@ def compute_merge_unit_groups(
         * | "feature_neighbors": focused on finding unit pairs whose spikes are close in the feature space using kNN.
           | It uses the following steps: "num_spikes", "snr", "remove_contaminated", "unit_locations",
           | "knn", "quality_score"
+        * | "slay": an approximate implementation of SLAy; the original implementation is at https://github.com/saikoukunt/SLAy.
+          | The spikeinterface version uses `template_similarity` rather than an auto-encoder.
+          | It uses the following steps: "template_similarity", "slay_score"

         If `preset` is None, you can specify the steps manually with the `steps` parameter.
     resolve_graph : bool, default: True
@@ -363,6 +374,14 @@ def compute_merge_unit_groups(
             )
             outs["pairs_decreased_score"] = pairs_decreased_score

+        elif step == "slay_score":
+
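+            # SLAy merge decision metric: M = sigma + k1 * rho - k2 * eta,
+            # thresholded below to keep only likely merges (see compute_slay_matrix)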
+            M_ij = compute_slay_matrix(
+                sorting_analyzer, params["k1"], params["k2"], templates_diff=outs.get("templates_diff"), pair_mask=pair_mask
+            )
+
+            pair_mask = pair_mask & (M_ij > params["slay_threshold"])
+
     # FINAL STEP : create the final list from pair_mask boolean matrix
     ind1, ind2 = np.nonzero(pair_mask)
     merge_unit_groups = list(zip(unit_ids[ind1], unit_ids[ind2]))
@@ -550,6 +569,7 @@ def get_potential_auto_merge(
         * "cross_contamination": the cross-contamination is not significant (`cc_thresh` and `p_value`)
         * "knn": the two units are close in the feature space
         * "quality_score": the unit "quality score" is increased after the merge
+        * "slay_score": a combined score factoring in a template similarity measure, a cross-correlation significance
+          measure, and a sliding refractory period violation measure, based on the SLAy algorithm.

     The "quality score" factors in the increase in firing rate (**f**) due to the merge and a possible increase in
     contamination (**C**), weighted by a factor **k** (`firing_contamination_balance`).
@@ -566,7 +586,7 @@ def get_potential_auto_merge(
     ----------
     sorting_analyzer : SortingAnalyzer
         The SortingAnalyzer
-    preset : "similarity_correlograms" | "x_contaminations" | "temporal_splits" | "feature_neighbors" | None, default: "similarity_correlograms"
+    preset : "similarity_correlograms" | "x_contaminations" | "temporal_splits" | "feature_neighbors" | "slay" | None, default: "similarity_correlograms"
         The preset to use for the auto-merge. Presets combine different steps into a recipe and focus on:

         * | "similarity_correlograms": mainly focused on template similarity and correlograms.
@@ -581,6 +601,9 @@ def get_potential_auto_merge(
         * | "feature_neighbors": focused on finding unit pairs whose spikes are close in the feature space using kNN.
           | It uses the following steps: "num_spikes", "snr", "remove_contaminated", "unit_locations",
           | "knn", "quality_score"
+        * | "slay": an approximate implementation of SLAy; the original implementation is at https://github.com/saikoukunt/SLAy.
+          | The spikeinterface version uses `template_similarity` rather than an auto-encoder.
+          | It uses the following steps: "template_similarity", "slay_score"

         If `preset` is None, you can specify the steps manually with the `steps` parameter.
     resolve_graph : bool, default: False
@@ -1525,3 +1548,251 @@ def estimate_cross_contamination(
     )

     return estimation, p_value
+
+
+def compute_slay_matrix(
+    sorting_analyzer: SortingAnalyzer,
+    k1: float,
+    k2: float,
+    templates_diff: np.ndarray | None = None,
+    pair_mask: np.ndarray | None = None,
+):
1560+ """
1561+ Computes the "merge decision metric" from the SLAy method, made from combining
1562+ a template similarity measure, a cross-correlation significance measure and a
1563+ sliding refractory period violation measure. A large M suggests that two
1564+ units should be merged.
1565+
1566+ Paramters
1567+ ---------
1568+ sorting_analyzer : SortingAnalyzer
1569+ The sorting analyzer object containing the spike sorting data
1570+ k1 : float
1571+ Coefficient determining the importance of the cross-correlation significance
1572+ k2 : float
1573+ Coefficient determining the importance of the sliding rp violation
1574+ templates_diff : np.ndarray | None
1575+ Pre-computed template similarity difference matrix. If None, it will be retrieved from the sorting_analyzer.
1576+ pair_mask : None | np.ndarray, default: None
1577+ A bool matrix describing which pairs are possible merges based on previous steps
1578+
1579+
1580+ References
1581+ ----------
1582+ Based on computation originally implemented in SLAy [Koukuntla]_.
1583+
1584+ Implementation is based on one of the original implementations written by Sai Koukuntla,
1585+ found at https://github.com/saikoukunt/SLAy.
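+
+    Examples
+    --------
+    A minimal usage sketch (assuming `sorting_analyzer` already has the "templates",
+    "template_similarity" and "correlograms" extensions computed; the values shown
+    are the preset defaults):
+
+    >>> M_ij = compute_slay_matrix(sorting_analyzer, k1=0.25, k2=1, templates_diff=None)
+    >>> pair_mask = M_ij > 0.5  # "slay_threshold"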
1586+ """
+
+    num_units = sorting_analyzer.get_num_units()
+
+    if pair_mask is None:
+        # by default, consider every unordered pair of units (strict upper triangle)
+        pair_mask = np.triu(np.ones((num_units, num_units), dtype=bool), 1)
+
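+    # sigma_ij: pairwise template similarity (higher values mean more similar templates)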
+    if templates_diff is not None:
+        sigma_ij = 1 - templates_diff
+    else:
+        sigma_ij = sorting_analyzer.get_extension("template_similarity").get_data()
+    rho_ij, eta_ij = compute_xcorr_and_rp(sorting_analyzer, pair_mask)
+
+    M_ij = sigma_ij + k1 * rho_ij - k2 * eta_ij
+
+    return M_ij
+
+
+def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, pair_mask: np.ndarray):
+    """
+    Computes a cross-correlation significance measure and a sliding refractory period violation
+    measure for all units in the `sorting_analyzer`.
+
+    Parameters
+    ----------
+    sorting_analyzer : SortingAnalyzer
+        The sorting analyzer object containing the spike sorting data
+    pair_mask : np.ndarray
+        A bool matrix describing which pairs are possible merges based on previous steps
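+
+    Returns
+    -------
+    rho_ij : np.ndarray
+        Matrix of cross-correlation significance values, one entry per pair of units.
+    eta_ij : np.ndarray
+        Matrix of sliding refractory period violation values, one entry per pair of units.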
1615+ """
+
+    correlograms_extension = sorting_analyzer.get_extension("correlograms")
+    ccgs, _ = correlograms_extension.get_data()
+
+    # bin size in ms; converted to seconds where the SLAy helper functions need it
+    bin_size_ms = correlograms_extension.params["bin_ms"]
+
+    rho_ij = np.zeros([len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)])
+    eta_ij = np.zeros([len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)])
+
+    for unit_index_1, _ in enumerate(sorting_analyzer.unit_ids):
+        for unit_index_2, _ in enumerate(sorting_analyzer.unit_ids):
+
+            # don't waste time computing the other metrics if the units are not candidate merges
+            if not pair_mask[unit_index_1, unit_index_2]:
+                continue
+
+            xgram = ccgs[unit_index_1, unit_index_2, :]
+
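+            # min_xcorr_rate=0 disables the window expansion and low-event-count
+            # penalties inside _compute_xcorr_pair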
+            rho_ij[unit_index_1, unit_index_2] = _compute_xcorr_pair(
+                xgram, bin_size_s=bin_size_ms / 1000, min_xcorr_rate=0
+            )
+            eta_ij[unit_index_1, unit_index_2] = _sliding_RP_viol_pair(xgram, bin_size_ms=bin_size_ms)
1639+
1640+ return rho_ij , eta_ij
1641+
1642+
+def _compute_xcorr_pair(
+    xgram,
+    bin_size_s: float,
+    min_xcorr_rate: float,
+) -> float:
+    """
+    Calculates a cross-correlation significance metric for a cluster pair.
+
+    Uses the Wasserstein distance between an observed cross-correlogram and a null
+    distribution as an estimate of how significant the dependence between
+    two neurons is. Cross-correlograms with low spike counts have large Wasserstein
+    distances from the null by chance, so we first try to expand the window size. If
+    that fails to yield enough spikes, we apply a penalty to the metric.
+
+    Ported from https://github.com/saikoukunt/SLAy.
+
+    Parameters
+    ----------
+    xgram : np.array
+        The raw cross-correlogram for the cluster pair.
+    bin_size_s : float
+        The width in seconds of the bin size of the input ccgs.
+    min_xcorr_rate : float
+        The minimum ccg firing rate in Hz.
+
+    Returns
+    -------
+    sig : float
+        The calculated cross-correlation significance metric.
+    """
+
+    from scipy.signal import butter, find_peaks_cwt, sosfiltfilt
+    from scipy.stats import wasserstein_distance
+
+    # an empty ccg carries no evidence of dependence
+    if xgram.sum() == 0:
+        return 0
+
+    # calculate the low-pass filtered second derivative of the ccg
+    fs = 1 / bin_size_s
+    cutoff_freq = 100  # Hz
+    nyquist = fs / 2
+    cutoff = cutoff_freq / nyquist
+    peak_width = 0.002 / bin_size_s  # 2 ms expressed in bins
+
+    xgram_2d = np.diff(xgram, 2)
+    sos = butter(4, cutoff, output="sos")
+    xgram_2d = sosfiltfilt(sos, xgram_2d)
+
+    # find the negative peaks of the second derivative of the ccg; these are the edges of dips in the ccg
+    peaks = find_peaks_cwt(-xgram_2d, peak_width, noise_perc=90) + 1
+    # if no peaks are found, return a very low significance
+    if peaks.shape[0] == 0:
+        return -4
+    peaks = np.abs(peaks - xgram.shape[0] / 2)
+    peaks = peaks[peaks > 0.5 * peak_width]
+    # all peaks can be filtered out here; treat this like the no-peaks case above
+    if peaks.shape[0] == 0:
+        return -4
+    min_peaks = np.sort(peaks)
+
+    # start with the peaks closest to 0 and move to the next set of peaks if the event count is too low
+    window_width = min_peaks * 1.5
+    starts = np.maximum(xgram.shape[0] / 2 - window_width, 0)
+    ends = np.minimum(xgram.shape[0] / 2 + window_width, xgram.shape[0] - 1)
+    ind = 0
+    xgram_window = xgram[int(starts[0]) : int(ends[0] + 1)]
+    xgram_sum = xgram_window.sum()
+    window_size = xgram_window.shape[0] * bin_size_s
+    while (xgram_sum < (min_xcorr_rate * window_size * 10)) and (ind < starts.shape[0]):
+        xgram_window = xgram[int(starts[ind]) : int(ends[ind] + 1)]
+        xgram_sum = xgram_window.sum()
+        window_size = xgram_window.shape[0] * bin_size_s
+        ind += 1
+    # use the whole ccg if peak finding fails
+    if ind == starts.shape[0]:
+        xgram_window = xgram
+
+    # TODO: error messages were appearing when xgram_window was all zero. Why was this happening?
+    if np.abs(xgram_window).sum() == 0:
+        return 0
+
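+    # Wasserstein distance between the windowed ccg (as a distribution over its support)
+    # and a flat, uniform null; a larger distance indicates a stronger dependence between the units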
+    sig = (
+        wasserstein_distance(
+            np.arange(xgram_window.shape[0]) / xgram_window.shape[0],
+            np.arange(xgram_window.shape[0]) / xgram_window.shape[0],
+            xgram_window,
+            np.ones_like(xgram_window),
+        )
+        * 4
+    )
+
+    # penalize low event counts, which inflate the distance from the null by chance
+    if xgram_window.sum() < (min_xcorr_rate * window_size):
+        sig *= (xgram_window.sum() / (min_xcorr_rate * window_size)) ** 2
+
+    # if sig < 0.04 and xgram_window.sum() < (min_xcorr_rate * window_size):
+    if xgram_window.sum() < (min_xcorr_rate / 4 * window_size):
+        sig = -4  # don't merge if the event count is way too low
+
+    return sig
+
+
+def _sliding_RP_viol_pair(
+    correlogram,
+    bin_size_ms: float,
+    accept_threshold: float = 0.15,
+) -> float:
+    """
+    Calculate the sliding refractory period violation confidence for a cluster.
+
+    Ported from https://github.com/saikoukunt/SLAy.
+
+    Parameters
+    ----------
+    correlogram : np.array
+        The auto-correlogram of the cluster.
+    bin_size_ms : float
+        The width in ms of the bin size of the input ccgs.
+    accept_threshold : float, default: 0.15
+        The acceptable contamination level, as a fraction of the baseline event rate.
+
+    Returns
+    -------
+    rp_viol : float
+        The refractory period violation confidence for the cluster.
+    """
+    from scipy.signal import butter, sosfiltfilt
+    from scipy.stats import poisson
+
+    # create various refractory period sizes to test (between 0 and 20x bin size)
+    all_refractory_periods = np.arange(0, 21 * bin_size_ms, bin_size_ms) / 1000
+    test_refractory_period_indices = np.array([1, 2, 4, 6, 8, 12, 16, 20], dtype="int8")
+    test_refractory_periods = [
+        all_refractory_periods[test_rp_index] for test_rp_index in test_refractory_period_indices
+    ]
+
+    # average the two halves of the acg to enforce symmetry, keeping only the second half:
+    # refractory period violations are counted outward from the center of the acg
+    half_len = int(correlogram.shape[0] / 2)
+    correlogram = (correlogram[half_len:] + correlogram[:half_len][::-1]) / 2
+
+    acg_cumsum = np.cumsum(correlogram)
+    sum_res = acg_cumsum[test_refractory_period_indices - 1]  # -1 because the 0th bin covers 0 to bin_size ms
+
+    # low-pass filter the acg and use its max as the baseline event rate
+    order = 4  # filter order
+    cutoff_freq = 250  # Hz
+    fs = 1 / bin_size_ms * 1000
+    nyquist = fs / 2
+    cutoff = cutoff_freq / nyquist
+    sos = butter(order, cutoff, btype="low", output="sos")
+    smoothed_acg = sosfiltfilt(sos, correlogram)
+
+    bin_rate_max = np.max(smoothed_acg)
+    max_conts_max = np.array(test_refractory_periods) / bin_size_ms * 1000 * (bin_rate_max * accept_threshold)
+    # compute the confidence of less than `accept_threshold` contamination at each refractory period
+    confs = 1 - poisson.cdf(sum_res, max_conts_max)
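+    # rp_viol is low when at least one tested refractory period gives high confidence
+    # that the contamination stays below accept_threshold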
+    rp_viol = 1 - confs.max()
+
+    return rp_viol