diff --git a/brainbox/io/one.py b/brainbox/io/one.py index c1c86726e..ebe9c3f74 100644 --- a/brainbox/io/one.py +++ b/brainbox/io/one.py @@ -1,4 +1,5 @@ """Functions for loading IBL ephys and trial data using the Open Neurophysiology Environment.""" + from dataclasses import dataclass, field import gc import logging @@ -29,6 +30,8 @@ from ibllib.pipes.ephys_alignment import EphysAlignment from ibllib.plots import vertical_lines, Density +from iblphotometry import fpio + import brainbox.plot from brainbox.io.spikeglx import Streamer from brainbox.ephys_plots import plot_brain_regions @@ -60,8 +63,7 @@ def load_lfp(eid, one=None, dataset_types=None, **kwargs): [one.load_dataset(eid, dset, download_only=True) for dset in dtypes] session_path = one.eid2path(eid) - efiles = [ef for ef in spikeglx.glob_ephys_files(session_path, bin_exists=False) - if ef.get('lf', None)] + efiles = [ef for ef in spikeglx.glob_ephys_files(session_path, bin_exists=False) if ef.get('lf', None)] return [spikeglx.Reader(ef['lf'], **kwargs) for ef in efiles] @@ -82,19 +84,21 @@ def _get_spike_sorting_collection(collections, pname): collection = next(filter(lambda c: c == f'alf/{pname}/pykilosort', collections), None) # otherwise, prefers the shortest collection = collection or next(iter(sorted(filter(lambda c: f'alf/{pname}' in c, collections), key=len)), None) - _logger.debug(f"selecting: {collection} to load amongst candidates: {collections}") + _logger.debug(f'selecting: {collection} to load amongst candidates: {collections}') return collection def _channels_alyx2bunch(chans): - channels = Bunch({ - 'atlas_id': np.array([ch['brain_region'] for ch in chans]), - 'x': np.array([ch['x'] for ch in chans]) / 1e6, - 'y': np.array([ch['y'] for ch in chans]) / 1e6, - 'z': np.array([ch['z'] for ch in chans]) / 1e6, - 'axial_um': np.array([ch['axial'] for ch in chans]), - 'lateral_um': np.array([ch['lateral'] for ch in chans]) - }) + channels = Bunch( + { + 'atlas_id': np.array([ch['brain_region'] for 
ch in chans]), + 'x': np.array([ch['x'] for ch in chans]) / 1e6, + 'y': np.array([ch['y'] for ch in chans]) / 1e6, + 'z': np.array([ch['z'] for ch in chans]) / 1e6, + 'axial_um': np.array([ch['axial'] for ch in chans]), + 'lateral_um': np.array([ch['lateral'] for ch in chans]), + } + ) return channels @@ -105,7 +109,7 @@ def _channels_traj2bunch(xyz_chans, brain_atlas): 'y': xyz_chans[:, 1], 'z': xyz_chans[:, 2], 'acronym': brain_regions['acronym'], - 'atlas_id': brain_regions['id'] + 'atlas_id': brain_regions['id'], } return channels @@ -115,7 +119,8 @@ def _channels_bunch2alf(channels): channels_ = { 'mlapdv': np.c_[channels['x'], channels['y'], channels['z']] * 1e6, 'brainLocationIds_ccf_2017': channels['atlas_id'], - 'localCoordinates': np.c_[channels['lateral_um'], channels['axial_um']]} + 'localCoordinates': np.c_[channels['lateral_um'], channels['axial_um']], + } return channels_ @@ -139,8 +144,9 @@ def _channels_alf2bunch(channels, brain_regions=None): return channels_ -def _load_spike_sorting(eid, one=None, collection=None, revision=None, return_channels=True, dataset_types=None, - brain_regions=None): +def _load_spike_sorting( + eid, one=None, collection=None, revision=None, return_channels=True, dataset_types=None, brain_regions=None +): """ Generic function to load spike sorting according data using ONE. 
@@ -184,7 +190,7 @@ def _load_spike_sorting(eid, one=None, collection=None, revision=None, return_ch # enumerate probes and load according to the name collections = one.list_collections(eid, filename='spikes*', collection=collection, revision=revision) if len(collections) == 0: - _logger.warning(f"eid {eid}: no collection found with collection filter: {collection}, revision: {revision}") + _logger.warning(f'eid {eid}: no collection found with collection filter: {collection}, revision: {revision}') pnames = list(set(c.split('/')[1] for c in collections)) spikes, clusters, channels = ({} for _ in range(3)) @@ -192,13 +198,14 @@ def _load_spike_sorting(eid, one=None, collection=None, revision=None, return_ch for pname in pnames: probe_collection = _get_spike_sorting_collection(collections, pname) - spikes[pname] = one.load_object(eid, collection=probe_collection, obj='spikes', - attribute=spike_attributes, namespace='') - clusters[pname] = one.load_object(eid, collection=probe_collection, obj='clusters', - attribute=cluster_attributes, namespace='') + spikes[pname] = one.load_object(eid, collection=probe_collection, obj='spikes', attribute=spike_attributes, namespace='') + clusters[pname] = one.load_object( + eid, collection=probe_collection, obj='clusters', attribute=cluster_attributes, namespace='' + ) if return_channels: channels = _load_channels_locations_from_disk( - eid, collection=collection, one=one, revision=revision, brain_regions=brain_regions) + eid, collection=collection, one=one, revision=revision, brain_regions=brain_regions + ) return spikes, clusters, channels else: return spikes, clusters @@ -220,7 +227,7 @@ def _load_channels_locations_from_disk(eid, collection=None, one=None, revision= channels = Bunch({}) collections = one.list_collections(eid, filename='channels*', collection=collection, revision=revision) if len(collections) == 0: - _logger.warning(f"eid {eid}: no collection found with collection filter: {collection}, revision: {revision}") + 
_logger.warning(f'eid {eid}: no collection found with collection filter: {collection}, revision: {revision}') probes = list(set([c.split('/')[1] for c in collections])) for probe in probes: probe_collection = _get_spike_sorting_collection(collections, probe) @@ -228,11 +235,12 @@ def _load_channels_locations_from_disk(eid, collection=None, one=None, revision= # if the spike sorter has not aligned data, try and get the alignment available if 'brainLocationIds_ccf_2017' not in channels[probe].keys(): aligned_channel_collections = one.list_collections( - eid, filename='channels.brainLocationIds_ccf_2017*', collection=probe_collection, revision=revision) + eid, filename='channels.brainLocationIds_ccf_2017*', collection=probe_collection, revision=revision + ) if len(aligned_channel_collections) == 0: - _logger.debug(f"no resolved alignment dataset found for {eid}/{probe}") + _logger.debug(f'no resolved alignment dataset found for {eid}/{probe}') continue - _logger.debug(f"looking for a resolved alignment dataset in {aligned_channel_collections}") + _logger.debug(f'looking for a resolved alignment dataset in {aligned_channel_collections}') ac_collection = _get_spike_sorting_collection(aligned_channel_collections, probe) channels_aligned = one.load_object(eid, 'channels', collection=ac_collection) channels[probe] = channel_locations_interpolation(channels_aligned, channels[probe]) @@ -274,8 +282,7 @@ def channel_locations_interpolation(channels_aligned, channels=None, brain_regio depths, ind, iinv = np.unique(channels['localCoordinates'][:, 1], return_index=True, return_inverse=True) channels['mlapdv'] = np.zeros((nch, 3)) for i in np.arange(3): - channels['mlapdv'][:, i] = np.interp( - depths, depth_aligned, channels_aligned['mlapdv'][ind_aligned, i])[iinv] + channels['mlapdv'][:, i] = np.interp(depths, depth_aligned, channels_aligned['mlapdv'][ind_aligned, i])[iinv] # the brain locations have to be interpolated by nearest neighbour fcn_interp = interp1d(depth_aligned, 
channels_aligned['brainLocationIds_ccf_2017'][ind_aligned], kind='nearest') channels['brainLocationIds_ccf_2017'] = fcn_interp(depths)[iinv].astype(np.int32) @@ -285,68 +292,62 @@ def channel_locations_interpolation(channels_aligned, channels=None, brain_regio return channels -def _load_channel_locations_traj(eid, probe=None, one=None, revision=None, aligned=False, - brain_atlas=None, return_source=False): +def _load_channel_locations_traj(eid, probe=None, one=None, revision=None, aligned=False, brain_atlas=None, return_source=False): if not hasattr(one, 'alyx'): return {}, None - _logger.debug(f"trying to load from traj {probe}") + _logger.debug(f'trying to load from traj {probe}') channels = Bunch() brain_atlas = brain_atlas or AllenAtlas # need to find the collection bruh insertion = one.alyx.rest('insertions', 'list', session=eid, name=probe)[0] collection = _collection_filter_from_args(probe=probe) - collections = one.list_collections(eid, filename='channels*', collection=collection, - revision=revision) + collections = one.list_collections(eid, filename='channels*', collection=collection, revision=revision) probe_collection = _get_spike_sorting_collection(collections, probe) chn_coords = one.load_dataset(eid, 'channels.localCoordinates', collection=probe_collection) depths = chn_coords[:, 1] - tracing = insertion.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}). \ - get('tracing_exists', False) - resolved = insertion.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}). \ - get('alignment_resolved', False) - counts = insertion.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}). 
\ - get('alignment_count', 0) + tracing = insertion.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}).get('tracing_exists', False) + resolved = insertion.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}).get('alignment_resolved', False) + counts = insertion.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}).get('alignment_count', 0) if tracing: xyz = np.array(insertion['json']['xyz_picks']) / 1e6 if resolved: - - _logger.debug(f'Channel locations for {eid}/{probe} have been resolved. ' - f'Channel and cluster locations obtained from ephys aligned histology ' - f'track.') - traj = one.alyx.rest('trajectories', 'list', session=eid, probe=probe, - provenance='Ephys aligned histology track')[0] + _logger.debug( + f'Channel locations for {eid}/{probe} have been resolved. ' + f'Channel and cluster locations obtained from ephys aligned histology ' + f'track.' + ) + traj = one.alyx.rest('trajectories', 'list', session=eid, probe=probe, provenance='Ephys aligned histology track')[0] align_key = insertion['json']['extended_qc']['alignment_stored'] feature = traj['json'][align_key][0] track = traj['json'][align_key][1] - ephysalign = EphysAlignment(xyz, depths, track_prev=track, - feature_prev=feature, - brain_atlas=brain_atlas, speedy=True) + ephysalign = EphysAlignment(xyz, depths, track_prev=track, feature_prev=feature, brain_atlas=brain_atlas, speedy=True) chans = ephysalign.get_channel_locations(feature, track) channels[probe] = _channels_traj2bunch(chans, brain_atlas) source = 'resolved' elif counts > 0 and aligned: - _logger.debug(f'Channel locations for {eid}/{probe} have not been ' - f'resolved. However, alignment flag set to True so channel and cluster' - f' locations will be obtained from latest available ephys aligned ' - f'histology track.') + _logger.debug( + f'Channel locations for {eid}/{probe} have not been ' + f'resolved. 
However, alignment flag set to True so channel and cluster' + f' locations will be obtained from latest available ephys aligned ' + f'histology track.' + ) # get the latest user aligned channels - traj = one.alyx.rest('trajectories', 'list', session=eid, probe=probe, - provenance='Ephys aligned histology track')[0] + traj = one.alyx.rest('trajectories', 'list', session=eid, probe=probe, provenance='Ephys aligned histology track')[0] align_key = insertion['json']['extended_qc']['alignment_stored'] feature = traj['json'][align_key][0] track = traj['json'][align_key][1] - ephysalign = EphysAlignment(xyz, depths, track_prev=track, - feature_prev=feature, - brain_atlas=brain_atlas, speedy=True) + ephysalign = EphysAlignment(xyz, depths, track_prev=track, feature_prev=feature, brain_atlas=brain_atlas, speedy=True) chans = ephysalign.get_channel_locations(feature, track) channels[probe] = _channels_traj2bunch(chans, brain_atlas) source = 'aligned' else: - _logger.debug(f'Channel locations for {eid}/{probe} have not been resolved. ' - f'Channel and cluster locations obtained from histology track.') + _logger.debug( + f'Channel locations for {eid}/{probe} have not been resolved. ' + f'Channel and cluster locations obtained from histology track.' 
+ ) # get the channels from histology tracing xyz = xyz[np.argsort(xyz[:, 2]), :] chans = histology.interpolate_along_track(xyz, (depths + TIP_SIZE_UM) / 1e6) @@ -398,12 +399,12 @@ def load_channel_locations(eid, probe=None, one=None, aligned=False, brain_atlas else: eid = one.to_eid(eid) collection = _collection_filter_from_args(probe=probe) - channels = _load_channels_locations_from_disk(eid, one=one, collection=collection, - brain_regions=brain_atlas.regions) + channels = _load_channels_locations_from_disk(eid, one=one, collection=collection, brain_regions=brain_atlas.regions) incomplete_probes = [k for k in channels if 'x' not in channels[k]] for iprobe in incomplete_probes: - channels_, source = _load_channel_locations_traj(eid, probe=iprobe, one=one, aligned=aligned, - brain_atlas=brain_atlas, return_source=True) + channels_, source = _load_channel_locations_traj( + eid, probe=iprobe, one=one, aligned=aligned, brain_atlas=brain_atlas, return_source=True + ) if channels_ is not None: channels[iprobe] = channels_[iprobe] return channels @@ -449,7 +450,8 @@ def merge_clusters_channels(dic_clus, channels, keys_to_add_extra=None): else: _logger.warning( f'Probe {label}: merging channels and clusters for key "{key}" has {nch_key} on channels' - f' but expected {max(clu_ch)}. Data in new cluster key "{key}" is returned empty.') + f' but expected {max(clu_ch)}. Data in new cluster key "{key}" is returned empty.' 
+ ) dic_clus[label][key] = [] except AssertionError: _logger.warning(f'Either clusters or channels does not have key {key}, could not merge') @@ -479,10 +481,9 @@ def load_passive_rfmap(eid, one=None): # Load in the receptive field mapping data rf_map = one.load_object(eid, obj='passiveRFM', collection='alf') - frames = np.fromfile(one.load_dataset(eid, '_iblrig_RFMapStim.raw.bin', - collection='raw_passive_data'), dtype="uint8") + frames = np.fromfile(one.load_dataset(eid, '_iblrig_RFMapStim.raw.bin', collection='raw_passive_data'), dtype='uint8') y_pix, x_pix = 15, 15 - frames = np.transpose(np.reshape(frames, [y_pix, x_pix, -1], order="F"), [2, 1, 0]) + frames = np.transpose(np.reshape(frames, [y_pix, x_pix, -1], order='F'), [2, 1, 0]) rf_map['frames'] = frames return rf_map @@ -553,13 +554,13 @@ def load_iti(trials): def load_channels_from_insertion(ins, depths=None, one=None, ba=None): - PROV_2_VAL = { 'Resolved': 90, 'Ephys aligned histology track': 70, 'Histology track': 50, 'Micro-manipulator': 30, - 'Planned': 10} + 'Planned': 10, + } one = one or ONE() ba = ba or atlas.AllenAtlas() @@ -573,21 +574,17 @@ def load_channels_from_insertion(ins, depths=None, one=None, ba=None): ins = atlas.Insertion.from_dict(traj) # Deepest coordinate first xyz = np.c_[ins.tip, ins.entry].T - xyz_channels = histology.interpolate_along_track(xyz, (depths + - TIP_SIZE_UM) / 1e6) + xyz_channels = histology.interpolate_along_track(xyz, (depths + TIP_SIZE_UM) / 1e6) else: xyz = np.array(ins['json']['xyz_picks']) / 1e6 if traj['provenance'] == 'Histology track': xyz = xyz[np.argsort(xyz[:, 2]), :] - xyz_channels = histology.interpolate_along_track(xyz, (depths + - TIP_SIZE_UM) / 1e6) + xyz_channels = histology.interpolate_along_track(xyz, (depths + TIP_SIZE_UM) / 1e6) else: align_key = ins['json']['extended_qc']['alignment_stored'] feature = traj['json'][align_key][0] track = traj['json'][align_key][1] - ephysalign = EphysAlignment(xyz, depths, track_prev=track, - 
feature_prev=feature, - brain_atlas=ba, speedy=True) + ephysalign = EphysAlignment(xyz, depths, track_prev=track, feature_prev=feature, brain_atlas=ba, speedy=True) xyz_channels = ephysalign.get_channel_locations(feature, track) return xyz_channels @@ -605,6 +602,7 @@ class SpikeSortingLoader: SpikeSortingLoader(session_path=session_path, pname='probe00') NB: When no ONE instance is passed, any datasets that are loaded will not be recorded. """ + one: One = None atlas: None = None pid: str = None @@ -613,7 +611,7 @@ class SpikeSortingLoader: session_path: ALFPath = '' # the following properties are the outcome of the post init function collections: list = None - datasets: list = None # list of all datasets belonging to the session + datasets: list = None # list of all datasets belonging to the session # the following properties are the outcome of a reading function files: dict = None raw_data_files: list = None # list of raw ap and lf files corresponding to the recording @@ -631,8 +629,10 @@ def __post_init__(self): self.eid, self.pname = self.one.pid2eid(self.pid) except NotImplementedError: if self.eid == '' or self.pname == '': - raise IOError("Cannot infer session id and probe name from pid. " - "You need to pass eid and pname explicitly when instantiating SpikeSortingLoader.") + raise IOError( + 'Cannot infer session id and probe name from pid. ' + 'You need to pass eid and pname explicitly when instantiating SpikeSortingLoader.' 
+ ) self.session_path = self.one.eid2path(self.eid) # then eid / pname combination elif self.session_path is None or self.session_path == '': @@ -649,8 +649,7 @@ def __post_init__(self): self.one._cache['datasets'] = cache._make_datasets_df(self.session_path, hash_files=False) self.eid = str(self.session_path.relative_to(self.session_path.parents[2])) # populates default properties - self.collections = self.one.list_collections( - self.eid, filename='spikes*', collection=f"alf/{self.pname}*") + self.collections = self.one.list_collections(self.eid, filename='spikes*', collection=f'alf/{self.pname}*') self.datasets = self.one.list_datasets(self.eid) if self.atlas is None: self.atlas = AllenAtlas() @@ -691,7 +690,7 @@ def _get_spike_sorting_collection(self, spike_sorter=None): for sorter in list([spike_sorter, 'iblsorter', 'pykilosort']): if sorter is None: continue - if sorter == "": + if sorter == '': collection = next(filter(lambda c: c == f'alf/{self.pname}', self.collections), None) else: collection = next(filter(lambda c: c == f'alf/{self.pname}/{sorter}', self.collections), None) @@ -699,7 +698,7 @@ def _get_spike_sorting_collection(self, spike_sorter=None): return collection # if none is found amongst the defaults, prefers the shortest collection = collection or next(iter(sorted(filter(lambda c: f'alf/{self.pname}' in c, self.collections), key=len)), None) - _logger.debug(f"selecting: {collection} to load amongst candidates: {self.collections}") + _logger.debug(f'selecting: {collection} to load amongst candidates: {self.collections}') return collection def load_spike_sorting_object(self, obj, *args, revision=None, **kwargs): @@ -724,8 +723,17 @@ def get_version(self, spike_sorter=None): dset = self.one.alyx.rest('datasets', 'list', session=self.eid, collection=collection, name='spikes.times.npy') return dset[0]['version'] if len(dset) else 'unknown' - def download_spike_sorting_object(self, obj, spike_sorter=None, dataset_types=None, collection=None, - 
attribute=None, missing='raise', revision=None, **kwargs): + def download_spike_sorting_object( + self, + obj, + spike_sorter=None, + dataset_types=None, + collection=None, + attribute=None, + missing='raise', + revision=None, + **kwargs, + ): """ Downloads an ALF object :param obj: object name, str between 'spikes', 'clusters' or 'channels' @@ -745,12 +753,18 @@ def download_spike_sorting_object(self, obj, spike_sorter=None, dataset_types=No return {}, {}, {} self.collection = self._get_spike_sorting_collection(spike_sorter=spike_sorter) collection = collection or self.collection - _logger.debug(f"loading spike sorting object {obj} from {collection}") + _logger.debug(f'loading spike sorting object {obj} from {collection}') attributes = self._get_attributes(dataset_types) try: self.files[obj] = self.one.load_object( - self.eid, obj=obj, attribute=attributes.get(obj, None), - collection=collection, download_only=True, revision=revision, **kwargs) + self.eid, + obj=obj, + attribute=attributes.get(obj, None), + collection=collection, + download_only=True, + revision=revision, + **kwargs, + ) except ALFObjectNotFound as e: if missing == 'raise': raise e @@ -778,13 +792,15 @@ def download_raw_electrophysiology(self, band='ap'): for suffix in [f'*.{band}.ch', f'*.{band}.meta', f'*.{band}.cbin']: try: # FIXME: this will fail if multiple LFP segments are found - raw_data_files.append(self.one.load_dataset( - self.eid, - download_only=True, - collection=f'raw_ephys_data/{self.pname}', - dataset=suffix, - check_hash=False, - )) + raw_data_files.append( + self.one.load_dataset( + self.eid, + download_only=True, + collection=f'raw_ephys_data/{self.pname}', + dataset=suffix, + check_hash=False, + ) + ) except ALFObjectNotFound: _logger.debug(f"{self.session_path} can't locate raw data collection raw_ephys_data/{self.pname}, file {suffix}") self.raw_data_files = list(set(self.raw_data_files + raw_data_files)) @@ -804,7 +820,7 @@ def raw_electrophysiology(self, stream=True, 
band='ap', **kwargs): return Streamer(pid=self.pid, one=self.one, typ=band, **kwargs) else: raw_data_files = self.download_raw_electrophysiology(band=band) - cbin_file = next(filter(lambda f: re.match(rf".*\.{band}\..*cbin", f.name), raw_data_files), None) + cbin_file = next(filter(lambda f: re.match(rf'.*\.{band}\..*cbin', f.name), raw_data_files), None) if cbin_file is not None: return spikeglx.Reader(cbin_file) @@ -812,10 +828,14 @@ def download_raw_waveforms(self, **kwargs): """ Downloads raw waveforms extracted from sorting to local disk. """ - _logger.debug(f"loading waveforms from {self.collection}") + _logger.debug(f'loading waveforms from {self.collection}') return self.one.load_object( - id=self.eid, obj="waveforms", attribute=["traces", "templates", "table", "channels"], - collection=self._get_spike_sorting_collection("pykilosort"), download_only=True, **kwargs + id=self.eid, + obj='waveforms', + attribute=['traces', 'templates', 'table', 'channels'], + collection=self._get_spike_sorting_collection('pykilosort'), + download_only=True, + **kwargs, ) def raw_waveforms(self, **kwargs): @@ -846,9 +866,10 @@ def load_channels(self, **kwargs): channels = self._load_object(self.files['electrodeSites'], wildcards=self.one.wildcards) channels['rawInd'] = np.arange(channels[list(channels.keys())[0]].shape[0]) if 'brainLocationIds_ccf_2017' not in channels: - _logger.debug(f"loading channels from alyx for {self.files['channels']}") + _logger.debug(f'loading channels from alyx for {self.files["channels"]}') _channels, self.histology = _load_channel_locations_traj( - self.eid, probe=self.pname, one=self.one, brain_atlas=self.atlas, return_source=True, aligned=True) + self.eid, probe=self.pname, one=self.one, brain_atlas=self.atlas, return_source=True, aligned=True + ) if _channels: channels = _channels[self.pname] else: @@ -858,18 +879,19 @@ def load_channels(self, **kwargs): @staticmethod def filter_files_by_namespace(all_files, namespace): - # Create dict for each 
file with available namespaces, no namespce is stored under the key None namespace_files = defaultdict(dict) available_namespaces = [] for file in all_files: nspace = file.namespace or None available_namespaces.append(nspace) - namespace_files[f"{file.object}.{file.attribute}"][nspace] = file + namespace_files[f'{file.object}.{file.attribute}'][nspace] = file if namespace not in set(available_namespaces): - _logger.info(f'Could not find manual curation results for {namespace}, returning default' - f' non manually curated spikesorting data') + _logger.info( + f'Could not find manual curation results for {namespace}, returning default' + f' non manually curated spikesorting data' + ) # Return the files with the chosen namespace. files = [f.get(namespace, f.get(None, None)) for f in namespace_files.values()] @@ -877,8 +899,9 @@ def filter_files_by_namespace(all_files, namespace): files = [f for f in files if f] return files - def load_spike_sorting(self, spike_sorter='iblsorter', revision=None, enforce_version=False, good_units=False, - namespace=None, **kwargs): + def load_spike_sorting( + self, spike_sorter='iblsorter', revision=None, enforce_version=False, good_units=False, namespace=None, **kwargs + ): """ Loads spikes, clusters and channels @@ -908,8 +931,10 @@ def load_spike_sorting(self, spike_sorter='iblsorter', revision=None, enforce_ve self.revision = revision if good_units and namespace is not None: - _logger.info('Good units table does not exist for manually curated spike sorting. Pass in namespace with' - 'good_units=False and filter the spikes post hoc by the good clusters.') + _logger.info( + 'Good units table does not exist for manually curated spike sorting. Pass in namespace with' + 'good_units=False and filter the spikes post hoc by the good clusters.' 
+ ) return [None] * 3 objects = ['passingSpikes', 'clusters', 'channels'] if good_units else None self.download_spike_sorting(spike_sorter=spike_sorter, revision=revision, objects=objects, **kwargs) @@ -934,17 +959,18 @@ def _assert_version_consistency(self): for k in ['spikes', 'clusters', 'channels', 'passingSpikes']: for fn in self.files.get(k, []): if self.spike_sorter: - assert fn.relative_to(self.session_path).parts[2] == self.spike_sorter, \ - f"You required strict version {self.spike_sorter}, {fn} does not match" + assert fn.relative_to(self.session_path).parts[2] == self.spike_sorter, ( + f'You required strict version {self.spike_sorter}, {fn} does not match' + ) if self.revision: - assert fn.revision == self.revision, \ - f"You required strict revision {self.revision}, {fn} does not match" + assert fn.revision == self.revision, f'You required strict revision {self.revision}, {fn} does not match' @staticmethod def compute_metrics(spikes, clusters=None): nc = clusters['channels'].size if clusters else np.unique(spikes['clusters']).size - metrics = pd.DataFrame(quick_unit_metrics( - spikes['clusters'], spikes['times'], spikes['amps'], spikes['depths'], cluster_ids=np.arange(nc))) + metrics = pd.DataFrame( + quick_unit_metrics(spikes['clusters'], spikes['times'], spikes['amps'], spikes['depths'], cluster_ids=np.arange(nc)) + ) return metrics @staticmethod @@ -969,7 +995,7 @@ def merge_clusters(spikes, clusters, channels, cache_dir=None, compute_metrics=F if metrics.shape[0] != nc: metrics = None if metrics is None or compute_metrics is True: - _logger.debug("recompute clusters metrics") + _logger.debug('recompute clusters metrics') metrics = SpikeSortingLoader.compute_metrics(spikes, clusters) if isinstance(cache_dir, Path): metrics.to_parquet(Path(cache_dir).joinpath('clusters.metrics.pqt')) @@ -992,12 +1018,15 @@ def _get_probe_info(self, revision=None): revision = revision if revision is not None else self.revision if self._sync is None: timestamps = 
self.one.load_dataset( - self.eid, dataset='_spikeglx_*.timestamps.npy', collection=f'raw_ephys_data/{self.pname}', revision=revision) + self.eid, dataset='_spikeglx_*.timestamps.npy', collection=f'raw_ephys_data/{self.pname}', revision=revision + ) _ = self.one.load_dataset( # this is not used here but we want to trigger the download for potential tasks - self.eid, dataset='_spikeglx_*.sync.npy', collection=f'raw_ephys_data/{self.pname}', revision=revision) + self.eid, dataset='_spikeglx_*.sync.npy', collection=f'raw_ephys_data/{self.pname}', revision=revision + ) try: - ap_meta = spikeglx.read_meta_data(self.one.load_dataset( - self.eid, dataset='_spikeglx_*.ap.meta', collection=f'raw_ephys_data/{self.pname}')) + ap_meta = spikeglx.read_meta_data( + self.one.load_dataset(self.eid, dataset='_spikeglx_*.ap.meta', collection=f'raw_ephys_data/{self.pname}') + ) fs = spikeglx._get_fs_from_meta(ap_meta) except ALFObjectNotFound: ap_meta = None @@ -1030,15 +1059,17 @@ def samples2times(self, values, direction='forward'): @property def pid2ref(self): - return f"{self.one.eid2ref(self.eid, as_dict=False)}_{self.pname}" + return f'{self.one.eid2ref(self.eid, as_dict=False)}_{self.pname}' def _default_plot_title(self, spikes): - title = f"{self.pid2ref}, {self.pid} \n" \ - f"{spikes['clusters'].size:_} spikes, {np.unique(spikes['clusters']).size:_} clusters" + title = ( + f'{self.pid2ref}, {self.pid} \n{spikes["clusters"].size:_} spikes, {np.unique(spikes["clusters"]).size:_} clusters' + ) return title - def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_series=None, - drift=None, title=None, **kwargs): + def raster( + self, spikes, channels, save_dir=None, br=None, label='raster', time_series=None, drift=None, title=None, **kwargs + ): """ :param spikes: spikes dictionary or Bunch :param channels: channels dictionary or Bunch. 
@@ -1052,13 +1083,14 @@ def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_ """ br = br or BrainRegions() time_series = time_series or {} - fig, axs = plt.subplots(2, 2, gridspec_kw={ - 'width_ratios': [.95, .05], 'height_ratios': [.1, .9]}, figsize=(16, 9), sharex='col') + fig, axs = plt.subplots( + 2, 2, gridspec_kw={'width_ratios': [0.95, 0.05], 'height_ratios': [0.1, 0.9]}, figsize=(16, 9), sharex='col' + ) axs[0, 1].set_axis_off() # axs[0, 0].set_xticks([]) if kwargs is None: # set default raster plot parameters - kwargs = {"t_bin": 0.007, "d_bin": 10, "vmax": 0.5} + kwargs = {'t_bin': 0.007, 'd_bin': 10, 'vmax': 0.5} brainbox.plot.driftmap(spikes['times'], spikes['depths'], ax=axs[1, 0], **kwargs) if title is None: title = self._default_plot_title(spikes) @@ -1066,8 +1098,14 @@ def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_ for k, ts in time_series.items(): vertical_lines(ts, ymin=0, ymax=3800, ax=axs[1, 0]) if 'atlas_id' in channels: - plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'], - brain_regions=br, display=True, ax=axs[1, 1], title=self.histology) + plot_brain_regions( + channels['atlas_id'], + channel_depths=channels['axial_um'], + brain_regions=br, + display=True, + ax=axs[1, 1], + title=self.histology, + ) axs[1, 0].set_ylim(0, 3800) axs[1, 0].set_xlim(spikes['times'][0], spikes['times'][-1]) fig.tight_layout() @@ -1077,28 +1115,33 @@ def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_ if 'drift' in self.files: drift = self._load_object(self.files['drift'], wildcards=self.one.wildcards) if isinstance(drift, dict): - axs[0, 0].plot(drift['times'], drift['um'], 'k', alpha=.5) + axs[0, 0].plot(drift['times'], drift['um'], 'k', alpha=0.5) axs[0, 0].set(ylim=[-15, 15]) if save_dir is not None: - png_file = save_dir.joinpath(f"{self.pid}_{self.pid2ref}_{label}.png") if Path(save_dir).is_dir() else Path(save_dir) + png_file = 
save_dir.joinpath(f'{self.pid}_{self.pid2ref}_{label}.png') if Path(save_dir).is_dir() else Path(save_dir) fig.savefig(png_file) plt.close(fig) gc.collect() else: return fig, axs - def plot_rawdata_snippet(self, sr, spikes, clusters, t0, - channels=None, - br: BrainRegions = None, - save_dir=None, - label='raster', - gain=-93, - title=None): - + def plot_rawdata_snippet( + self, + sr, + spikes, + clusters, + t0, + channels=None, + br: BrainRegions = None, + save_dir=None, + label='raster', + gain=-93, + title=None, + ): # compute the raw data offset and destripe, we take 400ms around t0 first_sample, last_sample = (int((t0 - 0.2) * sr.fs), int((t0 + 0.2) * sr.fs)) - raw = sr[first_sample:last_sample, :-sr.nsync].T + raw = sr[first_sample:last_sample, : -sr.nsync].T channel_labels = channels['labels'] if (channels is not None) and ('labels' in channels) else True destriped = ibldsp.voltage.destripe(raw, sr.fs, channel_labels=channel_labels) # filter out the spikes according to good/bad clusters and to the time slice @@ -1109,21 +1152,27 @@ def plot_rawdata_snippet(self, sr, spikes, clusters, t0, if title is None: title = self._default_plot_title(spikes) # display the raw data snippet with spikes overlaid - fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9), sharex='col') + fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [0.95, 0.05]}, figsize=(16, 9), sharex='col') Density(destriped, fs=sr.fs, taxis=1, gain=gain, ax=axs[0], t0=t0 - 0.2, unit='s') - axs[0].scatter(ss[sok] / sr.fs, sc[sok], color="green", alpha=0.5) - axs[0].scatter(ss[~sok] / sr.fs, sc[~sok], color="red", alpha=0.5) + axs[0].scatter(ss[sok] / sr.fs, sc[sok], color='green', alpha=0.5) + axs[0].scatter(ss[~sok] / sr.fs, sc[~sok], color='red', alpha=0.5) axs[0].set(title=title, xlim=[t0 - 0.035, t0 + 0.035]) # adds the channel locations if available if (channels is not None) and ('atlas_id' in channels): br = br or BrainRegions() - 
plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'], - brain_regions=br, display=True, ax=axs[1], title=self.histology) + plot_brain_regions( + channels['atlas_id'], + channel_depths=channels['axial_um'], + brain_regions=br, + display=True, + ax=axs[1], + title=self.histology, + ) axs[1].get_yaxis().set_visible(False) fig.tight_layout() if save_dir is not None: - png_file = save_dir.joinpath(f"{self.pid}_{self.pid2ref}_{label}.png") if Path(save_dir).is_dir() else Path(save_dir) + png_file = save_dir.joinpath(f'{self.pid}_{self.pid2ref}_{label}.png') if Path(save_dir).is_dir() else Path(save_dir) fig.savefig(png_file) plt.close(fig) gc.collect() @@ -1198,6 +1247,7 @@ class SessionLoader: functions: >>> sess_loader.load_wheel(sampling_rate=100) """ + one: One = None session_path: ALFPath = '' eid: str = '' @@ -1215,8 +1265,10 @@ def __post_init__(self): Checks for required inputs, sets session_path and eid, creates data_info table. """ if self.one is None: - raise ValueError("An input to one is required. If not connection to a database is desired, it can be " - "a fully local instance of One.") + raise ValueError( + 'An input to one is required. If not connection to a database is desired, it can be ' + 'a fully local instance of One.' 
+ ) # If session path is given, takes precedence over eid if self.session_path is not None and self.session_path != '': self.eid = self.one.to_eid(self.session_path) @@ -1226,15 +1278,9 @@ def __post_init__(self): if self.eid is not None and self.eid != '': self.session_path = self.one.eid2path(self.eid) else: - raise ValueError("If no session path is given, eid is required.") - - data_names = [ - 'trials', - 'wheel', - 'pose', - 'motion_energy', - 'pupil' - ] + raise ValueError('If no session path is given, eid is required.') + + data_names = ['trials', 'wheel', 'pose', 'motion_energy', 'pupil'] self.data_info = pd.DataFrame(columns=['name', 'is_loaded'], data=zip(data_names, [False] * len(data_names))) def load_session_data(self, trials=True, wheel=True, pose=True, motion_energy=True, pupil=True, reload=False): @@ -1263,33 +1309,21 @@ def load_session_data(self, trials=True, wheel=True, pose=True, motion_energy=Tr Whether to reload data that has already been loaded into this SessionLoader object, default is False """ load_df = self.data_info.copy() - load_df['to_load'] = [ - trials, - wheel, - pose, - motion_energy, - pupil - ] - load_df['load_func'] = [ - self.load_trials, - self.load_wheel, - self.load_pose, - self.load_motion_energy, - self.load_pupil - ] + load_df['to_load'] = [trials, wheel, pose, motion_energy, pupil] + load_df['load_func'] = [self.load_trials, self.load_wheel, self.load_pose, self.load_motion_energy, self.load_pupil] for idx, row in load_df.iterrows(): if row['to_load'] is False: - _logger.debug(f"Not loading {row['name']} data, set to False.") + _logger.debug(f'Not loading {row["name"]} data, set to False.') elif row['is_loaded'] is True and reload is False: - _logger.debug(f"Not loading {row['name']} data, is already loaded and reload=False.") + _logger.debug(f'Not loading {row["name"]} data, is already loaded and reload=False.') else: try: - _logger.info(f"Loading {row['name']} data") + _logger.info(f'Loading {row["name"]} data') 
row['load_func']() self.data_info.loc[idx, 'is_loaded'] = True except BaseException as e: - _logger.warning(f"Could not load {row['name']} data.") + _logger.warning(f'Could not load {row["name"]} data.') _logger.debug(e) def _find_behaviour_collection(self, obj): @@ -1310,8 +1344,10 @@ def _find_behaviour_collection(self, obj): if len(set(collections)) == 1: return collections[0] else: - _logger.error(f'Multiple collections found {collections}. Specify collection when loading, ' - f'e.g sl.load_{obj}(collection="{collections[0]}")') + _logger.error( + f'Multiple collections found {collections}. Specify collection when loading, ' + f'e.g sl.load_{obj}(collection="{collections[0]}")' + ) raise ALFMultipleCollectionsFound def load_trials(self, collection=None): @@ -1329,7 +1365,8 @@ def load_trials(self, collection=None): # itiDuration frequently has a mismatched dimension, and we don't need it, exclude using regex self.one.wildcards = False self.trials = self.one.load_object( - self.eid, 'trials', collection=collection, attribute=r'(?!itiDuration).*', revision=self.revision or None).to_df() + self.eid, 'trials', collection=collection, attribute=r'(?!itiDuration).*', revision=self.revision or None + ).to_df() self.one.wildcards = True self.data_info.loc[self.data_info['name'] == 'trials', 'is_loaded'] = True @@ -1358,9 +1395,11 @@ def load_wheel(self, fs=1000, corner_frequency=20, order=8, collection=None): # resample the wheel position and compute velocity, acceleration self.wheel = pd.DataFrame(columns=['times', 'position', 'velocity', 'acceleration']) self.wheel['position'], self.wheel['times'] = interpolate_position( - wheel_raw['timestamps'], wheel_raw['position'], freq=fs) + wheel_raw['timestamps'], wheel_raw['position'], freq=fs + ) self.wheel['velocity'], self.wheel['acceleration'] = velocity_filtered( - self.wheel['position'], fs=fs, corner_frequency=corner_frequency, order=order) + self.wheel['position'], fs=fs, corner_frequency=corner_frequency, order=order 
+ ) self.wheel = self.wheel.apply(np.float32) self.data_info.loc[self.data_info['name'] == 'wheel', 'is_loaded'] = True @@ -1386,7 +1425,8 @@ def load_pose(self, likelihood_thr=0.9, views=['left', 'right', 'body'], tracker self.pose = {} for view in views: pose_raw = self.one.load_object( - self.eid, f'{view}Camera', attribute=[tracker, 'times'], revision=self.revision or None) + self.eid, f'{view}Camera', attribute=[tracker, 'times'], revision=self.revision or None + ) # Double check if video timestamps are correct length or can be fixed times_fixed, dlc = self._check_video_timestamps(view, pose_raw['times'], pose_raw[tracker]) self.pose[f'{view}Camera'] = likelihood_threshold(dlc, likelihood_thr) @@ -1407,17 +1447,15 @@ def load_motion_energy(self, views=['left', 'right', 'body']): views: list List of camera views for which to try and load data. Possible options are {'left', 'right', 'body'} """ - names = {'left': 'whiskerMotionEnergy', - 'right': 'whiskerMotionEnergy', - 'body': 'bodyMotionEnergy'} + names = {'left': 'whiskerMotionEnergy', 'right': 'whiskerMotionEnergy', 'body': 'bodyMotionEnergy'} # empty the dictionary so that if one loads only one view, after having loaded several, the others don't linger self.motion_energy = {} for view in views: me_raw = self.one.load_object( - self.eid, f'{view}Camera', attribute=['ROIMotionEnergy', 'times'], revision=self.revision or None) + self.eid, f'{view}Camera', attribute=['ROIMotionEnergy', 'times'], revision=self.revision or None + ) # Double check if video timestamps are correct length or can be fixed - times_fixed, motion_energy = self._check_video_timestamps( - view, me_raw['times'], me_raw['ROIMotionEnergy']) + times_fixed, motion_energy = self._check_video_timestamps(view, me_raw['times'], me_raw['ROIMotionEnergy']) self.motion_energy[f'{view}Camera'] = pd.DataFrame(columns=[names[view]], data=motion_energy) self.motion_energy[f'{view}Camera'].insert(0, 'times', times_fixed) 
self.data_info.loc[self.data_info['name'] == 'motion_energy', 'is_loaded'] = True @@ -1428,7 +1466,7 @@ def load_licks(self): """ pass - def load_pupil(self, snr_thresh=5.): + def load_pupil(self, snr_thresh=5.0): """ Function to load raw and smoothed pupil diameter data from the left camera into SessionLoader.pupil. @@ -1448,8 +1486,7 @@ def load_pupil(self, snr_thresh=5.): # If unavailable compute on the fly else: _logger.info('Pupil diameter not available, trying to compute on the fly.') - if (self.data_info[self.data_info['name'] == 'pose']['is_loaded'].values[0] - and 'leftCamera' in self.pose.keys()): + if self.data_info[self.data_info['name'] == 'pose']['is_loaded'].values[0] and 'leftCamera' in self.pose.keys(): # If pose data is already loaded, we don't know if it was threshold at 0.9, so we need a little stunt copy_pose = self.pose['leftCamera'].copy() # Save the previously loaded pose data self.load_pose(views=['left'], likelihood_thr=0.9) # Load new with threshold 0.9 @@ -1463,16 +1500,18 @@ def load_pupil(self, snr_thresh=5.): try: self.pupil['pupilDiameter_smooth'] = get_smooth_pupil_diameter(self.pupil['pupilDiameter_raw'], 'left') except BaseException as e: - _logger.error("Loaded raw pupil diameter but computing smooth pupil diameter failed. " - "Saving all NaNs for pupilDiameter_smooth.") + _logger.error( + 'Loaded raw pupil diameter but computing smooth pupil diameter failed. ' + 'Saving all NaNs for pupilDiameter_smooth.' 
+ ) _logger.debug(e) self.pupil['pupilDiameter_smooth'] = np.nan if not np.all(np.isnan(self.pupil['pupilDiameter_smooth'])): - good_idxs = np.where( - ~np.isnan(self.pupil['pupilDiameter_smooth']) & ~np.isnan(self.pupil['pupilDiameter_raw']))[0] - snr = (np.var(self.pupil['pupilDiameter_smooth'][good_idxs]) / - (np.var(self.pupil['pupilDiameter_smooth'][good_idxs] - self.pupil['pupilDiameter_raw'][good_idxs]))) + good_idxs = np.where(~np.isnan(self.pupil['pupilDiameter_smooth']) & ~np.isnan(self.pupil['pupilDiameter_raw']))[0] + snr = np.var(self.pupil['pupilDiameter_smooth'][good_idxs]) / ( + np.var(self.pupil['pupilDiameter_smooth'][good_idxs] - self.pupil['pupilDiameter_raw'][good_idxs]) + ) if snr < snr_thresh: self.pupil = pd.DataFrame() raise ValueError(f'Pupil diameter SNR ({snr:.2f}) below threshold SNR ({snr_thresh}), removing data.') @@ -1508,6 +1547,7 @@ class EphysSessionLoader(SessionLoader): To select for a specific probe >>> EphysSessionLoader(eid=eid, one=one, pid=pid) """ + def __init__(self, *args, pname=None, pid=None, **kwargs): """ Needs an active connection in order to get the list of insertions in the session @@ -1539,3 +1579,63 @@ def load_spike_sorting(self, pnames=None): @property def probes(self): return {k: self.ephys[k]['ssl'].pid for k in self.ephys} + + +class PhotometrySessionLoader(SessionLoader): + photometry: dict = field(default_factory=dict, repr=False) + + def __init__(self, *args, photometry_collection: str = 'photometry', **kwargs): + self.photometry_collection = photometry_collection + self.revision = kwargs.get('revision', None) + + # determine if loading by eid or session path + self.load_by_path = True if 'session_path' in kwargs else False + + super().__init__(*args, **kwargs) + + def load_session_data(self, **kwargs): + super().load_session_data(**kwargs) + self.load_photometry() + + def load_photometry( + self, + restrict_to_session: bool = True, + pre: int = 5, + post: int = 5, + ): + # session path precedence over 
eid + if self.load_by_path: + raw_dfs = fpio.from_session_path( + self.session_path, + collection=self.photometry_collection, + revision=self.revision, + ) + else: # load by eid + raw_dfs = fpio.from_eid( + self.eid, + self.one, + collection=self.photometry_collection, + revision=self.revision, + ) + + if restrict_to_session: + if isinstance(self.trials, pd.DataFrame) and (self.trials.shape[0] == 0): + self.load_trials() + t_start = self.trials.iloc[0]['intervals_0'] + t_stop = self.trials.iloc[-1]['intervals_1'] + + for band in raw_dfs.keys(): + df = raw_dfs[band] + ix = np.logical_and( + df.index.values > t_start - pre, + df.index.values < t_stop + post, + ) + raw_dfs[band] = df.loc[ix] + + # the above indexing can lead to unevenly shaped bands. + # Cut to shortest + n = np.min([df.shape[0] for _, df in raw_dfs.items()]) + for band in raw_dfs.keys(): + raw_dfs[band] = raw_dfs[band].iloc[:n] + + self.photometry = raw_dfs diff --git a/ibllib/pipes/dynamic_pipeline.py b/ibllib/pipes/dynamic_pipeline.py index a7f64e8e4..69b97ffbf 100644 --- a/ibllib/pipes/dynamic_pipeline.py +++ b/ibllib/pipes/dynamic_pipeline.py @@ -588,11 +588,35 @@ def make_pipeline(session_path, **pkwargs): **kwargs, **mscope_kwargs, parents=[tasks['MesoscopePreprocess']]) if 'neurophotometrics' in devices: - # {'collection': 'raw_photometry_data', 'datetime': '2024-09-18T16:43:55.207000', - # 'fibers': {'G0': {'location': 'NBM'}, 'G1': {'location': 'SI'}}, 'sync_channel': 1} - photometry_kwargs = devices['neurophotometrics'] - tasks['FibrePhotometrySync'] = type('FibrePhotometrySync', ( - ptasks.FibrePhotometrySync,), {})(**kwargs, **photometry_kwargs) + # note: devices['neurophotometrics'] is the acquisition_description + sync_mode = devices['neurophotometrics'].get('sync_mode', 'bpod') # default to bpod for downward compatibility + + # passive photometry + task_protocols = acquisition_description['tasks'] + assert len(task_protocols) == 1, 'chained protocols are not yet supported for 
photometry extraction' + protocol = task_protocols[0] + if 'passive' in protocol: + assert sync_mode == 'daqami', 'passive protocol syncing only supported for DAQ based syncing' + tasks['FibrePhotometryPassiveChoiceWorld'] = type( + 'FibrePhotometryPassiveChoiceWorld', (ptasks.FibrePhotometryPassiveChoiceWorld,), {} + )( + **kwargs, + ) + + match sync_mode: + case 'bpod': + # for synchronization with the BNC inputs of the neurophotometrics receiving the sync pulses + # from the individual bpods + tasks['FibrePhotometryBpodSync'] = type('FibrePhotometryBpodSync', (ptasks.FibrePhotometryBpodSync,), {})( + **kwargs, + ) + case 'daqami': + # for synchronization with the DAQami receiving the sync pulses from the individual bpods + # as well as the frame clock from the FP3002 + if 'passive' not in protocol: # excluding passive session + tasks['FibrePhotometryDAQSync'] = type('FibrePhotometryDAQSync', (ptasks.FibrePhotometryDAQSync,), {})( + **kwargs, + ) p = mtasks.Pipeline(session_path=session_path, **pkwargs) p.tasks = tasks diff --git a/ibllib/pipes/neurophotometrics.py b/ibllib/pipes/neurophotometrics.py index 18f558c59..69156ce07 100644 --- a/ibllib/pipes/neurophotometrics.py +++ b/ibllib/pipes/neurophotometrics.py @@ -1,186 +1,686 @@ -"""Extraction tasks for fibrephotometry""" - import logging +from pathlib import Path import numpy as np import pandas as pd +from typing import Tuple, Optional, List +import pickle import ibldsp.utils import ibllib.io.session_params from ibllib.pipes import base_tasks from iblutil.io import jsonable +from nptdms import TdmsFile + +from abc import abstractmethod +import iblphotometry +from iblphotometry import fpio + +from one.api import ONE +import json +from scipy.optimize import minimize + _logger = logging.getLogger('ibllib') -""" -Neurophotometrics FP3002 specific information. -The light source map refers to the available LEDs on the system. -The flags refers to the byte encoding of led states in the system. 
-""" -LIGHT_SOURCE_MAP = { - 'color': ['None', 'Violet', 'Blue', 'Green'], - 'wavelength': [0, 415, 470, 560], - 'name': ['None', 'Isosbestic', 'GCaMP', 'RCaMP'], -} - -LED_STATES = { - 'Condition': { - 0: 'No additional signal', - 1: 'Output 1 signal HIGH', - 2: 'Output 0 signal HIGH', - 3: 'Stimulation ON', - 4: 'GPIO Line 2 HIGH', - 5: 'GPIO Line 3 HIGH', - 6: 'Input 1 HIGH', - 7: 'Input 0 HIGH', - 8: 'Output 0 signal HIGH + Stimulation', - 9: 'Output 0 signal HIGH + Input 0 signal HIGH', - 10: 'Input 0 signal HIGH + Stimulation', - 11: 'Output 0 HIGH + Input 0 HIGH + Stimulation', - }, - 'No LED ON': {0: 0, 1: 8, 2: 16, 3: 32, 4: 64, 5: 128, 6: 256, 7: 512, 8: 48, 9: 528, 10: 544, 11: 560}, - 'L415': {0: 1, 1: 9, 2: 17, 3: 33, 4: 65, 5: 129, 6: 257, 7: 513, 8: 49, 9: 529, 10: 545, 11: 561}, - 'L470': {0: 2, 1: 10, 2: 18, 3: 34, 4: 66, 5: 130, 6: 258, 7: 514, 8: 50, 9: 530, 10: 546, 11: 562}, - 'L560': {0: 4, 1: 12, 2: 20, 3: 36, 4: 68, 5: 132, 6: 260, 7: 516, 8: 52, 9: 532, 10: 548, 11: 564} -} - - -def _channel_meta(light_source_map=None): +def _int2digital_channels(values: np.ndarray) -> np.ndarray: + """decoder for the digital channel values from the tdms file into a channel + based array (rows are temporal samples, columns are channels). + + essentially does: + + 0 -> 0000 + 1 -> 1000 + 2 -> 0100 + 3 -> 1100 + 4 -> 0010 + 5 -> 1010 + 6 -> 0110 + ... + + the order from binary representation is reversed so + columns index represents channel index + + Parameters + ---------- + values : np.ndarray + the input values from the tdms digital channel + + Returns + ------- + np.ndarray + a (n x 4) array """ - Return table of light source wavelengths and corresponding colour labels. 
+ return np.array([list(f'{v:04b}'[::-1]) for v in values], dtype='int8') + + +def extract_timestamps_from_tdms_file( + tdms_filepath: Path, + save_path: Optional[Path] = None, + chunk_size=10000, +) -> dict: + """extractor for tdms files as written by the daqami software, configured for neurophotometrics + experiments: Frameclock is in an analog channel (AI?), DI1-4 are the bpod sync signals Parameters ---------- - light_source_map : dict - An optional map of light source wavelengths (nm) used and their corresponding colour name. + tdms_filepath : Path + path to TDMS file + save_path : Optional[Path], optional + if a path, save extracted timestamps from tdms file to this location, by default None + chunk_size : int, optional + if not None, process tdms data in chunks for decreased memory usage, by default 10000 Returns ------- - pandas.DataFrame - A sorted table of wavelength and colour name. + dict + a dict with the tdms channel names as keys and 'positive' the timestamps of the rising edges + 'negative' the falling edges """ - light_source_map = light_source_map or LIGHT_SOURCE_MAP - meta = pd.DataFrame.from_dict(light_source_map) - meta.index.rename('channel_id', inplace=True) - return meta + # + _logger.info(f'extracting timestamps from tdms file: {tdms_filepath}') + + # this should be 10kHz + tdms_file = TdmsFile.read(tdms_filepath) + groups = tdms_file.groups() + # this unfortunate hack is in here because there are a bunch of sessions + # where the frameclock is on DI0 + if len(groups) == 1: + has_analog_group = False + (digital_group,) = groups + if len(groups) == 2: + has_analog_group = True + analog_group, digital_group = groups + fs = digital_group.properties['ScanRate'] # this should be 10kHz + df = tdms_file.as_dataframe() -class FibrePhotometrySync(base_tasks.DynamicTask): + # inferring digital col name + (digital_col,) = [col for col in df.columns if 'Digital' in col] + vals = df[digital_col].values.astype('int8') + digital_channel_names = ['DI0', 
'DI1', 'DI2', 'DI3'] + + # ini + timestamps = {} + for ch in digital_channel_names: + timestamps[ch] = dict(positive=[], negative=[]) + + # chunked loop for memory efficiency + if chunk_size is not None: + n_chunks = df.shape[0] // chunk_size + for i in range(n_chunks): + vals_ = vals[i * chunk_size: (i + 1) * chunk_size] + # data = np.array([list(f'{v:04b}'[::-1]) for v in vals_], dtype='int8') + data = _int2digital_channels(vals_) + + for j, name in enumerate(digital_channel_names): + ix = np.where(np.diff(data[:, j]) == 1)[0] + (chunk_size * i) + timestamps[name]['positive'].append(ix / fs) + ix = np.where(np.diff(data[:, j]) == -1)[0] + (chunk_size * i) + timestamps[name]['negative'].append(ix / fs) + + for ch in digital_channel_names: + timestamps[ch]['positive'] = np.concatenate(timestamps[ch]['positive']) + timestamps[ch]['negative'] = np.concatenate(timestamps[ch]['negative']) + else: + data = _int2digital_channels(vals) + for j, name in enumerate(digital_channel_names): + ix = np.where(np.diff(data[:, j]) == 1)[0] + timestamps[name]['positive'].append(ix / fs) + ix = np.where(np.diff(data[:, j]) == 1)[0] + timestamps[name]['negative'].append(ix / fs) + + if has_analog_group: + # frameclock data is recorded on an analog channel + for channel in analog_group.channels(): + timestamps[channel.name] = {} + signal = (channel.data > 2.5).astype('int32') # assumes 0-5V + timestamps[channel.name]['positive'] = np.where(np.diff(signal) == 1)[0] / fs + timestamps[channel.name]['negative'] = np.where(np.diff(signal) == -1)[0] / fs + + if save_path is not None: + _logger.info(f'saving extracted timestamps to: {save_path}') + with open(save_path, 'wb') as fH: + pickle.dump(timestamps, fH) + + return timestamps + + +def extract_timestamps_from_bpod_jsonable(file_jsonable: str | Path, sync_states_names: List[str]): + _, bpod_data = jsonable.load_task_jsonable(file_jsonable) + timestamps = [] + for sync_name in sync_states_names: + timestamps.append( + np.array( + [ + 
data['States timestamps'][sync_name][0][0] + data['Trial start timestamp'] - data['Bpod start timestamp'] + for data in bpod_data + if sync_name in data['States timestamps'] + ] + ) + ) + timestamps = np.sort(np.concatenate(timestamps)) + timestamps = timestamps[~np.isnan(timestamps)] + return timestamps + + +class FibrePhotometryBaseSync(base_tasks.DynamicTask): + # base clas for syncing fibre photometry + # derived classes are: FibrePhotometryBpodSync and FibrePhotometryDAQSync priority = 90 job_size = 'small' - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.device_collection = self.get_device_collection( - 'neurophotometrics', device_collection='raw_photometry_data') - # we will work with the first protocol here - for task in self.session_params['tasks']: - self.task_protocol = next(k for k in task) + def __init__( + self, + session_path: str | Path, + one: ONE, + task_protocol: str | None = None, + task_collection: str | None = None, + assert_matching_timestamps: bool = True, + sync_states_names: list[str] | None = None, + sync_channel: int | str | None = None, # if set, overwrites the value extracted from the experiment_description + **kwargs, + ): + super().__init__(session_path, one=one, **kwargs) + self.photometry_collection = kwargs.get('collection', 'raw_photometry_data') # raw_photometry_data + self.kwargs = kwargs + self.task_protocol = task_protocol + self.task_collection = task_collection + self.assert_matching_timestamps = assert_matching_timestamps + + if self.task_protocol is None: + # we will work with the first protocol here + for task in self.session_params['tasks']: + self.task_protocol = next(k for k in task) + break + + if self.task_collection is None: + # if not provided, infer self.task_collection = ibllib.io.session_params.get_task_collection(self.session_params, self.task_protocol) - break + + # configuring the sync: state names + if sync_states_names is None: + if 'habituation' in self.task_protocol: + 
self.sync_states_names = ['iti', 'reward'] + else: + self.sync_states_names = ['trial_start', 'reward', 'exit_state'] + else: + self.sync_states_names = sync_states_names + + # configuring the sync: channel + if sync_channel is None: + self.sync_channel = kwargs.get('sync_channel', self.session_params['devices']['neurophotometrics']['sync_channel']) + else: + self.sync_channel = sync_channel + + def _get_bpod_timestamps(self) -> np.ndarray: + # the timestamps for syncing, in the time of the bpod + + file_jsonable = self.session_path.joinpath(self.task_collection, '_iblrig_taskData.raw.jsonable') + timestamps_bpod = extract_timestamps_from_bpod_jsonable(file_jsonable, self.sync_states_names) + return timestamps_bpod + + def _get_valid_bounds(self): + file_jsonable = self.session_path.joinpath(self.task_collection, '_iblrig_taskData.raw.jsonable') + _, bpod_data = jsonable.load_task_jsonable(file_jsonable) + return [bpod_data[0]['Trial start timestamp'] - 2, bpod_data[-1]['Trial end timestamp'] + 2] + + @abstractmethod + def _get_neurophotometrics_timestamps(self) -> np.ndarray: + # this function needs to be implemented in the derived classes: + # for bpod based syncing, the timestamps are in the digial inputs file + # for daq based syncing, the timestamps are extracted from the tdms file + ... 
+ + def _get_sync_function(self) -> Tuple[callable, list]: + # returns the synchronization function + # get the timestamps + timestamps_bpod = self._get_bpod_timestamps() + timestamps_nph = self._get_neurophotometrics_timestamps() + + # verify presence of sync timestamps + for source, timestamps in zip(['bpod', 'neurophotometrics'], [timestamps_bpod, timestamps_nph]): + assert len(timestamps) > 0, f'{source} sync timestamps are empty' + + sync_nph_to_bpod_fcn, drift_ppm, ix_nph, ix_bpod = ibldsp.utils.sync_timestamps( + timestamps_nph, timestamps_bpod, return_indices=True, linear=True + ) + if np.absolute(drift_ppm) > 20: + _logger.warning(f'sync with excessive drift: {drift_ppm}') + else: + _logger.info(f'synced with drift: {drift_ppm}') + + # assertion: 95% of timestamps in bpod need to be in timestamps of nph (but not the other way around) + if self.assert_matching_timestamps: + assert timestamps_bpod.shape[0] * 0.95 < ix_bpod.shape[0], 'less than 95% of bpod timestamps matched' + else: + if not (timestamps_bpod.shape[0] * 0.95 < ix_bpod.shape[0]): + _logger.warning( + f'less than 95% of bpod timestamps matched. 
\ + n_timestamps:{timestamps_bpod.shape[0]} matched:{ix_bpod.shape[0]}' + ) + + valid_bounds = self._get_valid_bounds() + return sync_nph_to_bpod_fcn, valid_bounds + + def load_data(self) -> pd.DataFrame: + # loads the raw photometry data + raw_photometry_folder = self.session_path / self.photometry_collection + photometry_df = fpio.from_neurophotometrics_file_to_photometry_df( + raw_photometry_folder / '_neurophotometrics_fpData.raw.pqt', + drop_first=False, + ) + return photometry_df + + def _run(self, **kwargs) -> Tuple[pd.DataFrame, pd.DataFrame]: + # 1) load photometry data + + # note: when loading daq based syncing, the SystemTimestamp column + # will be overridden with the timestamps from the tdms file + # the idea behind this is that the rest of the sync is then the same + # and handled by this base class + photometry_df = self.load_data() + + # 2) get the synchronization function + sync_nph_to_bpod_fcn, valid_bounds = self._get_sync_function() + + # 3) apply synchronization + photometry_df['times'] = sync_nph_to_bpod_fcn(photometry_df['times']) + photometry_df['valid'] = np.logical_and( + photometry_df['times'] >= valid_bounds[0], photometry_df['times'] <= valid_bounds[1] + ) + + # 4) write to disk + output_folder = self.session_path.joinpath('alf', 'photometry') + output_folder.mkdir(parents=True, exist_ok=True) + + # writing the synced photometry signal + photometry_filepath = self.session_path / 'alf' / 'photometry' / 'photometry.signal.pqt' + photometry_filepath.parent.mkdir(parents=True, exist_ok=True) + photometry_df.to_parquet(photometry_filepath) + + # writing the locations + rois = [] + for k, v in self.session_params['devices']['neurophotometrics']['fibers'].items(): + rois.append({'ROI': k, 'fiber': f'fiber_{v["location"]}', 'brain_region': v['location']}) + locations_df = pd.DataFrame(rois).set_index('ROI') + locations_filepath = self.session_path / 'alf' / 'photometry' / 'photometryROI.locations.pqt' + 
locations_filepath.parent.mkdir(parents=True, exist_ok=True) + locations_df.to_parquet(locations_filepath) + return photometry_filepath, locations_filepath + + +class FibrePhotometryBpodSync(FibrePhotometryBaseSync): + priority = 90 + job_size = 'small' + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) @property def signature(self): signature = { - 'input_files': [('_neurophotometrics_fpData.raw.pqt', self.device_collection, True, True), - ('_iblrig_taskData.raw.jsonable', self.task_collection, True, True), - ('_neurophotometrics_fpData.channels.csv', self.device_collection, True, True), - ('_neurophotometrics_fpData.digitalIntputs.pqt', self.device_collection, True)], - 'output_files': [('photometry.signal.pqt', 'alf/photometry', True), - ('photometryROI.locations.pqt', 'alf/photometry', True)] + 'input_files': [ + ('_neurophotometrics_fpData.raw.pqt', self.photometry_collection, True, True), + ('_iblrig_taskData.raw.jsonable', self.task_collection, True, True), + # ('_neurophotometrics_fpData.channels.csv', self.photometry_collection, True, True), + ('_neurophotometrics_fpData.digitalInputs.pqt', self.photometry_collection, True), + ], + 'output_files': [ + ('photometry.signal.pqt', 'alf/photometry', True), + ('photometryROI.locations.pqt', 'alf/photometry', True), + ], } return signature - def _sync_bpod_neurophotometrics(self): - """ - Perform the linear clock correction between bpod and neurophotometrics timestamps. - :return: interpolation function that outputs bpod timestamsp from neurophotometrics timestamps - """ - folder_raw_photometry = self.session_path.joinpath(self.device_collection) - df_digital_inputs = pd.read_parquet(folder_raw_photometry.joinpath('_neurophotometrics_fpData.digitalIntputs.pqt')) - # normally we should disregard the states and use the sync label. But bpod doesn't log TTL outs, - # only the states. This will change in the future but for now we are stuck with this. 
- if 'habituation' in self.task_protocol: - sync_states_names = ['iti', 'reward'] + def _get_neurophotometrics_timestamps(self) -> np.ndarray: + # for bpod based syncing, the timestamps for syncing are in the digital inputs file + raw_photometry_folder = self.session_path / self.photometry_collection + digital_inputs_filepath = raw_photometry_folder / '_neurophotometrics_fpData.digitalInputs.pqt' + digital_inputs_df = fpio.read_digital_inputs_file(digital_inputs_filepath, channel=self.sync_channel) + + # get the positive fronts + timestamps_nph = digital_inputs_df.groupby(['polarity', 'channel']).get_group((1, self.sync_channel))['times'].values + + # TODO replace this rudimentary spacer removal + # to implement: detect spacer / remove spacer methods + # timestamps_nph = timestamps_nph[15:] + return timestamps_nph + + +class FibrePhotometryDAQSync(FibrePhotometryBaseSync): + priority = 90 + job_size = 'small' + + def __init__( + self, + *args, + load_timestamps: bool = True, + # sync_channel: int | None = None, + frameclock_channel: int | None = None, + **kwargs, + ): + super().__init__(*args, **kwargs) + # setting up sync properties + frameclock_channel = ( + frameclock_channel or self.session_params['devices']['neurophotometrics']['sync_metadata']['frameclock_channel'] + ) + # downward compatibility - frameclock moved around, now is back on the AI7 + if frameclock_channel in ['0', 0]: + self.frameclock_channel_name = f'DI{frameclock_channel}' + elif frameclock_channel in ['7', 7]: + self.frameclock_channel_name = f'AI{frameclock_channel}' else: - sync_states_names = ['trial_start', 'reward', 'exit_state'] - # read in the raw behaviour data for syncing - file_jsonable = self.session_path.joinpath(self.task_collection, '_iblrig_taskData.raw.jsonable') - trials_table, bpod_data = jsonable.load_task_jsonable(file_jsonable) - # we get the timestamps of the states from the bpod data - tbpod = [] - for sname in sync_states_names: - tbpod.append(np.array( - [bd['States 
timestamps'][sname][0][0] + bd['Trial start timestamp'] for bd in bpod_data if - sname in bd['States timestamps']])) - tbpod = np.sort(np.concatenate(tbpod)) - tbpod = tbpod[~np.isnan(tbpod)] - # we get the timestamps for the photometry data - tph = df_digital_inputs['SystemTimestamp'].values[df_digital_inputs['Channel'] == self.kwargs['sync_channel']] - tph = tph[15:] # TODO: we may want to detect the spacers before removing it, especially for successive sessions - # sync the behaviour events to the photometry timestamps - fcn_nph_to_bpod_times, drift_ppm, iph, ibpod = ibldsp.utils.sync_timestamps( - tph, tbpod, return_indices=True, linear=True) - # then we check the alignment, should be less than the screen refresh rate - tcheck = fcn_nph_to_bpod_times(tph[iph]) - tbpod[ibpod] - _logger.info( - f'sync: n trials {len(bpod_data)}, n bpod sync {len(tbpod)}, n photometry {len(tph)}, n match {len(iph)}') - assert np.all(np.abs(tcheck) < 1 / 60), 'Sync issue detected, residual above 1/60s' - assert len(iph) / len(tbpod) > 0.95, 'Sync issue detected, less than 95% of the bpod events matched' - valid_bounds = [bpod_data[0]['Trial start timestamp'] - 2, bpod_data[-1]['Trial end timestamp'] + 2] - return fcn_nph_to_bpod_times, valid_bounds - - def _run(self, **kwargs): - """ - Extract photometry data from the raw neurophotometrics data in parquet - The extraction has 3 main steps: - 1. Synchronise the bpod and neurophotometrics timestamps. - 2. Extract the photometry data from the raw neurophotometrics data. - 3. 
Label the fibers correspondance with brain regions in a small table - :param kwargs: - :return: - """ - # 1) sync: we check the synchronisation, right now we only have bpod but soon the daq will be used - match list(self.session_params['sync'].keys())[0]: - case 'bpod': - fcn_nph_to_bpod_times, valid_bounds = self._sync_bpod_neurophotometrics() - case _: - raise NotImplementedError('Syncing with daq is not supported yet.') - # 2) reformat the raw data with wavelengths and meta-data - folder_raw_photometry = self.session_path.joinpath(self.device_collection) - fp_data = pd.read_parquet(folder_raw_photometry.joinpath('_neurophotometrics_fpData.raw.pqt')) - # Load channels and wavelength information - channel_meta_map = _channel_meta() - if (fn := folder_raw_photometry.joinpath('_neurophotometrics_fpData.channels.csv')).exists(): - led_states = pd.read_csv(fn) + self.frameclock_channel_name = frameclock_channel + + self.sync_channel = self.sync_channel or self.session_params['devices']['neurophotometrics']['sync_channel'] + + # whether or not to reextract from tdms or attempt to load from .pkl + self.load_timestamps = load_timestamps + + @property + def signature(self): + signature = { + 'input_files': [ + ('_neurophotometrics_fpData.raw.pqt', self.photometry_collection, True, True), + ('_iblrig_taskData.raw.jsonable', self.task_collection, True, True), + # ('_neurophotometrics_fpData.channels.csv', self.photometry_collection, True, True), + ('_mcc_DAQdata.raw.tdms', self.photometry_collection, True, True), + ], + 'output_files': [ + ('photometry.signal.pqt', 'alf/photometry', True), + ('photometryROI.locations.pqt', 'alf/photometry', True), + ], + } + return signature + + def load_data(self) -> pd.DataFrame: + # the point of this functions is to overwrite the SystemTimestamp column + # in the ibl_df with the values from the DAQ clock + # then syncing will work the same as for the bpod based syncing + photometry_df = super().load_data() + + # get daqami timestamps + # 
attempt to load + timestamps_filepath = self.session_path / self.photometry_collection / '_mcc_DAQdata.pkl' + if self.load_timestamps and timestamps_filepath.exists(): + with open(timestamps_filepath, 'rb') as fH: + self.timestamps = pickle.load(fH) + else: # extract timestamps: + tdms_filepath = self.session_path / self.photometry_collection / '_mcc_DAQdata.raw.tdms' + self.timestamps = extract_timestamps_from_tdms_file(tdms_filepath, save_path=timestamps_filepath) + + # timestamps of the frameclock in DAQ time + frame_timestamps = self.timestamps[self.frameclock_channel_name]['positive'] + + # compare number of frame timestamps + # and put them in the photometry_df SystemTimestamp column + # based on the different scenarios + frame_times_adjusted = False # for debugging reasons + + # they are the same, all is well + if photometry_df.shape[0] == frame_timestamps.shape[0]: + photometry_df['times'] = frame_timestamps + _logger.info(f'timestamps are of equal size {photometry_df.shape[0]}') + frame_times_adjusted = True + + # there are more timestamps recorded by DAQ than + # frames recorded by bonsai + elif photometry_df.shape[0] < frame_timestamps.shape[0]: + _logger.info(f'# bonsai frames: {photometry_df.shape[0]}, # daq timestamps: {frame_timestamps.shape[0]}') + # there is exactly one more timestamp recorded by the daq + # (probably bonsai drops the last incomplete frame) + if photometry_df.shape[0] == frame_timestamps.shape[0] - 1: + photometry_df['times'] = frame_timestamps[:-1] + # there are two more frames recorded by the DAQ than by + # bonsai - this is observed. 
TODO understand when this happens +            elif photometry_df.shape[0] == frame_timestamps.shape[0] - 2: +                photometry_df['times'] = frame_timestamps[:-2] +            # there are more frames recorded by the DAQ than that +            # this indicates an issue - +            elif photometry_df.shape[0] < frame_timestamps.shape[0] - 2: +                raise ValueError('more timestamps for frames recorded by the daqami than frames were recorded by bonsai.') +            frame_times_adjusted = True + +        # there are more frames recorded by bonsai than by the DAQ +        # this happens when the user stops the daqami recording before stopping the bonsai +        # or when daqami crashes +        elif photometry_df.shape[0] > frame_timestamps.shape[0]: +            # we drop all excess frames +            _logger.warning( +                f'#frames bonsai: {photometry_df.shape[0]} > #frames daqami {frame_timestamps.shape[0]}, dropping excess' +            ) +            n_frames_daqami = frame_timestamps.shape[0] +            photometry_df = photometry_df.iloc[:n_frames_daqami] +            photometry_df.loc[:, 'times'] = frame_timestamps +            frame_times_adjusted = True + +        if not frame_times_adjusted: +            raise ValueError('timestamp issue that hasnt been caught') + +        return photometry_df + +    def _get_neurophotometrics_timestamps(self) -> np.ndarray: +        # get the sync channel and the corresponding timestamps +        timestamps_nph = self.timestamps[f'DI{self.sync_channel}']['positive'] + +        # TODO replace this rudimentary spacer removal +        # to implement: detect spacer / remove spacer methods +        # timestamps_nph = timestamps_nph[15: ] +        return timestamps_nph + + +class FibrePhotometryPassiveChoiceWorld(base_tasks.BehaviourTask): +    priority = 90 +    job_size = 'small' + +    def __init__( +        self, +        session_path: str | Path, +        one: ONE, +        load_timestamps: bool = True, +        **kwargs, +    ): +        super().__init__(session_path, one=one, **kwargs) +        self.photometry_collection = kwargs.get('collection', 'raw_photometry_data') +        self.kwargs = kwargs +        self.load_timestamps = load_timestamps + +    def _run(self, **kwargs) -> Tuple[pd.DataFrame, pd.DataFrame]: +        # 
load the fixtures - from the relative delays between trials, an "absolute" time vector is + # created that is used for the synchronization + fixtures_path = ( + Path(iblphotometry.__file__).parent.parent + / 'iblphotometry_tests' + / 'fixtures' + / 'passiveChoiceWorld_trials_fixtures.pqt' + ) + + # getting the task_settings + with open(self.session_path / self.collection / '_iblrig_taskSettings.raw.json', 'r') as fH: + task_settings = json.load(fH) + + # getting the fixtures + fixtures_df = pd.read_parquet(fixtures_path).groupby('session_id').get_group(task_settings['SESSION_TEMPLATE_ID']) + + # the fixtures table contains delays between the individual stimuli + # in order to get their onset times, we need to do an adjusted cumsum of the intervals + # adjusted by: the length of each stimulus, plus the overhead time to load it and play it + # e.g. state machine time, bonsai delay etc. + + # stimulus durations + stim_durations = dict( + T=task_settings['GO_TONE_DURATION'], + N=task_settings['WHITE_NOISE_DURATION'], + G=0.3, # visual stimulus duration is hardcoded to 300ms + V=0.1, # V=0.1102 from a a session # to be replaced later down + ) + for s in fixtures_df['stim_type'].unique(): + fixtures_df.loc[fixtures_df['stim_type'] == s, 'delay'] = stim_durations[s] + + # the audio go cue times - recorded in the time of the mic clock + # this is assumed to be precise so we can use it to fit the unknown overhead + # time for each stim class + go_cue_times_mic = np.load(self.session_path / self.collection / '_iblmic_audioOnsetGoCue.times_mic.npy') + + # adding the delays + def obj_fun(x, go_cue_times_mic, fixtures_df): + # fit overhead + for s in ['T', 'N', 'G', 'V']: + if s == 'T' or s == 'N': + fixtures_df.loc[fixtures_df['stim_type'] == s, 'overhead'] = x[0] + if s == 'G': + fixtures_df.loc[fixtures_df['stim_type'] == s, 'overhead'] = x[1] + if s == 'V': + fixtures_df.loc[fixtures_df['stim_type'] == s, 'overhead'] = x[2] + + fixtures_df['t_rel'] = np.cumsum( + 
fixtures_df['stim_delay'].values + np.roll(fixtures_df['delay'].values, 1) + fixtures_df['overhead'].values, + ) + + go_cue_times_rel = fixtures_df.groupby('stim_type').get_group('T')['t_rel'].values + err = np.sum((np.diff(go_cue_times_rel) - np.diff(go_cue_times_mic)) ** 2) + return err + + # fitting the overheads + fixtures_df['overhead'] = 0.0 + bounds = ((0, np.inf), (0, np.inf), (0, np.inf)) + pfit = minimize(obj_fun, (0.0, 0.0, 0.0), args=(go_cue_times_mic, fixtures_df), bounds=bounds) + overheads = dict(zip(['T', 'N', 'G', 'V'], [pfit.x[0], pfit.x[0], pfit.x[1], pfit.x[2]])) + + # creating the relative time vector for each stimulus + for s in fixtures_df['stim_type'].unique(): + fixtures_df.loc[fixtures_df['stim_type'] == s, 'overhead'] = overheads[s] + fixtures_df['t_rel'] = np.cumsum( + fixtures_df['stim_delay'].values + np.roll(fixtures_df['delay'].values, 1) + fixtures_df['overhead'].values + ) + + # we now sync the valve times from the relative time and the neurophotometrics time + valve_times_rel = fixtures_df.groupby('stim_type').get_group('V')['t_rel'].values + + # getting the valve timestamps from the DAQ + timestamps_filepath = self.session_path / self.photometry_collection / '_mcc_DAQdata.pkl' + if self.load_timestamps and timestamps_filepath.exists(): + with open(timestamps_filepath, 'rb') as fH: + self.timestamps = pickle.load(fH) + else: # extract timestamps: + tdms_filepath = self.session_path / self.photometry_collection / '_mcc_DAQdata.raw.tdms' + self.timestamps = extract_timestamps_from_tdms_file(tdms_filepath, save_path=timestamps_filepath) + + sync_channel = self.session_params['devices']['neurophotometrics']['sync_channel'] + valve_times_daq = self.timestamps[f'DI{sync_channel}']['positive'] + + sync_fun_rel_to_daq, drift_ppm, ix_rel, ix_daq = ibldsp.utils.sync_timestamps( + valve_times_rel, valve_times_daq, return_indices=True, linear=True + ) + assert ix_rel.shape[0] == 40, 'not all bpod valve onset events are accepted by the sync 
function' + if np.absolute(drift_ppm) > 20: + _logger.warning(f'sync with excessive drift: {drift_ppm}') else: - led_states = pd.DataFrame(LED_STATES) - led_states = led_states.set_index('Condition') - # Extract signal columns into 2D array - rois = list(self.kwargs['fibers'].keys()) - out_df = fp_data.filter(items=rois, axis=1).sort_index(axis=1) - out_df['times'] = fcn_nph_to_bpod_times(fp_data['SystemTimestamp']) - out_df['valid'] = np.logical_and(out_df['times'] >= valid_bounds[0], out_df['times'] <= valid_bounds[1]) - out_df['wavelength'] = np.nan - out_df['name'] = '' - out_df['color'] = '' - # Extract channel index - states = fp_data.get('LedState', fp_data.get('Flags', None)) - for state in states.unique(): - ir, ic = np.where(led_states == state) - if ic.size == 0: - continue - for cn in ['name', 'color', 'wavelength']: - out_df.loc[states == state, cn] = channel_meta_map.iloc[ic[0]][cn] - # 3) label the brain regions + _logger.info(f'synced with drift: {drift_ppm}') + + # loads the raw photometry data + raw_photometry_folder = self.session_path / self.photometry_collection + photometry_df = fpio.from_neurophotometrics_file_to_photometry_df( + raw_photometry_folder / '_neurophotometrics_fpData.raw.pqt', + drop_first=False, + ) + + # load the photometry data and replace the timestamp column + # with the values from the frameclock timestamps as recorded by the DAQ + frameclock_channel_name = self.session_params['devices']['neurophotometrics']['sync_metadata']['frameclock_channel'] + frame_timestamps = self.timestamps[frameclock_channel_name]['positive'] + + # compare number of frame timestamps + # and put them in the photometry_df SystemTimestamp column + # based on the different scenarios + frame_times_adjusted = False # for debugging reasons + + # they are the same, all is well + if photometry_df.shape[0] == frame_timestamps.shape[0]: + photometry_df['times'] = frame_timestamps + _logger.info(f'timestamps are of equal size {photometry_df.shape[0]}') + 
frame_times_adjusted = True + +        # there are more timestamps recorded by DAQ than +        # frames recorded by bonsai +        elif photometry_df.shape[0] < frame_timestamps.shape[0]: +            _logger.info(f'# bonsai frames: {photometry_df.shape[0]}, # daq timestamps: {frame_timestamps.shape[0]}') +            # there is exactly one more timestamp recorded by the daq +            # (probably bonsai drops the last incomplete frame) +            if photometry_df.shape[0] == frame_timestamps.shape[0] - 1: +                photometry_df['times'] = frame_timestamps[:-1] +            # there are two more frames recorded by the DAQ than by +            # bonsai - this is observed. TODO understand when this happens +            elif photometry_df.shape[0] == frame_timestamps.shape[0] - 2: +                photometry_df['times'] = frame_timestamps[:-2] +            # there are more frames recorded by the DAQ than that +            # this indicates an issue - +            elif photometry_df.shape[0] < frame_timestamps.shape[0] - 2: +                raise ValueError('more timestamps for frames recorded by the daqami than frames were recorded by bonsai.') +            frame_times_adjusted = True + +        # there are more frames recorded by bonsai than by the DAQ +        # this happens when the user stops the daqami recording before stopping the bonsai +        # or when daqami crashes +        elif photometry_df.shape[0] > frame_timestamps.shape[0]: +            # we drop all excess frames +            _logger.warning( +                f'#frames bonsai: {photometry_df.shape[0]} > #frames daqami {frame_timestamps.shape[0]}, dropping excess' +            ) +            n_frames_daqami = frame_timestamps.shape[0] +            photometry_df = photometry_df.iloc[:n_frames_daqami] +            photometry_df.loc[:, 'times'] = frame_timestamps +            frame_times_adjusted = True + +        if not frame_times_adjusted: +            raise ValueError('timestamp issue that hasnt been caught') + +        # write to disk +        # the photometry signal +        photometry_filepath = self.session_path / 'alf' / 'photometry' / 'photometry.signal.pqt' +        photometry_filepath.parent.mkdir(parents=True, exist_ok=True) +        photometry_df.to_parquet(photometry_filepath) + +        # writing the locations         rois = [] -        c 
= 0 - for k, v in self.kwargs['fibers'].items(): - rois.append({'ROI': k, 'fiber': f'fiber{c:02d}', 'brain_region': v['location']}) - df_rois = pd.DataFrame(rois).set_index('ROI') - # to finish we write the dataframes to disk - out_path = self.session_path.joinpath('alf', 'photometry') - out_path.mkdir(parents=True, exist_ok=True) - out_df.to_parquet(file_signal := out_path.joinpath('photometry.signal.pqt')) - df_rois.to_parquet(file_locations := out_path.joinpath('photometryROI.locations.pqt')) - return file_signal, file_locations + for k, v in self.session_params['devices']['neurophotometrics']['fibers'].items(): + rois.append({'ROI': k, 'fiber': f'fiber_{v["location"]}', 'brain_region': v['location']}) + locations_df = pd.DataFrame(rois).set_index('ROI') + locations_filepath = self.session_path / 'alf' / 'photometry' / 'photometryROI.locations.pqt' + locations_filepath.parent.mkdir(parents=True, exist_ok=True) + locations_df.to_parquet(locations_filepath) + + # writing the passive events table + # get the valve open duration + timestamps_filepath = self.session_path / self.photometry_collection / '_mcc_DAQdata.pkl' + if self.load_timestamps and timestamps_filepath.exists(): + with open(timestamps_filepath, 'rb') as fH: + self.timestamps = pickle.load(fH) + else: # extract timestamps: + tdms_filepath = self.session_path / self.photometry_collection / '_mcc_DAQdata.raw.tdms' + self.timestamps = extract_timestamps_from_tdms_file(tdms_filepath, save_path=timestamps_filepath) + + ttl_durations = self.timestamps[f'DI{sync_channel}']['negative'] - self.timestamps[f'DI{sync_channel}']['positive'] + valve_open_dur = np.median(ttl_durations[ix_daq]) + passiveStims_df = pd.DataFrame( + dict( + valveOn=fixtures_df.groupby('stim_type').get_group('V')['t_rel'], + valveOff=fixtures_df.groupby('stim_type').get_group('V')['t_rel'] + valve_open_dur, + toneOn=fixtures_df.groupby('stim_type').get_group('T')['t_rel'], + 
toneOff=fixtures_df.groupby('stim_type').get_group('T')['t_rel'] + task_settings['GO_TONE_DURATION'], + noiseOn=fixtures_df.groupby('stim_type').get_group('N')['t_rel'], + noiseOff=fixtures_df.groupby('stim_type').get_group('N')['t_rel'] + task_settings['WHITE_NOISE_DURATION'], + ) + ) + # convert all times from fixture time (=rel) to daq time + passiveStims_df.iloc[:, :] = sync_fun_rel_to_daq(passiveStims_df.values) + passiveStims_filepath = self.session_path / 'alf' / self.collection / '_ibl_passiveStims.table.pqt' + passiveStims_filepath.parent.mkdir(exist_ok=True, parents=True) + passiveStims_df.reset_index().to_parquet(passiveStims_filepath) + + return photometry_filepath, locations_filepath, passiveStims_filepath diff --git a/ibllib/tests/fixtures/neurophotometrics/_ibl_experiment.description.yaml b/ibllib/tests/fixtures/neurophotometrics/_ibl_experiment.description.yaml new file mode 100644 index 000000000..8a39783cb --- /dev/null +++ b/ibllib/tests/fixtures/neurophotometrics/_ibl_experiment.description.yaml @@ -0,0 +1,35 @@ +devices: + cameras: + left: + collection: raw_video_data + sync_label: audio + microphone: + microphone: + collection: raw_task_data_00 + sync_label: audio + neurophotometrics: + collection: raw_photometry_data + datetime: '2025-05-26T15:08:40.237101' + fibers: + G0: + location: VTA + sync_channel: 2 + sync_metadata: + acquisition_software: daqami + collection: raw_photometry_data + frameclock_channel: 7 + sync_mode: daqami +procedures: +- Fiber photometry +projects: +- ibl_fibrephotometry +- practice +sync: + bpod: + acquisition_software: pybpod + collection: raw_task_data_00 + extension: .jsonable +tasks: +- _iblrig_tasks_advancedChoiceWorld: + collection: raw_task_data_00 +version: 1.0.0 diff --git a/ibllib/tests/test_neurophotometrics.py b/ibllib/tests/test_neurophotometrics.py new file mode 100644 index 000000000..6f29ce509 --- /dev/null +++ b/ibllib/tests/test_neurophotometrics.py @@ -0,0 +1,39 @@ +"""Tests for 
ibllib.pipes.neurophotometrics.""" + +import unittest +import tempfile +# from pathlib import Path +# import iblphotometry_tests +# from ibllib.pipes.neurophotometrics import FibrePhotometryBpodSync +# from ibllib.io import session_params + +# Mock suite2p which is imported in MesoscopePreprocess +# attrs = {'default_ops.return_value': {}} +# sys.modules['suite2p'] = mock.MagicMock(**attrs) + +# from iblscripts.ci.tests import base + + +class TestNeurophotometricsExtractor(unittest.TestCase): +    """ +    this class tests +    that the correct extractor is run based on the experiment description file +    this requires the setup to have + +    """ + +    def setUp(self) -> None: +        self.tmp_folder = tempfile.TemporaryDirectory() +        # self.session_folder = Path(self.tmp_folder.name) / 'subject' / '2020-01-01' / '001' +        # self.raw_photometry_folder = self.session_folder / 'raw_photometry_data' +        # self.raw_photometry_folder.mkdir(parents=True) + +    # def test_bpod_extractor(self): +    #     session_folder = Path(iblphotometry_tests.__file__).parent / 'data' / 'neurophotometrics' / 'raw_bpod_session' +    #     assert session_folder.exists() +    #     self.experiment_description = session_params.read_params(session_folder) +    #     FibrePhotometryBpodSync() + +    # def test_daqami_extractor(self): +    #     path = Path(__file__).parent / 'fixtures' / 'neurophotometrics' / '_ibl_experiment_description_bpod.yaml' +    #     self.experiment_description = session_params.read_params(path) diff --git a/requirements.txt b/requirements.txt index ebe13e951..b907c2ad7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,3 +33,4 @@ psychofit slidingRP>=1.1.1 # steinmetz lab refractory period metrics pyqt5 ibl-style +ibl-photometry>=0.1.2 \ No newline at end of file