"""Low Rank methods."""
from types import MappingProxyType
import numpy as np
from scipy.linalg import svd
from scipy.optimize import minimize
from .base import BaseSpaceTimeDenoiser, PatchedArray
from .utils import (
    eig_analysis,
    eig_synthesis,
    marchenko_pastur_median,
    svd_analysis,
    svd_synthesis,
)
from .._docs import fill_doc
NUMBA_AVAILABLE = True
try:
    import numba as nb
except ImportError:
    NUMBA_AVAILABLE = False
    pass
@fill_doc
class MPPCADenoiser(BaseSpaceTimeDenoiser):
    """Denoising using Marchenko-Pastur principal components analysis thresholding.

    Parameters
    ----------
    $patch_config
    threshold_scale: float
        An extra factor on the noise threshold: eigen components are kept when
        their eigenvalue exceeds ``threshold_scale ** 2`` times the estimated
        noise variance.
    """

    def __init__(self, patch_shape, patch_overlap, threshold_scale, **kwargs):
        super().__init__(patch_shape, patch_overlap, **kwargs)
        # Stored with the other per-patch options; presumably forwarded by the
        # base class to ``_patch_processing`` (TODO confirm in .base).
        self.input_denoising_kwargs["threshold_scale"] = threshold_scale

    def _patch_processing(self, patch, patch_idx=None, threshold_scale=1.0):
        """Process a patch with the MP-PCA method.

        Parameters
        ----------
        patch: np.ndarray
            The patch to denoise.
        patch_idx: optional
            Index of the patch; unused by this method.
        threshold_scale: float
            Factor applied to the noise standard deviation when selecting the
            eigen components to keep.

        Returns
        -------
        patch_new: np.ndarray
            The denoised patch.
        maxidx: int
            Number of eigen components kept.
        var_noise: float
            Estimated noise variance of the patch.
        """
        p_center, eig_vals, eig_vec, p_tmean = eig_analysis(patch)
        # The largest eigenvalues are read from the end of ``eig_vals``
        # (``eig_vals[~maxidx]``, ``eig_vals[:-maxidx]``), so they are assumed
        # to be sorted in ascending order.
        n_eig = len(eig_vals)
        n_samples = len(patch)
        maxidx = 0
        meanvar = np.mean(eig_vals)
        meanvar *= 4 * np.sqrt((n_eig - maxidx + 1) / n_samples)
        # Move the largest eigenvalues out of the noise set while the spread
        # of the remaining ones exceeds the Marchenko-Pastur width estimate.
        while maxidx < n_eig and meanvar < eig_vals[~maxidx] - eig_vals[0]:
            maxidx += 1
            meanvar = np.mean(eig_vals[:-maxidx])
            meanvar *= 4 * np.sqrt((n_eig - maxidx + 1) / n_samples)
        var_noise = np.mean(eig_vals[: n_eig - maxidx])
        maxidx = np.sum(eig_vals > (var_noise * threshold_scale ** 2))
        if maxidx == 0:
            patch_new = np.zeros_like(patch) + p_tmean
        else:
            patch_new = eig_synthesis(p_center, eig_vec, p_tmean, maxidx)
        # Equation (3) of Manjon 2013
        return patch_new, maxidx, var_noise
 
@fill_doc
class HybridPCADenoiser(BaseSpaceTimeDenoiser):
    """Denoising using the Hybrid-PCA thresholding method.

    Parameters
    ----------
    $patch_config
    """

    @fill_doc
    def denoise(
        self, input_data, mask=None, mask_threshold=50, noise_std=1.0, progbar=None
    ):
        """Denoise using the Hybrid-PCA method.

        Along with the input data a noise std map or value should be provided.

        Parameters
        ----------
        $input_config
        $mask_config
        $noise_std

        Returns
        -------
        $denoise_return
        """
        p_s, p_o = self._get_patch_param(input_data.shape)
        if np.ndim(noise_std) == 0:
            # Any real scalar (Python int/float, NumPy scalar, 0-d array):
            # the same std everywhere except along the last axis.
            var_apriori = noise_std ** 2 * np.ones(input_data.shape[:-1])
        else:
            # A std map shaped like ``input_data.shape[:-1]``.
            var_apriori = np.asanyarray(noise_std) ** 2
        # Repeat the variance map along the last axis so that it can be cut
        # into the same patches as the data.
        var_apriori = PatchedArray(
            np.broadcast_to(var_apriori[..., None], input_data.shape), p_s, p_o
        )
        self.input_denoising_kwargs["var_apriori"] = var_apriori
        return super().denoise(input_data, mask, mask_threshold, progbar=progbar)

    def _patch_processing(self, patch, patch_idx=None, var_apriori=None):
        """Process a patch with the Hybrid-PCA method.

        The largest eigenvalues are removed one by one until the mean of the
        remaining ones (the noise estimate) is no larger than the a-priori
        noise variance of the patch. At least two eigenvalues are always left
        in the noise set.

        Returns
        -------
        patch_new: np.ndarray
            The denoised patch.
        maxidx: int
            Number of eigen components kept.
        var_noise: float
            Mean of the eigenvalues considered as noise.
        """
        varest = np.mean(var_apriori.get_patch(patch_idx))
        p_center, eig_vals, eig_vec, p_tmean = eig_analysis(patch)
        maxidx = 0
        var_noise = np.mean(eig_vals)
        while var_noise > varest and maxidx < len(eig_vals) - 2:
            maxidx += 1
            var_noise = np.mean(eig_vals[:-maxidx])
        if maxidx == 0:  # all eigen values are noise
            patch_new = np.zeros_like(patch) + p_tmean
        else:
            patch_new = eig_synthesis(p_center, eig_vec, p_tmean, maxidx)
        # Equation (3) of Manjon2013
        return patch_new, maxidx, var_noise
 
@fill_doc
class RawSVDDenoiser(BaseSpaceTimeDenoiser):
    """
    Classical Patch wise singular value thresholding denoiser.

    Parameters
    ----------
    $patch_config
    threshold_value: float
        threshold value for the singular values.
    recombination: str
        Patch recombination strategy, forwarded to the base class.
    """

    def __init__(
        self, patch_shape, patch_overlap, threshold_value=1.0, recombination="weighted"
    ):
        self._threshold_val = threshold_value
        super().__init__(patch_shape, patch_overlap, recombination)

    @fill_doc
    def denoise(
        self,
        input_data,
        mask=None,
        mask_threshold=50,
        threshold_scale=1.0,
        progbar=None,
    ):
        """Denoise the input_data, according to mask.

        Patches are extracted sequentially and processed by the implemented
        `_patch_processing` function.
        Only patches which have at least a voxel in the mask ROI are processed.

        Parameters
        ----------
        $input_config
        $mask_config
        threshold_scale: float
            Extra factor for the threshold of singular values.

        Returns
        -------
        $denoise_return
        """
        # Effective threshold, read by ``_patch_processing``.
        self._threshold = self._threshold_val * threshold_scale
        return super().denoise(input_data, mask, mask_threshold, progbar=progbar)

    def _patch_processing(self, patch, patch_idx=None, **kwargs):
        """Process a patch with the simple SVT method.

        Singular values that are not strictly above ``self._threshold`` are
        set to zero.

        Returns
        -------
        p_new: np.ndarray
            The denoised patch.
        maxidx: int
            Number of singular values kept.
        noise_var: float
            Always NaN (no noise estimate for this method).
        """
        # Centering for better precision in SVD
        u_vec, s_values, v_vec, p_tmean = svd_analysis(patch)
        keep = s_values > self._threshold
        maxidx = np.sum(keep)
        if maxidx == 0:
            p_new = np.zeros_like(patch) + p_tmean
        else:
            # Use the same mask as the count above, so a value equal to the
            # threshold is dropped consistently.
            s_values[~keep] = 0
            p_new = svd_synthesis(u_vec, s_values, v_vec, p_tmean, maxidx)
        # Equation (3) in Manjon 2013
        return p_new, maxidx, np.nan
 
@fill_doc
class NordicDenoiser(RawSVDDenoiser):
    """Denoising using the NORDIC method.

    Parameters
    ----------
    $patch_config
    """

    @fill_doc
    def denoise(
        self,
        input_data,
        mask=None,
        mask_threshold=50,
        noise_std=1.0,
        n_iter_threshold=10,
        progbar=None,
    ):
        """Denoise using the NORDIC method.

        Along with the input data a noise std map or value should be provided.

        Parameters
        ----------
        $input_config
        $mask_config
        $noise_std
        n_iter_threshold: int
            Number of Monte-Carlo simulations used to estimate the average
            largest singular value of a unit-variance Gaussian noise patch.

        Returns
        -------
        $denoise_return

        Raises
        ------
        ValueError
            If ``noise_std`` is neither an array nor a real scalar.
        """
        # NORDIC uses a single threshold: a noise map is reduced to its mean.
        if isinstance(noise_std, np.ndarray):
            noise_std = np.mean(noise_std)
        elif isinstance(noise_std, (int, np.integer)) and not isinstance(
            noise_std, bool
        ):
            noise_std = float(noise_std)
        if not isinstance(noise_std, (float, np.floating)):
            raise ValueError(
                "For NORDIC the noise level must be either an"
                + " array or a float specifying the std in the volume.",
            )
        patch_shape, _ = self._get_patch_param(input_data.shape)
        # Monte-Carlo estimate of the largest singular value of a
        # (prod(patch_shape), input_data.shape[-1]) standard Gaussian matrix.
        max_sval = sum(
            np.max(
                svd(
                    np.random.randn(np.prod(patch_shape), input_data.shape[-1]),
                    compute_uv=False,
                )
            )
            for _ in range(n_iter_threshold)
        )
        max_sval /= n_iter_threshold
        self._threshold = noise_std * max_sval
        # Bypass RawSVDDenoiser.denoise, which would overwrite ``_threshold``
        # with ``threshold_value * threshold_scale``.
        return super(RawSVDDenoiser, self).denoise(
            input_data, mask, mask_threshold=mask_threshold, progbar=progbar
        )
 
# From MATLAB implementation
def _opt_loss_x(y, beta):
    """Compute (8) of donoho2017.

    For a normalized singular value ``y`` at or above the Marchenko-Pastur
    bulk edge ``1 + sqrt(beta)``, return the asymptotic clean singular value
    ``sqrt(0.5 * (t + sqrt(t**2 - 4 * beta)))`` with ``t = y**2 - beta - 1``.
    Below the edge the value is 0.

    Parameters
    ----------
    y: array_like
        Singular values normalized by the noise scale.
    beta: float
        Aspect ratio of the matrix.

    Returns
    -------
    np.ndarray
        The values ``x(y)``, never NaN.
    """
    y = np.asarray(y)
    above = y >= (1 + np.sqrt(beta))
    tmp = y ** 2 - beta - 1
    # Below the edge the discriminant and the outer argument can be negative.
    # Clip them and mask the result instead of multiplying a NaN by 0, which
    # would stay NaN.
    disc = np.sqrt(np.maximum(tmp ** 2 - 4 * beta, 0))
    return np.where(above, np.sqrt(np.maximum(0.5 * (tmp + disc), 0)), 0)
def _opt_ope_shrink(singvals, beta=1):
    """Perform optimal threshold of singular values for operator norm.

    The shrunk value is ``_opt_loss_x(singvals, beta)`` clipped below at 0.
    """
    clean_vals = _opt_loss_x(singvals, beta)
    return np.maximum(0, clean_vals)
def _opt_nuc_shrink(singvals, beta=1):
    """Perform optimal threshold of singular values for nuclear norm.

    Applies ``max(0, x**4 - sqrt(beta) * x * y - beta) / (x**2 * y)`` with
    ``x = _opt_loss_x(y, beta)`` for ``y >= 1 + sqrt(beta)``, and 0 below
    that edge.

    Parameters
    ----------
    singvals: array_like
        Singular values normalized by the noise scale.
    beta: float
        Aspect ratio of the matrix.

    Returns
    -------
    np.ndarray
        Shrunk (normalized) singular values, never NaN.
    """
    y = np.asarray(singvals)
    tmp = _opt_loss_x(y, beta)
    above = y >= (1 + np.sqrt(beta))
    numerator = np.maximum(0, tmp ** 4 - (np.sqrt(beta) * tmp * y) - beta)
    # Below the edge x(y) is 0 (or NaN), giving 0/0. Above it, x > 0 and
    # y > 0, so the guarded denominator is safe.
    denominator = np.where(above, (tmp ** 2) * y, 1)
    return np.where(above, numerator / denominator, 0)
def _opt_fro_shrink(singvals, beta=1):
    """Perform optimal threshold of singular values for frobenius norm.

    Applies ``sqrt((y**2 - beta - 1)**2 - 4 * beta) / y`` for
    ``y >= 1 + sqrt(beta)`` and 0 below that edge (Gavish & Donoho 2017).

    Parameters
    ----------
    singvals: array_like
        Singular values normalized by the noise scale.
    beta: float
        Aspect ratio of the matrix.

    Returns
    -------
    np.ndarray
        Shrunk (normalized) singular values.
    """
    y = np.asarray(singvals)
    above = y >= (1 + np.sqrt(beta))
    # The discriminant is also positive for y <= 1 - sqrt(beta). Those values
    # must be zeroed by the bulk-edge indicator, not by clipping.
    numerator = np.sqrt(np.maximum(((y ** 2) - beta - 1) ** 2 - 4 * beta, 0))
    # Only divide where y is above the edge (y >= 1 there), avoiding 0/0.
    safe_y = np.where(above, y, 1)
    return np.where(above, numerator / safe_y, 0)
@fill_doc
class OptimalSVDDenoiser(BaseSpaceTimeDenoiser):
    """
    Optimal Shrinkage of singular values for a specific norm.

    Parameters
    ----------
    $patch_config
    loss: str
        The loss determines the choice of the optimal thresholding function
        associated to it. The losses `"fro"`, `"nuc"` and `"op"` (or its
        alias `"ope"`) are supported, for the frobenius, nuclear and operator
        norm, respectively.
    recombination: str
        Patch recombination strategy, forwarded to the base class.
    """

    _OPT_LOSS_SHRINK = MappingProxyType(
        {
            "fro": _opt_fro_shrink,
            "nuc": _opt_nuc_shrink,
            "ope": _opt_ope_shrink,
            # "op" is the name advertised in the class docstring.
            "op": _opt_ope_shrink,
        }
    )

    def __init__(
        self,
        patch_shape,
        patch_overlap,
        loss="fro",
        recombination="weighted",
    ):
        if loss not in OptimalSVDDenoiser._OPT_LOSS_SHRINK:
            raise ValueError(
                f"Unsupported loss: '{loss}', use any of "
                f"{list(OptimalSVDDenoiser._OPT_LOSS_SHRINK)}"
            )
        super().__init__(patch_shape, patch_overlap, recombination=recombination)
        self.input_denoising_kwargs[
            "shrink_func"
        ] = OptimalSVDDenoiser._OPT_LOSS_SHRINK[loss]

    @fill_doc
    def denoise(
        self,
        input_data,
        mask=None,
        mask_threshold=50,
        noise_std=None,
        eps_marshenko_pastur=1e-7,
        progbar=None,
    ):
        """
        Optimal thresholding denoising method.

        Parameters
        ----------
        $input_config
        $mask_config
        $noise_std
        eps_marshenko_pastur: float
            The precision with which the Marchenko-Pastur median is computed.
            That median is only used when ``noise_std`` is None.

        Returns
        -------
        $denoise_return

        Notes
        -----
        Reimplementation of the original Matlab code [#]_ in python.
        When ``noise_std`` is None, the noise level of each patch is estimated
        from the median of its singular values.

        References
        ----------
        .. [#] Gavish, Matan, and David L. Donoho. \
            "Optimal Shrinkage of Singular Values."
            IEEE Transactions on Information Theory 63, no. 4 (April 2017): 2137–52.
            https://doi.org/10.1109/TIT.2017.2653801.
        """
        p_s, p_o = self._get_patch_param(input_data.shape)
        n_frames = input_data.shape[-1]
        n_voxels = np.prod(p_s)
        # Gavish & Donoho define the aspect ratio as beta = m / n <= 1.
        self.input_denoising_kwargs["mp_median"] = marchenko_pastur_median(
            beta=min(n_frames, n_voxels) / max(n_frames, n_voxels),
            eps=eps_marshenko_pastur,
        )
        if noise_std is None:
            self.input_denoising_kwargs["var_apriori"] = None
        else:
            if np.ndim(noise_std) == 0:
                # Any real scalar: the same std everywhere.
                var_apriori = noise_std ** 2 * np.ones(input_data.shape[:-1])
            else:
                var_apriori = np.asanyarray(noise_std) ** 2
            var_apriori = PatchedArray(
                np.broadcast_to(var_apriori[..., None], input_data.shape), p_s, p_o
            )
            self.input_denoising_kwargs["var_apriori"] = var_apriori
        return super().denoise(input_data, mask, mask_threshold, progbar=progbar)

    def _patch_processing(
        self,
        patch,
        patch_idx=None,
        shrink_func=None,
        mp_median=None,
        var_apriori=None,
    ):
        """Shrink the singular values of a patch with ``shrink_func``.

        Returns
        -------
        p_new: np.ndarray
            The denoised patch.
        maxidx: int
            One past the index of the last non-zero shrunk singular value.
        noise_var: float
            Always NaN.
        """
        u_vec, s_values, v_vec, p_tmean = svd_analysis(patch)
        n_small = min(patch.shape[0], patch.shape[1])
        n_large = max(patch.shape[0], patch.shape[1])
        if var_apriori is not None:
            sigma = np.mean(np.sqrt(var_apriori.get_patch(patch_idx)))
        else:
            sigma = np.median(s_values) / np.sqrt(n_large * mp_median)
        # Noise singular values of an (m x n) matrix with m <= n scale as
        # sigma * sqrt(n): normalizing by the larger dimension puts the
        # Marchenko-Pastur bulk edge at 1 + sqrt(beta).
        scale_factor = np.sqrt(n_large) * sigma
        thresh_s_values = scale_factor * shrink_func(
            s_values / scale_factor,
            beta=n_small / n_large,
        )
        thresh_s_values[np.isnan(thresh_s_values)] = 0
        if np.any(thresh_s_values):
            maxidx = np.max(np.nonzero(thresh_s_values)) + 1
            p_new = svd_synthesis(u_vec, thresh_s_values, v_vec, p_tmean, maxidx)
        else:
            maxidx = 0
            p_new = np.zeros_like(patch) + p_tmean
        return p_new, maxidx, np.nan
 
def _sure_atn_cost(X, method, sing_vals, gamma, sigma=None, tau=None):
    """
    Compute the SURE cost function.

    Evaluate a (generalized) Stein unbiased risk estimate of the adaptive
    trace-norm shrinkage ``d * max(1 - (tau / d) ** gamma, 0)`` applied to
    the singular values ``d`` of X.

    Parameters
    ----------
    X: np.ndarray
        2D matrix; only its shape ``(n, p)`` is used.
    method: str
        If ``"qut"``, ``gamma`` is reparametrized and used as
        ``exp(gamma) + 1`` while ``tau`` is used as given. Otherwise ``tau``
        is reparametrized and used as ``exp(tau)`` while ``gamma`` is used as
        given. ``"gsure"`` also selects the GSURE criterion.
    sing_vals : singular values of X
    gamma: float
        Shrinkage exponent (see ``method`` for its parametrization).
    sigma: float
        Noise standard deviation; not used when ``method == "gsure"``.
    tau: float
        Shrinkage threshold (see ``method`` for its parametrization).

    Returns
    -------
    float
        ``rss / (1 - div / (n * p)) ** 2`` for ``"gsure"``, otherwise
        ``sigma**2 * (2 * div - n * p) + rss``.
    """
    n, p = np.shape(X)
    if method == "qut":
        gamma = np.exp(gamma) + 1
    else:
        tau = np.exp(tau)
    sing_vals2 = sing_vals ** 2
    n_vals = len(sing_vals)
    # NOTE(review): float32 accumulator, presumably to match the float32
    # numba signatures declared for this function.
    D = np.zeros((n_vals, n_vals), dtype=np.float32)
    # Shrunk singular values.
    dhat = sing_vals * np.maximum(1 - ((tau / sing_vals) ** gamma), 0)
    tmp = sing_vals * dhat
    # Cross terms d_i * dhat_i / (d_i**2 - d_j**2) of the divergence; the
    # diagonal is excluded by dividing by +inf there.
    for i in range(n_vals):
        diff2i = sing_vals2[i] - sing_vals2
        diff2i[i] = np.inf
        D[i, :] = tmp[i] / diff2i
    # Derivative of the shrinkage function, zero below the threshold tau.
    gradd = (1 + (gamma - 1) * (tau / sing_vals) ** gamma) * (sing_vals >= tau)
    # Divergence: diagonal terms, |n - p| term and (twice) the cross terms.
    div = np.sum(gradd + abs(n - p) * dhat / sing_vals) + 2 * np.sum(D)
    rss = np.sum((dhat - sing_vals) ** 2)
    if method == "gsure":
        return rss / (1 - div / n / p) ** 2
    return (sigma ** 2) * ((-n * p) + (2 * div)) + rss
# Optional numba-compiled version of ``_sure_atn_cost``. Because explicit
# signatures are given, ``nb.njit`` compiles eagerly when this module is
# imported. Arguments are float32 (X, sing_vals, gamma) with either float32 or
# float64 ``sigma``/``tau``, and the result is float32.
# NOTE(review): ``sure_atn_cost`` is not called in this part of the file; the
# estimators call the pure-Python ``_sure_atn_cost``. Also, ``fastmath=True``
# lets LLVM assume no inf/NaN values, yet ``_sure_atn_cost`` relies on
# dividing by ``np.inf`` to zero the diagonal; verify before using it.
# ``s`` and ``d`` below end up as module-level names.
if NUMBA_AVAILABLE:
    s = nb.float32
    d = nb.float64
    sure_atn_cost = nb.njit(
        [
            s(s[:, :], nb.types.unicode_type, s[:], s, s, s),
            s(s[:, :], nb.types.unicode_type, s[:], s, d, d),
        ],
        fastmath=True,
    )(_sure_atn_cost)
def _atn_shrink(singvals, gamma, tau):
    """Adaptive trace norm shrinkage of singular values.

    Each value ``d`` becomes ``d * max(1 - (tau / d) ** gamma, 0)``. For
    ``gamma > 0``, values at or below ``tau`` are set to zero and larger
    values are shrunk relatively less as they grow.
    """
    ratio_pow = (tau / singvals) ** gamma
    factor = np.maximum(1 - ratio_pow, 0)
    return singvals * factor
def _get_gamma_tau_qut(patch, sing_vals, stdest, gamma0, nbsim):
    """Estimate gamma and tau using the quantile method.

    ``tau`` is a quantile of the largest singular value of ``nbsim`` Gaussian
    noise matrices shaped like ``patch`` and scaled by ``stdest``. ``gamma``
    is ``gamma0`` when it is a real scalar. Otherwise it is obtained by
    minimizing the SURE cost with ``tau`` fixed.

    Parameters
    ----------
    patch: np.ndarray
        2D patch being denoised.
    sing_vals: np.ndarray
        Singular values associated with ``patch``.
    stdest: float
        Noise standard deviation.
    gamma0: float, int or None
        Fixed value for gamma. Any non-scalar value (e.g. None) triggers the
        SURE-based estimation of gamma.
    nbsim: int
        Number of Monte-Carlo noise simulations used to estimate tau.

    Returns
    -------
    gamma: float or np.ndarray
    tau: float
    """
    maxd = np.empty(nbsim)
    for i in range(nbsim):
        maxd[i] = np.max(
            svd(
                np.random.randn(*patch.shape) * stdest,
                compute_uv=False,
                overwrite_a=True,
            )
        )
    # auto estimation of tau.
    tau = np.quantile(maxd, 1 - 1 / np.sqrt(np.log(max(patch.shape))))
    # A single value for gamma was provided: integers count too (not bool).
    if isinstance(gamma0, (int, float, np.integer, np.floating)) and not isinstance(
        gamma0, bool
    ):
        return gamma0, tau

    def sure_gamma(gamma):
        # ``gamma`` is the log-reparametrization expected by the "qut" cost.
        return _sure_atn_cost(
            X=patch,
            method="qut",
            sing_vals=sing_vals,
            gamma=gamma,
            sigma=stdest,
            tau=tau,
        )

    res_opti = minimize(sure_gamma, 0)
    return np.exp(res_opti.x) + 1, tau
def _get_gamma_tau(patch, sing_vals, stdest, method, gamma0, tau0):
    """Estimate gamma and tau.

    For each candidate gamma, ``log(tau)`` is optimized with
    ``scipy.optimize.minimize`` on the SURE-type cost, and the pair with the
    lowest cost is kept.

    Parameters
    ----------
    patch: np.ndarray
        2D patch being denoised.
    sing_vals: np.ndarray
        Singular values associated with ``patch``.
    stdest: float or None
        Noise standard deviation, passed as ``sigma`` to the cost.
    method: str
        Cost variant forwarded to ``_sure_atn_cost``.
    gamma0: float, iterable of float, or None
        Candidate value(s) of gamma. None uses the grid 1.0, 1.1, ..., 5.0.
    tau0: float or None
        Initial guess for ``log(tau)``; defaults to ``log(median(sing_vals))``.

    Returns
    -------
    gamma: float
        Selected gamma.
    tau: np.ndarray
        ``exp`` of the optimal ``log(tau)``, as returned by ``minimize``.

    Raises
    ------
    ValueError
        If ``gamma0`` contains no candidate value.
    """
    if gamma0 is None:
        gamma_grid = np.linspace(1.0, 5.0, 41)
    elif np.ndim(gamma0) == 0:
        gamma_grid = [gamma0]
    else:
        gamma_grid = gamma0
    if tau0 is None:
        tau0 = np.log(np.median(sing_vals))

    best_gamma, best_tau, best_cost = None, None, np.inf
    for gamma in gamma_grid:

        def cost_fn(log_tau, gamma=gamma):
            # Bind the current gamma as a default to avoid late binding.
            return _sure_atn_cost(
                X=patch,
                method=method,
                gamma=gamma,
                sing_vals=sing_vals,
                sigma=stdest,
                tau=log_tau,
            )

        res_opti = minimize(cost_fn, tau0)
        cost = cost_fn(res_opti.x)
        # A NaN cost (e.g. from zero singular values) never beats a finite
        # one. The first candidate is kept as a fallback so a value is always
        # returned.
        if not np.isfinite(cost):
            cost = np.inf
        if best_gamma is None or cost < best_cost:
            best_gamma, best_tau, best_cost = gamma, np.exp(res_opti.x), cost
    if best_gamma is None:
        raise ValueError("gamma0 must contain at least one candidate value.")
    return best_gamma, best_tau
@fill_doc
class AdaptiveDenoiser(BaseSpaceTimeDenoiser):
    """Adaptive Denoiser.

    Parameters
    ----------
    $patch_config
    method: str
        Parameter selection strategy, one of "sure", "qut" or "gsure"
        (case-insensitive).
    recombination: str
        Patch recombination strategy, forwarded to the base class.
    nbsim: int
        Number of noise simulations used by the "qut" method to estimate tau.
    """

    _SUPPORTED_METHOD = ["sure", "qut", "gsure"]

    def __init__(
        self,
        patch_shape,
        patch_overlap,
        method="SURE",
        recombination="weighted",
        nbsim=500,
    ):
        super().__init__(patch_shape, patch_overlap, recombination)
        if method.lower() not in self._SUPPORTED_METHOD:
            raise ValueError(
                f"Unsupported method: '{method}', use any of {self._SUPPORTED_METHOD}"
            )
        self.input_denoising_kwargs["method"] = method.lower()
        self.input_denoising_kwargs["nbsim"] = nbsim

    @fill_doc
    def denoise(
        self,
        input_data,
        mask=None,
        mask_threshold=50,
        tau0=None,
        noise_std=None,
        gamma0=None,
        progbar=None,
    ):
        """
        Adaptive denoiser.

        Perform the denoising using the adaptive trace norm estimator. [#]_

        Parameters
        ----------
        $input_config
        $mask_config
        tau0: float, optional
            Initial guess for ``log(tau)`` used by the "sure" and "gsure"
            methods. Defaults to the log of the median singular value of
            each patch.
        $noise_std
        gamma0: optional
            For "qut", a fixed float value of gamma (otherwise gamma is
            estimated). For "sure" and "gsure", an iterable of candidate
            values of gamma.

        Returns
        -------
        $denoise_return

        Raises
        ------
        ValueError
            If ``noise_std`` is None and the method is not "gsure".

        Notes
        -----
        The "gsure" criterion does not use the noise level, so ``noise_std``
        may be left to None for that method only.

        References
        ----------
        .. [#] J. Josse and S. Sardy, “Adaptive Shrinkage of singular values.”
            arXiv, Nov. 22, 2014.
            doi: 10.48550/arXiv.1310.6602.
        """
        self.input_denoising_kwargs["gamma0"] = gamma0
        self.input_denoising_kwargs["tau0"] = tau0
        p_s, p_o = self._get_patch_param(input_data.shape)
        method = self.input_denoising_kwargs["method"]
        if noise_std is None:
            if method != "gsure":
                raise ValueError(
                    f"noise_std must be provided for the '{method}' method."
                )
            var_apriori = None
        else:
            if np.ndim(noise_std) == 0:
                # Any real scalar: the same std everywhere.
                var_apriori = noise_std ** 2 * np.ones(input_data.shape[:-1])
            else:
                var_apriori = np.asanyarray(noise_std) ** 2
            var_apriori = PatchedArray(
                np.broadcast_to(var_apriori[..., None], input_data.shape), p_s, p_o
            )
        self.input_denoising_kwargs["var_apriori"] = var_apriori
        return super().denoise(input_data, mask, mask_threshold, progbar=progbar)

    def _patch_processing(
        self,
        patch,
        patch_idx=None,
        gamma0=None,
        tau0=None,
        var_apriori=None,
        method=None,
        nbsim=None,
    ):
        """Denoise a patch with the adaptive trace norm shrinkage.

        Returns
        -------
        p_new: np.ndarray
            The denoised patch.
        maxidx: int
            One past the index of the last non-zero shrunk singular value.
        noise_var: float
            Always NaN.
        """
        if var_apriori is None:
            # Only allowed by ``denoise`` for "gsure", whose cost ignores sigma.
            stdest = None
        else:
            stdest = np.sqrt(np.mean(var_apriori.get_patch(patch_idx)))
        u_vec, sing_vals, v_vec, p_tmean = svd_analysis(patch)
        if method == "qut":
            gamma, tau = _get_gamma_tau_qut(patch, sing_vals, stdest, gamma0, nbsim)
        else:
            gamma, tau = _get_gamma_tau(patch, sing_vals, stdest, method, gamma0, tau0)
        # end of parameter selection
        # Perform thresholding
        thresh_s_values = _atn_shrink(sing_vals, gamma=gamma, tau=tau)
        if np.any(thresh_s_values):
            maxidx = np.max(np.nonzero(thresh_s_values)) + 1
            p_new = svd_synthesis(u_vec, thresh_s_values, v_vec, p_tmean, maxidx)
        else:
            maxidx = 0
            p_new = np.zeros_like(patch) + p_tmean
        return p_new, maxidx, np.nan