"""Common utilities for bindings."""
from __future__ import annotations
import logging
from dataclasses import dataclass
import numpy as np
from patch_denoise.denoise import (
hybrid_pca,
mp_pca,
nordic,
optimal_thresholding,
raw_svt,
adaptive_thresholding,
)
def _optimal_fro(*args, **kwargs):
    """Call ``optimal_thresholding`` with ``loss="fro"`` preselected."""
    return optimal_thresholding(*args, loss="fro", **kwargs)


def _optimal_nuc(*args, **kwargs):
    """Call ``optimal_thresholding`` with ``loss="nuc"`` preselected."""
    return optimal_thresholding(*args, loss="nuc", **kwargs)


def _optimal_ope(*args, **kwargs):
    """Call ``optimal_thresholding`` with ``loss="ope"`` preselected."""
    return optimal_thresholding(*args, loss="ope", **kwargs)


def _adaptive_qut(*args, **kwargs):
    """Call ``adaptive_thresholding`` with ``method="qut"`` preselected."""
    return adaptive_thresholding(*args, method="qut", **kwargs)


# Lookup table from a denoiser name to the callable implementing it.
# The ``None`` key maps to ``None``, i.e. no denoiser is attached to it.
DENOISER_MAP = {
    None: None,
    "mp-pca": mp_pca,
    "hybrid-pca": hybrid_pca,
    "raw": raw_svt,
    "optimal-fro": _optimal_fro,
    # Configured exactly like "optimal-fro": both preselect loss="fro".
    "optimal-fro-noise": _optimal_fro,
    "optimal-nuc": _optimal_nuc,
    "optimal-ope": _optimal_ope,
    "nordic": nordic,
    "adaptive-qut": _adaptive_qut,
}

# One-letter codes and the recombination names they stand for.
_RECOMBINATION = {"w": "weighted", "c": "center", "a": "average"}
[docs]
@dataclass
class DenoiseParameters:
"""Denoise Parameters data structure."""
method: str = None
patch_shape: int | tuple[int, ...] = 11
patch_overlap: int | tuple[int, ...] = 0
recombination: str = "weighted" # "center" is also available
mask_threshold: int = 10
@property
def pretty_name(self):
"""Return a pretty name for the representation of parameters."""
if self.method:
name = self.method
for attr in [
"patch_shape",
"patch_overlap",
"recombination",
"mask_threshold",
]:
if getattr(self, attr):
name += f"_{getattr(self, attr)}"
else:
name = "noisy"
return name
@property
def pretty_par(self):
"""Get pretty representation of parameters."""
name = f"{self.patch_shape}_{self.patch_overlap}{self.recombination[0]}"
return name
[docs]
@classmethod
def get_str(cls, **kwargs):
"""Get full string representation from set of kwargs."""
return cls(**kwargs).pretty_name
[docs]
@classmethod
def from_str(self, config_str):
"""Create a DenoiseParameters from a string."""
if "noisy" in config_str:
return DenoiseParameters(
method=None,
patch_shape=None,
patch_overlap=None,
recombination=None,
mask_threshold=None,
)
else:
conf = config_str.split("_")
d = DenoiseParameters()
if conf:
d.method = conf.pop(0)
if conf:
d.patch_shape = int(conf.pop(0))
if conf:
d.patch_overlap = int(conf.pop(0))
if conf:
c = conf.pop(0)
d.recombination = c
if conf:
d.mask_threshold = int(conf.pop(0))
return d
def __str__(self):
"""Get string representation."""
return self.pretty_name
def load_as_array(input):
    """Load a file as a numpy array, and return affine matrix if available.

    Parameters
    ----------
    input : str, os.PathLike or None
        File to read. A ``.npy`` suffix is read with ``numpy.load``; a name
        with ``.nii`` among its suffixes (so ``.nii.gz`` as well) is read
        with nibabel. The parameter name shadows the ``input`` builtin but is
        kept so keyword callers continue to work.

    Returns
    -------
    tuple
        ``(array, affine)``. ``affine`` is ``None`` for ``.npy`` files, and
        both items are ``None`` when ``input`` is ``None``.

    Raises
    ------
    ValueError
        If the file is neither a numpy nor a nifti file.
    """
    from pathlib import Path

    if input is None:
        return None, None

    # Normalise so plain string paths work as well as Path objects.
    path = Path(input)
    if path.suffix == ".npy":
        return np.load(path), None
    if ".nii" in path.suffixes:
        # Deferred import: nibabel is only needed for NIfTI files.
        import nibabel as nib

        nii = nib.load(path)
        return nii.get_fdata(dtype=np.float32), nii.affine
    raise ValueError("Unsupported file format. use numpy or nifti formats.")
def save_array(data, affine, filename):
    """Save array to file, with affine matrix if required.

    Parameters
    ----------
    data : numpy.ndarray
        Array to write.
    affine : numpy.ndarray or None
        Affine stored with NIfTI output. When ``None`` a 4x4 identity is
        used. Not used for ``.npy`` output.
    filename : str, os.PathLike or None
        Destination. A name with ``.nii`` among its suffixes is written as
        NIfTI through nibabel, a ``.npy`` suffix with ``numpy.save``.
        ``None`` disables saving.

    Returns
    -------
    str, os.PathLike or None
        ``filename`` exactly as it was passed in.

    Raises
    ------
    ValueError
        If the file name is neither a numpy nor a nifti file name.
    """
    from pathlib import Path

    if filename is None:
        return None

    # Normalise once so that str and Path arguments take the same branches.
    path = Path(filename)
    if ".nii" in path.suffixes:
        # Deferred import: nibabel is only needed for NIfTI output.
        import nibabel as nib

        if affine is None:
            # NIfTI affines are 4x4 whatever the dimensionality of ``data``.
            affine = np.eye(4)
        nib.Nifti1Image(data, affine).to_filename(path)
    elif path.suffix == ".npy":
        np.save(path, data)
    else:
        # Fail loudly instead of returning as if the data had been written.
        raise ValueError("Unsupported file format. use numpy or nifti formats.")
    return filename
def load_complex_nifti(mag_file, phase_file, filename=None):  # pragma: no cover
    """Build a complex-valued array from a magnitude file and a phase file.

    Both files are read with ``load_as_array`` and combined as
    ``mag * exp(1j * phase)``. Optionally, the result is saved as a ``.npy``
    file.

    Parameters
    ----------
    mag_file
        The source magnitude file, forwarded to ``load_as_array``.
    phase_file
        The source phase file, forwarded to ``load_as_array``.
    filename : default None
        The output filename. Nothing is written when it is ``None``.

    Returns
    -------
    tuple
        The complex array and the affine matrix of the magnitude file.
    """
    magnitude, affine = load_as_array(mag_file)
    phase, phase_affine = load_as_array(phase_file)

    # A mismatch is only reported; the magnitude affine is returned either way.
    affines_match = np.allclose(affine, phase_affine)
    if not affines_match:
        logging.warning("Affine matrices for magnitude and phase are not the same")

    logging.info("Phase data range is [%.2f %.2f]", np.min(phase), np.max(phase))
    logging.info(
        "Mag data range is [%.2f %.2f]", np.min(magnitude), np.max(magnitude)
    )

    complex_img = magnitude * np.exp(1j * phase)
    if filename is None:
        return complex_img, affine
    np.save(filename, complex_img)
    return complex_img, affine
def compute_mask(array, convex=False):
    """Compute a mask for an array using Otsu's method.

    The time axis is assumed to be the last one. The array is averaged over
    that axis, then each slice along the last remaining axis is thresholded
    independently.

    Parameters
    ----------
    array : numpy.ndarray
        Array to compute mask for.
    convex : bool, default False
        If True, each slice of the mask is replaced by its convex hull.

    Returns
    -------
    numpy.ndarray
        Boolean mask with the shape of the time-averaged array.
    """
    from skimage.filters import threshold_otsu
    from skimage.morphology import convex_hull_image

    time_average = array.mean(axis=-1)
    mask = np.zeros(time_average.shape, dtype=bool)
    for idx in range(time_average.shape[-1]):
        plane = time_average[..., idx]
        foreground = plane > threshold_otsu(plane)
        # Each slice is independent, so the hull can be taken straight away.
        mask[..., idx] = convex_hull_image(foreground) if convex else foreground
    return mask