import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, RadioButtons, Button
from matplotlib.lines import Line2D
from matplotlib.colors import Normalize, LogNorm
import itertools
class NDImageViewer:
    """Interactive viewer for 2D, 3D, and 4D image data with real-time analysis tools.

    This class provides an interactive matplotlib-based interface for viewing and
    analyzing multi-dimensional image data. It supports 2D images, 3D image cubes
    (e.g., time series), and 4D data (e.g., time series of grouped observations).
    Features include selectable scaling modes, cut profile extraction, and
    draggable cut tools for pixel-level analysis.

    Parameters
    ----------
    data : array-like
        Input image data with shape (ny, nx) for 2D, (nt, ny, nx) for 3D, or
        (nt, ng, ny, nx) for 4D, where nt=time, ng=group, ny/nx=spatial dimensions.
    cmap : str, optional
        Matplotlib colormap name for image display. Default is ``'magma'``.

    Attributes
    ----------
    data : ndarray
        The input image data.
    ndim : int
        Number of dimensions in the data (2, 3, or 4).
    i_image : int
        Current image/frame index for 3D/4D data (default 0).
    i_group : int
        Current group index for 4D data (default 0).
    scale_mode : str
        Current scaling mode ('linear', 'log', 'asinh', or 'zscale').
    cuts : list of dict
        Finalized cuts, each containing 'p1', 'p2', 'line', 'color', and
        'visible' keys.
    active_cut : dict or None
        Cut currently being drawn (awaiting two clicks).
    fig : matplotlib.figure.Figure
        The interactive figure object.
    ax_img : matplotlib.axes.Axes
        Main image display axes.
    ax_prof : matplotlib.axes.Axes
        Profile plot axes for cut data.
    im : matplotlib.image.AxesImage
        The displayed image object.
    cbar : matplotlib.colorbar.Colorbar
        Colorbar for the image display.

    Notes
    -----
    - Data must be 2D, 3D, or 4D; other dimensions will raise a ValueError.
    - Slider limits come from the finite data values; if there are none, the
      range falls back to [0, 1].
    - The viewer is interactive: sliders update in real time, and cuts can be
      drawn and modified by clicking/dragging on the image.
    - The figure style is applied only to this viewer's figure; the global
      ``matplotlib.rcParams`` are left untouched.

    Examples
    --------
    View a 2D image:

    >>> import numpy as np
    >>> data_2d = np.random.rand(256, 256)
    >>> viewer = NDImageViewer(data=data_2d)
    >>> viewer.show()

    View a 3D time series:

    >>> data_3d = np.random.rand(100, 256, 256)  # 100 frames
    >>> viewer = NDImageViewer(data=data_3d, cmap='viridis')
    >>> viewer.show()

    View 4D data and extract cuts:

    >>> data_4d = np.random.rand(10, 8, 256, 256)  # 10 times, 8 groups
    >>> viewer = NDImageViewer(data=data_4d)
    >>> viewer.show()
    >>> # Click "Add cut" button, then click two points on the image to define a cut
    """

    # Maximum distance (in data/pixel units) at which a click grabs a cut
    # endpoint or a cut line for dragging.
    _GRAB_RADIUS = 10

    def __init__(self, data, cmap="magma"):
        """Initialize the NDImageViewer.

        Parameters
        ----------
        data : array-like
            Image data (2D, 3D, or 4D).
        cmap : str, optional
            Colormap name. Default is ``'magma'``.

        Raises
        ------
        ValueError
            If data is not 2D, 3D, or 4D.
        """
        self.data = np.asarray(data)
        self.cmap = cmap
        self.ndim = self.data.ndim
        if self.ndim not in (2, 3, 4):
            raise ValueError("Data must be 2D, 3D, or 4D")
        self.i_image = 0
        self.i_group = 0
        self.scale_mode = "linear"
        self.cuts = []
        self.active_cut = None
        # Drag state: set by _on_click, consumed by _on_drag, cleared by _on_release.
        self.drag_mode = None  # None, "endpoint" or "line"
        self.dragged_cut = None  # index into self.cuts
        self.dragged_endpoint = None  # "p1" or "p2" in endpoint mode
        self.drag_start = None  # previous mouse position in line mode
        self.color_cycle = itertools.cycle(
            ["cyan", "orange", "lime", "magenta", "red", "yellow"]
        )
        self._setup_data()
        self._setup_figure()
        self._connect_events()

    # ----------------------------------------------------
    # Data helpers
    # ----------------------------------------------------
    def _setup_data(self):
        """Initialize data ranges and extract the first image for display.

        Computes the min/max over finite values (falling back to [0, 1] when
        the data contain no finite value), initializes ``vmin``/``vmax`` to
        that range, and records the first image and its shape.
        """
        self.img0 = self._get_image()
        self.ny, self.nx = self.img0.shape
        finite_values = self.data[np.isfinite(self.data)]
        if finite_values.size:
            self.data_min = finite_values.min()
            self.data_max = finite_values.max()
        else:
            # All NaN/inf: nothing to scale against, so use a unit range to
            # keep the sliders and the initial Normalize valid.
            self.data_min, self.data_max = 0.0, 1.0
        # The viewer starts in linear mode, so start from the true data range
        # (log mode derives its own positive lower bound in _compute_norm).
        self.vmin = self.data_min
        self.vmax = self.data_max

    def _get_image(self):
        """Retrieve the current 2D image from the data based on current indices.

        Returns
        -------
        ndarray
            2D image array (shape: ny, nx).
        """
        if self.ndim == 2:
            return self.data
        if self.ndim == 3:
            return self.data[self.i_image]
        return self.data[self.i_image, self.i_group]

    # ----------------------------------------------------
    # Figure & layout
    # ----------------------------------------------------
    def _setup_figure(self):
        """Create and configure the matplotlib figure and axes layout.

        Sets up three main components:

        - Left panel: main image display with colorbar
        - Top-right: interactive controls (sliders, scale selector, cut buttons)
        - Bottom-right: cut profile plot
        """
        style = {
            "figure.facecolor": "#f2f2f2",
            "axes.facecolor": "#ffffff",
            "axes.edgecolor": "#aaaaaa",
            "axes.linewidth": 0.8,
            "font.size": 10,
        }
        # Scope the style to the artists created here instead of mutating the
        # global rcParams, which would restyle every figure created later.
        with plt.rc_context(style):
            self.fig = plt.figure(figsize=(12, 6))

            # ---- Left square image panel ----
            self.ax_img = self.fig.add_axes([0.05, 0.15, 0.38, 0.75])
            self.ax_img.set_xticks([])
            self.ax_img.set_yticks([])
            self.ax_img.set_title("ND Image Viewer", pad=10)
            self.im = self.ax_img.imshow(
                self.img0,
                cmap=self.cmap,
                norm=Normalize(self.vmin, self.vmax),
                origin="lower",
            )
            self.cbar = self.fig.colorbar(
                self.im,
                ax=self.ax_img,
                fraction=0.05,
                pad=0.04,
            )

            # ---- Top-right controls ----
            self._setup_controls()

            # ---- Bottom-right profiles ----
            self.ax_prof = self.fig.add_axes([0.55, 0.15, 0.40, 0.35])
            self.ax_prof.set_xlabel("Pixel index")
            self.ax_prof.set_ylabel("Value", labelpad=10)

    def _setup_controls(self):
        """Create and arrange interactive control widgets.

        Widgets include:

        - vmin/vmax sliders for intensity scaling
        - Scale mode selector (Linear/Log/Asinh/Zscale)
        - Image and group sliders (for 3D/4D data)
        - "Add cut" and "Remove cut" buttons
        """
        px, pw = 0.55, 0.38
        y = 0.85
        dy = 0.055

        def label(text, ypos):
            self.fig.text(px, ypos, text, weight="semibold", color="#444")

        def slider_axes(ypos):
            # All sliders share the same horizontal placement and size.
            return self.fig.add_axes([px + 0.06, ypos + 0.015, pw - 0.2, 0.02])

        # A zero-width range (constant image) makes the sliders unusable;
        # widen it by one unit in that case.
        slider_lo, slider_hi = self.data_min, self.data_max
        if slider_hi <= slider_lo:
            slider_hi = slider_lo + 1.0

        label("Scaling", y)
        y -= dy
        # vmin slider
        self.fig.text(px, y + 0.02, "vmin", fontsize=9)
        self.ax_vmin = slider_axes(y)
        self.s_vmin = Slider(self.ax_vmin, "", slider_lo, slider_hi, valinit=self.vmin)
        y -= dy
        # vmax slider
        self.fig.text(px, y + 0.02, "vmax", fontsize=9)
        self.ax_vmax = slider_axes(y)
        self.s_vmax = Slider(self.ax_vmax, "", slider_lo, slider_hi, valinit=self.vmax)

        # ---- Scale mode: one column of radio buttons right of the sliders ----
        self.fig.text(px + 0.32, 0.85, "Scale mode", weight="semibold", color="#444")
        self.ax_scale = self.fig.add_axes([px + 0.32, 0.68, 0.08, 0.15])
        self.scale_radio = RadioButtons(
            self.ax_scale,
            ["Linear", "Log", "Asinh", "Zscale"],
            active=0,
        )
        for txt in self.scale_radio.labels:
            txt.set_fontsize(9)

        # ---- Navigation ----
        y -= dy * 0.8
        if self.ndim >= 3:
            label("Navigation", y)
            y -= dy
            self.fig.text(px, y + 0.02, "Image", fontsize=9)
            self.ax_img_idx = slider_axes(y)
            self.s_img = Slider(self.ax_img_idx, "", 0, self.data.shape[0] - 1,
                                valinit=0, valstep=1)
        if self.ndim == 4:
            y -= dy
            self.fig.text(px, y + 0.02, "Group", fontsize=9)
            self.ax_grp = slider_axes(y)
            self.s_grp = Slider(self.ax_grp, "", 0, self.data.shape[1] - 1,
                                valinit=0, valstep=1)

        # ---- Cuts ----
        y -= dy * 1.3
        # Per-dimensionality offset that keeps the "Cuts" row at a similar
        # height regardless of how many navigation sliders were added above.
        y -= {2: 0.105, 3: 0.05, 4: -0.005}[self.ndim]
        label("Cuts", y)
        self.ax_add_cut = self.fig.add_axes([px + 0.14, y - 0.01, 0.12, 0.045])
        self.btn_add_cut = Button(self.ax_add_cut, "Add cut")
        self.ax_remove_cut = self.fig.add_axes([px + 0.14 + 0.14, y - 0.01, 0.12, 0.045])
        self.btn_remove_cut = Button(self.ax_remove_cut, "Remove cut")

    # ----------------------------------------------------
    # Cut management
    # ----------------------------------------------------
    def _add_cut(self):
        """Initialize a new cut for the user to draw.

        Sets ``self.active_cut`` to a new dict and assigns a color from the
        color cycle. The user then clicks twice on the image to define the cut.
        """
        self.active_cut = {
            "p1": None,
            "p2": None,
            "line": None,
            "color": next(self.color_cycle),
            "visible": True,
        }

    def _remove_cut(self):
        """Remove the cut being drawn, or else the last finalized cut.

        If a cut is being drawn (``self.active_cut`` is not None), cancel it.
        Otherwise, pop the last cut from ``self.cuts``, remove its line from
        the image and redraw the profiles.
        """
        if self.active_cut is not None:
            self.active_cut = None
        elif self.cuts:
            cut = self.cuts.pop()
            if cut["line"] is not None:
                cut["line"].remove()
            self._update_profiles()

    def _finalize_cut(self):
        """Finalize a cut after two points have been clicked.

        Creates a Line2D artist on the image and appends the cut to
        ``self.cuts``. Resets ``self.active_cut`` to None and updates
        the profile plot.
        """
        cut = self.active_cut
        line = Line2D(
            [cut["p1"][0], cut["p2"][0]],
            [cut["p1"][1], cut["p2"][1]],
            lw=2,
            color=cut["color"],
            picker=5,
        )
        self.ax_img.add_line(line)
        cut["line"] = line
        self.cuts.append(cut)
        self.active_cut = None
        self._update_profiles()

    def _update_profiles(self):
        """Recompute and redraw cut profiles in the profile axes.

        Clears the profile plot and redraws a profile for every visible cut by
        sampling the current image along the cut line.
        """
        self.ax_prof.cla()
        self.ax_prof.set_xlabel("Pixel index")
        self.ax_prof.set_ylabel("Value", labelpad=10)
        img = self._get_image()
        for cut in self.cuts:
            if not cut["visible"]:
                continue
            dist, prof = self._sample_cut(img, cut["p1"], cut["p2"])
            self.ax_prof.plot(dist, prof, color=cut["color"], lw=1.8)
        self.fig.canvas.draw_idle()

    def _sample_cut(self, img, p1, p2, npts=None):
        """Sample pixel values along a line segment.

        Sample coordinates are spaced evenly between the two points; each is
        rounded to the nearest pixel (no value interpolation) and clipped to
        the image bounds.

        Parameters
        ----------
        img : ndarray
            2D image array.
        p1 : array-like
            Starting point [x, y].
        p2 : array-like
            Ending point [x, y].
        npts : int, optional
            Number of sample points. If None, uses the pixel length of the
            cut along its longer axis. At least 2 points are always used.

        Returns
        -------
        dist : ndarray
            Distance along the cut from *p1* (in pixels).
        values : ndarray
            Sampled pixel values along the line.
        """
        dx = p2[0] - p1[0]
        dy = p2[1] - p1[1]
        if npts is None:
            npts = int(max(abs(dx), abs(dy))) + 1
        npts = max(npts, 2)
        x = np.linspace(p1[0], p2[0], npts)
        y = np.linspace(p1[1], p2[1], npts)
        # Clip so endpoints dragged past the image edge still index valid pixels.
        xi = np.clip(np.round(x).astype(int), 0, self.nx - 1)
        yi = np.clip(np.round(y).astype(int), 0, self.ny - 1)
        dist = np.sqrt((x - p1[0]) ** 2 + (y - p1[1]) ** 2)
        return dist, img[yi, xi]

    def _get_distance_to_point(self, p1, p2, threshold=10):
        """Compute the distance between two points and check it against a threshold.

        Parameters
        ----------
        p1 : array-like
            First point [x, y].
        p2 : array-like
            Second point [x, y].
        threshold : float, optional
            Distance threshold in pixels. Default is 10.

        Returns
        -------
        endpoint : str or None
            ``'p1'`` if the two points are closer than *threshold*, otherwise None.
        distance : float or None
            The Euclidean distance, or None if not within threshold.
        """
        distance = float(np.hypot(p1[0] - p2[0], p1[1] - p2[1]))
        if distance < threshold:
            return "p1", distance
        return None, None

    def _nearest_endpoint(self, cut, point, threshold=10):
        """Return which endpoint of *cut* is nearest to *point*, if any is in reach.

        Parameters
        ----------
        cut : dict
            Cut with 'p1' and 'p2' keys.
        point : array-like
            Query point [x, y].
        threshold : float, optional
            Distance threshold in pixels. Default is 10.

        Returns
        -------
        str or None
            ``'p1'`` or ``'p2'`` for the closer endpoint within *threshold*
            (``'p1'`` wins ties), or None if neither is within reach.
        """
        best_name, best_dist = None, None
        for name in ("p1", "p2"):
            hit, dist = self._get_distance_to_point(cut[name], point, threshold)
            if hit is not None and (best_dist is None or dist < best_dist):
                best_name, best_dist = name, dist
        return best_name

    def _get_distance_to_line(self, p1, p2, point, threshold=10):
        """Compute the perpendicular distance from a point to a line segment.

        Parameters
        ----------
        p1 : array-like
            Line segment start [x, y].
        p2 : array-like
            Line segment end [x, y].
        point : array-like
            Query point [x, y].
        threshold : float, optional
            Distance threshold in pixels. Default is 10.

        Returns
        -------
        distance : float
            Perpendicular distance to the segment (or, for a zero-length
            segment, distance to its single point). Returns ``inf`` when the
            point does not project onto the segment or is not closer than
            *threshold*.
        """
        x1, y1 = p1
        x2, y2 = p2
        x0, y0 = point
        den = np.sqrt((y2 - y1) ** 2 + (x2 - x1) ** 2)
        if den == 0:
            # Degenerate segment: fall back to point-to-point distance.
            dist = np.sqrt((x0 - x1) ** 2 + (y0 - y1) ** 2)
            return dist if dist < threshold else float("inf")
        num = abs((y2 - y1) * x0 - (x2 - x1) * y0 + x2 * y1 - y2 * x1)
        dist = num / den
        # Projection parameter: 0 at p1, 1 at p2.
        t = ((x0 - x1) * (x2 - x1) + (y0 - y1) * (y2 - y1)) / (den ** 2)
        if 0 <= t <= 1 and dist < threshold:
            return dist
        return float("inf")

    # ----------------------------------------------------
    # Events
    # ----------------------------------------------------
    def _connect_events(self):
        """Connect matplotlib event handlers for interactivity.

        Connects slider callbacks, button click handlers, and mouse event
        handlers for image interaction.
        """
        self.s_vmin.on_changed(self._update_image)
        self.s_vmax.on_changed(self._update_image)
        if self.ndim >= 3:
            self.s_img.on_changed(self._update_image)
        if self.ndim == 4:
            self.s_grp.on_changed(self._update_image)
        self.scale_radio.on_clicked(self._change_scale)
        self.btn_add_cut.on_clicked(lambda _: self._add_cut())
        self.btn_remove_cut.on_clicked(lambda _: self._remove_cut())
        self.fig.canvas.mpl_connect("button_press_event", self._on_click)
        self.fig.canvas.mpl_connect("motion_notify_event", self._on_drag)
        self.fig.canvas.mpl_connect("button_release_event", self._on_release)
        self.fig.canvas.mpl_connect("pick_event", self._on_pick)

    def _update_image(self, val=None):
        """Update the displayed image and rescale it based on current settings.

        Called when sliders change or the scale mode is switched. Fetches the
        current image, applies the selected scaling mode, updates the display
        and colorbar, and recomputes cut profiles.

        Parameters
        ----------
        val : float, optional
            Value from slider (unused, present for callback compatibility).
        """
        if self.ndim >= 3:
            self.i_image = int(self.s_img.val)
        if self.ndim == 4:
            self.i_group = int(self.s_grp.val)
        img = self._get_image()
        self.im.set_data(img)
        self.im.set_norm(self._compute_norm(img))
        self.cbar.update_normal(self.im)
        self._update_profiles()

    def _compute_norm(self, img):
        """Build the color normalization for *img* under the current scale mode.

        Parameters
        ----------
        img : ndarray
            The 2D image being displayed.

        Returns
        -------
        matplotlib.colors.Normalize
            ``LogNorm`` for 'log', ``SymLogNorm`` for 'asinh', a 2-98
            percentile ``Normalize`` for 'zscale', else a linear ``Normalize``
            using the vmin/vmax sliders.
        """
        # The two sliders can be dragged past each other; order them so every
        # norm below receives vmin <= vmax (matplotlib raises otherwise).
        lo, hi = sorted((self.s_vmin.val, self.s_vmax.val))
        if self.scale_mode == "log":
            positive = img[img > 0]
            floor = positive.min() if positive.size else 1e-10
            lo = max(lo, floor)
            # LogNorm needs 0 < vmin <= vmax.
            return LogNorm(lo, max(hi, lo))
        if self.scale_mode == "asinh":
            from matplotlib.colors import SymLogNorm
            # NOTE: a symmetric-log stretch (linear within +/-0.03) stands in
            # for asinh scaling here.
            return SymLogNorm(linthresh=0.03, vmin=lo, vmax=hi)
        if self.scale_mode == "zscale":
            # Simple zscale-like normalization using the 2nd/98th percentiles
            # of the finite pixels (not the IRAF zscale algorithm).
            finite = img[np.isfinite(img)]
            if finite.size:
                vmin_z, vmax_z = np.percentile(finite, [2, 98])
                return Normalize(vmin_z, vmax_z)
        return Normalize(lo, hi)

    def _change_scale(self, label):
        """Switch the scaling mode and update the image.

        Parameters
        ----------
        label : str
            Scale mode name ('Linear', 'Log', 'Asinh', or 'Zscale').
        """
        self.scale_mode = label.lower()
        self._update_image()

    def _on_click(self, event):
        """Handle mouse button press events on the image.

        Supports:

        - Drawing cuts (two clicks to define start and end points)
        - Grabbing the nearest cut endpoint, or a cut line, for dragging

        Parameters
        ----------
        event : matplotlib.backend_bases.MouseEvent
            Mouse event with xdata, ydata, and inaxes attributes.
        """
        if event.inaxes != self.ax_img:
            return
        point = [event.xdata, event.ydata]

        # Drawing a new cut: first click sets p1, second sets p2 and finalizes.
        if self.active_cut is not None:
            if self.active_cut["p1"] is None:
                self.active_cut["p1"] = point
            else:
                self.active_cut["p2"] = point
                self._finalize_cut()
            return

        # Otherwise try to grab an existing cut (endpoints take priority).
        radius = self._GRAB_RADIUS
        for i, cut in enumerate(self.cuts):
            endpoint = self._nearest_endpoint(cut, point, threshold=radius)
            if endpoint is not None:
                self.dragged_cut = i
                self.dragged_endpoint = endpoint
                self.drag_mode = "endpoint"
                return
            line_dist = self._get_distance_to_line(cut["p1"], cut["p2"], point,
                                                   threshold=radius)
            if line_dist < radius:
                self.dragged_cut = i
                self.drag_mode = "line"
                self.drag_start = point
                return

    def _on_drag(self, event):
        """Handle mouse motion events for moving cuts or endpoints.

        Moves the grabbed endpoint to the cursor, or translates the whole cut
        by the cursor displacement since the previous motion event.

        Parameters
        ----------
        event : matplotlib.backend_bases.MouseEvent
            Mouse motion event.
        """
        if event.inaxes != self.ax_img or self.drag_mode is None or self.dragged_cut is None:
            return
        cut = self.cuts[self.dragged_cut]
        point = [event.xdata, event.ydata]
        if self.drag_mode == "endpoint":
            cut["p1" if self.dragged_endpoint == "p1" else "p2"] = point
        elif self.drag_mode == "line":
            dx = point[0] - self.drag_start[0]
            dy = point[1] - self.drag_start[1]
            cut["p1"] = [cut["p1"][0] + dx, cut["p1"][1] + dy]
            cut["p2"] = [cut["p2"][0] + dx, cut["p2"][1] + dy]
            self.drag_start = point
        else:
            return
        cut["line"].set_data([cut["p1"][0], cut["p2"][0]], [cut["p1"][1], cut["p2"][1]])
        self._update_profiles()

    def _on_release(self, event):
        """Handle mouse button release events to end dragging.

        Parameters
        ----------
        event : matplotlib.backend_bases.MouseEvent
            Mouse release event.
        """
        self.drag_mode = None
        self.dragged_cut = None
        self.dragged_endpoint = None

    def _on_pick(self, event):
        """Handle pick events on cut lines (currently unused).

        Parameters
        ----------
        event : matplotlib.backend_bases.PickEvent
            Pick event (fires when artist is clicked with appropriate picker tolerance).
        """
        pass

    # ----------------------------------------------------
    # Public API
    # ----------------------------------------------------
    def show(self):
        """Display the interactive viewer window.

        Calls ``plt.show()`` to render the figure and start the interactive loop.
        """
        plt.show()