Skip to content

Commit 235dd72

Browse files
committed
Merge branch 'release/2.17.2'
2 parents 5211b77 + fd50b6a commit 235dd72

File tree

15 files changed

+239
-72
lines changed

15 files changed

+239
-72
lines changed

brainbox/behavior/wheel.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55
from numpy import pi
66
import scipy.interpolate as interpolate
7-
from scipy.signal import convolve, windows
7+
import scipy.signal
88
from scipy.linalg import hankel
99
import matplotlib.pyplot as plt
1010
from matplotlib.collections import LineCollection
@@ -90,6 +90,21 @@ def velocity(re_ts, re_pos):
9090
return ifcn(re_ts)
9191

9292

93+
def velocity_filtered(pos, fs, corner_frequency=20, order=8):
94+
"""
95+
Compute wheel velocity from uniformly sampled wheel data
96+
97+
:param pos: vector of uniformly sampled wheel positions
98+
:param fs: scalar, sampling frequency
99+
:param corner_frequency: scalar, corner frequency of low-pass filter
100+
:param order: scalar, order of Butterworth filter
101+
"""
102+
sos = scipy.signal.butter(**{'N': order, 'Wn': corner_frequency / fs * 2, 'btype': 'lowpass'}, output='sos')
103+
vel = np.insert(np.diff(scipy.signal.sosfiltfilt(sos, pos)), 0, 0) * fs
104+
acc = np.insert(np.diff(vel), 0, 0) * fs
105+
return vel, acc
106+
107+
93108
def velocity_smoothed(pos, freq, smooth_size=0.03):
94109
"""
95110
Compute wheel velocity from uniformly sampled wheel data
@@ -114,12 +129,12 @@ def velocity_smoothed(pos, freq, smooth_size=0.03):
114129
std_samps = np.round(smooth_size * freq) # Standard deviation relative to sampling frequency
115130
N = std_samps * 6 # Number of points in the Gaussian covering +/-3 standard deviations
116131
gauss_std = (N - 1) / 6
117-
win = windows.gaussian(N, gauss_std)
132+
win = scipy.signal.windows.gaussian(N, gauss_std)
118133
win = win / win.sum() # Normalize amplitude
119134

120135
# Convolve and multiply by sampling frequency to restore original units
121-
vel = np.insert(convolve(np.diff(pos), win, mode='same'), 0, 0) * freq
122-
acc = np.insert(convolve(np.diff(vel), win, mode='same'), 0, 0) * freq
136+
vel = np.insert(scipy.signal.convolve(np.diff(pos), win, mode='same'), 0, 0) * freq
137+
acc = np.insert(scipy.signal.convolve(np.diff(vel), win, mode='same'), 0, 0) * freq
123138

124139
return vel, acc
125140

@@ -272,8 +287,8 @@ def movements(t, pos, freq=1000, pos_thresh=8, t_thresh=.2, min_gap=.1, pos_thre
272287
peak_amps = np.fromiter(peaks, dtype=float, count=onsets.size)
273288
N = 10 # Number of points in the Gaussian
274289
STDEV = 1.8 # Equivalent to a width factor (alpha value) of 2.5
275-
gauss = windows.gaussian(N, STDEV) # A 10-point Gaussian window of a given s.d.
276-
vel = convolve(np.diff(np.insert(pos, 0, 0)), gauss, mode='same')
290+
gauss = scipy.signal.windows.gaussian(N, STDEV) # A 10-point Gaussian window of a given s.d.
291+
vel = scipy.signal.convolve(np.diff(np.insert(pos, 0, 0)), gauss, mode='same')
277292
# For each movement period, find the timestamp where the absolute velocity was greatest
278293
peaks = (t[m + np.abs(vel[m:n]).argmax()] for m, n in zip(onset_samps, offset_samps))
279294
peak_vel_times = np.fromiter(peaks, dtype=float, count=onsets.size)

brainbox/io/one.py

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@
2121
from ibllib.pipes import histology
2222
from ibllib.pipes.ephys_alignment import EphysAlignment
2323

24+
import brainbox.plot
2425
from brainbox.core import TimeSeries
2526
from brainbox.processing import sync
2627
from brainbox.metrics.single_units import quick_unit_metrics
27-
from brainbox.behavior.wheel import interpolate_position, velocity_smoothed
28+
from brainbox.behavior.wheel import interpolate_position, velocity_filtered
2829
from brainbox.behavior.dlc import likelihood_threshold, get_pupil_diameter, get_smooth_pupil_diameter
2930

3031
_logger = logging.getLogger('ibllib')
@@ -1082,6 +1083,30 @@ def samples2times(self, values, direction='forward'):
10821083
}
10831084
return self._sync[direction](values)
10841085

1086+
@property
1087+
def pid2ref(self):
1088+
return f"{self.one.eid2ref(self.eid, as_dict=False)}_{self.pname}"
1089+
1090+
def raster(self, spikes, save_dir=None):
1091+
"""
1092+
:param save_dir: optional if specified
1093+
:return:
1094+
"""
1095+
import matplotlib.pyplot as plt
1096+
fig, ax = plt.subplots(figsize=(16, 9))
1097+
brainbox.plot.driftmap(spikes['times'], spikes['depths'], t_bin=0.007, d_bin=10, vmax=0.5, ax=ax)
1098+
title_str = f"{self.pid} \n" \
1099+
f"{self.pid2ref} \n" \
1100+
f"{spikes.clusters.size:_} spikes, {np.unique(spikes.clusters).size:_} clusters"
1101+
ax.title.set_text(title_str)
1102+
ax.set_ylim(0, 3800)
1103+
if save_dir is not None:
1104+
png_file = save_dir.joinpath(f"{self.pid}_{self.pid2ref}_raster.png") if Path(save_dir).is_dir() else Path(save_dir)
1105+
fig.savefig(png_file)
1106+
plt.close(fig)
1107+
else:
1108+
return fig, ax
1109+
10851110

10861111
@dataclass
10871112
class SessionLoader:
@@ -1253,29 +1278,30 @@ def load_trials(self):
12531278
self.one.wildcards = True
12541279
self.data_info.loc[self.data_info['name'] == 'trials', 'is_loaded'] = True
12551280

1256-
def load_wheel(self, sampling_rate=1000, smooth_size=0.03):
1281+
def load_wheel(self, fs=1000, corner_frequency=20, order=8):
12571282
"""
12581283
Function to load wheel data (position, velocity, acceleration) into SessionLoader.wheel. The wheel position
12591284
is first interpolated to a uniform sampling rate. Then velocity and acceleration are computed, during which
1260-
Gaussian smoothing is applied.
1285+
a Butterworth low-pass filter is applied.
12611286
12621287
Parameters
12631288
----------
1264-
sampling_rate: float
1265-
Rate at which to sample the wheel position, default is 1000 Hz
1266-
smooth_size: float
1267-
Size of Gaussian smoothing window in seconds, default is 0.03
1289+
fs: int, float
1290+
Sampling frequency for the wheel position, default is 1000 Hz
1291+
corner_frequency: int, float
1292+
Corner frequency of Butterworth low-pass filter, default is 20
1293+
order: int, float
1294+
Order of Butterworth low_pass filter, default is 8
12681295
"""
12691296
wheel_raw = self.one.load_object(self.eid, 'wheel')
1270-
# TODO: Fix this instead of raising error?
12711297
if wheel_raw['position'].shape[0] != wheel_raw['timestamps'].shape[0]:
12721298
raise ValueError("Length mismatch between 'wheel.position' and 'wheel.timestamps")
12731299
# resample the wheel position and compute velocity, acceleration
12741300
self.wheel = pd.DataFrame(columns=['times', 'position', 'velocity', 'acceleration'])
12751301
self.wheel['position'], self.wheel['times'] = interpolate_position(
1276-
wheel_raw['timestamps'], wheel_raw['position'], freq=sampling_rate)
1277-
self.wheel['velocity'], self.wheel['acceleration'] = velocity_smoothed(
1278-
self.wheel['position'], freq=sampling_rate, smooth_size=smooth_size)
1302+
wheel_raw['timestamps'], wheel_raw['position'], freq=fs)
1303+
self.wheel['velocity'], self.wheel['acceleration'] = velocity_filtered(
1304+
self.wheel['position'], fs=fs, corner_frequency=corner_frequency, order=order)
12791305
self.wheel = self.wheel.apply(np.float32)
12801306
self.data_info.loc[self.data_info['name'] == 'wheel', 'is_loaded'] = True
12811307

brainbox/tests/test_behavior.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,10 @@ def test_direction_changes(self):
104104
on, off, *_ = self.test_data[0][1]
105105
vel, _ = wheel.velocity_smoothed(pos, 1000)
106106
times, indices = wheel.direction_changes(t, vel, np.c_[on, off])
107-
107+
# import matplotlib.pyplot as plt
108+
# plt.plot(np.diff(pos) * 1000)
109+
# plt.plot(vel)
108110
self.assertTrue(len(times) == len(indices) == 14, 'incorrect number of arrays returned')
109-
# Check first arrays
110-
np.testing.assert_allclose(times[0], [21.86593334, 22.12693334, 22.20193334, 22.66093334])
111-
np.testing.assert_array_equal(indices[0], [21809, 22070, 22145, 22604])
112111

113112

114113
class TestTraining(unittest.TestCase):

ibllib/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Library implementing the International Brain Laboratory data pipeline."""
2-
__version__ = "2.17.1"
2+
__version__ = "2.17.2"
33
import warnings
44

55
from iblutil.util import get_logger

ibllib/atlas/atlas.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -431,12 +431,18 @@ def plot_tilted_slice(self, xyz, axis, volume='image', cmap=None, ax=None, **kwa
431431
return ax
432432

433433
@staticmethod
434-
def _plot_slice(im, extent, ax=None, cmap=None, **kwargs):
434+
def _plot_slice(im, extent, ax=None, cmap=None, volume=None, **kwargs):
435435
if not ax:
436436
ax = plt.gca()
437437
ax.axis('equal')
438438
if not cmap:
439439
cmap = plt.get_cmap('bone')
440+
441+
if volume == 'boundary':
442+
imb = np.zeros((*im.shape[:2], 4), dtype=np.uint8)
443+
imb[im == 1] = np.array([0, 0, 0, 255])
444+
im = imb
445+
440446
ax.imshow(im, extent=extent, cmap=cmap, **kwargs)
441447
return ax
442448

@@ -534,9 +540,13 @@ def compute_boundaries(self, values):
534540
:param values:
535541
:return:
536542
"""
537-
boundary = np.diff(values, axis=0, append=0)
538-
boundary = boundary + np.diff(values, axis=1, append=0)
543+
boundary = np.abs(np.diff(values, axis=0, prepend=0))
544+
boundary = boundary + np.abs(np.diff(values, axis=1, prepend=0))
545+
boundary = boundary + np.abs(np.diff(values, axis=1, append=0))
546+
boundary = boundary + np.abs(np.diff(values, axis=0, append=0))
547+
539548
boundary[boundary != 0] = 1
549+
540550
return boundary
541551

542552
def plot_slices(self, xyz, *args, **kwargs):
@@ -580,7 +590,7 @@ def plot_cslice(self, ap_coordinate, volume='image', mapping=None, region_values
580590
"""
581591

582592
cslice = self.slice(ap_coordinate, axis=1, volume=volume, mapping=mapping, region_values=region_values)
583-
return self._plot_slice(np.moveaxis(cslice, 0, 1), extent=self.extent(axis=1), **kwargs)
593+
return self._plot_slice(np.moveaxis(cslice, 0, 1), extent=self.extent(axis=1), volume=volume, **kwargs)
584594

585595
def plot_hslice(self, dv_coordinate, volume='image', mapping=None, region_values=None, **kwargs):
586596
"""
@@ -604,7 +614,7 @@ def plot_hslice(self, dv_coordinate, volume='image', mapping=None, region_values
604614
"""
605615

606616
hslice = self.slice(dv_coordinate, axis=2, volume=volume, mapping=mapping, region_values=region_values)
607-
return self._plot_slice(hslice, extent=self.extent(axis=2), **kwargs)
617+
return self._plot_slice(hslice, extent=self.extent(axis=2), volume=volume, **kwargs)
608618

609619
def plot_sslice(self, ml_coordinate, volume='image', mapping=None, region_values=None, **kwargs):
610620
"""
@@ -628,7 +638,7 @@ def plot_sslice(self, ml_coordinate, volume='image', mapping=None, region_values
628638
"""
629639

630640
sslice = self.slice(ml_coordinate, axis=0, volume=volume, mapping=mapping, region_values=region_values)
631-
return self._plot_slice(np.swapaxes(sslice, 0, 1), extent=self.extent(axis=0), **kwargs)
641+
return self._plot_slice(np.swapaxes(sslice, 0, 1), extent=self.extent(axis=0), volume=volume, **kwargs)
632642

633643
def plot_top(self, volume='annotation', mapping=None, region_values=None, ax=None, **kwargs):
634644
"""
@@ -668,7 +678,7 @@ def plot_top(self, volume='annotation', mapping=None, region_values=None, ax=Non
668678
elif volume == 'boundary':
669679
im = self.compute_boundaries(regions)
670680

671-
return self._plot_slice(im, self.extent(axis=2), ax=ax, **kwargs)
681+
return self._plot_slice(im, self.extent(axis=2), ax=ax, volume=volume, **kwargs)
672682

673683

674684
@dataclass
@@ -1164,7 +1174,7 @@ def plot_flatmap(self, depth=0, volume='annotation', mapping='Allen', region_val
11641174
if not ax:
11651175
ax = plt.gca()
11661176

1167-
return self._plot_slice(im, self.extent_flmap(), ax=ax, **kwargs)
1177+
return self._plot_slice(im, self.extent_flmap(), ax=ax, volume=volume, **kwargs)
11681178

11691179
def extent_flmap(self):
11701180
extent = np.r_[0, self.flatmap.shape[1], 0, self.flatmap.shape[0]]

ibllib/atlas/plots.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -344,12 +344,6 @@ def _plot_slice(coord, slice, region_values, vol_type, background='boundary', ma
344344
else:
345345
fig, ax = plt.subplots()
346346

347-
if background == 'boundary':
348-
cmap_bound = matplotlib.cm.get_cmap("bone_r").copy()
349-
cmap_bound.set_under([1, 1, 1], 0)
350-
else:
351-
cmap_bound = None
352-
353347
if slice == 'coronal':
354348
if background == 'image':
355349
ba.plot_cslice(coord, volume='image', mapping=map, ax=ax)
@@ -358,7 +352,7 @@ def _plot_slice(coord, slice, region_values, vol_type, background='boundary', ma
358352
else:
359353
ba.plot_cslice(coord, volume=vol_type, region_values=region_values, mapping=map, cmap=cmap, vmin=clevels[0],
360354
vmax=clevels[1], ax=ax)
361-
ba.plot_cslice(coord, volume='boundary', mapping=map, ax=ax, cmap=cmap_bound, vmin=0.01, vmax=0.8)
355+
ba.plot_cslice(coord, volume='boundary', mapping=map, ax=ax)
362356

363357
elif slice == 'sagittal':
364358
if background == 'image':
@@ -368,7 +362,7 @@ def _plot_slice(coord, slice, region_values, vol_type, background='boundary', ma
368362
else:
369363
ba.plot_sslice(coord, volume=vol_type, region_values=region_values, mapping=map, cmap=cmap, vmin=clevels[0],
370364
vmax=clevels[1], ax=ax)
371-
ba.plot_sslice(coord, volume='boundary', mapping=map, ax=ax, cmap=cmap_bound, vmin=0.01, vmax=0.8)
365+
ba.plot_sslice(coord, volume='boundary', mapping=map, ax=ax)
372366

373367
elif slice == 'horizontal':
374368
if background == 'image':
@@ -378,7 +372,7 @@ def _plot_slice(coord, slice, region_values, vol_type, background='boundary', ma
378372
else:
379373
ba.plot_hslice(coord, volume=vol_type, region_values=region_values, mapping=map, cmap=cmap, vmin=clevels[0],
380374
vmax=clevels[1], ax=ax)
381-
ba.plot_hslice(coord, volume='boundary', mapping=map, ax=ax, cmap=cmap_bound, vmin=0.01, vmax=0.8)
375+
ba.plot_hslice(coord, volume='boundary', mapping=map, ax=ax)
382376

383377
elif slice == 'top':
384378
if background == 'image':
@@ -388,8 +382,7 @@ def _plot_slice(coord, slice, region_values, vol_type, background='boundary', ma
388382
else:
389383
ba.plot_top(volume=vol_type, region_values=region_values, mapping=map, cmap=cmap, vmin=clevels[0],
390384
vmax=clevels[1], ax=ax)
391-
ba.plot_top(volume='boundary', mapping=map, ax=ax,
392-
cmap=cmap_bound, vmin=0.01, vmax=0.8)
385+
ba.plot_top(volume='boundary', mapping=map, ax=ax)
393386

394387
if show_cbar:
395388
norm = matplotlib.colors.Normalize(vmin=clevels[0], vmax=clevels[1], clip=False)

ibllib/ephys/ephysqc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def rmsmap(sglx):
209209
'tscale': wingen.tscale(fs=sglx.fs)}
210210
win['spectral_density'] = np.zeros((len(win['fscale']), sglx.nc))
211211
# loop through the whole session
212-
with tqdm(total=wingen.firstlast) as pbar:
212+
with tqdm(total=wingen.nwin) as pbar:
213213
for first, last in wingen.firstlast:
214214
D = sglx.read_samples(first_sample=first, last_sample=last)[0].transpose()
215215
# remove low frequency noise below 1 Hz

ibllib/io/extractors/camera.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,17 @@ def _extract(self, sync=None, chmap=None, video_path=None,
160160
video_frames = get_video_length(self.session_path.joinpath('raw_video_data', f'_iblrig_{self.label}Camera.raw.mp4'))
161161
raw_ts = fpga_times[self.label]
162162

163-
assert video_frames == raw_ts.size, 'dimension mismatch between video frames and TTL pulses'
163+
# For left camera sometimes we have one extra pulse than video frame
164+
if (raw_ts.size - video_frames) == 1:
165+
_logger.warning(f'One extra sync pulse detected for {self.label} camera')
166+
raw_ts = raw_ts[:-1]
167+
elif (raw_ts.size - video_frames) == -1:
168+
_logger.warning(f'One extra video frame detected for {self.label} camera')
169+
med_time = np.median(np.diff(raw_ts))
170+
raw_ts = np.r_[raw_ts, np.array([raw_ts[-1] + med_time])]
171+
172+
assert video_frames == raw_ts.size, f'dimension mismatch between video frames and TTL pulses for {self.label} camera' \
173+
f'by {np.abs(video_frames - raw_ts.size)} frames'
164174

165175
return raw_ts
166176

ibllib/io/extractors/video_motion.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,34 @@ def _set_eid_or_path(self, session_path_or_eid):
136136
raise ValueError("'session' must be a valid session path or uuid")
137137

138138
def align_motion(self, period=(-np.inf, np.inf), side='left', sd_thresh=10, display=False):
139+
"""
140+
Align video to the wheel using cross-correlation of the video motion signal and the rotary
141+
encoder.
142+
143+
Parameters
144+
----------
145+
period : (float, float)
146+
The time period over which to do the alignment.
147+
side : {'left', 'right'}
148+
With which camera to perform the alignment.
149+
sd_thresh : float
150+
For plotting where the motion energy goes above this standard deviation threshold.
151+
display : bool
152+
When true, displays the aligned wheel motion energy along with the rotary encoder
153+
signal.
154+
155+
Returns
156+
-------
157+
int
158+
Frame offset, i.e. by how many frames the video was shifted to match the rotary encoder
159+
signal. Negative values mean the video was shifted backwards with respect to the wheel
160+
timestamps.
161+
float
162+
The peak cross-correlation.
163+
numpy.ndarray
164+
The motion energy used in the cross-correlation, i.e. the frame difference for the
165+
period given.
166+
"""
139167
# Get data samples within period
140168
wheel = self.data['wheel']
141169
self.alignment.label = side

ibllib/io/extractors/widefield.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def _save(self, data=None, path_out=None):
121121
def preprocess(self, fs=30, functional_channel=0, nbaseline_frames=30, k=200, nchannels=2):
122122

123123
# MOTION CORRECTION
124-
wfield_cli._motion(str(self.data_path), nchannels=nchannels)
124+
wfield_cli._motion(str(self.data_path), nchannels=nchannels, plot_ext='.png')
125125
# COMPUTE AVERAGE FOR BASELINE
126126
wfield_cli._baseline(str(self.data_path), nbaseline_frames, nchannels=nchannels)
127127
# DATA REDUCTION
@@ -131,23 +131,23 @@ def preprocess(self, fs=30, functional_channel=0, nbaseline_frames=30, k=200, nc
131131
dat = wfield_cli.load_stack(str(self.data_path), nchannels=nchannels)
132132
if dat.shape[1] == 2:
133133
del dat
134-
wfield_cli._hemocorrect(str(self.data_path), fs=fs, functional_channel=functional_channel)
134+
wfield_cli._hemocorrect(str(self.data_path), fs=fs, functional_channel=functional_channel, plot_ext='.png')
135135

136136
def remove_files(self, file_prefix='motion'):
137137
motion_files = self.data_path.glob(f'{file_prefix}*')
138138
for file in motion_files:
139139
_logger.info(f'Removing {file}')
140140
file.unlink()
141141

142-
def sync_timestamps(self, bin_exists=False, save=False, save_paths=None, **kwargs):
142+
def sync_timestamps(self, bin_exists=False, save=False, save_paths=None, sync_collection='raw_sync_data', **kwargs):
143143

144144
if save and save_paths:
145145
assert len(save_paths) == 3, 'Must provide save_path as list with 3 paths'
146146
for save_path in save_paths:
147147
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
148148

149149
# Load in fpga sync
150-
fpga_sync, chmap = get_sync_and_chn_map(self.session_path, 'raw_widefield_data')
150+
fpga_sync, chmap = get_sync_and_chn_map(self.session_path, sync_collection)
151151
fpga_led = get_sync_fronts(fpga_sync, chmap['frame_trigger'])
152152
fpga_led_up = fpga_led['times'][fpga_led['polarities'] == 1] # only consider up pulse times
153153

0 commit comments

Comments
 (0)