Source code for snake.toolkit.plotting

"""Plotting utilities for the project."""

import matplotlib
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import NDArray
from mpl_toolkits.axes_grid1.axes_divider import Size, make_axes_locatable
from skimage.measure import find_contours
from matplotlib.cm import ScalarMappable


def get_coolgraywarm(
    thresh: float = 3, max: float = 7
) -> matplotlib.colors.ListedColormap:
    """Get a cool-warm colormap, with gray inside the threshold.

    Parameters
    ----------
    thresh
        Absolute value below which the colors are replaced by gray.
    max
        Absolute value assumed to sit at both ends of the colormap, i.e. the
        map is built for data displayed with ``vmin=-max`` and ``vmax=max``.

    Returns
    -------
    matplotlib.colors.ListedColormap
        A 256-entry copy of matplotlib's "coolwarm" map whose central band,
        of half-width ``thresh / max`` of each half of the map, is light gray.
    """
    n_colors = 256
    center = n_colors // 2
    coolwarm = matplotlib.colormaps["coolwarm"].resampled(n_colors)
    newcolors = coolwarm(np.linspace(0, 1, n_colors))
    gray = np.array([0.8, 0.8, 0.8, 1])
    # Share of each half of the map lying inside the threshold, capped at 1 so
    # that thresh > max grays the whole map instead of wrapping around through
    # negative indices.
    ratio = min(abs(thresh / max), 1.0)
    # The band must run from the lower index to the upper one; a reversed
    # slice (upper:lower) is empty and would leave the colormap untouched.
    lower = int(center - ratio * center)
    upper = int(center + ratio * center)
    newcolors[lower:upper, :] = gray
    return matplotlib.colors.ListedColormap(newcolors)
# %%
def _padded_extent(occupied: NDArray, pad: int) -> slice:
    """Return the slice spanning the True entries of a 1D mask, widened by pad."""
    hits = np.flatnonzero(occupied)
    if hits.size == 0:
        # Nothing stands out along this axis: keep the full extent rather than
        # failing on an empty index array.
        return slice(0, occupied.size)
    start = max(0, int(hits[0]) - pad)
    # hits[-1] is the last occupied index (inclusive); a slice stop is
    # exclusive, hence the +1 before padding.
    stop = min(int(hits[-1]) + 1 + pad, occupied.size)
    return slice(start, stop)


def get_axis_properties(
    array_bg: NDArray,
    cuts: tuple[int, ...],
    width_inches: float,
    cbar: bool = True,
    arr_pad: int = 4,
) -> tuple[
    NDArray,
    NDArray,
    tuple[tuple[slice, slice], ...],
    tuple[tuple[Any, Any, Any], ...],
]:
    """Generate mpl toolkit axes divider sizes and the crop of each cut.

    For each of the three orthogonal cuts through ``array_bg``, a bounding box
    of the "interesting" region is computed and padded by ``arr_pad`` pixels.
    Boolean cuts are used directly as the region mask; other cuts are
    thresholded at half of the 95th percentile of their absolute value.

    Parameters
    ----------
    array_bg
        Volume indexed with three indices (one per cut).
    cuts
        Index of the cut along the first, second and third axis.
    width_inches
        Total width handed to ``_get_hdiv_vdiv`` to size the panels.
    cbar
        Whether room for a colorbar is added to the horizontal sizes.
    arr_pad
        Number of pixels added on each side of the bounding boxes, clipped to
        the extent of the cut.

    Returns
    -------
    hdiv, vdiv
        Horizontal and vertical sizes computed by ``_get_hdiv_vdiv``.
    bbox
        One ``(row_slice, col_slice)`` pair per cut.
    slices
        The three index tuples selecting each cut in the volume.
    """
    slices = (np.s_[cuts[0], :, :], np.s_[:, cuts[1], :], np.s_[:, :, cuts[2]])
    bbox: list[tuple[slice, slice]] = []
    for index in slices:
        cut = array_bg[index]
        if cut.dtype != "bool":
            magnitude = np.abs(cut)
            mask = magnitude > 0.5 * np.percentile(magnitude, 95)
        else:
            mask = cut
        bbox.append(
            (
                _padded_extent(np.any(mask, axis=1), arr_pad),
                _padded_extent(np.any(mask, axis=0), arr_pad),
            )
        )

    hdiv, vdiv = _get_hdiv_vdiv(array_bg, bbox, slices, width_inches, cbar=cbar)
    return hdiv, vdiv, tuple(bbox), slices
def _get_hdiv_vdiv( array_bg: NDArray, bbox: tuple[tuple[slice]], slices: tuple[slice], width_inches: float, cbar: bool = False, ) -> tuple[NDArray, NDArray]: sizes = np.array([(bb.stop - bb.start) for b in bbox for bb in b]) sizes = tuple(array_bg[s][b].shape for s, b in zip(slices, bbox, strict=False)) alpha1 = sizes[1][1] / sizes[2][1] update_sizes = [[0, 0], [0, 0], [0, 0]] update_sizes[2][0] = sizes[2][0] update_sizes[2][1] = sizes[2][1] alpha1 = sizes[2][1] / sizes[1][1] update_sizes[1][0] = sizes[1][0] * alpha1 update_sizes[1][1] = sizes[1][1] * alpha1 alpha2 = (update_sizes[2][0] + update_sizes[1][0]) / sizes[0][0] update_sizes[0][0] = sizes[0][0] * alpha2 update_sizes[0][1] = sizes[0][1] * alpha2 aspect = update_sizes[0][0] / (update_sizes[0][1] + update_sizes[1][1]) split_lr = update_sizes[0][1] / (update_sizes[1][1] + update_sizes[0][1]) split_tb = update_sizes[1][0] / (update_sizes[1][0] + update_sizes[2][0]) hdiv = [ width_inches * split_lr, width_inches * (1 - split_lr), ] if cbar: hdiv.extend( [ 0.02 * hdiv[0], 0.02 * hdiv[0], ] ) np.array(hdiv) height_inches = width_inches * aspect vdiv = np.array([height_inches * split_tb, height_inches * (1 - split_tb)]) return hdiv, vdiv
def get_mask_cuts_mask(mask: NDArray) -> tuple[int, ...]:
    """Get the cuts that expose the maximum number of voxels of the mask.

    For every axis of ``mask``, the values are summed over all the other axes
    and the index with the largest sum is kept.

    Parameters
    ----------
    mask
        Array whose sum along a cut measures how much of the mask the cut
        exposes. Any number of dimensions is accepted.

    Returns
    -------
    tuple[int, ...]
        One index per axis of ``mask``; on ties the first index wins.
    """
    mask = np.asarray(mask)
    axes = range(mask.ndim)
    return tuple(
        int(np.argmax(np.sum(mask, axis=tuple(a for a in axes if a != axis))))
        for axis in axes
    )
def plot_frames_activ(
    background: NDArray,
    z_score: NDArray | None,
    roi: NDArray | None,
    ax: plt.Axes,
    slices: tuple[Any, ...],
    bbox: tuple[Any, ...],
    z_thresh: float = 3,
    z_max: float = 11,
    bg_cmap: str = "gray",
) -> tuple[plt.Axes, matplotlib.image.AxesImage]:
    """Plot activation maps and background.

    Parameters
    ----------
    background
        Array shown as the base image; its global min and max set the color
        range.
    z_score
        Array overlaid on the background, or None to skip the overlay.
        Values with an absolute value below ``z_thresh`` are not drawn.
    roi
        Array whose contours are drawn in cyan, or None to skip them.
    ax
        Axes to draw on; its ticks are removed.
    slices
        Index applied first to each array to extract the cut.
    bbox
        Index applied to the cut to crop it.
    z_thresh
        Threshold under which z-scores are hidden.
    z_max
        The overlay is displayed with ``vmin=-z_max`` and ``vmax=z_max``.
    bg_cmap
        Colormap of the background image.

    Returns
    -------
    ax
        The axes passed in.
    im
        The last image drawn: the overlay if ``z_score`` is given, otherwise
        the background.
    """
    bg = background[slices][bbox].squeeze()
    im = ax.imshow(
        bg,
        vmin=np.min(background),
        vmax=np.max(background),
        cmap=bg_cmap,
        origin="lower",
        aspect="equal",
    )

    if z_score is not None:
        # Work on a float copy: indexing yields a view, so masking in place
        # would write NaN into the caller's array (and fail on integer data).
        masked_z = z_score[slices][bbox].squeeze().astype(float)
        # np.nan rather than the np.NaN alias, which NumPy 2.0 removed.
        masked_z[np.abs(masked_z) < z_thresh] = np.nan
        im = ax.imshow(
            masked_z,
            alpha=1,
            cmap=get_coolgraywarm(z_thresh, max=z_max),
            vmin=-z_max,
            vmax=z_max,
            aspect="equal",
            interpolation="nearest",
            origin="lower",
        )

    if roi is not None:
        roi_cut = roi[slices][bbox].squeeze()
        for contour in find_contours(roi_cut):
            # Contours come as (row, col) pairs; plot wants x (col) then y (row).
            ax.plot(
                contour[:, 1],
                contour[:, 0] - 0.5,
                c="cyan",
                label="ground-truth",
                linewidth=1,
            )

    ax.set_xticks([])
    ax.set_yticks([])
    return ax, im
def axis3dcut(
    fig: plt.Figure,
    ax: plt.Axes,
    background: NDArray,
    z_score: NDArray | None,
    gt_roi: NDArray | None = None,
    width_inches: float = 7,
    cbar: bool = True,
    cuts: tuple[int, ...] | None = None,
    bbox: tuple[tuple[Any, Any], ...] | None = None,
    slices: tuple[tuple[Any, Any, Any], ...] | None = None,
    bg_cmap: str = "gray",
    z_thresh: float = 3,
    z_max: float = 11,
) -> tuple[plt.Figure, plt.Axes, tuple[int, ...]]:
    """Display a 3D image with zscore and ground truth ROI.

    Three panels, one per orthogonal cut, are laid out inside the area of
    ``ax`` with an axes divider; ``ax`` itself is kept as an invisible host on
    top of them.

    Parameters
    ----------
    fig
        Figure receiving the new axes.
    ax
        Host axes whose position is divided between the panels.
    background
        Volume shown as the base image of every panel.
    z_score
        Volume overlaid on the background, or None for background only.
    gt_roi
        Volume whose contours are drawn on every panel. Required when
        ``cuts`` is None, as the cuts are then derived from it.
    width_inches
        Width used to size the panels.
    cbar
        Whether to add a colorbar (z-scores if ``z_score`` is given, else
        background intensities).
    cuts
        Index of the cut along each axis; computed from ``gt_roi`` if None.
    bbox, slices
        Precomputed crops and cut indices. Both or neither must be given;
        when omitted they are derived from ``background`` and the cuts.
    bg_cmap
        Colormap of the background.
    z_thresh, z_max
        Threshold and display range of the z-scores, applied to both the
        panels and the colorbar so that the two stay consistent.

    Returns
    -------
    fig, ax
        The figure and host axes passed in.
    cuts
        The cuts that were displayed.

    Raises
    ------
    ValueError
        If neither ``cuts`` nor ``gt_roi`` is given, or if only one of
        ``bbox`` and ``slices`` is given.
    """
    if cuts is not None:
        cuts_ = cuts
    elif gt_roi is not None:
        cuts_ = get_mask_cuts_mask(gt_roi)
    else:
        raise ValueError("Missing gt_roi to compute ideal cuts.")

    if bbox is None and slices is None:
        hdiv, vdiv, bbox_, slices_ = get_axis_properties(
            background, cuts_, width_inches, cbar=cbar
        )
    elif bbox is not None and slices is not None:
        hdiv, vdiv = _get_hdiv_vdiv(background, bbox, slices, width_inches, cbar=cbar)
        bbox_, slices_ = bbox, slices
    else:
        raise ValueError("Missing either bbox or slices.")

    divider = make_axes_locatable(ax)
    divider.set_horizontal([Size.Fixed(s) for s in hdiv])
    divider.set_vertical([Size.Fixed(s) for s in vdiv])

    # Grid cells of the panels: the first spans both rows of column 0, the
    # other two each take one row of column 1.
    panels: list[plt.Axes] = []
    for nx, ny, ny1 in [(0, 0, 2), (1, 0, 1), (1, 1, 2)]:
        panel = plt.Axes(fig, ax.get_position(original=True))
        panel.set_axes_locator(divider.new_locator(nx=nx, ny=ny, ny1=ny1))
        fig.add_axes(panel)
        panels.append(panel)

    for i, panel in enumerate(panels):
        plot_frames_activ(
            background,
            z_score,
            gt_roi,
            panel,
            slices_[i],
            bbox_[i],
            z_thresh=z_thresh,
            z_max=z_max,
            bg_cmap=bg_cmap,
        )

    if cbar:
        cax = type(ax)(fig, ax.get_position(original=True))
        # Column 3 is the last of the two extra horizontal cells added for the
        # colorbar; it spans the full height.
        cax.set_axes_locator(divider.new_locator(nx=3, ny=0, ny1=-1))
        if z_score is not None:
            # Same colormap and limits as the overlays drawn on the panels.
            im = ScalarMappable(
                norm="linear", cmap=get_coolgraywarm(z_thresh, max=z_max)
            )
            im.set_clim(-z_max, z_max)
            matplotlib.colorbar.Colorbar(cax, im, orientation="vertical")
            cax.set_ylabel("z-scores", labelpad=-20)
            # Ticks every 2 units from the threshold outward, within the range.
            ticks = np.arange(z_thresh, z_max + 1, 2)
            ticks = ticks[ticks <= z_max]
            cax.set_yticks(np.concatenate([-ticks, ticks]))
        else:
            # use the background image
            im = ScalarMappable(norm="linear", cmap=bg_cmap)
            im.set_clim(vmin=np.min(background), vmax=np.max(background))
            matplotlib.colorbar.Colorbar(cax, im, orientation="vertical")
        fig.add_axes(cax)

    ax.set_axes_locator(divider.new_locator(nx=0, ny=0, ny1=-1, nx1=-1))
    ax.set_zorder(10)
    ax.axis("off")

    return fig, ax, cuts_