From 3319855d237fb427637c575a3564c0847f36d270 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 12 Mar 2026 16:15:28 +0000 Subject: [PATCH 1/4] Add valid periods shaded regions on scatter plots --- spikeinterface_gui/basescatterview.py | 56 +++++++++++++++++++++++---- spikeinterface_gui/controller.py | 6 +++ spikeinterface_gui/view_base.py | 12 ++++-- 3 files changed, 64 insertions(+), 10 deletions(-) diff --git a/spikeinterface_gui/basescatterview.py b/spikeinterface_gui/basescatterview.py index 44ea159c..9dee13a2 100644 --- a/spikeinterface_gui/basescatterview.py +++ b/spikeinterface_gui/basescatterview.py @@ -9,6 +9,7 @@ class BaseScatterView(ViewBase): _depend_on = None _settings = [ {'name': "auto_decimate", 'type': 'bool', 'value' : True}, + {'name': "display_valid_periods", 'type': 'bool', 'value' : True}, {'name': 'max_spikes_per_unit', 'type': 'int', 'value' : 5_000}, {'name': 'alpha', 'type': 'float', 'value' : 0.7, 'limits':(0, 1.), 'step':0.05}, {'name': 'scatter_size', 'type': 'float', 'value' : 2., 'step':0.5}, @@ -39,6 +40,8 @@ def __init__(self, spike_data, y_label, controller=None, parent=None, backend="q ViewBase.__init__(self, controller=controller, parent=parent, backend=backend) + self.valid_period_regions = [] + def get_unit_data(self, unit_id, segment_index=0): inds = self.controller.get_spike_indices(unit_id, segment_index=segment_index) @@ -298,8 +301,8 @@ def _qt_refresh(self, set_scatter_range=False): all_inds = [] ymins = [] ymaxs = [] - for unit_id in self.controller.get_visible_unit_ids(): - + visible_units = self.controller.get_visible_unit_ids() + for unit_id in visible_units: spike_times, spike_data, hist_count, hist_bins, ymin, ymax, inds = self.get_unit_data( unit_id, segment_index=segment_index @@ -336,9 +339,28 @@ def _qt_refresh(self, set_scatter_range=False): self.viewBox2.setXRange(0, self._max_count, padding = 0.0) # explicitly set the y-range of the histogram to match the spike data - spike_times, spike_data = self.get_selected_spikes_data(segment_index=self.combo_seg.currentIndex(), visible_inds=all_inds) + spike_times, spike_data = self.get_selected_spikes_data(segment_index=segment_index, visible_inds=all_inds) self.scatter_select.setData(spike_times, spike_data) + if self.settings["display_valid_periods"] and self.controller.valid_periods is not None: + for region in self.valid_period_regions: + self.plot.removeItem(region) + self.valid_period_regions = [] + for unit_id in visible_units: + valid_periods_unit = self.controller.valid_periods[segment_index][unit_id] + color = self.get_unit_color(unit_id, alpha=0.3) + pen_color = pg.mkColor(color) + for period in valid_periods_unit: + t_start = self.controller.sample_index_to_time(period[0]) + t_end = self.controller.sample_index_to_time(period[1]) + region = pg.LinearRegionItem([t_start, t_end], movable=False, brush=pen_color) + self.plot.addItem(region, ignoreBounds=True) + self.valid_period_regions.append(region) + else: + for region in self.valid_period_regions: + self.plot.removeItem(region) + self.valid_period_regions = [] + def _qt_on_time_info_updated(self): if self.combo_seg.currentIndex() != self.controller.get_time()[1]: self._block_auto_refresh_and_notify = True @@ -496,8 +518,6 @@ def _panel_make_layout(self): ), ) ) - # self.hist_lines = [] - self.noise_harea = [] self.plotted_inds = [] def _panel_refresh(self, set_scatter_range=False): @@ -569,6 +589,30 @@ def _panel_refresh(self, set_scatter_range=False): # handle selected spikes self._panel_update_selected_spikes() + # Add valid period regions + if self.settings["display_valid_periods"] and self.controller.valid_periods is not None: + for region in self.valid_period_regions: + self.scatter_fig.renderers.remove(region) + self.valid_period_regions = [] + for unit_id in visible_unit_ids: + valid_periods_unit = self.controller.valid_periods[segment_index][unit_id] + color = self.get_unit_color(unit_id) + color_shade = self.get_unit_color(unit_id, alpha=0.3) + for period in valid_periods_unit: + t_start = self.controller.sample_index_to_time(period[0]) + t_end = self.controller.sample_index_to_time(period[1]) + region = self.scatter_fig.varea( + x=[t_start, t_end], + y1=[-1_000_000, -1_000_000], + y2=[1_000_000, 1_000_000], + fill_color=color_shade + ) + self.valid_period_regions.append(region) + else: + for region in self.valid_period_regions: + self.scatter_fig.renderers.remove(region) + self.valid_period_regions = [] + # Defer Range updates to avoid nested document lock issues # def update_ranges(): if set_scatter_range or not self._first_refresh_done: @@ -578,8 +622,6 @@ def _panel_refresh(self, set_scatter_range=False): self.hist_fig.x_range.end = max_count self.hist_fig.xaxis.ticker = FixedTicker(ticks=[0, max_count // 2, max_count]) - # Schedule the update to run after the current event loop iteration - # pn.state.execute(update_ranges, schedule=True) def _panel_on_select_button(self, event): import panel as pn diff --git a/spikeinterface_gui/controller.py b/spikeinterface_gui/controller.py index 26b86fa1..5dcf324c 100644 --- a/spikeinterface_gui/controller.py +++ b/spikeinterface_gui/controller.py @@ -247,6 +247,12 @@ def __init__( pc_ext = analyzer.get_extension('principal_components') self.pc_ext = pc_ext + if analyzer.has_extension("valid_unit_periods"): + valid_periods_ext = analyzer.get_extension("valid_unit_periods") + self.valid_periods = valid_periods_ext.get_data(outputs="by_unit") + else: + self.valid_periods = None + self._potential_merges = None t1 = time.perf_counter() diff --git a/spikeinterface_gui/view_base.py b/spikeinterface_gui/view_base.py index 52aaebdc..db0ed502 100644 --- a/spikeinterface_gui/view_base.py +++ b/spikeinterface_gui/view_base.py @@ -1,6 +1,8 @@ import time from contextlib import contextmanager +import numpy as np + class ViewBase: id: str = None _supported_backend = [] @@ -159,7 +161,7 @@ def continue_from_user(self, warning_msg, action, *args): # Panel: asynchronous approach with callback self._panel_insert_warning_with_choice(warning_msg, action, *args) - def get_unit_color(self, unit_id): + def get_unit_color(self, unit_id, alpha=1.0): if self.backend == "qt": from .myqt import QT @@ -170,14 +172,18 @@ def get_unit_color(self, unit_id): color = self.controller.get_unit_color(unit_id) r, g, b, a = color qcolor = QT.QColor(int(r * 255), int(g * 255), int(b * 255)) + # only cache self.controller._cached_qcolors[unit_id] = qcolor - - return self.controller._cached_qcolors[unit_id] + else: + qcolor = self.controller._cached_qcolors[unit_id] + qcolor.setAlpha(int(alpha * 255)) + return qcolor elif self.backend == "panel": import matplotlib color = self.controller.get_unit_color(unit_id) + color = color[:3] + (np.float64(alpha),) html_color = matplotlib.colors.rgb2hex(color, keep_alpha=True) return html_color From 18997b4d5b28d4ec640dd23b1077c818c39bd39e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 12 Mar 2026 16:16:36 +0000 Subject: [PATCH 2/4] Add comment --- spikeinterface_gui/basescatterview.py | 1 + 1 file changed, 1 insertion(+) diff --git a/spikeinterface_gui/basescatterview.py b/spikeinterface_gui/basescatterview.py index 9dee13a2..c82d97de 100644 --- a/spikeinterface_gui/basescatterview.py +++ b/spikeinterface_gui/basescatterview.py @@ -601,6 +601,7 @@ def _panel_refresh(self, set_scatter_range=False): for period in valid_periods_unit: t_start = self.controller.sample_index_to_time(period[0]) t_end = self.controller.sample_index_to_time(period[1]) + # set y1/y2 to very large values to make sure they always display region = self.scatter_fig.varea( x=[t_start, t_end], y1=[-1_000_000, -1_000_000], From dc89ec74d387eb8ee8a1094bfb08227b755ceed8 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 13 Mar 2026 11:01:14 +0000 Subject: [PATCH 3/4] add pen color in qt --- spikeinterface_gui/basescatterview.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spikeinterface_gui/basescatterview.py b/spikeinterface_gui/basescatterview.py index c82d97de..de1db685 100644 --- a/spikeinterface_gui/basescatterview.py +++ b/spikeinterface_gui/basescatterview.py @@ -353,7 +353,7 @@ def _qt_refresh(self, set_scatter_range=False): for period in valid_periods_unit: t_start = self.controller.sample_index_to_time(period[0]) t_end = self.controller.sample_index_to_time(period[1]) - region = pg.LinearRegionItem([t_start, t_end], movable=False, brush=pen_color) + region = pg.LinearRegionItem([t_start, t_end], movable=False, brush=pen_color, pen=pen_color) self.plot.addItem(region, ignoreBounds=True) self.valid_period_regions.append(region) else: From a7ac71d4882523eccd8f5c67951ba52c0d961c31 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 23 Mar 2026 16:50:55 +0100 Subject: [PATCH 4/4] change alpha, switch to quad and instantiate source in make_layout --- spikeinterface_gui/basescatterview.py | 44 +++++++++++++++------------ 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/spikeinterface_gui/basescatterview.py b/spikeinterface_gui/basescatterview.py index de1db685..595c46fc 100644 --- a/spikeinterface_gui/basescatterview.py +++ b/spikeinterface_gui/basescatterview.py @@ -348,7 +348,7 @@ def _qt_refresh(self, set_scatter_range=False): self.valid_period_regions = [] for unit_id in visible_units: valid_periods_unit = self.controller.valid_periods[segment_index][unit_id] - color = self.get_unit_color(unit_id, alpha=0.3) + color = self.get_unit_color(unit_id, alpha=0.2) pen_color = pg.mkColor(color) for period in valid_periods_unit: t_start = self.controller.sample_index_to_time(period[0]) @@ -474,6 +474,17 @@ def _panel_make_layout(self): # Add SelectionGeometry event handler to capture lasso vertices self.scatter_fig.on_event('selectiongeometry', self._on_panel_selection_geometry) + self.valid_periods_source = ColumnDataSource(data=dict( + left=[], right=[], top=[], bottom=[], fill_color=[] + )) + self.scatter_fig.quad( + left="left", right="right", + top="top", bottom="bottom", + fill_color="fill_color", + line_color=None, + source=self.valid_periods_source, + ) + self.hist_source = ColumnDataSource(data={"x": [], "y": []}) self.hist_data_source = ColumnDataSource(data=dict(x=[], y=[], color=[])) self.hist_fig = bpl.figure( @@ -589,30 +600,23 @@ def _panel_refresh(self, set_scatter_range=False): # handle selected spikes self._panel_update_selected_spikes() - # Add valid period regions + # Update valid period regions if self.settings["display_valid_periods"] and self.controller.valid_periods is not None: - for region in self.valid_period_regions: - self.scatter_fig.renderers.remove(region) - self.valid_period_regions = [] + lefts, rights, tops, bottoms, colors = [], [], [], [], [] for unit_id in visible_unit_ids: valid_periods_unit = self.controller.valid_periods[segment_index][unit_id] - color = self.get_unit_color(unit_id) - color_shade = self.get_unit_color(unit_id, alpha=0.3) + color_shade = self.get_unit_color(unit_id, alpha=0.2) for period in valid_periods_unit: - t_start = self.controller.sample_index_to_time(period[0]) - t_end = self.controller.sample_index_to_time(period[1]) - # set y1/y2 to very large values to make sure they always display - region = self.scatter_fig.varea( - x=[t_start, t_end], - y1=[-1_000_000, -1_000_000], - y2=[1_000_000, 1_000_000], - fill_color=color_shade - ) - self.valid_period_regions.append(region) + lefts.append(self.controller.sample_index_to_time(period[0])) + rights.append(self.controller.sample_index_to_time(period[1])) + tops.append(1_000_000) + bottoms.append(-1_000_000) + colors.append(color_shade) + self.valid_periods_source.data = dict( + left=lefts, right=rights, top=tops, bottom=bottoms, fill_color=colors + ) else: - for region in self.valid_period_regions: - self.scatter_fig.renderers.remove(region) - self.valid_period_regions = [] + self.valid_periods_source.data = dict(left=[], right=[], top=[], bottom=[], fill_color=[]) # Defer Range updates to avoid nested document lock issues # def update_ranges():