Source code for patch_denoise.space_time.base

"""Base Structure for patch-based denoising on spatio-temporal dimension."""

import abc
import logging

import numpy as np
from tqdm.auto import tqdm

from .._docs import fill_doc


class PatchedArray:
    """A container exposing an array as a grid of (possibly overlapping) patches.

    The patches are strided views on the underlying array (no copy): writing
    to a patch writes to the array, and overlapping patches alias the same
    memory.

    Parameters
    ----------
    array: np.ndarray or tuple
        Array to wrap. If a tuple is given, it is used as the shape of a new
        zero-filled array of type ``dtype``.
    patch_shape: tuple
        Size of a patch along every dimension of the array. It is clipped to
        the array shape.
    patch_overlap: tuple
        Overlap between consecutive patches along every dimension.
    dtype: optional
        Data type of the array created when ``array`` is a shape tuple.
    padding_mode: str
        Accepted for API compatibility; currently unused, the array is never
        padded.
    **kwargs
        Accepted for API compatibility; currently unused.

    Raises
    ------
    ValueError
        If the overlap is larger than the patch on some dimension, if the
        patch parameters do not match the number of dimensions of the array,
        or if the step is zero on a dimension that the patch does not span.

    Notes
    -----
    The patch grid starts at the origin and is never padded, so trailing
    elements that do not fit a whole step are not part of any patch.
    """

    def __init__(
        self,
        array,
        patch_shape,
        patch_overlap,
        dtype=None,
        padding_mode="edge",
        **kwargs,
    ):
        if isinstance(array, tuple):
            array = np.zeros(array, dtype=dtype)
        self._arr = array
        self._ps = np.asarray(patch_shape)
        self._po = np.asarray(patch_overlap)

        dimensions = self._arr.ndim
        step = self._ps - self._po
        if np.any(step < 0):
            raise ValueError("overlap should be smaller than patch on every dimension.")
        if self._ps.size != dimensions or step.size != dimensions:
            raise ValueError(
                "self._ps and step must have the same number of dimensions as the "
                "input self._array."
            )

        # Ensure patch size is not larger than the array size along each axis.
        self._ps = np.minimum(self._ps, self._arr.shape)

        # A zero step is only meaningful when a single patch spans the whole
        # axis; otherwise the grid size below would require a division by zero.
        if np.any((step == 0) & (self._ps < self._arr.shape)):
            raise ValueError(
                "overlap must be strictly smaller than patch on every dimension "
                "that the patch does not span entirely."
            )

        # Number of patches along each axis (1 when the patch spans the axis).
        grid_shape = tuple(
            int((self._arr.shape[i] - self._ps[i]) // step[i] + 1)
            if self._ps[i] < self._arr.shape[i]
            else 1
            for i in range(dimensions)
        )
        shape = grid_shape + tuple(int(s) for s in self._ps)
        # Leading strides move from one patch to the next, trailing strides
        # move inside a patch (same as the wrapped array).
        strides = (
            tuple(
                int(self._arr.strides[i] * step[i])
                if self._ps[i] < self._arr.shape[i]
                else 0
                for i in range(dimensions)
            )
            + self._arr.strides
        )

        # Create the sliding view.
        self.sliding_view = np.lib.stride_tricks.as_strided(
            self._arr, shape=shape, strides=strides
        )
        self._grid_shape = grid_shape

    @property
    def n_patches(self):
        """Get number of patches."""
        return int(np.prod(self._grid_shape))

    def get_patch(self, idx):
        """Get patch at linear index ``idx`` (a view on the array)."""
        return self.sliding_view[np.unravel_index(idx, self._grid_shape)]

    def set_patch(self, idx, value):
        """Set patch at linear index ``idx`` with value."""
        self.sliding_view[np.unravel_index(idx, self._grid_shape)] = value

    def add2patch(self, idx, value):
        """Add to patch, in place."""
        patch = self.get_patch(idx)
        patch += value

    def __getattr__(self, name):
        """Get attribute of underlying array."""
        # ``__getattr__`` is only called when regular lookup fails. If ``_arr``
        # itself is missing (instance not initialised, e.g. during copy or
        # unpickling) delegating would recurse forever.
        if name == "_arr":
            raise AttributeError(name)
        return getattr(self._arr, name)
@fill_doc
class BaseSpaceTimeDenoiser(abc.ABC):
    """
    Base Class for Patch-based denoising methods for dynamical data.

    Parameters
    ----------
    $patch_config
    """

    def __init__(self, patch_shape, patch_overlap, recombination="weighted"):
        self.p_shape = patch_shape
        self.p_ovl = patch_overlap

        if recombination not in ["weighted", "average", "center"]:
            raise ValueError(
                "recombination must be one of 'weighted', 'average', 'center'"
            )

        self.recombination = recombination

        # Extra keyword arguments forwarded to ``_patch_processing`` for
        # every patch.
        self.input_denoising_kwargs = {}

    @fill_doc
    def denoise(self, input_data, mask=None, mask_threshold=50, progbar=None):
        """Denoise the input_data, according to mask.

        Patches are extracted sequentially and processed by the implemented
        `_patch_processing` function.
        Only patches whose share of voxels inside the mask ROI is above
        ``mask_threshold`` (in percent) are processed.

        Parameters
        ----------
        $input_config
        $mask_config

        Returns
        -------
        $denoise_return
        """
        data_shape = tuple(input_data.shape)
        p_s, p_o = self._get_patch_param(data_shape)

        input_data = PatchedArray(input_data, p_s, p_o)
        output_data = PatchedArray(data_shape, p_s, p_o, dtype=input_data.dtype)
        patch_weights = PatchedArray(data_shape, p_s, p_o, dtype=np.float32)
        # TODO: rank_map and noise_std_estimate are allocated and returned,
        # but nothing fills them from ``maxidx`` / ``noise_var`` yet.
        rank_map = np.zeros(data_shape, dtype=np.int32)
        noise_std_estimate = np.zeros(data_shape, dtype=np.float32)

        # Build a boolean mask with the same shape as the data.
        if mask is None:
            process_mask = np.full(data_shape, True)
        else:
            # ``~process_mask`` below must be a logical negation.
            mask = np.asarray(mask, dtype=bool)
            if mask.shape == data_shape[:-1]:
                # Repeat the spatial mask along the last (time) axis.
                process_mask = np.broadcast_to(mask[..., None], data_shape)
            elif mask.shape == data_shape:
                process_mask = mask
            else:
                raise ValueError(
                    f"mask shape {mask.shape} is not compatible with the input "
                    f"data shape {data_shape}."
                )

        process_mask = PatchedArray(
            process_mask, p_s, p_o, padding_mode="constant", constant_values=0
        )

        # Effective patch shape (PatchedArray clips it to the data shape).
        patch_shape = input_data.sliding_view.shape[len(data_shape):]
        # Spatial center of a patch; the last (time) axis is kept whole.
        center_pos = tuple(p // 2 for p in patch_shape[:-1])
        patch_space_size = int(np.prod(patch_shape[:-1]))

        # Select only the patches sufficiently covered by the mask.
        get_it = np.zeros(input_data.n_patches, dtype=bool)
        for i in range(len(get_it)):
            pm = process_mask.get_patch(i)
            if 100 * np.count_nonzero(pm) / pm.size > mask_threshold:
                get_it[i] = True

        select_patches = np.nonzero(get_it)[0]
        del get_it

        if progbar is None:
            progbar = tqdm(total=len(select_patches))
        elif progbar is not False:
            progbar.reset(total=len(select_patches))

        for i in select_patches:
            # Casorati matrix: one row per voxel of the patch, one column
            # per time point.
            input_patch_casorati = input_data.get_patch(i).reshape(
                patch_space_size, -1
            )
            p_denoise, maxidx, noise_var = self._patch_processing(
                input_patch_casorati,
                patch_idx=i,
                **self.input_denoising_kwargs,
            )

            p_denoise = np.reshape(p_denoise, patch_shape)
            if self.recombination == "center":
                # Keep only the time series of the central voxel.
                output_data.get_patch(i)[center_pos] = p_denoise[center_pos]
            elif self.recombination == "weighted":
                theta = 1 / (2 + maxidx)
                output_data.add2patch(i, p_denoise * theta)
                patch_weights.add2patch(i, theta)
            elif self.recombination == "average":
                output_data.add2patch(i, p_denoise)
                patch_weights.add2patch(i, 1)
            else:
                raise ValueError(
                    "recombination must be one of 'weighted', 'average', 'center'"
                )
            if progbar:
                progbar.update()

        # Averaging the overlapping pixels.
        # this is only required for averaging recombinations.
        mask_array = process_mask._arr
        output_data = output_data._arr
        patch_weights = patch_weights._arr

        if self.recombination in ["average", "weighted"]:
            # Voxels not covered by any processed patch have a zero weight:
            # leave them at zero instead of computing 0 / 0.
            np.divide(
                output_data,
                patch_weights,
                out=output_data,
                where=patch_weights > 0,
            )

        output_data[~mask_array] = 0

        return output_data, patch_weights, noise_std_estimate, rank_map

    @abc.abstractmethod
    def _patch_processing(self, patch, patch_slice=None, **kwargs):
        """Process a patch.

        Implemented by child classes.

        ``denoise`` calls it with the patch reshaped as a 2D matrix (one row
        per voxel of the patch, time along the columns), the keyword argument
        ``patch_idx`` (linear index of the patch) and the content of
        ``input_denoising_kwargs``. It must return a tuple
        ``(denoised_patch, maxidx, noise_var)``, where ``denoised_patch`` can
        be reshaped to the patch shape.
        """

    def _get_patch_param(self, data_shape):
        """Return tuple for patch_shape and patch_overlap.

        It works from whatever the input format was (int or list): an integer
        is repeated for every dimension but the last, and when the last (time)
        dimension is missing, the full length of the last axis of
        ``data_shape`` is appended.

        A warning is logged when the patch has fewer voxels than time points,
        as the patch matrix is then not tall and skinny.
        """
        pp = [None, None]
        for i, attr in enumerate(["p_shape", "p_ovl"]):
            p = getattr(self, attr)
            if isinstance(p, list):
                p = tuple(p)
            elif isinstance(p, (int, np.integer)):
                p = (p,) * (len(data_shape) - 1)
            if len(p) == len(data_shape) - 1:
                # add the time dimension
                p = (*p, data_shape[-1])
            pp[i] = p

        # Number of voxels in a patch, time dimension excluded.
        n_voxels = np.prod(pp[0][:-1])
        if n_voxels < data_shape[-1]:
            logging.warning(
                f"the number of voxel in patch ({n_voxels}) is smaller than the"
                f" last dimension ({data_shape[-1]}), this makes an ill-conditioned"
                " matrix for SVD.",
                stacklevel=2,
            )
        return tuple(pp)