Skip to content

Commit 11cd0ed

Browse files
authored
Merge pull request #440 from int-brain-lab/develop
Release 2.9.1
2 parents ce861a2 + 41f1df3 commit 11cd0ed

File tree

15 files changed

+520
-120
lines changed

15 files changed

+520
-120
lines changed

MANIFEST.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
include ibllib/atlas/allen_structure_tree.csv
22
include ibllib/atlas/beryl.npy
33
include ibllib/atlas/cosmos.npy
4+
include ibllib/atlas/mappings.pqt
45
include ibllib/io/extractors/extractor_types.json
56
include brainbox/tests/wheel_test.p
67
recursive-include brainbox/tests/fixtures *

brainbox/behavior/training.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ def compute_performance(trials, signed_contrast=None, block=None):
426426
block_idx = trials.probabilityLeft == block
427427

428428
if not np.any(block_idx):
429-
return np.nan * np.zeros(2)
429+
return np.nan * np.zeros(3)
430430

431431
contrasts, n_contrasts = np.unique(signed_contrast[block_idx], return_counts=True)
432432
rightward = trials.choice == -1
@@ -584,15 +584,15 @@ def plot_psychometric(trials, ax=None, title=None, **kwargs):
584584
signed_contrast = get_signed_contrast(trials)
585585
contrasts_fit = np.arange(-100, 100)
586586

587-
prob_right_50, contrasts, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.5)
587+
prob_right_50, contrasts_50, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.5)
588588
pars_50 = compute_psychometric(trials, signed_contrast=signed_contrast, block=0.5)
589589
prob_right_fit_50 = psy.erf_psycho_2gammas(pars_50, contrasts_fit)
590590

591-
prob_right_20, contrasts, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.2)
591+
prob_right_20, contrasts_20, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.2)
592592
pars_20 = compute_psychometric(trials, signed_contrast=signed_contrast, block=0.2)
593593
prob_right_fit_20 = psy.erf_psycho_2gammas(pars_20, contrasts_fit)
594594

595-
prob_right_80, contrasts, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.8)
595+
prob_right_80, contrasts_80, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.8)
596596
pars_80 = compute_psychometric(trials, signed_contrast=signed_contrast, block=0.8)
597597
prob_right_fit_80 = psy.erf_psycho_2gammas(pars_80, contrasts_fit)
598598

@@ -606,11 +606,11 @@ def plot_psychometric(trials, ax=None, title=None, **kwargs):
606606
# TODO error bars
607607

608608
fit_50 = ax.plot(contrasts_fit, prob_right_fit_50, color=cmap[1])
609-
data_50 = ax.scatter(contrasts, prob_right_50, color=cmap[1])
609+
data_50 = ax.scatter(contrasts_50, prob_right_50, color=cmap[1])
610610
fit_20 = ax.plot(contrasts_fit, prob_right_fit_20, color=cmap[0])
611-
data_20 = ax.scatter(contrasts, prob_right_20, color=cmap[0])
611+
data_20 = ax.scatter(contrasts_20, prob_right_20, color=cmap[0])
612612
fit_80 = ax.plot(contrasts_fit, prob_right_fit_80, color=cmap[2])
613-
data_80 = ax.scatter(contrasts, prob_right_80, color=cmap[2])
613+
data_80 = ax.scatter(contrasts_80, prob_right_80, color=cmap[2])
614614
ax.legend([fit_50[0], data_50, fit_20[0], data_20, fit_80[0], data_80],
615615
['p_left=0.5 fit', 'p_left=0.5 data', 'p_left=0.2 fit', 'p_left=0.2 data', 'p_left=0.8 fit', 'p_left=0.8 data'],
616616
loc='upper left')
@@ -631,9 +631,9 @@ def plot_reaction_time(trials, ax=None, title=None, **kwargs):
631631
"""
632632

633633
signed_contrast = get_signed_contrast(trials)
634-
reaction_50, contrasts, _ = compute_reaction_time(trials, signed_contrast=signed_contrast, block=0.5)
635-
reaction_20, contrasts, _ = compute_reaction_time(trials, signed_contrast=signed_contrast, block=0.2)
636-
reaction_80, contrasts, _ = compute_reaction_time(trials, signed_contrast=signed_contrast, block=0.8)
634+
reaction_50, contrasts_50, _ = compute_reaction_time(trials, signed_contrast=signed_contrast, block=0.5)
635+
reaction_20, contrasts_20, _ = compute_reaction_time(trials, signed_contrast=signed_contrast, block=0.2)
636+
reaction_80, contrasts_80, _ = compute_reaction_time(trials, signed_contrast=signed_contrast, block=0.8)
637637

638638
cmap = sns.diverging_palette(20, 220, n=3, center="dark")
639639

@@ -642,9 +642,9 @@ def plot_reaction_time(trials, ax=None, title=None, **kwargs):
642642
else:
643643
fig = plt.gcf()
644644

645-
data_50 = ax.plot(contrasts, reaction_50, '-o', color=cmap[1])
646-
data_20 = ax.plot(contrasts, reaction_20, '-o', color=cmap[0])
647-
data_80 = ax.plot(contrasts, reaction_80, '-o', color=cmap[2])
645+
data_50 = ax.plot(contrasts_50, reaction_50, '-o', color=cmap[1])
646+
data_20 = ax.plot(contrasts_20, reaction_20, '-o', color=cmap[0])
647+
data_80 = ax.plot(contrasts_80, reaction_80, '-o', color=cmap[2])
648648

649649
# TODO error bars
650650

Lines changed: 28 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,37 @@
11
"""
22
Get spikes, clusters and channels data
33
========================================
4-
Downloads and loads in spikes, clusters and channels data for a given session. Data is returned
4+
Downloads and loads in spikes, clusters and channels data for a given probe insertion.
55
6+
There could be several spike sorting collections, by default the loader will get the pykilosort collection
7+
8+
The channel locations can come from several sources, it will load the most advanced version of the histology available,
9+
regardless of the spike sorting version loaded. The steps are (from most advanced to fresh out of the imaging):
10+
- alf: the final version of channel locations, same as resolved with the difference that data has been written out to files
11+
- resolved: channel locations alignments have been agreed upon
12+
- aligned: channel locations have been aligned, but review or other alignments are pending, potentially not accurate
13+
- traced: the histology track has been recovered from microscopy, however the depths may not match, inaccurate data
614
"""
7-
import brainbox.io.one as bbone
815

916
from one.api import ONE
17+
from ibllib.atlas import AllenAtlas
18+
from brainbox.io.one import SpikeSortingLoader
19+
20+
21+
one = ONE(base_url='https://openalyx.internationalbrainlab.org')
22+
ba = AllenAtlas()
23+
24+
insertions = one.alyx.rest('insertions', 'list')
25+
pid = insertions[0]['id']
26+
sl = SpikeSortingLoader(pid, one=one, atlas=ba)
27+
spikes, clusters, channels = sl.load_spike_sorting()
28+
clusters_labeled = SpikeSortingLoader.merge_clusters(spikes, clusters, channels)
29+
30+
# the histology property holds the provenance of the current channel locations
31+
print(sl.histology)
1032

11-
one = ONE(base_url='https://openalyx.internationalbrainlab.org', silent=True)
12-
13-
# Find eid of interest
14-
eid = one.search(subject='CSH_ZAD_029', date='2020-09-19')[0]
15-
16-
##################################################################################################
17-
# Example 1:
18-
# Download spikes, clusters and channels data for all available probes for this session.
19-
# The data for each probe is returned as a dict
20-
spikes, clusters, channels = bbone.load_spike_sorting_with_channel(eid, one=one)
21-
print(spikes.keys())
22-
print(spikes['probe01'].keys())
23-
24-
##################################################################################################
25-
# Example 2:
26-
# Download spikes, clusters and channels data for a single probe
27-
spikes, clusters, channels = bbone.load_spike_sorting_with_channel(eid, one=one, probe='probe01')
28-
print(spikes.keys())
29-
30-
##################################################################################################
31-
# Example 3:
32-
# The default spikes and clusters datasets that are downloaded are '
33-
# ['clusters.channels',
34-
# 'clusters.depths',
35-
# 'clusters.metrics',
36-
# 'spikes.clusters',
37-
# 'spikes.times']
38-
# If we also want to load for example, 'clusters.peakToTrough we can add a dataset_types argument
39-
40-
spikes, clusters, channels = bbone.load_spike_sorting_with_channel(eid, one=one, probe='probe01',
41-
dataset_types=['clusters.peakToTrough'])
42-
print(clusters['probe01'].keys())
33+
# available spike sorting collections for this probe insertion
34+
print(sl.collections)
4335

36+
# the collection that has been loaded
37+
print(sl.collection)

brainbox/io/one.py

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,8 @@ def channel_locations_interpolation(channels_aligned, channels=None, brain_regio
270270

271271
def _load_channel_locations_traj(eid, probe=None, one=None, revision=None, aligned=False,
272272
brain_atlas=None, return_source=False):
273+
if not hasattr(one, 'alyx'):
274+
return {}, None
273275
_logger.debug(f"trying to load from traj {probe}")
274276
channels = Bunch()
275277
brain_atlas = brain_atlas or AllenAtlas
@@ -416,6 +418,8 @@ def load_spike_sorting_fast(eid, one=None, probe=None, dataset_types=None, spike
416418
:param return_collection: (False) if True, will return the collection used to load
417419
:return: spikes, clusters, channels (dict of bunch, 1 bunch per probe)
418420
"""
421+
_logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting_fast will be removed in future versions.'
422+
'Use brainbox.io.one.SpikeSortingLoader instead')
419423
if collection is None:
420424
collection = _collection_filter_from_args(probe, spike_sorter)
421425
_logger.debug(f"load spike sorting with collection filter {collection}")
@@ -455,6 +459,8 @@ def load_spike_sorting(eid, one=None, probe=None, dataset_types=None, spike_sort
455459
:param return_collection:(bool - False) if True, returns the collection for loading the data
456460
:return: spikes, clusters (dict of bunch, 1 bunch per probe)
457461
"""
462+
_logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting will be removed in future versions.'
463+
'Use brainbox.io.one.SpikeSortingLoader instead')
458464
collection = _collection_filter_from_args(probe, spike_sorter)
459465
_logger.debug(f"load spike sorting with collection filter {collection}")
460466
spikes, clusters = _load_spike_sorting(eid=eid, one=one, collection=collection, revision=revision,
@@ -506,6 +512,8 @@ def load_spike_sorting_with_channel(eid, one=None, probe=None, aligned=False, da
506512
'atlas_id', 'x', 'y', 'z'). Atlas IDs non-lateralized.
507513
"""
508514
# --- Get spikes and clusters data
515+
_logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting will be removed in future versions.'
516+
'Use brainbox.io.one.SpikeSortingLoader instead')
509517
one = one or ONE()
510518
brain_atlas = brain_atlas or AllenAtlas()
511519
spikes, clusters, collection = load_spike_sorting(
@@ -862,12 +870,17 @@ def load_channels_from_insertion(ins, depths=None, one=None, ba=None):
862870

863871
@dataclass
864872
class SpikeSortingLoader:
865-
"""Class for loading spike sorting"""
866-
pid: str
873+
"""
874+
Object that will load spike sorting data for a given probe insertion.
875+
876+
877+
"""
867878
one: ONE
868-
atlas: None
869-
# the following properties are the outcome of the post init function
879+
atlas: None = None
880+
pid: str = None
870881
eid: str = ''
882+
pname: str = ''
883+
# the following properties are the outcome of the post init function
871884
session_path: Path = ''
872885
collections: list = None
873886
datasets: list = None # list of all datasets belonging to the session
@@ -878,7 +891,10 @@ class SpikeSortingLoader:
878891
spike_sorting_path: Path = None
879892

880893
def __post_init__(self):
881-
self.eid, self.pname = self.one.pid2eid(self.pid)
894+
if self.pid is not None:
895+
self.eid, self.pname = self.one.pid2eid(self.pid)
896+
if self.atlas is None:
897+
self.atlas = AllenAtlas()
882898
self.session_path = self.one.eid2path(self.eid)
883899
self.collections = self.one.list_collections(
884900
self.eid, filename='spikes*', collection=f"alf/{self.pname}*")
@@ -909,32 +925,61 @@ def _get_spike_sorting_collection(self, spike_sorter='pykilosort', revision=None
909925
return collection
910926

911927
def _download_spike_sorting_object(self, obj, spike_sorter='pykilosort', dataset_types=None):
928+
"""
929+
Downloads an ALF object
930+
:param obj: object name, str between 'spikes', 'clusters' or 'channels'
931+
:param spike_sorter: (defaults to 'pykilosort')
932+
:param dataset_types: list of extra dataset types
933+
:return:
934+
"""
912935
if len(self.collections) == 0:
913936
return {}, {}, {}
914937
self.collection = self._get_spike_sorting_collection(spike_sorter=spike_sorter)
938+
_logger.debug(f"loading spike sorting from {self.collection}")
915939
spike_attributes, cluster_attributes = self._get_attributes(dataset_types)
916940
attributes = {'spikes': spike_attributes, 'clusters': cluster_attributes, 'channels': None}
917941
self.files[obj] = self.one.load_object(self.eid, obj=obj, attribute=attributes[obj],
918942
collection=self.collection, download_only=True)
919943

920944
def download_spike_sorting(self, **kwargs):
921-
"""spike_sorter='pykilosort', dataset_types=None"""
945+
"""
946+
Downloads spikes, clusters and channels
947+
:param spike_sorter: (defaults to 'pykilosort')
948+
:param dataset_types: list of extra dataset types
949+
:return:
950+
"""
922951
for obj in ['spikes', 'clusters', 'channels']:
923952
self._download_spike_sorting_object(obj=obj, **kwargs)
924953
self.spike_sorting_path = self.files['spikes'][0].parent
925954

926955
def load_spike_sorting(self, **kwargs):
927-
"""spike_sorter='pykilosort', dataset_types=None"""
956+
"""
957+
Loads spikes, clusters and channels
958+
959+
There could be several spike sorting collections, by default the loader will get the pykilosort collection
960+
961+
The channel locations can come from several sources, it will load the most advanced version of the histology available,
962+
regardless of the spike sorting version loaded. The steps are (from most advanced to fresh out of the imaging):
963+
- alf: the final version of channel locations, same as resolved with the difference that data is on file
964+
- resolved: channel locations alignments have been agreed upon
965+
- aligned: channel locations have been aligned, but review or other alignments are pending, potentially not accurate
966+
- traced: the histology track has been recovered from microscopy, however the depths may not match, inaccurate data
967+
968+
:param spike_sorter: (defaults to 'pykilosort')
969+
:param dataset_types: list of extra dataset types
970+
:return:
971+
"""
928972
if len(self.collections) == 0:
929973
return {}, {}, {}
930974
self.download_spike_sorting(**kwargs)
931975
channels = alfio.load_object(self.files['channels'], wildcards=self.one.wildcards)
932976
clusters = alfio.load_object(self.files['clusters'], wildcards=self.one.wildcards)
933977
spikes = alfio.load_object(self.files['spikes'], wildcards=self.one.wildcards)
934978
if 'brainLocationIds_ccf_2017' not in channels:
935-
channels, self.histology = _load_channel_locations_traj(
979+
_channels, self.histology = _load_channel_locations_traj(
936980
self.eid, probe=self.pname, one=self.one, brain_atlas=self.atlas, return_source=True)
937-
channels = channels[self.pname]
981+
if _channels:
982+
channels = _channels[self.pname]
938983
else:
939984
channels = _channels_alf2bunch(channels, brain_regions=self.atlas.regions)
940985
self.histology = 'alf'

ibllib/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "2.9.0"
1+
__version__ = "2.9.1"
22
import warnings
33

44
from ibllib.misc import logger_config

0 commit comments

Comments
 (0)