Skip to content

Refactor callback handling #33

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 61 additions & 4 deletions src/napari_matplotlib/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
FigureCanvas,
NavigationToolbar2QT,
)
from matplotlib.figure import Figure
from qtpy.QtWidgets import QVBoxLayout, QWidget

mpl.rc("axes", edgecolor="white")
Expand All @@ -15,6 +14,7 @@

mpl.rc("xtick", color="white")
mpl.rc("ytick", color="white")

__all__ = ["NapariMPLWidget"]


Expand All @@ -23,8 +23,12 @@ class NapariMPLWidget(QWidget):
Base widget that can be embedded as a napari widget and contains a
Matplotlib canvas.

This creates a single Figure, and sub-classes should implement logic for
drawing on that Figure.
This creates a single FigureCanvas, which contains a single Figure.

This class also handles callbacks to automatically update figures when
the layer selection or z-step is changed in the napari viewer. To take
advantage of this sub-classes should implement the ``clear()`` and
``draw()`` methods.

Attributes
----------
Expand All @@ -34,13 +38,14 @@ class NapariMPLWidget(QWidget):
Matplotlib figure.
canvas : matplotlib.backends.backend_qt5agg.FigureCanvas
Matplotlib canvas.
layers : `list`
List of currently selected napari layers.
"""

def __init__(self, napari_viewer: napari.viewer.Viewer):
super().__init__()

self.viewer = napari_viewer
self.figure = Figure(figsize=(5, 3), tight_layout=True)
self.canvas = FigureCanvas()
self.canvas.figure.patch.set_facecolor("#262930")
self.toolbar = NavigationToolbar2QT(self.canvas, self)
Expand All @@ -49,9 +54,61 @@ def __init__(self, napari_viewer: napari.viewer.Viewer):
self.layout().addWidget(self.toolbar)
self.layout().addWidget(self.canvas)

self.setup_callbacks()

@property
def n_selected_layers(self) -> int:
"""
Number of currently selected layers.
"""
return len(self.layers)

@property
def current_z(self) -> int:
"""
Current z-step of the viewer.
"""
return self.viewer.dims.current_step[0]

def setup_callbacks(self) -> None:
"""
Setup callbacks for:
- Layer selection changing
- z-step changing
"""
# z-step changed in viewer
self.viewer.dims.events.current_step.connect(self._draw)
# Layer selection changed in viewer
self.viewer.layers.selection.events.active.connect(self.update_layers)

def update_layers(self, event: napari.utils.events.Event) -> None:
"""
Update the currently selected layers and re-draw.
"""
self.layers = list(self.viewer.layers.selection)
self._draw()

def _draw(self) -> None:
"""
Clear current figure, check selected layers are correct, and draw new
figure if so.
"""
self.clear()
if self.n_selected_layers != self.n_layers_input:
return
self.draw()
self.canvas.draw()

def clear(self) -> None:
"""
Clear any previously drawn figures.

This is a no-op, and is intended for derived classes to override.
"""

def draw(self) -> None:
"""
Re-draw any figures.

This is a no-op, and is intended for derived classes to override.
"""
34 changes: 9 additions & 25 deletions src/napari_matplotlib/histogram.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,34 @@
import napari
import numpy as np

from .base import NapariMPLWidget

__all__ = ["HistogramWidget"]

import napari

_COLORS = {"r": "tab:red", "g": "tab:green", "b": "tab:blue"}


class HistogramWidget(NapariMPLWidget):
"""
Widget to display a histogram of the currently selected layer.

Attributes
----------
layer : `napari.layers.Layer`
Current layer being histogrammed.
Display a histogram of the currently selected layer.
"""

n_layers_input = 1

def __init__(self, napari_viewer: napari.viewer.Viewer):
super().__init__(napari_viewer)
self.axes = self.canvas.figure.subplots()
self.layer = self.viewer.layers[-1]

self.viewer.dims.events.current_step.connect(self.hist_current_layer)
self.viewer.layers.selection.events.active.connect(self.update_layer)
self.update_layers(None)

self.hist_current_layer()

def update_layer(self, event: napari.utils.events.Event) -> None:
"""
Update the currently selected layer.
"""
# Update current layer when selection changed in viewer
if event.value:
self.layer = event.value
self.hist_current_layer()
def clear(self) -> None:
self.axes.clear()

def hist_current_layer(self) -> None:
def draw(self) -> None:
"""
Clear the axes and histogram the currently selected layer/slice.
"""
self.axes.clear()
layer = self.layer
layer = self.layers[0]
bins = np.linspace(np.min(layer.data), np.max(layer.data), 100)

if layer.data.ndim - layer.rgb == 3:
Expand All @@ -67,4 +52,3 @@ def hist_current_layer(self) -> None:
self.axes.hist(data.ravel(), bins=bins, label=layer.name)

self.axes.legend()
self.canvas.draw()
30 changes: 4 additions & 26 deletions src/napari_matplotlib/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,40 +12,19 @@ class ScatterWidget(NapariMPLWidget):

If there are more than 500 data points, a 2D histogram is displayed instead
of a scatter plot, to avoid too many scatter points.

Attributes
----------
layers : list[`napari.layers.Layer`]
Current two layers being scattered.
"""

n_layers_input = 2

def __init__(self, napari_viewer: napari.viewer.Viewer):
super().__init__(napari_viewer)
self.axes = self.canvas.figure.subplots()
self.layers = self.viewer.layers[-2:]

self.viewer.dims.events.current_step.connect(
self.scatter_current_layers
)
self.viewer.layers.selection.events.changed.connect(self.update_layers)

self.scatter_current_layers()

def update_layers(self, event: napari.utils.events.Event) -> None:
"""
Update the currently selected layers.
"""
# Update current layer when selection changed in viewer
layers = self.viewer.layers.selection
if len(layers) == 2:
self.layers = list(layers)
self.scatter_current_layers()
self.update_layers(None)

def scatter_current_layers(self) -> None:
def draw(self) -> None:
"""
Clear the axes and scatter the currently selected layers.
"""
self.axes.clear()
data = [layer.data[self.current_z] for layer in self.layers]
if data[0].size < 500:
self.axes.scatter(data[0], data[1], alpha=0.5)
Expand All @@ -58,4 +37,3 @@ def scatter_current_layers(self) -> None:
)
self.axes.set_xlabel(self.layers[0].name)
self.axes.set_ylabel(self.layers[1].name)
self.canvas.draw()