Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions spikeinterface_gui/basescatterview.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def _qt_make_layout(self):
tb = self.qt_widget.view_toolbar
self.combo_seg = QT.QComboBox()
tb.addWidget(self.combo_seg)
self.combo_seg.addItems([ f'Segment {segment_index}' for segment_index in range(self.controller.num_segments) ])
self.combo_seg.addItems([f'Segment {segment_index}' for segment_index in range(self.controller.num_segments)])
self.combo_seg.currentIndexChanged.connect(self._qt_change_segment)
add_stretch_to_qtoolbar(tb)
self.lasso_but = QT.QPushButton("select", checkable = True)
Expand Down Expand Up @@ -278,7 +278,7 @@ def _qt_refresh(self):
# make a copy of the color
color = QT.QColor(self.get_unit_color(unit_id))
color.setAlpha(int(self.settings['alpha']*255))
self.scatter.addPoints(x=spike_times, y=spike_data, pen=pg.mkPen(None), brush=color)
self.scatter.addPoints(x=spike_times, y=spike_data, pen=pg.mkPen(None), brush=color)

color = self.get_unit_color(unit_id)
curve = pg.PlotCurveItem(hist_count, hist_bins[:-1], fillLevel=None, fillOutline=True, brush=color, pen=color)
Expand Down
95 changes: 83 additions & 12 deletions spikeinterface_gui/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,8 @@

from spikeinterface.widgets.utils import get_unit_colors
from spikeinterface import compute_sparsity
from spikeinterface.core import get_template_extremum_channel
import spikeinterface.postprocessing
import spikeinterface.qualitymetrics
from spikeinterface.core import get_template_extremum_channel, BaseEvent
from spikeinterface.core.sorting_tools import spike_vector_to_indices
from spikeinterface.core.core_tools import check_json
from spikeinterface.curation import validate_curation_dict
from spikeinterface.curation.curation_model import CurationModel
from spikeinterface.widgets.utils import make_units_table_from_analyzer
Expand All @@ -33,10 +30,23 @@


class Controller():
def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save_on_compute=False,
curation=False, curation_data=None, label_definitions=None, with_traces=True,
displayed_unit_properties=None,
extra_unit_properties=None, skip_extensions=None, disable_save_settings_button=False):
def __init__(
self,
analyzer=None,
backend="qt",
parent=None,
verbose=False,
save_on_compute=False,
curation=False,
curation_data=None,
label_definitions=None,
with_traces=True,
displayed_unit_properties=None,
extra_unit_properties=None,
skip_extensions=None,
disable_save_settings_button=False,
events=None
):
self.views = []
skip_extensions = skip_extensions if skip_extensions is not None else []

Expand Down Expand Up @@ -220,6 +230,62 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save
self.pc_ext = pc_ext

self._potential_merges = None
# some direct attribute
self.num_segments = self.analyzer.get_num_segments()
self.sampling_frequency = self.analyzer.sampling_frequency

self.events = None
if events is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets move this events handling in a separate function no ?

if verbose:
print('\tLoading events')
self.events = {}
if isinstance(events, dict):
for key, val in events.items():
if not isinstance(val, dict):
if verbose:
print(f'\tSkipping event {key}: not a dict')
continue
if 'samples' not in val and 'times' not in val:
if verbose:
print(f'\tSkipping event {key}: missing samples or times')
continue
if 'times' in val:
samples_data = val['times']
convert_to_samples = True
else:
samples_data = val['samples']
convert_to_samples = False
if self.num_segments > 1:
if not len(samples_data) == self.num_segments:
if verbose:
print(f'\tSkipping event {key}: inconsistent number of samples')
continue
else:
# here we make sure samples is a list of list
if np.array(samples_data).ndim == 1:
samples_data = [samples_data]
if convert_to_samples:
self.events[key] = [np.array(self.time_to_sample_index(s)) for s in samples_data]
else:
self.events[key] = [np.array(s) for s in samples_data]
elif isinstance(events, BaseEvent):
event_names = events.channel_ids
self.events = {
event_name: [] for event_name in event_names
}
for event_name in event_names:
for segment_index in range(self.num_segments):
event_times_segment = events.get_event_times(
channel_id=event_name,
segment_index=segment_index
)
event_samples_segment = self.analyzer.time_to_sample_index(
event_times_segment
)
self.events[event_name].append(np.array(event_samples_segment))

if len(self.events) == 0:
self.events = None

t1 = time.perf_counter()
if verbose:
Expand All @@ -229,10 +295,6 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save

self._extremum_channel = get_template_extremum_channel(self.analyzer, peak_sign='neg', outputs='index')

# some direct attribute
self.num_segments = self.analyzer.get_num_segments()
self.sampling_frequency = self.analyzer.sampling_frequency

# spikeinterface handle colors in matplotlib style tuple values in range (0,1)
self.refresh_colors()

Expand Down Expand Up @@ -489,6 +551,13 @@ def time_to_sample_index(self, time):
else:
return int(time * self.sampling_frequency)

def get_events(self, event_name):
if self.events is None:
return None
if event_name not in self.events:
return None
return self.events[event_name][self.time_info['segment_index']]

def get_information_txt(self):
nseg = self.analyzer.get_num_segments()
nchan = self.analyzer.get_num_channels()
Expand Down Expand Up @@ -715,6 +784,8 @@ def set_channel_visibility(self, visible_channel_inds):
def has_extension(self, extension_name):
if extension_name == 'recording':
return self.analyzer.has_recording() or self.analyzer.has_temporary_recording()
elif extension_name == 'events':
return self.events is not None
else:
# extension needs to be loaded
if extension_name in self.skip_extensions:
Expand Down
Loading