
Commit 3e80794

Merge pull request #802 from int-brain-lab/develop

2.38.0

2 parents 7b279d2 + 91ba20d

32 files changed (+1254, -574 lines)

.github/workflows/ibllib_ci.yml

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ jobs:
       - name: Flake8
         run: |
           python -m flake8
-          python -m flake8 --select D --ignore E ibllib/qc/camera.py
+          python -m flake8 --select D --ignore E ibllib/qc/camera.py ibllib/qc/task_metrics.py
       - name: Brainbox tests
         run: |
           cd brainbox
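
For context: the D-series codes require a docstring plugin such as flake8-docstrings (an assumption; the workflow shows only the invocation), so this step now lints docstring style in ibllib/qc/task_metrics.py as well, while ignoring the pycodestyle E codes. A minimal sketch of the style those checks enforce, using a hypothetical function name:

def outcome_summary(outcomes):
    """Summarise QC outcomes as a dict of non-passing checks.

    A one-line summary ending in a period, followed by a blank line,
    is the kind of convention the D codes verify.
    """
    return {k: v for k, v in outcomes.items() if v != 'PASS'}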

brainbox/behavior/training.py

Lines changed: 5 additions & 2 deletions
@@ -796,8 +796,11 @@ def compute_reaction_time(trials, stim_on_type='stimOn_times', stim_off_type='re
         block_idx = trials.probabilityLeft == block

     contrasts, n_contrasts = np.unique(signed_contrast[block_idx], return_counts=True)
-    reaction_time = np.vectorize(lambda x: np.nanmedian((trials[stim_off_type] - trials[stim_on_type])
-                                                        [(x == signed_contrast) & block_idx]))(contrasts)
+    reaction_time = np.vectorize(
+        lambda x: np.nanmedian((trials[stim_off_type] - trials[stim_on_type])[(x == signed_contrast) & block_idx]),
+        otypes=[float]
+    )(contrasts)
+
     if compute_ci:
         ci = np.full((contrasts.size, 2), np.nan)
         for i, x in enumerate(contrasts):
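
The added otypes=[float] matters when contrasts is empty: np.vectorize normally infers its output dtype by calling the function on the first element, so a size-0 input raises a ValueError. A standalone sketch (not from the commit) of the failure mode:

import numpy as np

contrasts = np.array([])  # e.g. no trials in this block
f = np.vectorize(lambda x: np.nanmedian([x]))
# f(contrasts) raises ValueError on size-0 input unless otypes is set
f = np.vectorize(lambda x: np.nanmedian([x]), otypes=[float])
print(f(contrasts))  # array([], dtype=float64)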

brainbox/behavior/wheel.py

Lines changed: 5 additions & 5 deletions
@@ -88,8 +88,8 @@ def velocity(re_ts, re_pos):
     for line in traceback.format_stack():
         print(line.strip())

-    msg = 'brainbox.behavior.wheel.velocity has been deprecated. Use velocity_filtered instead.'
-    warnings.warn(msg, DeprecationWarning)
+    msg = 'brainbox.behavior.wheel.velocity will soon be removed. Use velocity_filtered instead.'
+    warnings.warn(msg, FutureWarning)
     logging.getLogger(__name__).warning(msg)

     dp = np.diff(re_pos)
@@ -153,8 +153,8 @@ def velocity_smoothed(pos, freq, smooth_size=0.03):
     for line in traceback.format_stack():
         print(line.strip())

-    msg = 'brainbox.behavior.wheel.velocity_smoothed has been deprecated. Use velocity_filtered instead.'
-    warnings.warn(msg, DeprecationWarning)
+    msg = 'brainbox.behavior.wheel.velocity_smoothed will be removed. Use velocity_filtered instead.'
+    warnings.warn(msg, FutureWarning)
     logging.getLogger(__name__).warning(msg)

     # Define our smoothing window with an area of 1 so the units won't be changed
@@ -188,7 +188,7 @@ def last_movement_onset(t, vel, event_time):
         print(line.strip())

     msg = 'brainbox.behavior.wheel.last_movement_onset has been deprecated. Use get_movement_onset instead.'
-    warnings.warn(msg, DeprecationWarning)
+    warnings.warn(msg, FutureWarning)
     logging.getLogger(__name__).warning(msg)

     # Look back from timestamp
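
The switch from DeprecationWarning to FutureWarning changes who sees the message: Python's default warning filters hide DeprecationWarning unless it is triggered from code in __main__, while FutureWarning is always displayed, which suits a removal notice aimed at end users. A quick sketch:

import warnings

# FutureWarning is shown by the default filters even from imported library code;
# a DeprecationWarning here would be silenced for most callers.
warnings.warn('velocity will soon be removed', FutureWarning)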

brainbox/io/one.py

Lines changed: 21 additions & 17 deletions
@@ -40,6 +40,7 @@

 SPIKES_ATTRIBUTES = ['clusters', 'times', 'amps', 'depths']
 CLUSTERS_ATTRIBUTES = ['channels', 'depths', 'metrics', 'uuids']
+WAVEFORMS_ATTRIBUTES = ['templates']


 def load_lfp(eid, one=None, dataset_types=None, **kwargs):
@@ -128,6 +129,10 @@ def _channels_alf2bunch(channels, brain_regions=None):
         'axial_um': channels['localCoordinates'][:, 1],
         'lateral_um': channels['localCoordinates'][:, 0],
     }
+    # here if we have some extra keys, they will carry over to the next dictionary
+    for k in channels:
+        if k not in list(channels_.keys()) + ['mlapdv', 'brainLocationIds_ccf_2017', 'localCoordinates']:
+            channels_[k] = channels[k]
     if brain_regions:
         channels_['acronym'] = brain_regions.get(channels_['atlas_id'])['acronym']
     return channels_
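
A toy illustration (not the library function itself) of the new carry-over behaviour: any key not already consumed by the conversion is copied into the output dictionary verbatim:

channels = {'localCoordinates': [[20, 0]], 'rawInd': [0], 'labels': ['good']}
channels_ = {'axial_um': [0], 'lateral_um': [20]}  # built from localCoordinates
for k in channels:
    if k not in list(channels_.keys()) + ['mlapdv', 'brainLocationIds_ccf_2017', 'localCoordinates']:
        channels_[k] = channels[k]
print(channels_)  # rawInd and labels carried over; localCoordinates was consumed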
@@ -851,14 +856,14 @@ def _load_object(self, *args, **kwargs):
     @staticmethod
     def _get_attributes(dataset_types):
         """returns attributes to load for spikes and clusters objects"""
-        if dataset_types is None:
-            return SPIKES_ATTRIBUTES, CLUSTERS_ATTRIBUTES
-        else:
-            spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp]
-            cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl]
-            spike_attributes = list(set(SPIKES_ATTRIBUTES + spike_attributes))
-            cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes))
-            return spike_attributes, cluster_attributes
+        dataset_types = [] if dataset_types is None else dataset_types
+        spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp]
+        spike_attributes = list(set(SPIKES_ATTRIBUTES + spike_attributes))
+        cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl]
+        cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes))
+        waveform_attributes = [cl.split('.')[1] for cl in dataset_types if 'waveforms.' in cl]
+        waveform_attributes = list(set(WAVEFORMS_ATTRIBUTES + waveform_attributes))
+        return {'spikes': spike_attributes, 'clusters': cluster_attributes, 'waveforms': waveform_attributes}

     def _get_spike_sorting_collection(self, spike_sorter='pykilosort'):
         """
@@ -891,14 +896,15 @@ def get_version(self, spike_sorter='pykilosort'):
         return dset[0]['version'] if len(dset) else 'unknown'

     def download_spike_sorting_object(self, obj, spike_sorter='pykilosort', dataset_types=None, collection=None,
-                                      missing='raise', **kwargs):
+                                      attribute=None, missing='raise', **kwargs):
         """
         Downloads an ALF object
         :param obj: object name, str between 'spikes', 'clusters' or 'channels'
         :param spike_sorter: (defaults to 'pykilosort')
         :param dataset_types: list of extra dataset types, for example ['spikes.samples']
         :param collection: string specifying the collection, for example 'alf/probe01/pykilosort'
         :param kwargs: additional arguments to be passed to one.api.One.load_object
+        :param attribute: list of attributes to load for the object
         :param missing: 'raise' (default) or 'ignore'
         :return:
         """
@@ -907,8 +913,7 @@ def download_spike_sorting_object(self, obj, spike_sorter='pykilosort', dataset_
         self.collection = self._get_spike_sorting_collection(spike_sorter=spike_sorter)
         collection = collection or self.collection
         _logger.debug(f"loading spike sorting object {obj} from {collection}")
-        spike_attributes, cluster_attributes = self._get_attributes(dataset_types)
-        attributes = {'spikes': spike_attributes, 'clusters': cluster_attributes}
+        attributes = self._get_attributes(dataset_types)
         try:
             self.files[obj] = self.one.load_object(
                 self.eid, obj=obj, attribute=attributes.get(obj, None),
@@ -986,11 +991,10 @@ def load_channels(self, **kwargs):
         """
         # we do not specify the spike sorter on purpose here: the electrode sites do not depend on the spike sorting
         self.download_spike_sorting_object(obj='electrodeSites', collection=f'alf/{self.pname}', missing='ignore')
-        if 'electrodeSites' in self.files:
-            channels = self._load_object(self.files['electrodeSites'], wildcards=self.one.wildcards)
-        else:  # otherwise, we try to load the channel object from the spike sorting folder - this may not contain histology
-            self.download_spike_sorting_object(obj='channels', **kwargs)
-            channels = self._load_object(self.files['channels'], wildcards=self.one.wildcards)
+        self.download_spike_sorting_object(obj='channels', missing='ignore', **kwargs)
+        channels = self._load_object(self.files['channels'], wildcards=self.one.wildcards)
+        if 'electrodeSites' in self.files:  # if common dict keys, electrodeSites prevails
+            channels = channels | self._load_object(self.files['electrodeSites'], wildcards=self.one.wildcards)
         if 'brainLocationIds_ccf_2017' not in channels:
             _logger.debug(f"loading channels from alyx for {self.files['channels']}")
             _channels, self.histology = _load_channel_locations_traj(
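
The PEP 584 dict union used here makes the precedence explicit: for keys present in both objects, the right-hand operand wins, so histology-aligned electrodeSites data overrides the spike-sorter copy. A toy sketch:

channels = {'rawInd': [0, 1], 'x': 'from spike sorting'}
electrode_sites = {'x': 'from histology'}
merged = channels | electrode_sites  # right-hand side prevails on shared keys
print(merged['x'])       # 'from histology'
print(merged['rawInd'])  # [0, 1]; keys unique to either side are kept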
@@ -1000,7 +1004,7 @@ def load_channels(self, **kwargs):
         else:
             channels = _channels_alf2bunch(channels, brain_regions=self.atlas.regions)
             self.histology = 'alf'
-        return channels
+        return Bunch(channels)

     def load_spike_sorting(self, spike_sorter='pykilosort', revision=None, enforce_version=True, good_units=False, **kwargs):
         """

brainbox/processing.py

Lines changed: 20 additions & 50 deletions
@@ -1,17 +1,16 @@
-'''
-Processes data from one form into another, e.g. taking spike times and binning them into
-non-overlapping bins and convolving spike times with a gaussian kernel.
-'''
+"""Process data from one form into another.
+
+For example, taking spike times and binning them into non-overlapping bins and convolving spike
+times with a gaussian kernel.
+"""

 import numpy as np
 import pandas as pd
 from scipy import interpolate, sparse
 from brainbox import core
-from iblutil.numerical import bincount2D as _bincount2D
+from iblutil.numerical import bincount2D
 from iblutil.util import Bunch
 import logging
-import warnings
-import traceback

 _logger = logging.getLogger(__name__)

@@ -118,35 +117,6 @@ def sync(dt, times=None, values=None, timeseries=None, offsets=None, interp='zer
     return syncd


-def bincount2D(x, y, xbin=0, ybin=0, xlim=None, ylim=None, weights=None):
-    """
-    Computes a 2D histogram by aggregating values in a 2D array.
-
-    :param x: values to bin along the 2nd dimension (c-contiguous)
-    :param y: values to bin along the 1st dimension
-    :param xbin:
-        scalar: bin size along 2nd dimension
-        0: aggregate according to unique values
-        array: aggregate according to exact values (count reduce operation)
-    :param ybin:
-        scalar: bin size along 1st dimension
-        0: aggregate according to unique values
-        array: aggregate according to exact values (count reduce operation)
-    :param xlim: (optional) 2 values (array or list) that restrict range along 2nd dimension
-    :param ylim: (optional) 2 values (array or list) that restrict range along 1st dimension
-    :param weights: (optional) defaults to None, weights to apply to each value for aggregation
-    :return: 3 numpy arrays MAP [ny,nx] image, xscale [nx], yscale [ny]
-    """
-    for line in traceback.format_stack():
-        print(line.strip())
-    warning_text = """Deprecation warning: bincount2D() is now a part of iblutil.
-                      brainbox.processing.bincount2D is deprecated and will be removed in
-                      future versions. Please replace imports with iblutil.numerical.bincount2D."""
-    _logger.warning(warning_text)
-    warnings.warn(warning_text, DeprecationWarning)
-    return _bincount2D(x, y, xbin, ybin, xlim, ylim, weights)
-
-
 def compute_cluster_average(spike_clusters, spike_var):
     """
     Quickish way to compute the average of some quantity across spikes in each cluster given
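
With the deprecated wrapper gone, callers that still import bincount2D from brainbox.processing should switch to its maintained home. A minimal migration sketch:

import numpy as np
from iblutil.numerical import bincount2D  # replaces brainbox.processing.bincount2D

x = np.array([0, 1, 1, 2])  # binned along the 2nd dimension
y = np.array([0, 0, 1, 1])  # binned along the 1st dimension
r, xscale, yscale = bincount2D(x, y)  # xbin=ybin=0: aggregate by unique values
print(r)
# [[1 1 0]
#  [0 1 1]]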
@@ -197,7 +167,7 @@ def bin_spikes(spikes, binsize, interval_indices=False):


 def get_units_bunch(spks_b, *args):
-    '''
+    """
     Returns a bunch, where the bunch keys are keys from `spks` with labels of spike information
     (e.g. unit IDs, times, features, etc.), and the values for each key are arrays with values for
     each unit: these arrays are ordered and can be indexed by unit id.
@@ -223,18 +193,18 @@ def get_units_bunch(spks_b, *args):
     --------
     1) Create a units bunch given a spikes bunch, and get the amps for unit #4 from the units
     bunch.
-    >>> import brainbox as bb
-    >>> import alf.io as aio
+    >>> from brainbox import processing
+    >>> import one.alf.io as alfio
     >>> import ibllib.ephys.spikes as e_spks
     (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
     >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
-    >>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
-    >>> units_b = bb.processing.get_units_bunch(spks_b)
+    >>> spks_b = alfio.load_object(path_to_alf_out, 'spikes')
+    >>> units_b = processing.get_units_bunch(spks_b)
     # Get amplitudes for unit 4.
     >>> amps = units_b['amps']['4']

     TODO add computation time estimate?
-    '''
+    """

     # Initialize `units`
     units_b = Bunch()
@@ -261,7 +231,7 @@ def get_units_bunch(spks_b, *args):


 def filter_units(units_b, t, **kwargs):
-    '''
+    """
     Filters units according to some parameters. **kwargs are the keyword parameters used to filter
     the units.

@@ -299,24 +269,24 @@ def filter_units(units_b, t, **kwargs):
     Examples
     --------
     1) Filter units according to the default parameters.
-    >>> import brainbox as bb
-    >>> import alf.io as aio
+    >>> from brainbox import processing
+    >>> import one.alf.io as alfio
     >>> import ibllib.ephys.spikes as e_spks
     (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
     >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
     # Get a spikes bunch, units bunch, and filter the units.
-    >>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
-    >>> units_b = bb.processing.get_units_bunch(spks_b, ['times', 'amps', 'clusters'])
+    >>> spks_b = alfio.load_object(path_to_alf_out, 'spikes')
+    >>> units_b = processing.get_units_bunch(spks_b, ['times', 'amps', 'clusters'])
     >>> T = spks_b['times'][-1] - spks_b['times'][0]
-    >>> filtered_units = bb.processing.filter_units(units_b, T)
+    >>> filtered_units = processing.filter_units(units_b, T)

     2) Filter units with no minimum amplitude, a minimum firing rate of 1 Hz, and a max false
     positive rate of 0.2, given a refractory period of 2 ms.
-    >>> filtered_units = bb.processing.filter_units(units_b, T, min_amp=0, min_fr=1)
+    >>> filtered_units = processing.filter_units(units_b, T, min_amp=0, min_fr=1)

     TODO: `units_b` input arg could eventually be replaced by `clstrs_b` if the required metrics
     are in `clstrs_b['metrics']`
-    '''
+    """

     # Set params
     params = {'min_amp': 50e-6, 'min_fr': 0.5, 'max_fpr': 0.2, 'rp': 0.002}  # defaults

brainbox/tests/test_behavior.py

Lines changed: 6 additions & 0 deletions
@@ -124,6 +124,12 @@ def test_get_movement_onset(self):
         with self.assertRaises(ValueError):
             wheel.get_movement_onset(intervals, np.random.permutation(self.trials['feedback_times']))

+    def test_velocity_deprecation(self):
+        """Ensure brainbox.behavior.wheel.velocity is removed."""
+        from datetime import datetime
+        self.assertTrue(datetime.today() < datetime(2024, 8, 1),
+                        'remove brainbox.behavior.wheel.velocity, velocity_smoothed and last_movement_onset')
+

 class TestTraining(unittest.TestCase):
     def setUp(self):
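
The new test is a sunset timer: it starts failing on 1 August 2024, prompting removal of the three deprecated wheel functions. To check that the FutureWarning itself fires, a test along these lines could also be written (a sketch, not part of this commit):

import unittest
import numpy as np
from brainbox.behavior import wheel

class TestWheelWarnings(unittest.TestCase):
    def test_velocity_warns(self):
        re_ts, re_pos = np.arange(4.), np.arange(4.)
        with self.assertWarns(FutureWarning):
            wheel.velocity(re_ts, re_pos)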

brainbox/tests/test_processing.py

Lines changed: 1 addition & 15 deletions
@@ -1,7 +1,6 @@
 from brainbox import processing, core
 import unittest
 import numpy as np
-import datetime


 class TestProcessing(unittest.TestCase):
@@ -63,15 +62,6 @@ def test_sync(self):
         self.assertTrue(times2.min() >= resamp2.times.min())
         self.assertTrue(times2.max() <= resamp2.times.max())

-    def test_bincount2D_deprecation(self):
-        # Timer to remove bincount2D (now in iblutil)
-        # Once this test fails:
-        # - Remove the bincount2D method in processing.py
-        # - Remove the import from iblutil at the top of that file
-        # - Delete this test
-        if datetime.datetime.now() > datetime.datetime(2024, 6, 30):
-            raise NotImplementedError
-
     def test_compute_cluster_averag(self):
         # Create fake data for 3 clusters
         clust1 = np.ones(40)
@@ -104,10 +94,6 @@ def test_compute_cluster_averag(self):
         self.assertTrue(np.all(count == (40, 40, 50)))


-def test_get_unit_bunches():
-    pass
-
-
-if __name__ == "__main__":
+if __name__ == '__main__':
     np.random.seed(0)
     unittest.main(exit=False)
