Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions spikeinterface_gui/metricsview.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,27 +111,25 @@ def _qt_refresh(self):

scatter.setData(x=values2, y=values1)

visible_unit_ids = self.controller.get_visible_unit_ids()
visible_unit_ids = self.controller.get_visible_unit_indices()

for unit_ind, unit_id in self.controller.iter_visible_units():
color = self.get_unit_color(unit_id)
scatter.addPoints(x=[values2[unit_ind]], y=[values1[unit_ind]], pen=pg.mkPen(None), brush=color)
if (not np.isnan(values2[unit_ind])) and (not np.isnan(values1[unit_ind])):
scatter.addPoints(x=[values2[unit_ind]], y=[values1[unit_ind]], pen=pg.mkPen(None), brush=color)

# self.scatter.addPoints(x=scatter_x[unit_id], y=scatter_y[unit_id], pen=pg.mkPen(None), brush=color)
# self.scatter_select.setData(selected_scatter_x, selected_scatter_y)
elif c == r:
values1 = units_table[visible_metrics[r]].values
values1_no_nans = values1[~np.isnan(values1)]

count, bins = np.histogram(values1, bins=self.settings['num_bins'])
count, bins = np.histogram(values1_no_nans, bins=self.settings['num_bins'])
curve = pg.PlotCurveItem(bins, count, stepMode='center', fillLevel=0, brush=white_brush, pen=white_brush)
plot.addItem(curve)

for unit_ind, unit_id in self.controller.iter_visible_units():
x = values1[unit_ind]
color = self.get_unit_color(unit_id)
line = pg.InfiniteLine(pos=x, angle=90, movable=False, pen=color)
plot.addItem(line)
if not np.isnan(x):
line = pg.InfiniteLine(pos=x, angle=90, movable=False, pen=color)
plot.addItem(line)

def _qt_select_metrics(self):
if not self.tree_visible_metrics.isVisible():
Expand Down Expand Up @@ -212,6 +210,8 @@ def _panel_refresh(self):
col2 = visible_metrics[c]
values1 = units_table[col1].values
values2 = units_table[col2].values
values1_no_nans = values1[~np.isnan(values1)]
values2_no_nans = values2[~np.isnan(values2)]

plot = bpl.figure(
width=plot_size, height=plot_size,
Expand All @@ -227,7 +227,7 @@ def _panel_refresh(self):
plot.xaxis.axis_label = col1
plot.yaxis.axis_label = "Count"
# Create histogram
hist, edges = np.histogram(values1, bins=self.settings['num_bins'])
hist, edges = np.histogram(values1_no_nans, bins=self.settings['num_bins'])
if len(hist) > 0 and max(hist) > 0:
plot.quad(
top=hist, bottom=0, left=edges[:-1], right=edges[1:],
Expand All @@ -238,8 +238,9 @@ def _panel_refresh(self):
max_hist = max(hist)
for unit_ind, unit_id in self.controller.iter_visible_units():
x = values1[unit_ind]
color = self.get_unit_color(unit_id)
plot.line([x, x], [0, max_hist], line_width=2, color=color, alpha=0.8)
if not np.isnan(x):
color = self.get_unit_color(unit_id)
plot.line([x, x], [0, max_hist], line_width=2, color=color, alpha=0.8)
else:
# Off-diagonal - scatter plot
plot.xaxis.axis_label = col2
Expand All @@ -251,8 +252,8 @@ def _panel_refresh(self):

# Plot all points in light color first
all_source = ColumnDataSource({
'x': values2,
'y': values1,
'x': values2_no_nans,
'y': values1_no_nans,
'color': colors
})
plot.scatter('x', 'y', source=all_source, size=8, color='color', alpha=0.5)
Expand Down