Skip to content

Commit d9c2f65

Browse files
Fix gaussian filter with time vector (#4268)
Co-authored-by: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com>
1 parent d340536 commit d9c2f65

File tree

3 files changed

+27
-10
lines changed

3 files changed

+27
-10
lines changed

src/spikeinterface/core/zarrextractors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def __init__(
161161
if np.isnan(t_start):
162162
t_start = None
163163
time_kwargs["t_start"] = t_start
164-
time_kwargs["sampling_frequency"] = sampling_frequency
164+
time_kwargs["sampling_frequency"] = sampling_frequency
165165

166166
rec_segment = ZarrRecordingSegment(self._root, trace_name, **time_kwargs)
167167
self.add_recording_segment(rec_segment)

src/spikeinterface/preprocessing/filter_gaussian.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,29 +50,37 @@ def __init__(
5050
raise ValueError("At least one of `freq_min`,`freq_max` should be specified.")
5151

5252
for parent_segment in recording._recording_segments:
53-
self.add_recording_segment(GaussianFilterRecordingSegment(parent_segment, freq_min, freq_max, margin_sd))
53+
# Sampling frequency is taken from recording since segments may not have it set (in case of time_vector)
54+
self.add_recording_segment(
55+
GaussianFilterRecordingSegment(parent_segment, freq_min, freq_max, margin_sd, self.sampling_frequency)
56+
)
5457

5558
self._kwargs = {"recording": recording, "freq_min": freq_min, "freq_max": freq_max}
5659

5760

5861
class GaussianFilterRecordingSegment(BasePreprocessorSegment):
5962
def __init__(
60-
self, parent_recording_segment: BaseRecordingSegment, freq_min: float, freq_max: float, margin_sd: float = 5.0
63+
self,
64+
parent_recording_segment: BaseRecordingSegment,
65+
freq_min: float,
66+
freq_max: float,
67+
margin_sd: float = 5.0,
68+
parent_sampling_frequency: float = None,
6169
):
6270
BasePreprocessorSegment.__init__(self, parent_recording_segment)
6371

6472
self.freq_min = freq_min
6573
self.freq_max = freq_max
6674
self.cached_gaussian = dict()
6775

68-
sf = parent_recording_segment.sampling_frequency
76+
self.parent_sampling_frequency = parent_sampling_frequency
6977

7078
# Margin from widest gaussian
7179
sigmas = []
7280
if freq_min is not None:
73-
sigmas.append(sf / (2 * np.pi * freq_min))
81+
sigmas.append(self.parent_sampling_frequency / (2 * np.pi * freq_min))
7482
if freq_max is not None:
75-
sigmas.append(sf / (2 * np.pi * freq_max))
83+
sigmas.append(self.parent_sampling_frequency / (2 * np.pi * freq_max))
7684
self.margin = 1 + int(max(sigmas) * margin_sd)
7785

7886
def get_traces(
@@ -117,11 +125,12 @@ def _create_gaussian(self, N: int, cutoff_f: float):
117125
if cutoff_f in self.cached_gaussian and N in self.cached_gaussian[cutoff_f]:
118126
return self.cached_gaussian[cutoff_f][N]
119127

120-
sf = self.parent_recording_segment.sampling_frequency
121-
faxis = np.fft.fftfreq(N, d=1 / sf)
128+
faxis = np.fft.fftfreq(N, d=1 / self.parent_sampling_frequency)
122129

123-
if cutoff_f > sf / 8: # The Fourier transform of a Gaussian with a very low sigma isn't a Gaussian.
124-
sigma = sf / (2 * np.pi * cutoff_f)
130+
if (
131+
cutoff_f > self.parent_sampling_frequency / 8
132+
): # The Fourier transform of a Gaussian with a very low sigma isn't a Gaussian.
133+
sigma = self.parent_sampling_frequency / (2 * np.pi * cutoff_f)
125134
limit = int(round(5 * sigma)) + 1
126135
xaxis = np.arange(-limit, limit + 1) / sigma
127136
gaussian = normal_pdf(xaxis) / sigma

src/spikeinterface/preprocessing/tests/test_filter_gaussian.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,14 @@ def test_filter_gaussian(tmp_path):
3636
assert np.allclose(original_trace[60:-60], saved1_trace[60:-60], rtol=1e-3, atol=1e-3)
3737
assert np.allclose(original_trace[60:-60], saved2_trace[60:-60], rtol=1e-3, atol=1e-3)
3838

39+
# test filter gaussian with time_Vector
40+
for segment_index in range(recording.get_num_segments()):
41+
times = recording.get_times(segment_index) + (segment_index + 1) * 10.0
42+
recording.set_times(times, segment_index=segment_index)
43+
44+
rec_filtered_tv = gaussian_filter(recording)
45+
assert rec_filtered_tv.get_traces(segment_index=0, end_frame=100).shape == (100, 3)
46+
3947

4048
@pytest.mark.parametrize("freq_min", [None, 10, 50, 100])
4149
@pytest.mark.parametrize("freq_max", [None, 10, 50, 100])

0 commit comments

Comments
 (0)