
Commit 3291976

Authored by samuelgarcia, h-mayorquin, and alejoe91

Fix compute matching v3 (#2182)

* some change to test
* another change
* another attempt
* attempt merge
* add condition
* add auth
* fix test and simpler implementation
* small typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* avoid corner case of doing the matching loop twice
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Remove n_jobs
* Little docs cleanup
* Remove internal n_jobs
* Remove last internal n_jobs
* Apply suggestions from code review
* fix test
* comment to test
* docstring improvements
* variable naming
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* new proposal for compute_matching
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: Heberto Mayorquin <h.mayorquin@gmail.com>
Co-authored-by: Alessio Buccino <alejoe9187@gmail.com>

1 parent fa0b034 · commit 3291976

File tree

3 files changed: +111 additions, -89 deletions


src/spikeinterface/comparison/comparisontools.py

Lines changed: 65 additions & 65 deletions
@@ -124,103 +124,105 @@ def get_optimized_compute_matching_matrix():
 
     @numba.jit(nopython=True, nogil=True)
     def compute_matching_matrix(
-        frames_spike_train1,
-        frames_spike_train2,
+        spike_frames_train1,
+        spike_frames_train2,
         unit_indices1,
         unit_indices2,
-        num_units_sorting1,
-        num_units_sorting2,
+        num_units_train1,
+        num_units_train2,
         delta_frames,
     ):
         """
         Compute a matrix representing the matches between two spike trains.
 
         Given two spike trains, this function finds matching spikes based on a temporal proximity criterion
         defined by `delta_frames`. The resulting matrix indicates the number of matches between units
-        in `frames_spike_train1` and `frames_spike_train2`.
+        in `spike_frames_train1` and `spike_frames_train2`.
 
         Parameters
         ----------
-        frames_spike_train1 : ndarray
-            Array of frames for the first spike train. Should be ordered in ascending order.
-        frames_spike_train2 : ndarray
-            Array of frames for the second spike train. Should be ordered in ascending order.
+        spike_frames_train1 : ndarray
+            An array of integer frame numbers corresponding to spike times for the first train. Must be in ascending order.
+        spike_frames_train2 : ndarray
+            An array of integer frame numbers corresponding to spike times for the second train. Must be in ascending order.
         unit_indices1 : ndarray
-            Array indicating the unit indices corresponding to each spike in `frames_spike_train1`.
+            An array of integers where `unit_indices1[i]` gives the unit index associated with the spike at `spike_frames_train1[i]`.
         unit_indices2 : ndarray
-            Array indicating the unit indices corresponding to each spike in `frames_spike_train2`.
-        num_units_sorting1 : int
-            Total number of units in the first spike train.
-        num_units_sorting2 : int
-            Total number of units in the second spike train.
+            An array of integers where `unit_indices2[i]` gives the unit index associated with the spike at `spike_frames_train2[i]`.
+        num_units_train1 : int
+            The total count of unique units in the first spike train.
+        num_units_train2 : int
+            The total count of unique units in the second spike train.
         delta_frames : int
-            Maximum difference in frames between two spikes to consider them as a match.
+            The inclusive upper limit on the frame difference for which two spikes are considered matching. That is,
+            if `abs(spike_frames_train1[i] - spike_frames_train2[j]) <= delta_frames` then the spikes at `spike_frames_train1[i]`
+            and `spike_frames_train2[j]` are considered matching.
 
         Returns
         -------
         matching_matrix : ndarray
-            A matrix of shape (num_units_sorting1, num_units_sorting2) where each entry [i, j] represents
-            the number of matching spikes between unit i of `frames_spike_train1` and unit j of `frames_spike_train2`.
+            A 2D numpy array of shape `(num_units_train1, num_units_train2)`. Each element `[i, j]` represents
+            the count of matching spike pairs between unit `i` from `spike_frames_train1` and unit `j` from `spike_frames_train2`.
+
 
         Notes
         -----
         This algorithm identifies matching spikes between two ordered spike trains.
         By iterating through each spike in the first train, it compares them against spikes in the second train,
         determining matches based on the two spikes frames being within `delta_frames` of each other.
 
-        To avoid redundant comparisons the algorithm maintains a reference, `lower_search_limit_in_second_train`,
+        To avoid redundant comparisons the algorithm maintains a reference, `second_train_search_start`,
         which signifies the minimal index in the second spike train that might match the upcoming spike
-        in the first train. This means that the start of the search moves forward in the second train as the
-        matches between the two trains are found decreasing the number of comparisons needed.
+        in the first train.
+
+        The logic can be summarized as follows:
+        1. Iterate through each spike in the first train.
+        2. For each spike, find the first match in the second train.
+        3. Save the index of the first match as the new `second_train_search_start`.
+        4. For each match, find as many matches as possible from the first match onwards.
 
-        An important condition here is thatthe same spike is not matched twice. This is managed by keeping track
-        of the last matched frame for each unit pair in `previous_frame1_match` and `previous_frame2_match`
+        An important condition here is that the same spike is not matched twice. This is managed by keeping track
+        of the last matched frame for each unit pair in `last_match_frame1` and `last_match_frame2`
 
         For more details on the rationale behind this approach, refer to the documentation of this module and/or
-        the metrics section in SpikeForest documentation.
+        the metrics section in SpikeForest documentation.
         """
 
-        matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint16)
+        matching_matrix = np.zeros((num_units_train1, num_units_train2), dtype=np.uint16)
 
         # Used to avoid the same spike matching twice
-        previous_frame1_match = -np.ones_like(matching_matrix, dtype=np.int64)
-        previous_frame2_match = -np.ones_like(matching_matrix, dtype=np.int64)
-
-        lower_search_limit_in_second_train = 0
-
-        for index1 in range(len(frames_spike_train1)):
-            # Keeps track of which frame in the second spike train should be used as a search start for matches
-            index2 = lower_search_limit_in_second_train
-            frame1 = frames_spike_train1[index1]
-
-            # Determine next_frame1 if current frame is not the last frame
-            not_in_the_last_loop = index1 < len(frames_spike_train1) - 1
-            if not_in_the_last_loop:
-                next_frame1 = frames_spike_train1[index1 + 1]
-
-            while index2 < len(frames_spike_train2):
-                frame2 = frames_spike_train2[index2]
-                not_a_match = abs(frame1 - frame2) > delta_frames
-                if not_a_match:
-                    # Go to the next frame in the first train
+        last_match_frame1 = -np.ones_like(matching_matrix, dtype=np.int64)
+        last_match_frame2 = -np.ones_like(matching_matrix, dtype=np.int64)
+
+        num_spike_frames_train1 = len(spike_frames_train1)
+        num_spike_frames_train2 = len(spike_frames_train2)
+
+        # Keeps track of which frame in the second spike train should be used as a search start for matches
+        second_train_search_start = 0
+        for index1 in range(num_spike_frames_train1):
+            frame1 = spike_frames_train1[index1]
+
+            for index2 in range(second_train_search_start, num_spike_frames_train2):
+                frame2 = spike_frames_train2[index2]
+                if frame2 < frame1 - delta_frames:
+                    # no match, move the left limit for the next loop
+                    second_train_search_start += 1
+                    continue
+                elif frame2 > frame1 + delta_frames:
+                    # no match, stop the search in train2 and continue with the next spike in train1
                     break
+                else:
+                    # match
+                    unit_index1, unit_index2 = unit_indices1[index1], unit_indices2[index2]
 
-                # Map the match to a matrix
-                row, column = unit_indices1[index1], unit_indices2[index2]
-
-                # The same spike cannot be matched twice see the notes in the docstring for more info on this constraint
-                if frame1 != previous_frame1_match[row, column] and frame2 != previous_frame2_match[row, column]:
-                    previous_frame1_match[row, column] = frame1
-                    previous_frame2_match[row, column] = frame2
-
-                    matching_matrix[row, column] += 1
-
-                index2 += 1
+                    if (
+                        frame1 != last_match_frame1[unit_index1, unit_index2]
+                        and frame2 != last_match_frame2[unit_index1, unit_index2]
+                    ):
+                        last_match_frame1[unit_index1, unit_index2] = frame1
+                        last_match_frame2[unit_index1, unit_index2] = frame2
 
-                # Advance the lower_search_limit_in_second_train if the next frame in the first train does not match
-                not_a_match_with_next = abs(next_frame1 - frame2) > delta_frames
-                if not_a_match_with_next:
-                    lower_search_limit_in_second_train = index2
+                        matching_matrix[unit_index1, unit_index2] += 1
 
         return matching_matrix
 
@@ -230,7 +232,7 @@ def compute_matching_matrix(
     return compute_matching_matrix
 
 
-def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None):
+def make_match_count_matrix(sorting1, sorting2, delta_frames):
     num_units_sorting1 = sorting1.get_num_units()
     num_units_sorting2 = sorting2.get_num_units()
     matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint16)
@@ -275,7 +277,7 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None):
     return match_event_counts_df
 
 
-def make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1):
+def make_agreement_scores(sorting1, sorting2, delta_frames):
     """
     Make the agreement matrix.
     No threshold (min_score) is applied at this step.
@@ -291,8 +293,6 @@ def make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1):
         The second sorting extractor
     delta_frames: int
         Number of frames to consider spikes coincident
-    n_jobs: int
-        Number of jobs to run in parallel
 
     Returns
     -------
@@ -309,7 +309,7 @@ def make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1):
     event_counts1 = pd.Series(ev_counts1, index=unit1_ids)
     event_counts2 = pd.Series(ev_counts2, index=unit2_ids)
 
-    match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=n_jobs)
+    match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames)
 
     agreement_scores = make_agreement_scores_from_count(match_event_count, event_counts1, event_counts2)
 
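
To make the new control flow easier to follow outside the jit-compiled kernel, here is a minimal, numba-free sketch of the same matching loop. It is an illustration only (the name match_count_sketch is not part of the codebase) and assumes both spike-frame arrays are already sorted in ascending order, as the docstring requires.

import numpy as np


def match_count_sketch(spike_frames_train1, spike_frames_train2,
                       unit_indices1, unit_indices2,
                       num_units_train1, num_units_train2, delta_frames):
    # Per-unit-pair match counts, mirroring the jit-compiled kernel above.
    matching_matrix = np.zeros((num_units_train1, num_units_train2), dtype=np.uint16)

    # Last matched frames per unit pair, so the same spike is never counted twice.
    last_match_frame1 = -np.ones_like(matching_matrix, dtype=np.int64)
    last_match_frame2 = -np.ones_like(matching_matrix, dtype=np.int64)

    second_train_search_start = 0
    for index1, frame1 in enumerate(spike_frames_train1):
        for index2 in range(second_train_search_start, len(spike_frames_train2)):
            frame2 = spike_frames_train2[index2]
            if frame2 < frame1 - delta_frames:
                # frame2 cannot match this or any later frame1: shrink the search window.
                second_train_search_start += 1
                continue
            if frame2 > frame1 + delta_frames:
                # Past the window for frame1: stop scanning train2 for this spike.
                break
            # Within +/- delta_frames: a candidate match for this unit pair.
            u1, u2 = unit_indices1[index1], unit_indices2[index2]
            if frame1 != last_match_frame1[u1, u2] and frame2 != last_match_frame2[u1, u2]:
                last_match_frame1[u1, u2] = frame1
                last_match_frame2[u1, u2] = frame2
                matching_matrix[u1, u2] += 1
    return matching_matrix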

src/spikeinterface/comparison/paircomparisons.py

Lines changed: 1 addition & 3 deletions
@@ -84,9 +84,7 @@ def _do_agreement(self):
         self.event_counts2 = do_count_event(self.sorting2)
 
         # matrix of event match count for each pair
-        self.match_event_count = make_match_count_matrix(
-            self.sorting1, self.sorting2, self.delta_frames, n_jobs=self.n_jobs
-        )
+        self.match_event_count = make_match_count_matrix(self.sorting1, self.sorting2, self.delta_frames)
 
         # agreement matrix score for each pair
         self.agreement_scores = make_agreement_scores_from_count(
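
From the caller's side, the only visible change is that n_jobs is no longer accepted by these helpers. A minimal usage sketch is below; the synthetic sortings come from generate_sorting in spikeinterface.core, and its arguments here (num_units, durations, seed) are illustrative assumptions rather than part of this commit.

import spikeinterface.core as si
from spikeinterface.comparison.comparisontools import make_match_count_matrix, make_agreement_scores

# Two small synthetic sortings to compare (parameters chosen only for illustration).
sorting1 = si.generate_sorting(num_units=3, durations=[10.0], seed=0)
sorting2 = si.generate_sorting(num_units=3, durations=[10.0], seed=1)

# Updated signatures: only the two sortings and delta_frames, no n_jobs keyword.
match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames=10)
agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames=10)
print(agreement_scores)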

src/spikeinterface/comparison/tests/test_comparisontools.py

Lines changed: 45 additions & 21 deletions
@@ -135,6 +135,23 @@ def test_make_match_count_matrix_repeated_matching_but_no_double_counting():
     assert_array_equal(result.to_numpy(), expected_result)
 
 
+def test_make_match_count_matrix_test_proper_search_in_the_second_train():
+    "Search exhaustively in the second train, but only within the delta_frames window, do not terminate search early"
+    frames_spike_train1 = [500, 600, 800]
+    frames_spike_train2 = [0, 100, 200, 300, 500, 800]
+    unit_indices1 = [0, 0, 0]
+    unit_indices2 = [0, 0, 0, 0, 0, 0]
+    delta_frames = 20
+
+    sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, frames_spike_train2, unit_indices2)
+
+    result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames)
+
+    expected_result = np.array([[2]])
+
+    assert_array_equal(result.to_numpy(), expected_result)
+
+
 def test_make_agreement_scores():
     delta_frames = 10
 
@@ -150,15 +167,15 @@ def test_make_agreement_scores():
         [0, 0, 5],
     )
 
-    agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1)
+    agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames)
     print(agreement_scores)
 
     ok = np.array([[2 / 3, 0], [0, 1.0]], dtype="float64")
 
     assert_array_equal(agreement_scores.values, ok)
 
     # test if symetric
-    agreement_scores2 = make_agreement_scores(sorting2, sorting1, delta_frames, n_jobs=1)
+    agreement_scores2 = make_agreement_scores(sorting2, sorting1, delta_frames)
     assert_array_equal(agreement_scores, agreement_scores2.T)
 
 
@@ -178,7 +195,7 @@ def test_make_possible_match():
         [0, 0, 5],
     )
 
-    agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1)
+    agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames)
 
     possible_match_12, possible_match_21 = make_possible_match(agreement_scores, min_accuracy)
 
@@ -207,7 +224,7 @@ def test_make_best_match():
         [0, 0, 5],
     )
 
-    agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1)
+    agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames)
 
     best_match_12, best_match_21 = make_best_match(agreement_scores, min_accuracy)
 
@@ -236,7 +253,7 @@ def test_make_hungarian_match():
         [0, 0, 5],
     )
 
-    agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1)
+    agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames)
 
     hungarian_match_12, hungarian_match_21 = make_hungarian_match(agreement_scores, min_accuracy)
 
@@ -344,8 +361,8 @@ def test_do_confusion_matrix():
 
     event_counts1 = do_count_event(sorting1)
     event_counts2 = do_count_event(sorting2)
-    match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=1)
-    agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1)
+    match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames)
+    agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames)
     hungarian_match_12, hungarian_match_21 = make_hungarian_match(agreement_scores, min_accuracy)
 
     confusion = do_confusion_matrix(event_counts1, event_counts2, hungarian_match_12, match_event_count)
@@ -363,8 +380,8 @@ def test_do_confusion_matrix():
 
     event_counts1 = do_count_event(sorting1)
     event_counts2 = do_count_event(sorting2)
-    match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=1)
-    agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1)
+    match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames)
+    agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames)
    hungarian_match_12, hungarian_match_21 = make_hungarian_match(agreement_scores, min_accuracy)
 
     confusion = do_confusion_matrix(event_counts1, event_counts2, hungarian_match_12, match_event_count)
@@ -391,8 +408,8 @@ def test_do_count_score_and_perf():
 
     event_counts1 = do_count_event(sorting1)
     event_counts2 = do_count_event(sorting2)
-    match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=1)
-    agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1)
+    match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames)
+    agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames)
     hungarian_match_12, hungarian_match_21 = make_hungarian_match(agreement_scores, min_accuracy)
 
     count_score = do_count_score(event_counts1, event_counts2, hungarian_match_12, match_event_count)
@@ -415,13 +432,20 @@ def test_do_count_score_and_perf():
 
 if __name__ == "__main__":
     test_make_match_count_matrix()
-    test_make_agreement_scores()
-
-    test_make_possible_match()
-    test_make_best_match()
-    test_make_hungarian_match()
-
-    test_do_score_labels()
-    test_compare_spike_trains()
-    test_do_confusion_matrix()
-    test_do_count_score_and_perf()
+    test_make_match_count_matrix_sorting_with_itself_simple()
+    test_make_match_count_matrix_sorting_with_itself_longer()
+    test_make_match_count_matrix_with_mismatched_sortings()
+    test_make_match_count_matrix_no_double_matching()
+    test_make_match_count_matrix_repeated_matching_but_no_double_counting()
+    test_make_match_count_matrix_test_proper_search_in_the_second_train()
+
+    # test_make_agreement_scores()
+
+    # test_make_possible_match()
+    # test_make_best_match()
+    # test_make_hungarian_match()
+
+    # test_do_score_labels()
+    # test_compare_spike_trains()
+    # test_do_confusion_matrix()
+    # test_do_count_score_and_perf()
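
The expectation of the new test above (a single unit pair with exactly 2 matches) can be reproduced with a hypothetical brute-force cross-check that is independent of the optimized kernel; only the spikes at frames 500 and 800 in the first train have a partner within 20 frames in the second train.

# Hypothetical brute-force cross-check of the new test's data; not part of the test suite.
frames_spike_train1 = [500, 600, 800]
frames_spike_train2 = [0, 100, 200, 300, 500, 800]
delta_frames = 20

matched_in_train2 = set()
match_count = 0
for frame1 in frames_spike_train1:
    for index2, frame2 in enumerate(frames_spike_train2):
        if index2 not in matched_in_train2 and abs(frame1 - frame2) <= delta_frames:
            # Count each pair once and never reuse a spike from the second train.
            matched_in_train2.add(index2)
            match_count += 1
            break

print(match_count)  # 2: frame 500 matches 500 and frame 800 matches 800; frame 600 has no partner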
