@@ -124,103 +124,105 @@ def get_optimized_compute_matching_matrix():
124124
125125 @numba .jit (nopython = True , nogil = True )
126126 def compute_matching_matrix (
127- frames_spike_train1 ,
128- frames_spike_train2 ,
127+ spike_frames_train1 ,
128+ spike_frames_train2 ,
129129 unit_indices1 ,
130130 unit_indices2 ,
131- num_units_sorting1 ,
132- num_units_sorting2 ,
131+ num_units_train1 ,
132+ num_units_train2 ,
133133 delta_frames ,
134134 ):
135135 """
136136 Compute a matrix representing the matches between two spike trains.
137137
138138 Given two spike trains, this function finds matching spikes based on a temporal proximity criterion
139139 defined by `delta_frames`. The resulting matrix indicates the number of matches between units
140- in `frames_spike_train1 ` and `frames_spike_train2 `.
140+ in `spike_frames_train1 ` and `spike_frames_train2 `.
141141
142142 Parameters
143143 ----------
144- frames_spike_train1 : ndarray
145- Array of frames for the first spike train. Should be ordered in ascending order.
146- frames_spike_train2 : ndarray
147- Array of frames for the second spike train. Should be ordered in ascending order.
144+ spike_frames_train1 : ndarray
145+ An array of integer frame numbers corresponding to spike times for the first train. Must be in ascending order.
146+ spike_frames_train2 : ndarray
147+ An array of integer frame numbers corresponding to spike times for the second train. Must be in ascending order.
148148 unit_indices1 : ndarray
149- Array indicating the unit indices corresponding to each spike in `frames_spike_train1 `.
149+ An array of integers where `unit_indices1[i]` gives the unit index associated with the spike at `spike_frames_train1[i] `.
150150 unit_indices2 : ndarray
151- Array indicating the unit indices corresponding to each spike in `frames_spike_train2 `.
152- num_units_sorting1 : int
153- Total number of units in the first spike train.
154- num_units_sorting2 : int
155- Total number of units in the second spike train.
151+ An array of integers where `unit_indices2[i]` gives the unit index associated with the spike at `spike_frames_train2[i] `.
152+ num_units_train1 : int
153+ The total count of unique units in the first spike train.
154+ num_units_train2 : int
155+ The total count of unique units in the second spike train.
156156 delta_frames : int
157- Maximum difference in frames between two spikes to consider them as a match.
157+ The inclusive upper limit on the frame difference for which two spikes are considered matching. That is
158+ if `abs(spike_frames_train1[i] - spike_frames_train2[j]) <= delta_frames` then the spikes at `spike_frames_train1[i]`
159+ and `spike_frames_train2[j]` are considered matching.
158160
159161 Returns
160162 -------
161163 matching_matrix : ndarray
162- A matrix of shape (num_units_sorting1, num_units_sorting2) where each entry [i, j] represents
163- the number of matching spikes between unit i of `frames_spike_train1` and unit j of `frames_spike_train2`.
164+ A 2D numpy array of shape `(num_units_train1, num_units_train2)`. Each element `[i, j]` represents
165+ the count of matching spike pairs between unit `i` from `spike_frames_train1` and unit `j` from `spike_frames_train2`.
166+
164167
165168 Notes
166169 -----
167170 This algorithm identifies matching spikes between two ordered spike trains.
168171 By iterating through each spike in the first train, it compares them against spikes in the second train,
169172 determining matches based on the two spikes frames being within `delta_frames` of each other.
170173
171- To avoid redundant comparisons the algorithm maintains a reference, `lower_search_limit_in_second_train `,
174+ To avoid redundant comparisons the algorithm maintains a reference, `second_train_search_start `,
172175 which signifies the minimal index in the second spike train that might match the upcoming spike
173- in the first train. This means that the start of the search moves forward in the second train as the
174- matches between the two trains are found decreasing the number of comparisons needed.
176+ in the first train.
177+
178+ The logic can be summarized as follows:
179+ 1. Iterate through each spike in the first train
180+ 2. For each spike, find the first match in the second train.
181+ 3. Save the index of the first match as the new `second_train_search_start `
182+ 3. For each match, find as many matches as possible from the first match onwards.
175183
176- An important condition here is thatthe same spike is not matched twice. This is managed by keeping track
177- of the last matched frame for each unit pair in `previous_frame1_match ` and `previous_frame2_match `
184+ An important condition here is that the same spike is not matched twice. This is managed by keeping track
185+ of the last matched frame for each unit pair in `last_match_frame1 ` and `last_match_frame2 `
178186
179187 For more details on the rationale behind this approach, refer to the documentation of this module and/or
180- the metrics section in SpikeForest documentation.
188+ the metrics section in SpikeForest documentation.
181189 """
182190
183- matching_matrix = np .zeros ((num_units_sorting1 , num_units_sorting2 ), dtype = np .uint16 )
191+ matching_matrix = np .zeros ((num_units_train1 , num_units_train2 ), dtype = np .uint16 )
184192
185193 # Used to avoid the same spike matching twice
186- previous_frame1_match = - np .ones_like (matching_matrix , dtype = np .int64 )
187- previous_frame2_match = - np .ones_like (matching_matrix , dtype = np .int64 )
188-
189- lower_search_limit_in_second_train = 0
190-
191- for index1 in range (len (frames_spike_train1 )):
192- # Keeps track of which frame in the second spike train should be used as a search start for matches
193- index2 = lower_search_limit_in_second_train
194- frame1 = frames_spike_train1 [index1 ]
195-
196- # Determine next_frame1 if current frame is not the last frame
197- not_in_the_last_loop = index1 < len (frames_spike_train1 ) - 1
198- if not_in_the_last_loop :
199- next_frame1 = frames_spike_train1 [index1 + 1 ]
200-
201- while index2 < len (frames_spike_train2 ):
202- frame2 = frames_spike_train2 [index2 ]
203- not_a_match = abs (frame1 - frame2 ) > delta_frames
204- if not_a_match :
205- # Go to the next frame in the first train
194+ last_match_frame1 = - np .ones_like (matching_matrix , dtype = np .int64 )
195+ last_match_frame2 = - np .ones_like (matching_matrix , dtype = np .int64 )
196+
197+ num_spike_frames_train1 = len (spike_frames_train1 )
198+ num_spike_frames_train2 = len (spike_frames_train2 )
199+
200+ # Keeps track of which frame in the second spike train should be used as a search start for matches
201+ second_train_search_start = 0
202+ for index1 in range (num_spike_frames_train1 ):
203+ frame1 = spike_frames_train1 [index1 ]
204+
205+ for index2 in range (second_train_search_start , num_spike_frames_train2 ):
206+ frame2 = spike_frames_train2 [index2 ]
207+ if frame2 < frame1 - delta_frames :
208+ # no match move the left limit for the next loop
209+ second_train_search_start += 1
210+ continue
211+ elif frame2 > frame1 + delta_frames :
212+ # no match stop search in train2 and continue increment in train1
206213 break
214+ else :
215+ # match
216+ unit_index1 , unit_index2 = unit_indices1 [index1 ], unit_indices2 [index2 ]
207217
208- # Map the match to a matrix
209- row , column = unit_indices1 [index1 ], unit_indices2 [index2 ]
210-
211- # The same spike cannot be matched twice see the notes in the docstring for more info on this constraint
212- if frame1 != previous_frame1_match [row , column ] and frame2 != previous_frame2_match [row , column ]:
213- previous_frame1_match [row , column ] = frame1
214- previous_frame2_match [row , column ] = frame2
215-
216- matching_matrix [row , column ] += 1
217-
218- index2 += 1
218+ if (
219+ frame1 != last_match_frame1 [unit_index1 , unit_index2 ]
220+ and frame2 != last_match_frame2 [unit_index1 , unit_index2 ]
221+ ):
222+ last_match_frame1 [unit_index1 , unit_index2 ] = frame1
223+ last_match_frame2 [unit_index1 , unit_index2 ] = frame2
219224
220- # Advance the lower_search_limit_in_second_train if the next frame in the first train does not match
221- not_a_match_with_next = abs (next_frame1 - frame2 ) > delta_frames
222- if not_a_match_with_next :
223- lower_search_limit_in_second_train = index2
225+ matching_matrix [unit_index1 , unit_index2 ] += 1
224226
225227 return matching_matrix
226228
@@ -230,7 +232,7 @@ def compute_matching_matrix(
230232 return compute_matching_matrix
231233
232234
233- def make_match_count_matrix (sorting1 , sorting2 , delta_frames , n_jobs = None ):
235+ def make_match_count_matrix (sorting1 , sorting2 , delta_frames ):
234236 num_units_sorting1 = sorting1 .get_num_units ()
235237 num_units_sorting2 = sorting2 .get_num_units ()
236238 matching_matrix = np .zeros ((num_units_sorting1 , num_units_sorting2 ), dtype = np .uint16 )
@@ -275,7 +277,7 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None):
275277 return match_event_counts_df
276278
277279
278- def make_agreement_scores (sorting1 , sorting2 , delta_frames , n_jobs = 1 ):
280+ def make_agreement_scores (sorting1 , sorting2 , delta_frames ):
279281 """
280282 Make the agreement matrix.
281283 No threshold (min_score) is applied at this step.
@@ -291,8 +293,6 @@ def make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1):
291293 The second sorting extractor
292294 delta_frames: int
293295 Number of frames to consider spikes coincident
294- n_jobs: int
295- Number of jobs to run in parallel
296296
297297 Returns
298298 -------
@@ -309,7 +309,7 @@ def make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1):
309309 event_counts1 = pd .Series (ev_counts1 , index = unit1_ids )
310310 event_counts2 = pd .Series (ev_counts2 , index = unit2_ids )
311311
312- match_event_count = make_match_count_matrix (sorting1 , sorting2 , delta_frames , n_jobs = n_jobs )
312+ match_event_count = make_match_count_matrix (sorting1 , sorting2 , delta_frames )
313313
314314 agreement_scores = make_agreement_scores_from_count (match_event_count , event_counts1 , event_counts2 )
315315
0 commit comments