Skip to content

Commit b700ca1

Browse files
committed
Merge branch 'develop' into docs
2 parents 2f61fb4 + e1eb551 commit b700ca1

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+1414
-5778
lines changed

brainbox/behavior/wheel.py

Lines changed: 0 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
"""
22
Set of functions to handle wheel data.
33
"""
4-
import logging
5-
import warnings
6-
import traceback
7-
84
import numpy as np
95
from numpy import pi
106
from iblutil.numerical import between_sorted
@@ -68,42 +64,6 @@ def interpolate_position(re_ts, re_pos, freq=1000, kind='linear', fill_gaps=None
6864
return yinterp, t
6965

7066

71-
def velocity(re_ts, re_pos):
72-
"""
73-
(DEPRECATED) Compute wheel velocity from non-uniformly sampled wheel data. Returns the velocity
74-
at the same samples locations as the position through interpolation.
75-
76-
Parameters
77-
----------
78-
re_ts : array_like
79-
Array of timestamps
80-
re_pos: array_like
81-
Array of unwrapped wheel positions
82-
83-
Returns
84-
-------
85-
np.ndarray
86-
numpy array of velocities
87-
"""
88-
for line in traceback.format_stack():
89-
print(line.strip())
90-
91-
msg = 'brainbox.behavior.wheel.velocity will soon be removed. Use velocity_filtered instead.'
92-
warnings.warn(msg, FutureWarning)
93-
logging.getLogger(__name__).warning(msg)
94-
95-
dp = np.diff(re_pos)
96-
dt = np.diff(re_ts)
97-
# Compute raw velocity
98-
vel = dp / dt
99-
# Compute velocity time scale
100-
tv = re_ts[:-1] + dt / 2
101-
# interpolate over original time scale
102-
if tv.size > 1:
103-
ifcn = interpolate.interp1d(tv, vel, fill_value="extrapolate")
104-
return ifcn(re_ts)
105-
106-
10767
def velocity_filtered(pos, fs, corner_frequency=20, order=8):
10868
"""
10969
Compute wheel velocity from uniformly sampled wheel data.
@@ -130,83 +90,6 @@ def velocity_filtered(pos, fs, corner_frequency=20, order=8):
13090
return vel, acc
13191

13292

133-
def velocity_smoothed(pos, freq, smooth_size=0.03):
134-
"""
135-
(DEPRECATED) Compute wheel velocity from uniformly sampled wheel data.
136-
137-
Parameters
138-
----------
139-
pos : array_like
140-
Array of wheel positions
141-
smooth_size : float
142-
Size of Gaussian smoothing window in seconds
143-
freq : float
144-
Sampling frequency of the data
145-
146-
Returns
147-
-------
148-
vel : np.ndarray
149-
Array of velocity values
150-
acc : np.ndarray
151-
Array of acceleration values
152-
"""
153-
for line in traceback.format_stack():
154-
print(line.strip())
155-
156-
msg = 'brainbox.behavior.wheel.velocity_smoothed will be removed. Use velocity_filtered instead.'
157-
warnings.warn(msg, FutureWarning)
158-
logging.getLogger(__name__).warning(msg)
159-
160-
# Define our smoothing window with an area of 1 so the units won't be changed
161-
std_samps = np.round(smooth_size * freq) # Standard deviation relative to sampling frequency
162-
N = std_samps * 6 # Number of points in the Gaussian covering +/-3 standard deviations
163-
gauss_std = (N - 1) / 6
164-
win = scipy.signal.windows.gaussian(N, gauss_std)
165-
win = win / win.sum() # Normalize amplitude
166-
167-
# Convolve and multiply by sampling frequency to restore original units
168-
vel = np.insert(scipy.signal.convolve(np.diff(pos), win, mode='same'), 0, 0) * freq
169-
acc = np.insert(scipy.signal.convolve(np.diff(vel), win, mode='same'), 0, 0) * freq
170-
171-
return vel, acc
172-
173-
174-
def last_movement_onset(t, vel, event_time):
175-
"""
176-
(DEPRECATED) Find the time at which movement started, given an event timestamp that occurred during the
177-
movement.
178-
179-
Movement start is defined as the first sample after the velocity has been zero for at least 50ms.
180-
Wheel inputs should be evenly sampled.
181-
182-
:param t: numpy array of wheel timestamps in seconds
183-
:param vel: numpy array of wheel velocities
184-
:param event_time: timestamp anywhere during movement of interest, e.g. peak velocity
185-
:return: timestamp of movement onset
186-
"""
187-
for line in traceback.format_stack():
188-
print(line.strip())
189-
190-
msg = 'brainbox.behavior.wheel.last_movement_onset has been deprecated. Use get_movement_onset instead.'
191-
warnings.warn(msg, FutureWarning)
192-
logging.getLogger(__name__).warning(msg)
193-
194-
# Look back from timestamp
195-
threshold = 50e-3
196-
mask = t < event_time
197-
times = t[mask]
198-
vel = vel[mask]
199-
t = None # Initialize
200-
for i, t in enumerate(times[::-1]):
201-
i = times.size - i
202-
idx = np.min(np.where((t - times) < threshold))
203-
if np.max(np.abs(vel[idx:i])) < 0.5:
204-
break
205-
206-
# Return timestamp
207-
return t
208-
209-
21093
def get_movement_onset(intervals, event_times):
21194
"""
21295
Find the time at which movement started, given an event timestamp that occurred during the

brainbox/io/one.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import spikeglx
2121

2222
import ibldsp.voltage
23+
from ibldsp.waveform_extraction import WaveformsLoader
2324
from iblutil.util import Bunch
2425
from iblatlas.atlas import AllenAtlas, BrainRegions
2526
from iblatlas import atlas
@@ -975,6 +976,21 @@ def raw_electrophysiology(self, stream=True, band='ap', **kwargs):
975976
if cbin_file is not None:
976977
return spikeglx.Reader(cbin_file)
977978

979+
def download_raw_waveforms(self, **kwargs):
980+
"""
981+
Downloads raw waveforms extracted from sorting to local disk.
982+
"""
983+
_logger.debug(f"loading waveforms from {self.collection}")
984+
return self.one.load_object(
985+
self.eid, "waveforms",
986+
attribute=["traces", "templates", "table", "channels"],
987+
collection=self._get_spike_sorting_collection("pykilosort"), download_only=True, **kwargs
988+
)
989+
990+
def raw_waveforms(self, **kwargs):
991+
wf_paths = self.download_raw_waveforms(**kwargs)
992+
return WaveformsLoader(wf_paths[0].parent, wfs_dtype=np.float16)
993+
978994
def load_channels(self, **kwargs):
979995
"""
980996
Loads channels
@@ -1318,6 +1334,7 @@ class SessionLoader:
13181334
one: One = None
13191335
session_path: Path = ''
13201336
eid: str = ''
1337+
revision: str = ''
13211338
data_info: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
13221339
trials: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
13231340
wheel: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
@@ -1445,7 +1462,7 @@ def load_trials(self, collection=None):
14451462
# itiDuration frequently has a mismatched dimension, and we don't need it, exclude using regex
14461463
self.one.wildcards = False
14471464
self.trials = self.one.load_object(
1448-
self.eid, 'trials', collection=collection, attribute=r'(?!itiDuration).*').to_df()
1465+
self.eid, 'trials', collection=collection, attribute=r'(?!itiDuration).*', revision=self.revision or None).to_df()
14491466
self.one.wildcards = True
14501467
self.data_info.loc[self.data_info['name'] == 'trials', 'is_loaded'] = True
14511468

@@ -1468,7 +1485,7 @@ def load_wheel(self, fs=1000, corner_frequency=20, order=8, collection=None):
14681485
"""
14691486
if not collection:
14701487
collection = self._find_behaviour_collection('wheel')
1471-
wheel_raw = self.one.load_object(self.eid, 'wheel', collection=collection)
1488+
wheel_raw = self.one.load_object(self.eid, 'wheel', collection=collection, revision=self.revision or None)
14721489
if wheel_raw['position'].shape[0] != wheel_raw['timestamps'].shape[0]:
14731490
raise ValueError("Length mismatch between 'wheel.position' and 'wheel.timestamps")
14741491
# resample the wheel position and compute velocity, acceleration
@@ -1498,7 +1515,7 @@ def load_pose(self, likelihood_thr=0.9, views=['left', 'right', 'body']):
14981515
# empty the dictionary so that if one loads only one view, after having loaded several, the others don't linger
14991516
self.pose = {}
15001517
for view in views:
1501-
pose_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['dlc', 'times'])
1518+
pose_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['dlc', 'times'], revision=self.revision or None)
15021519
# Double check if video timestamps are correct length or can be fixed
15031520
times_fixed, dlc = self._check_video_timestamps(view, pose_raw['times'], pose_raw['dlc'])
15041521
self.pose[f'{view}Camera'] = likelihood_threshold(dlc, likelihood_thr)
@@ -1525,7 +1542,8 @@ def load_motion_energy(self, views=['left', 'right', 'body']):
15251542
# empty the dictionary so that if one loads only one view, after having loaded several, the others don't linger
15261543
self.motion_energy = {}
15271544
for view in views:
1528-
me_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['ROIMotionEnergy', 'times'])
1545+
me_raw = self.one.load_object(
1546+
self.eid, f'{view}Camera', attribute=['ROIMotionEnergy', 'times'], revision=self.revision or None)
15291547
# Double check if video timestamps are correct length or can be fixed
15301548
times_fixed, motion_energy = self._check_video_timestamps(
15311549
view, me_raw['times'], me_raw['ROIMotionEnergy'])
@@ -1550,7 +1568,7 @@ def load_pupil(self, snr_thresh=5.):
15501568
will be considered unusable and will be discarded.
15511569
"""
15521570
# Try to load from features
1553-
feat_raw = self.one.load_object(self.eid, 'leftCamera', attribute=['times', 'features'])
1571+
feat_raw = self.one.load_object(self.eid, 'leftCamera', attribute=['times', 'features'], revision=self.revision or None)
15541572
if 'features' in feat_raw.keys():
15551573
times_fixed, feats = self._check_video_timestamps('left', feat_raw['times'], feat_raw['features'])
15561574
self.pupil = feats.copy()

brainbox/tests/test_behavior.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,6 @@ def test_get_movement_onset(self):
124124
with self.assertRaises(ValueError):
125125
wheel.get_movement_onset(intervals, np.random.permutation(self.trials['feedback_times']))
126126

127-
def test_velocity_deprecation(self):
128-
"""Ensure brainbox.behavior.wheel.velocity is removed."""
129-
from datetime import datetime
130-
self.assertTrue(datetime.today() < datetime(2024, 8, 1),
131-
'remove brainbox.behavior.wheel.velocity, velocity_smoothed and last_movement_onset')
132-
133127

134128
class TestTraining(unittest.TestCase):
135129
def setUp(self):

ibllib/ephys/ephysqc.py

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -205,57 +205,65 @@ def run(self, update: bool = False, overwrite: bool = True, stream: bool = None,
205205
return qc_files
206206

207207

208-
def rmsmap(sglx):
208+
def rmsmap(sglx, spectra=True, nmod=1):
209209
"""
210210
Computes RMS map in time domain and spectra for each channel of Neuropixel probe
211211
212212
:param sglx: Open spikeglx reader
213+
:param spectra: Whether to compute the spectra
214+
:param nmod: take every nmod windows, in cases where we don't want to compute over the whole signal
213215
:return: a dictionary with amplitudes in channeltime space, channelfrequency space, time
214216
and frequency scales
215217
"""
216218
rms_win_length_samples = 2 ** np.ceil(np.log2(sglx.fs * RMS_WIN_LENGTH_SECS))
217219
# the window generator will generates window indices
218220
wingen = utils.WindowGenerator(ns=sglx.ns, nswin=rms_win_length_samples, overlap=0)
221+
nwin = np.ceil(wingen.nwin / nmod).astype(int)
219222
# pre-allocate output dictionary of numpy arrays
220-
win = {'TRMS': np.zeros((wingen.nwin, sglx.nc)),
221-
'nsamples': np.zeros((wingen.nwin,)),
223+
win = {'TRMS': np.zeros((nwin, sglx.nc)),
224+
'nsamples': np.zeros((nwin,)),
222225
'fscale': fourier.fscale(WELCH_WIN_LENGTH_SAMPLES, 1 / sglx.fs, one_sided=True),
223-
'tscale': wingen.tscale(fs=sglx.fs)}
226+
'tscale': wingen.tscale(fs=sglx.fs)[::nmod]}
224227
win['spectral_density'] = np.zeros((len(win['fscale']), sglx.nc))
225228
# loop through the whole session
226229
with tqdm(total=wingen.nwin) as pbar:
227-
for first, last in wingen.firstlast:
230+
for iwindow, (first, last) in enumerate(wingen.firstlast):
231+
if np.mod(iwindow, nmod) != 0:
232+
continue
233+
228234
D = sglx.read_samples(first_sample=first, last_sample=last)[0].transpose()
229235
# remove low frequency noise below 1 Hz
230236
D = fourier.hp(D, 1 / sglx.fs, [0, 1])
231-
iw = wingen.iw
237+
iw = np.floor(wingen.iw / nmod).astype(int)
232238
win['TRMS'][iw, :] = utils.rms(D)
233239
win['nsamples'][iw] = D.shape[1]
234-
# the last window may be smaller than what is needed for welch
235-
if last - first < WELCH_WIN_LENGTH_SAMPLES:
236-
continue
237-
# compute a smoothed spectrum using welch method
238-
_, w = signal.welch(
239-
D, fs=sglx.fs, window='hann', nperseg=WELCH_WIN_LENGTH_SAMPLES,
240-
detrend='constant', return_onesided=True, scaling='density', axis=-1
241-
)
242-
win['spectral_density'] += w.T
240+
if spectra:
241+
# the last window may be smaller than what is needed for welch
242+
if last - first < WELCH_WIN_LENGTH_SAMPLES:
243+
continue
244+
# compute a smoothed spectrum using welch method
245+
_, w = signal.welch(
246+
D, fs=sglx.fs, window='hann', nperseg=WELCH_WIN_LENGTH_SAMPLES,
247+
detrend='constant', return_onesided=True, scaling='density', axis=-1
248+
)
249+
win['spectral_density'] += w.T
243250
# print at least every 20 windows
244251
if (iw % min(20, max(int(np.floor(wingen.nwin / 75)), 1))) == 0:
245252
pbar.update(iw)
246253
sglx.close()
247254
return win
248255

249256

250-
def extract_rmsmap(sglx, out_folder=None, overwrite=False):
257+
def extract_rmsmap(sglx, out_folder=None, overwrite=False, spectra=True, nmod=1):
251258
"""
252259
Wrapper for rmsmap that outputs _ibl_ephysRmsMap and _ibl_ephysSpectra ALF files
253260
254261
:param sglx: Open spikeglx Reader with data for which to compute rmsmap
255262
:param out_folder: folder in which to store output ALF files. Default uses the folder in which
256263
the `fbin` file lives.
257264
:param overwrite: do not re-extract if all ALF files already exist
258-
:param label: string or list of strings that will be appended to the filename before extension
265+
:param spectra: Whether to compute the spectral density across the signal
266+
:param nmod: take every nmod windows, in cases where we don't want to compute over the whole signal
259267
:return: None
260268
"""
261269
if out_folder is None:
@@ -271,18 +279,19 @@ def extract_rmsmap(sglx, out_folder=None, overwrite=False):
271279
_logger.warning(f'RMS map already exists for .{sglx.type} data in {out_folder}, skipping. Use overwrite option.')
272280
return files_time + files_freq
273281
# crunch numbers
274-
rms = rmsmap(sglx)
282+
rms = rmsmap(sglx, spectra=spectra, nmod=nmod)
275283
# output ALF files, single precision with the optional label as suffix before extension
276284
if not out_folder.exists():
277285
out_folder.mkdir()
278286
tdict = {'rms': rms['TRMS'].astype(np.single), 'timestamps': rms['tscale'].astype(np.single)}
279-
fdict = {'power': rms['spectral_density'].astype(np.single),
280-
'freqs': rms['fscale'].astype(np.single)}
281287
out_time = alfio.save_object_npy(
282288
out_folder, object=alf_object_time, dico=tdict, namespace='iblqc')
283-
out_freq = alfio.save_object_npy(
284-
out_folder, object=alf_object_freq, dico=fdict, namespace='iblqc')
285-
return out_time + out_freq
289+
if spectra:
290+
fdict = {'power': rms['spectral_density'].astype(np.single),
291+
'freqs': rms['fscale'].astype(np.single)}
292+
out_freq = alfio.save_object_npy(
293+
out_folder, object=alf_object_freq, dico=fdict, namespace='iblqc')
294+
return out_time + out_freq if spectra else out_time
286295

287296

288297
def raw_qc_session(session_path, overwrite=False):

0 commit comments

Comments
 (0)