"""Base Structure for patch-based denoising on spatio-temporal dimension."""
import abc
import logging
import numpy as np
from tqdm.auto import tqdm
from .._docs import fill_doc
class PatchedArray:
"""A container for accessing custom view of array easily.
Parameters
----------
    array: np.ndarray or tuple
        Array to create patches from, or a shape tuple used to allocate a new
        zero-filled array of the given ``dtype``.
    patch_shape: tuple
        Shape of a single patch along each dimension.
    patch_overlap: tuple
        Overlap between two consecutive patches along each dimension.
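
    Examples
    --------
    A minimal sketch of the intended usage, with toy shapes chosen purely for
    illustration:

    >>> arr = np.arange(16.0).reshape(4, 4)
    >>> patches = PatchedArray(arr, (2, 2), (0, 0))
    >>> int(patches.n_patches)
    4
    >>> patches.get_patch(0).shape
    (2, 2)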
"""
def __init__(
self,
array,
patch_shape,
patch_overlap,
dtype=None,
padding_mode="edge",
**kwargs,
):
if isinstance(array, tuple):
array = np.zeros(array, dtype=dtype)
self._arr = array
self._ps = np.asarray(patch_shape)
        self._po = np.asarray(patch_overlap)
dimensions = self._arr.ndim
step = self._ps - self._po
        if np.any(step < 0):
            raise ValueError(
                "patch_overlap must be smaller than patch_shape along every dimension."
            )
        if self._ps.size != dimensions or step.size != dimensions:
            raise ValueError(
                "patch_shape and patch_overlap must have the same number of "
                "dimensions as the input array."
            )
        # Ensure the patch size is not larger than the array size along each axis.
        self._ps = np.minimum(self._ps, self._arr.shape)
# Calculate the shape and strides of the sliding view
grid_shape = tuple(
((self._arr.shape[i] - self._ps[i]) // step[i] + 1)
if self._ps[i] < self._arr.shape[i]
else 1
for i in range(dimensions)
)
shape = grid_shape + tuple(self._ps)
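        # When a patch spans the whole axis, the grid has a single entry along
        # that axis and a zero stride keeps the view anchored at its start.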
strides = (
tuple(
self._arr.strides[i] * step[i]
if self._ps[i] < self._arr.shape[i]
else 0
for i in range(dimensions)
)
+ self._arr.strides
)
# Create the sliding view
self.sliding_view = np.lib.stride_tricks.as_strided(
self._arr, shape=shape, strides=strides
)
self._grid_shape = grid_shape
@property
def n_patches(self):
"""Get number of patches."""
return np.prod(self._grid_shape)
def get_patch(self, idx):
"""Get patch at linear index ``idx``."""
return self.sliding_view[np.unravel_index(idx, self._grid_shape)]
def set_patch(self, idx, value):
"""Set patch at linear index ``idx`` with value."""
self.sliding_view[np.unravel_index(idx, self._grid_shape)]
def add2patch(self, idx, value):
"""Add to patch, in place."""
        patch = self.get_patch(idx)
        # ``get_patch`` returns a view, so the in-place addition writes back
        # into the underlying array.
        patch += value
# def sync(self):
# """Apply the padded value to the array back."""
# np.copyto(
# self._array,
# self._padded_array[
# tuple(
# np.s_[: (s + 1 - ps) if (s - ps) else s]
# for ps, s in zip(self._ps, self._padded_array.shape)
# )
# ],
# )
# def get(self):
# """Return the regular array, after applying the padded values."""
# self.sync()
# return self._array
def __getattr__(self, name):
"""Get attribute of underlying array."""
return getattr(self._arr, name)
@fill_doc
class BaseSpaceTimeDenoiser(abc.ABC):
"""
    Base class for patch-based denoising methods on dynamical (spatio-temporal) data.

Parameters
----------
$patch_config
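
    Notes
    -----
    Concrete denoisers subclass this class and implement ``_patch_processing``.
    A minimal sketch (the ``IdentityDenoiser`` below is hypothetical and only
    illustrates the expected return values):

    >>> class IdentityDenoiser(BaseSpaceTimeDenoiser):
    ...     def _patch_processing(self, patch, patch_slice=None, **kwargs):
    ...         # return the patch unchanged, a dummy rank and a NaN noise variance
    ...         return patch, 0, np.nan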
"""
def __init__(self, patch_shape, patch_overlap, recombination="weighted"):
self.p_shape = patch_shape
self.p_ovl = patch_overlap
if recombination not in ["weighted", "average", "center"]:
raise ValueError(
"recombination must be one of 'weighted', 'average', 'center'"
)
self.recombination = recombination
self.input_denoising_kwargs = dict()
@fill_doc
def denoise(self, input_data, mask=None, mask_threshold=50, progbar=None):
"""Denoise the input_data, according to mask.
Patches are extracted sequentially and process by the implemented
`_patch_processing` function.
Only patches which have at least a voxel in the mask ROI are processed.
Parameters
----------
$input_config
$mask_config
Returns
-------
$denoise_return
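
        Examples
        --------
        A hedged sketch of a typical call; ``MyDenoiser`` is a hypothetical
        subclass and ``data``/``brain_mask`` are arrays provided by the user:

        >>> denoiser = MyDenoiser(patch_shape=11, patch_overlap=5)  # doctest: +SKIP
        >>> output, weights, noise_std, ranks = denoiser.denoise(
        ...     data, mask=brain_mask
        ... )  # doctest: +SKIP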
"""
data_shape = input_data.shape
p_s, p_o = self._get_patch_param(data_shape)
input_data = PatchedArray(input_data, p_s, p_o)
output_data = PatchedArray(data_shape, p_s, p_o, dtype=input_data.dtype)
patch_weights = PatchedArray(data_shape, p_s, p_o, dtype=np.float32)
rank_map = PatchedArray(data_shape, p_s, p_o, dtype=np.int32)
noise_std_estimate = PatchedArray(data_shape, p_s, p_o, dtype=np.float32)
        # Create default mask
        if mask is None:
            process_mask = np.full(data_shape, True)
        elif mask.shape == input_data.shape[:-1]:
            # Extend the spatial mask along the last (time) dimension.
            process_mask = np.broadcast_to(mask[..., None], input_data.shape)
        else:
            process_mask = mask
        process_mask = PatchedArray(
            process_mask, p_s, p_o, padding_mode="constant", constant_values=0
        )
        # Spatial center of a patch, keeping the full time dimension.
        center_pos = (*(p // 2 for p in p_s[:-1]), slice(None))
        patch_space_size = np.prod(p_s[:-1])
        # Select only the patches whose mask coverage exceeds the threshold.
get_it = np.zeros(input_data.n_patches, dtype=bool)
for i in range(len(get_it)):
pm = process_mask.get_patch(i)
if 100 * np.sum(pm) / pm.size > mask_threshold:
get_it[i] = True
select_patches = np.nonzero(get_it)[0]
del get_it
if progbar is None:
progbar = tqdm(total=len(select_patches))
elif progbar is not False:
progbar.reset(total=len(select_patches))
for i in select_patches:
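            # Build the Casorati matrix: rows index the voxels of the patch,
            # columns index the time frames.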
input_patch_casorati = input_data.get_patch(i).reshape(patch_space_size, -1)
p_denoise, maxidx, noise_var = self._patch_processing(
input_patch_casorati,
patch_idx=i,
**self.input_denoising_kwargs,
)
p_denoise = np.reshape(p_denoise, p_s)
if self.recombination == "center":
output_data.get_patch(i)[center_pos] = p_denoise[center_pos]
elif self.recombination == "weighted":
theta = 1 / (2 + maxidx)
output_data.add2patch(i, p_denoise * theta)
patch_weights.add2patch(i, theta)
elif self.recombination == "average":
output_data.add2patch(i, p_denoise)
patch_weights.add2patch(i, 1)
else:
raise ValueError(
"recombination must be one of 'weighted', 'average', 'center'"
)
if progbar:
progbar.update()
        # Average the overlapping voxels; only needed for the "average" and
        # "weighted" recombinations.
        output_data = output_data._arr
        patch_weights = patch_weights._arr
        if self.recombination in ["average", "weighted"]:
            output_data /= patch_weights
        output_data[~process_mask._arr] = 0
        return output_data, patch_weights, noise_std_estimate._arr, rank_map._arr
# if self.recombination == "center":
# patch_center = (
# *(slice(ps // 2, ps // 2 + 1) for ps in patch_shape),
# slice(None, None, None),
# )
# patchs_weight = np.zeros(data_shape[:-1], np.float32)
# noise_std_estimate = np.zeros(data_shape[:-1], dtype=np.float32)
# # discard useless patches
# patch_locs = get_patch_locs(patch_shape, patch_overlap, data_shape)
# get_it = np.zeros(len(patch_locs), dtype=bool)
# for i, patch_tl in enumerate(patch_locs):
# patch_slice = tuple(
# slice(tl, tl + ps) for tl, ps in zip(patch_tl, patch_shape)
# )
# if 100 * np.sum(process_mask[patch_slice]) / patch_size > mask_threshold:
# get_it[i] = True
# logging.info(f"Denoise {100 * np.sum(get_it) / len(patch_locs):.2f}% patches")
# patch_locs = np.ascontiguousarray(patch_locs[get_it])
# if progbar is None:
# progbar = tqdm(total=len(patch_locs))
# elif progbar is not False:
# progbar.reset(total=len(patch_locs))
# for patch_tl in patch_locs:
# patch_slice = tuple(
# slice(tl, tl + ps) for tl, ps in zip(patch_tl, patch_shape)
# )
# process_mask[patch_slice] = 1
# # building the casoratti matrix
# patch = np.reshape(input_data[patch_slice], (-1, input_data.shape[-1]))
# # Replace all nan by mean value of patch.
# # FIXME this behaviour should be documented
# # And ideally chosen by the user.
# patch[np.isnan(patch)] = np.mean(patch)
# p_denoise, maxidx, noise_var = self._patch_processing(
# patch,
# patch_slice=patch_slice,
# **self.input_denoising_kwargs,
# )
# p_denoise = np.reshape(p_denoise, (*patch_shape, -1))
# patch_center_img = tuple(
# ptl + ps // 2 for ptl, ps in zip(patch_tl, patch_shape)
# )
# if self.recombination == "center":
# output_data[patch_center_img] = p_denoise[patch_center]
# noise_std_estimate[patch_center_img] += noise_var
# elif self.recombination == "weighted":
# theta = 1 / (2 + maxidx)
# output_data[patch_slice] += p_denoise * theta
# patchs_weight[patch_slice] += theta
# elif self.recombination == "average":
# output_data[patch_slice] += p_denoise
# patchs_weight[patch_slice] += 1
# else:
# raise ValueError(
# "recombination must be one of 'weighted', 'average', 'center'"
# )
# if not np.isnan(noise_var):
# noise_std_estimate[patch_slice] += noise_var
# # the top left corner of the patch is used as id for the patch.
# rank_map[patch_center_img] = maxidx
# if progbar:
# progbar.update()
# # Averaging the overlapping pixels.
# # this is only required for averaging recombinations.
# if self.recombination in ["average", "weighted"]:
# output_data /= patchs_weight[..., None]
# noise_std_estimate /= patchs_weight
# output_data[~process_mask] = 0
# return output_data, patchs_weight, noise_std_estimate, rank_map
@abc.abstractmethod
def _patch_processing(self, patch, patch_slice=None, **kwargs):
"""Process a patch.
Implemented by child classes.
"""
def _get_patch_param(self, data_shape):
"""Return tuple for patch_shape and patch_overlap.
It works from whatever the input format was (int or list).
This method also ensure that the patch will provide tall and skinny matrices.
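
        A hedged sketch, using a hypothetical minimal subclass for
        illustration:

        >>> class _Toy(BaseSpaceTimeDenoiser):
        ...     def _patch_processing(self, patch, patch_slice=None, **kwargs):
        ...         return patch, 0, np.nan
        >>> _Toy(patch_shape=11, patch_overlap=5)._get_patch_param((64, 64, 64, 200))
        ((11, 11, 11, 200), (5, 5, 5, 200))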
"""
pp = [None, None]
for i, attr in enumerate(["p_shape", "p_ovl"]):
p = getattr(self, attr)
if isinstance(p, list):
p = tuple(p)
elif isinstance(p, (int, np.integer)):
p = (p,) * (len(data_shape) - 1)
if len(p) == len(data_shape) - 1:
# add the time dimension
p = (*p, data_shape[-1])
pp[i] = p
if np.prod(pp[0][:-1]) < data_shape[-1]:
logging.warning(
f"the number of voxel in patch ({np.prod(pp[0])}) is smaller than the"
f" last dimension ({data_shape[-1]}), this makes an ill-conditioned"
"matrix for SVD.",
stacklevel=2,
)
return tuple(pp)