From 043ffb34e388c567b15d387c0820d581b2b5598e Mon Sep 17 00:00:00 2001 From: David Stansby Date: Thu, 5 May 2022 21:24:24 +0100 Subject: [PATCH] Refactor callback handling --- src/napari_matplotlib/base.py | 65 ++++++++++++++++++++++++++++-- src/napari_matplotlib/histogram.py | 34 +++++----------- src/napari_matplotlib/scatter.py | 30 ++------------ 3 files changed, 74 insertions(+), 55 deletions(-) diff --git a/src/napari_matplotlib/base.py b/src/napari_matplotlib/base.py index 9a14b747..6bfbd093 100644 --- a/src/napari_matplotlib/base.py +++ b/src/napari_matplotlib/base.py @@ -4,7 +4,6 @@ FigureCanvas, NavigationToolbar2QT, ) -from matplotlib.figure import Figure from qtpy.QtWidgets import QVBoxLayout, QWidget mpl.rc("axes", edgecolor="white") @@ -15,6 +14,7 @@ mpl.rc("xtick", color="white") mpl.rc("ytick", color="white") + __all__ = ["NapariMPLWidget"] @@ -23,8 +23,12 @@ class NapariMPLWidget(QWidget): Base widget that can be embedded as a napari widget and contains a Matplotlib canvas. - This creates a single Figure, and sub-classes should implement logic for - drawing on that Figure. + This creates a single FigureCanvas, which contains a single Figure. + + This class also handles callbacks to automatically update figures when + the layer selection or z-step is changed in the napari viewer. To take + advantage of this sub-classes should implement the ``clear()`` and + ``draw()`` methods. Attributes ---------- @@ -34,13 +38,14 @@ class NapariMPLWidget(QWidget): Matplotlib figure. canvas : matplotlib.backends.backend_qt5agg.FigureCanvas Matplotlib canvas. + layers : `list` + List of currently selected napari layers. """ def __init__(self, napari_viewer: napari.viewer.Viewer): super().__init__() self.viewer = napari_viewer - self.figure = Figure(figsize=(5, 3), tight_layout=True) self.canvas = FigureCanvas() self.canvas.figure.patch.set_facecolor("#262930") self.toolbar = NavigationToolbar2QT(self.canvas, self) @@ -49,9 +54,61 @@ def __init__(self, napari_viewer: napari.viewer.Viewer): self.layout().addWidget(self.toolbar) self.layout().addWidget(self.canvas) + self.setup_callbacks() + + @property + def n_selected_layers(self) -> int: + """ + Number of currently selected layers. + """ + return len(self.layers) + @property def current_z(self) -> int: """ Current z-step of the viewer. """ return self.viewer.dims.current_step[0] + + def setup_callbacks(self) -> None: + """ + Setup callbacks for: + - Layer selection changing + - z-step changing + """ + # z-step changed in viewer + self.viewer.dims.events.current_step.connect(self._draw) + # Layer selection changed in viewer + self.viewer.layers.selection.events.active.connect(self.update_layers) + + def update_layers(self, event: napari.utils.events.Event) -> None: + """ + Update the currently selected layers and re-draw. + """ + self.layers = list(self.viewer.layers.selection) + self._draw() + + def _draw(self) -> None: + """ + Clear current figure, check selected layers are correct, and draw new + figure if so. + """ + self.clear() + if self.n_selected_layers != self.n_layers_input: + return + self.draw() + self.canvas.draw() + + def clear(self) -> None: + """ + Clear any previously drawn figures. + + This is a no-op, and is intended for derived classes to override. + """ + + def draw(self) -> None: + """ + Re-draw any figures. + + This is a no-op, and is intended for derived classes to override. + """ diff --git a/src/napari_matplotlib/histogram.py b/src/napari_matplotlib/histogram.py index c9a9d820..2180bf08 100644 --- a/src/napari_matplotlib/histogram.py +++ b/src/napari_matplotlib/histogram.py @@ -1,49 +1,34 @@ -import napari import numpy as np from .base import NapariMPLWidget __all__ = ["HistogramWidget"] +import napari _COLORS = {"r": "tab:red", "g": "tab:green", "b": "tab:blue"} class HistogramWidget(NapariMPLWidget): """ - Widget to display a histogram of the currently selected layer. - - Attributes - ---------- - layer : `napari.layers.Layer` - Current layer being histogrammed. + Display a histogram of the currently selected layer. """ + n_layers_input = 1 + def __init__(self, napari_viewer: napari.viewer.Viewer): super().__init__(napari_viewer) self.axes = self.canvas.figure.subplots() - self.layer = self.viewer.layers[-1] - - self.viewer.dims.events.current_step.connect(self.hist_current_layer) - self.viewer.layers.selection.events.active.connect(self.update_layer) + self.update_layers(None) - self.hist_current_layer() - - def update_layer(self, event: napari.utils.events.Event) -> None: - """ - Update the currently selected layer. - """ - # Update current layer when selection changed in viewer - if event.value: - self.layer = event.value - self.hist_current_layer() + def clear(self) -> None: + self.axes.clear() - def hist_current_layer(self) -> None: + def draw(self) -> None: """ Clear the axes and histogram the currently selected layer/slice. """ - self.axes.clear() - layer = self.layer + layer = self.layers[0] bins = np.linspace(np.min(layer.data), np.max(layer.data), 100) if layer.data.ndim - layer.rgb == 3: @@ -67,4 +52,3 @@ def hist_current_layer(self) -> None: self.axes.hist(data.ravel(), bins=bins, label=layer.name) self.axes.legend() - self.canvas.draw() diff --git a/src/napari_matplotlib/scatter.py b/src/napari_matplotlib/scatter.py index 79575880..c3b12742 100644 --- a/src/napari_matplotlib/scatter.py +++ b/src/napari_matplotlib/scatter.py @@ -12,40 +12,19 @@ class ScatterWidget(NapariMPLWidget): If there are more than 500 data points, a 2D histogram is displayed instead of a scatter plot, to avoid too many scatter points. - - Attributes - ---------- - layers : list[`napari.layers.Layer`] - Current two layers being scattered. """ + n_layers_input = 2 + def __init__(self, napari_viewer: napari.viewer.Viewer): super().__init__(napari_viewer) self.axes = self.canvas.figure.subplots() - self.layers = self.viewer.layers[-2:] - - self.viewer.dims.events.current_step.connect( - self.scatter_current_layers - ) - self.viewer.layers.selection.events.changed.connect(self.update_layers) - - self.scatter_current_layers() - - def update_layers(self, event: napari.utils.events.Event) -> None: - """ - Update the currently selected layers. - """ - # Update current layer when selection changed in viewer - layers = self.viewer.layers.selection - if len(layers) == 2: - self.layers = list(layers) - self.scatter_current_layers() + self.update_layers(None) - def scatter_current_layers(self) -> None: + def draw(self) -> None: """ Clear the axes and scatter the currently selected layers. """ - self.axes.clear() data = [layer.data[self.current_z] for layer in self.layers] if data[0].size < 500: self.axes.scatter(data[0], data[1], alpha=0.5) @@ -58,4 +37,3 @@ def scatter_current_layers(self) -> None: ) self.axes.set_xlabel(self.layers[0].name) self.axes.set_ylabel(self.layers[1].name) - self.canvas.draw()