Source code for patch_denoise.space_time.lowrank

"""Low Rank methods."""

from types import MappingProxyType

import numpy as np
from scipy.linalg import svd
from scipy.optimize import minimize

from .base import BaseSpaceTimeDenoiser, PatchedArray
from .utils import (
    eig_analysis,
    eig_synthesis,
    marchenko_pastur_median,
    svd_analysis,
    svd_synthesis,
)
from .._docs import fill_doc

# Optional accelerator: numba is used only to pre-compile ``_sure_atn_cost``
# when it is importable; everything else works without it.
try:
    import numba as nb
except ImportError:
    NUMBA_AVAILABLE = False
else:
    NUMBA_AVAILABLE = True


@fill_doc
class MPPCADenoiser(BaseSpaceTimeDenoiser):
    """Denoising using Marchenko-Pastur principal components analysis thresholding.

    Parameters
    ----------
    $patch_config
    threshold_scale: float
        An extra factor multiplying the threshold.
    """

    def __init__(self, patch_shape, patch_overlap, threshold_scale, **kwargs):
        super().__init__(patch_shape, patch_overlap, **kwargs)
        # Stored in ``input_denoising_kwargs`` so that it reaches
        # ``_patch_processing`` as its ``threshold_scale`` keyword.
        self.input_denoising_kwargs["threshold_scale"] = threshold_scale

    def _patch_processing(self, patch, patch_idx=None, threshold_scale=1.0):
        """Process a patch with the MP-PCA method.

        The number of signal components is first searched with a
        Marchenko-Pastur style criterion; the noise variance is then the mean
        of the remaining eigenvalues, and every eigenvalue strictly above
        ``var_noise * threshold_scale**2`` is kept for the reconstruction.

        Returns
        -------
        tuple
            ``(denoised_patch, n_kept_components, noise_variance)``.
        """
        p_center, eig_vals, eig_vec, p_tmean = eig_analysis(patch)
        n_obs = len(patch)

        def _spread_bound(values, count):
            # Bound compared against the eigenvalue spread, for ``count``
            # components assumed to be noise.
            return np.mean(values) * (4 * np.sqrt((count + 1) / n_obs))

        n_signal = 0
        bound = _spread_bound(eig_vals, len(eig_vals) - n_signal)
        # eig_vals[-1 - k] is the (k+1)-th value from the end; the slicing
        # below treats the head of the array as noise, i.e. assumes
        # eig_vals is sorted ascending (TODO confirm against eig_analysis).
        while n_signal < len(eig_vals) and bound < eig_vals[-1 - n_signal] - eig_vals[0]:
            n_signal += 1
            bound = _spread_bound(eig_vals[:-n_signal], len(eig_vec) - n_signal)
        var_noise = np.mean(eig_vals[: len(eig_vals) - n_signal])

        n_kept = np.sum(eig_vals > (var_noise * threshold_scale ** 2))
        if n_kept == 0:
            # No component survives: the patch collapses to its mean.
            return np.zeros_like(patch) + p_tmean, n_kept, var_noise
        # Equation (3) of Manjon 2013
        patch_new = eig_synthesis(p_center, eig_vec, p_tmean, n_kept)
        return patch_new, n_kept, var_noise
@fill_doc
class HybridPCADenoiser(BaseSpaceTimeDenoiser):
    """Denoising using the Hybrid-PCA thresholding method.

    Parameters
    ----------
    $patch_config
    """

    @fill_doc
    def denoise(
        self, input_data, mask=None, mask_threshold=50, noise_std=1.0, progbar=None
    ):
        """Denoise using the Hybrid-PCA method.

        Along with the input data a noise std map or value should be provided.

        Parameters
        ----------
        $input_config
        $mask_config
        $noise_std

        Returns
        -------
        $denoise_return
        """
        p_s, p_o = self._get_patch_param(input_data.shape)
        if np.ndim(noise_std) == 0:
            # Any real scalar (Python int/float or NumPy scalar): constant map.
            # Previously only floats were accepted and an int crashed below.
            var_apriori = noise_std ** 2 * np.ones(input_data.shape[:-1])
        else:
            # Array-like noise map over the spatial dimensions.
            var_apriori = np.asarray(noise_std) ** 2
        # Repeat the variance map along the last (time) axis so it can be
        # patched exactly like the input data.
        var_apriori = PatchedArray(
            np.broadcast_to(var_apriori[..., None], input_data.shape), p_s, p_o
        )
        self.input_denoising_kwargs["var_apriori"] = var_apriori
        return super().denoise(input_data, mask, mask_threshold, progbar=progbar)

    def _patch_processing(self, patch, patch_idx=None, var_apriori=None):
        """Process a patch with the Hybrid-PCA method.

        Components are discarded (from the end of the eigenvalue array) until
        the mean of the remaining eigenvalues drops to the a-priori noise
        variance of the patch.

        Returns
        -------
        tuple
            ``(denoised_patch, n_kept_components, noise_variance)``.
        """
        varest = np.mean(var_apriori.get_patch(patch_idx))
        p_center, eig_vals, eig_vec, p_tmean = eig_analysis(patch)
        maxidx = 0
        var_noise = np.mean(eig_vals)
        while var_noise > varest and maxidx < len(eig_vals) - 2:
            maxidx += 1
            var_noise = np.mean(eig_vals[:-maxidx])
        if maxidx == 0:  # all eigen values are noise
            patch_new = np.zeros_like(patch) + p_tmean
        else:
            patch_new = eig_synthesis(p_center, eig_vec, p_tmean, maxidx)
        # Equation (3) of Manjon2013
        return patch_new, maxidx, var_noise
@fill_doc
class RawSVDDenoiser(BaseSpaceTimeDenoiser):
    """
    Classical patch-wise singular value (hard) thresholding denoiser.

    Parameters
    ----------
    $patch_config
    threshold_value: float
        threshold value for the singular values.
    """

    def __init__(
        self, patch_shape, patch_overlap, threshold_value=1.0, recombination="weighted"
    ):
        # Assigned before the base initializer runs (original ordering kept).
        self._threshold_val = threshold_value
        super().__init__(patch_shape, patch_overlap, recombination)

    @fill_doc
    def denoise(
        self,
        input_data,
        mask=None,
        mask_threshold=50,
        threshold_scale=1.0,
        progbar=None,
    ):
        """Denoise the input_data, according to mask.

        Patches are extracted sequentially and processed by the implemented
        `_patch_processing` function.
        Only patches which have at least a voxel in the mask ROI are processed.

        Parameters
        ----------
        $input_config
        $mask_config
        threshold_scale: float
            Extra factor for the threshold of singular values.

        Returns
        -------
        $denoise_return
        """
        # Effective threshold used by ``_patch_processing`` for this call.
        self._threshold = self._threshold_val * threshold_scale
        return super().denoise(input_data, mask, mask_threshold, progbar=progbar)

    def _patch_processing(self, patch, patch_idx=None, **kwargs):
        """Process a patch with the simple SVT method.

        The number of kept components is the count of singular values strictly
        above ``self._threshold``; singular values below it are zeroed.

        Returns
        -------
        tuple
            ``(denoised_patch, n_kept_components, nan)`` -- no noise estimate.
        """
        # svd_analysis also returns the mean it removed (presumably centering
        # for better SVD precision -- see svd_analysis).
        u_vec, s_values, v_vec, p_tmean = svd_analysis(patch)
        cutoff = self._threshold
        n_kept = np.sum(s_values > cutoff)
        if not n_kept:
            # Nothing survives: the patch collapses to the removed mean.
            return np.zeros_like(patch) + p_tmean, n_kept, np.nan
        s_values[s_values < cutoff] = 0
        # Equation (3) in Manjon 2013
        denoised = svd_synthesis(u_vec, s_values, v_vec, p_tmean, n_kept)
        return denoised, n_kept, np.nan
@fill_doc
class NordicDenoiser(RawSVDDenoiser):
    """Denoising using the NORDIC method.

    Parameters
    ----------
    $patch_config
    """

    @fill_doc
    def denoise(
        self,
        input_data,
        mask=None,
        mask_threshold=50,
        noise_std=1.0,
        n_iter_threshold=10,
        progbar=None,
    ):
        """Denoise using the NORDIC method.

        Along with the input data a noise std map or value should be provided.

        Parameters
        ----------
        $input_config
        $mask_config
        $noise_std
        n_iter_threshold: int
            Number of Monte-Carlo draws of a unit-variance Gaussian patch used
            to estimate its expected largest singular value.

        Returns
        -------
        $denoise_return

        Raises
        ------
        ValueError
            If ``noise_std`` is neither an array nor a real scalar, or if
            ``n_iter_threshold`` is not positive.
        """
        if n_iter_threshold < 1:
            # Previously surfaced as a ZeroDivisionError below.
            raise ValueError("n_iter_threshold must be a positive integer.")
        patch_shape, _ = self._get_patch_param(input_data.shape)
        # compute the threshold using Monte-Carlo Simulations.
        max_sval = sum(
            np.max(
                svd(
                    np.random.randn(np.prod(patch_shape), input_data.shape[-1]),
                    compute_uv=False,
                )
            )
            for _ in range(n_iter_threshold)
        )
        max_sval /= n_iter_threshold
        if isinstance(noise_std, np.ndarray):
            noise_std = np.mean(noise_std)
        # Any real scalar is accepted (a plain int such as 1 used to be rejected).
        if isinstance(noise_std, bool) or not isinstance(
            noise_std, (int, float, np.integer, np.floating)
        ):
            raise ValueError(
                "For NORDIC the noise level must be either an"
                + " array or a float specifying the std in the volume.",
            )
        self._threshold = noise_std * max_sval
        # Skip RawSVDDenoiser.denoise on purpose: it would overwrite
        # self._threshold with threshold_value * threshold_scale.
        return super(RawSVDDenoiser, self).denoise(
            input_data, mask, mask_threshold=mask_threshold, progbar=progbar
        )
# From MATLAB implementation def _opt_loss_x(y, beta): """Compute (8) of donoho2017.""" tmp = y ** 2 - beta - 1 return np.sqrt(0.5 * (tmp + np.sqrt((tmp ** 2) - (4 * beta)))) * ( y >= (1 + np.sqrt(beta)) ) def _opt_ope_shrink(singvals, beta=1): """Perform optimal threshold of singular values for operator norm.""" return np.maximum(_opt_loss_x(singvals, beta), 0) def _opt_nuc_shrink(singvals, beta=1): """Perform optimal threshold of singular values for nuclear norm.""" tmp = _opt_loss_x(singvals, beta) return ( np.maximum( 0, (tmp ** 4 - (np.sqrt(beta) * tmp * singvals) - beta), ) / ((tmp ** 2) * singvals) ) def _opt_fro_shrink(singvals, beta=1): """Perform optimal threshold of singular values for frobenius norm.""" return np.sqrt( np.maximum( (((singvals ** 2) - beta - 1) ** 2 - 4 * beta), 0, ) / singvals )
@fill_doc
class OptimalSVDDenoiser(BaseSpaceTimeDenoiser):
    """
    Optimal Shrinkage of singular values for a specific norm.

    Parameters
    ----------
    $patch_config
    loss: str
        The loss determines the choice of the optimal thresholding function
        associated to it. The losses `"fro"`, `"nuc"` and `"ope"` (alias
        `"op"`) are supported, for the frobenius, nuclear and operator norm,
        respectively.
    recombination: str
        Forwarded to the base class.

    Raises
    ------
    ValueError
        If ``loss`` is not supported.
    """

    _OPT_LOSS_SHRINK = MappingProxyType(
        {
            "fro": _opt_fro_shrink,
            "nuc": _opt_nuc_shrink,
            "ope": _opt_ope_shrink,
            # Alias: the class documentation has always advertised "op".
            "op": _opt_ope_shrink,
        }
    )

    def __init__(
        self,
        patch_shape,
        patch_overlap,
        loss="fro",
        recombination="weighted",
    ):
        super().__init__(patch_shape, patch_overlap, recombination=recombination)
        if loss not in OptimalSVDDenoiser._OPT_LOSS_SHRINK:
            raise ValueError(
                f"Unsupported loss: '{loss}', use any of "
                f"{list(OptimalSVDDenoiser._OPT_LOSS_SHRINK)}"
            )
        self.input_denoising_kwargs[
            "shrink_func"
        ] = OptimalSVDDenoiser._OPT_LOSS_SHRINK[loss]

    @fill_doc
    def denoise(
        self,
        input_data,
        mask=None,
        mask_threshold=50,
        noise_std=None,
        eps_marshenko_pastur=1e-7,
        progbar=None,
    ):
        """
        Optimal thresholding denoising method.

        Parameters
        ----------
        $input_config
        $mask_config
        $noise_std
            If None, the noise level of each patch is estimated from the median
            singular value (Marchenko-Pastur median).
        eps_marshenko_pastur: float
            The precision with which the optimal threshold is computed.

        Returns
        -------
        $denoise_return

        Notes
        -----
        Reimplementation of the original Matlab code [#]_ in python.

        References
        ----------
        .. [#] Gavish, Matan, and David L. Donoho. \
            "Optimal Shrinkage of Singular Values."
            IEEE Transactions on Information Theory 63, no. 4 (April 2017): 2137–52.
            https://doi.org/10.1109/TIT.2017.2653801.
        """
        p_s, p_o = self._get_patch_param(input_data.shape)
        # Patches presumably have prod(p_s) rows and input_data.shape[-1]
        # columns (matches the beta used in _patch_processing). Gavish &
        # Donoho assume beta = smaller / larger dimension <= 1.
        n_rows = int(np.prod(p_s))
        n_cols = input_data.shape[-1]
        self.input_denoising_kwargs["mp_median"] = marchenko_pastur_median(
            beta=min(n_rows, n_cols) / max(n_rows, n_cols),
            eps=eps_marshenko_pastur,
        )
        if noise_std is None:
            self.input_denoising_kwargs["var_apriori"] = None
        else:
            if np.ndim(noise_std) == 0:
                # Any real scalar (an int used to crash below).
                var_apriori = noise_std ** 2 * np.ones(input_data.shape[:-1])
            else:
                var_apriori = np.asarray(noise_std) ** 2
            var_apriori = PatchedArray(
                np.broadcast_to(var_apriori[..., None], input_data.shape), p_s, p_o
            )
            self.input_denoising_kwargs["var_apriori"] = var_apriori
        return super().denoise(input_data, mask, mask_threshold, progbar=progbar)

    def _patch_processing(
        self,
        patch,
        patch_idx=None,
        shrink_func=None,
        mp_median=None,
        var_apriori=None,
    ):
        """Shrink the singular values of a patch with ``shrink_func``.

        The singular values are normalized by ``sqrt(n) * sigma`` where ``n``
        is the *larger* patch dimension (Gavish & Donoho convention, under
        which pure-noise singular values fall in ``[1 - sqrt(beta),
        1 + sqrt(beta)]``), shrunk, then rescaled.

        Returns
        -------
        tuple
            ``(denoised_patch, n_kept_components, nan)``.
        """
        u_vec, s_values, v_vec, p_tmean = svd_analysis(patch)
        n_large, n_small = max(patch.shape), min(patch.shape)
        beta = n_small / n_large
        if var_apriori is not None:
            sigma = np.mean(np.sqrt(var_apriori.get_patch(patch_idx)))
        else:
            sigma = np.median(s_values) / np.sqrt(n_large * mp_median)
        # Previously sqrt(patch.shape[1]) was used: harmless when sigma is
        # estimated (it cancels) but wrong with an a-priori noise level.
        scale_factor = np.sqrt(n_large) * sigma
        thresh_s_values = scale_factor * shrink_func(
            s_values / scale_factor,
            beta=beta,
        )
        # Some shrinkage functions may yield NaN below the bulk edge.
        thresh_s_values[np.isnan(thresh_s_values)] = 0

        if np.any(thresh_s_values):
            maxidx = np.max(np.nonzero(thresh_s_values)) + 1
            p_new = svd_synthesis(u_vec, thresh_s_values, v_vec, p_tmean, maxidx)
        else:
            maxidx = 0
            p_new = np.zeros_like(patch) + p_tmean

        return p_new, maxidx, np.nan
def _sure_atn_cost(X, method, sing_vals, gamma, sigma=None, tau=None):
    """
    Compute the SURE cost function of the adaptive trace norm estimator.

    Parameters
    ----------
    X: np.ndarray
        Data matrix; only its shape ``(n, p)`` is used.
    method: str
        ``"qut"``: ``gamma`` is given in log-space (``exp(gamma) + 1`` is
        used) and ``tau`` as is. Otherwise ``tau`` is given in log-space
        (``exp(tau)`` is used) and ``gamma`` as is.
        ``"gsure"`` returns the generalized SURE, which does not use
        ``sigma``; any other value returns SURE, which requires ``sigma``.
    sing_vals : np.ndarray
        Singular values of X.
    gamma: float
    sigma: float
        Noise standard deviation.
    tau: float

    Returns
    -------
    float
        The cost value.
    """
    n, p = np.shape(X)
    if method == "qut":
        gamma = np.exp(gamma) + 1
    else:
        tau = np.exp(tau)
    sing_vals2 = sing_vals ** 2
    n_vals = len(sing_vals)
    D = np.zeros((n_vals, n_vals), dtype=np.float32)
    dhat = sing_vals * np.maximum(1 - ((tau / sing_vals) ** gamma), 0)
    tmp = sing_vals * dhat
    # Cross terms of the divergence: s_i f(s_i) / (s_i^2 - s_j^2), i != j
    # (the diagonal is zeroed through the infinite denominator).
    for i in range(n_vals):
        diff2i = sing_vals2[i] - sing_vals2
        diff2i[i] = np.inf
        D[i, :] = tmp[i] / diff2i
    gradd = (1 + (gamma - 1) * (tau / sing_vals) ** gamma) * (sing_vals >= tau)
    div = np.sum(gradd + abs(n - p) * dhat / sing_vals) + 2 * np.sum(D)
    rss = np.sum((dhat - sing_vals) ** 2)
    if method == "gsure":
        return rss / (1 - div / n / p) ** 2
    return (sigma ** 2) * ((-n * p) + (2 * div)) + rss


if NUMBA_AVAILABLE:
    s = nb.float32
    d = nb.float64
    sure_atn_cost = nb.njit(
        [
            s(s[:, :], nb.types.unicode_type, s[:], s, s, s),
            s(s[:, :], nb.types.unicode_type, s[:], s, d, d),
        ],
        fastmath=True,
    )(_sure_atn_cost)


def _atn_shrink(singvals, gamma, tau):
    """Adaptive trace norm shrinkage: ``s * max(1 - (tau / s)**gamma, 0)``."""
    return singvals * np.maximum(1 - (tau / singvals) ** gamma, 0)


def _get_gamma_tau_qut(patch, sing_vals, stdest, gamma0, nbsim):
    """Estimate gamma and tau using the quantile method.

    ``tau`` is the ``1 - 1 / sqrt(log(max(n, p)))`` quantile of the largest
    singular value of ``nbsim`` Gaussian noise matrices shaped like ``patch``
    with standard deviation ``stdest``. If ``gamma0`` is a real scalar it is
    used as is; otherwise gamma is obtained by minimizing the SURE cost.
    """
    maxd = np.empty(nbsim)
    for i in range(nbsim):
        maxd[i] = np.max(
            svd(
                np.random.randn(*patch.shape) * stdest,
                compute_uv=False,
                overwrite_a=True,
            )
        )
    # auto estimation of tau.
    tau = np.quantile(maxd, 1 - 1 / np.sqrt(np.log(max(*patch.shape))))

    # A single real value (int or float) fixes gamma; previously ints were
    # silently ignored and gamma was re-estimated.
    if isinstance(gamma0, (int, float, np.integer, np.floating)) and not isinstance(
        gamma0, bool
    ):
        return gamma0, tau

    def sure_gamma(log_gamma):
        return _sure_atn_cost(
            X=patch,
            method="qut",
            sing_vals=sing_vals,
            gamma=log_gamma,
            sigma=stdest,
            tau=tau,
        )

    res_opti = minimize(sure_gamma, 0)
    # The "qut" cost parametrizes gamma as exp(x) + 1 (keeps gamma > 1).
    gamma = np.exp(res_opti.x) + 1
    return gamma, tau


def _get_gamma_tau(patch, sing_vals, stdest, method, gamma0, tau0):
    """Estimate gamma and tau for the "sure" / "gsure" methods.

    For every candidate gamma, log(tau) is optimized from ``tau0`` and the
    (gamma, tau) pair with the smallest cost is returned (first one on ties).

    Parameters
    ----------
    gamma0: float, sequence of float or None
        Candidate gamma value(s). None uses the grid 1, 1.1, ..., 5
        (presumably mirroring ``gamma.seq`` of R ``denoiseR::adashrink``).
        A scalar is treated as a single candidate.
    tau0: float or None
        Starting value of log(tau); defaults to ``log(median(sing_vals))``.
    """
    if gamma0 is None:
        gammas = np.linspace(1.0, 5.0, 41)
    else:
        gammas = np.atleast_1d(gamma0)
    if gammas.size == 0:
        raise ValueError("At least one candidate value of gamma is required.")

    if tau0 is None:
        tau0 = np.log(np.median(sing_vals))

    costs = []
    log_taus = []
    for g in gammas:

        def _cost(log_tau, g=g):  # bind g now (avoid late-binding closure)
            return _sure_atn_cost(
                X=patch,
                method=method,
                gamma=g,
                sing_vals=sing_vals,
                sigma=stdest,
                tau=log_tau,
            )

        res_opti = minimize(_cost, tau0)
        costs.append(float(res_opti.fun))
        log_taus.append(res_opti.x)

    costs = np.asarray(costs, dtype=float)
    # Smallest cost, first one on ties; fall back to the first candidate if
    # every cost is NaN (previously gamma/tau could end up unbound).
    best = 0 if np.all(np.isnan(costs)) else int(np.nanargmin(costs))
    return gammas[best], np.exp(log_taus[best])
@fill_doc
class AdaptiveDenoiser(BaseSpaceTimeDenoiser):
    """Adaptive Denoiser.

    Parameters
    ----------
    $patch_config
    method: str
        Parameter selection method, one of "SURE", "QUT" or "GSURE"
        (case-insensitive).
    recombination: str
        Forwarded to the base class.
    nbsim: int
        Number of noise simulations used by the "qut" method.

    Raises
    ------
    ValueError
        If ``method`` is not supported.
    """

    _SUPPORTED_METHOD = ["sure", "qut", "gsure"]

    def __init__(
        self,
        patch_shape,
        patch_overlap,
        method="SURE",
        recombination="weighted",
        nbsim=500,
    ):
        super().__init__(patch_shape, patch_overlap, recombination)
        if method.lower() not in self._SUPPORTED_METHOD:
            raise ValueError(
                f"Unsupported method: '{method}', use any of {self._SUPPORTED_METHOD}"
            )
        self.input_denoising_kwargs["method"] = method.lower()
        self.input_denoising_kwargs["nbsim"] = nbsim

    @fill_doc
    def denoise(
        self,
        input_data,
        mask=None,
        mask_threshold=50,
        tau0=None,
        noise_std=None,
        gamma0=None,
        progbar=None,
    ):
        """
        Adaptive denoiser.

        Perform the denoising using the adaptive trace norm estimator. [#]_

        Parameters
        ----------
        $input_config
        $mask_config
        tau0: float, optional
            Initial value handed to the tau selection (see ``_get_gamma_tau``).
        $noise_std
            Required for the "sure" and "qut" methods; may be None for "gsure",
            whose cost does not depend on the noise level.
        gamma0: float or sequence of float, optional
            Value(s) of gamma handed to the gamma/tau selection (see
            ``_get_gamma_tau`` and ``_get_gamma_tau_qut``).

        Returns
        -------
        $denoise_return

        Raises
        ------
        ValueError
            If ``noise_std`` is None while the method needs a noise level.

        References
        ----------
        .. [#] J. Josse and S. Sardy, “Adaptive Shrinkage of singular values.”
               arXiv, Nov. 22, 2014.
               doi: 10.48550/arXiv.1310.6602.

        """
        self.input_denoising_kwargs["gamma0"] = gamma0
        self.input_denoising_kwargs["tau0"] = tau0
        p_s, p_o = self._get_patch_param(input_data.shape)

        if noise_std is None:
            # Previously crashed with ``None ** 2``.
            if self.input_denoising_kwargs["method"] != "gsure":
                raise ValueError(
                    "noise_std must be provided for the "
                    f"'{self.input_denoising_kwargs['method']}' method."
                )
            var_apriori = None
        else:
            if np.ndim(noise_std) == 0:
                # Any real scalar (an int used to crash below).
                var_apriori = noise_std ** 2 * np.ones(input_data.shape[:-1])
            else:
                var_apriori = np.asarray(noise_std) ** 2
            var_apriori = PatchedArray(
                np.broadcast_to(var_apriori[..., None], input_data.shape), p_s, p_o
            )
        self.input_denoising_kwargs["var_apriori"] = var_apriori
        return super().denoise(input_data, mask, mask_threshold, progbar=progbar)

    def _patch_processing(
        self,
        patch,
        patch_idx=None,
        gamma0=None,
        tau0=None,
        var_apriori=None,
        method=None,
        nbsim=None,
    ):
        """Shrink the singular values of a patch with the ATN estimator.

        Returns
        -------
        tuple
            ``(denoised_patch, n_kept_components, nan)``.
        """
        if var_apriori is None:
            # Only reachable for "gsure", whose cost does not use sigma.
            stdest = None
        else:
            stdest = np.sqrt(np.mean(var_apriori.get_patch(patch_idx)))

        u_vec, sing_vals, v_vec, p_tmean = svd_analysis(patch)
        if method == "qut":
            gamma, tau = _get_gamma_tau_qut(patch, sing_vals, stdest, gamma0, nbsim)
        else:
            gamma, tau = _get_gamma_tau(patch, sing_vals, stdest, method, gamma0, tau0)
        # end of parameter selection

        # Perform thresholding

        thresh_s_values = _atn_shrink(sing_vals, gamma=gamma, tau=tau)
        if np.any(thresh_s_values):
            maxidx = np.max(np.nonzero(thresh_s_values)) + 1
            p_new = svd_synthesis(u_vec, thresh_s_values, v_vec, p_tmean, maxidx)
        else:
            maxidx = 0
            p_new = np.zeros_like(patch) + p_tmean
        return p_new, maxidx, np.nan