"""Nipype Bindings, Provide bindings To apply patch based denoising."""
import os
import nibabel as nib
import numpy as np
from nipype.interfaces.base import (
BaseInterfaceInputSpec,
File,
SimpleInterface,
TraitedSpec,
isdefined,
traits,
)
from nipype.utils.filemanip import split_filename
from .utils import DENOISER_MAP, DenoiseParameters
from ..space_time.utils import estimate_noise
[docs]
class PatchDenoiseOutputSpec(TraitedSpec):
    """OutputSpec for Denoising Interface."""

    # Path of the NIfTI file holding the denoised data.
    denoised_file = File(desc="denoised file")
    # Path of the noise map; the interface fills it with a standard
    # deviation (not a variance), hence the description below.
    noise_std_map = File(desc="a map of the noise standard deviation.")
    # Path of the NIfTI file holding the rank map.
    rank_map = File(desc="a map of the rank of the patch.")
[docs]
class PatchDenoise(SimpleInterface):
    """Patch based denoising interface.

    Reads either a magnitude image (``in_mag``) or a real/imaginary pair
    (``in_real`` / ``in_imag``), applies the denoiser selected by the
    denoising parameters and writes three NIfTI files (denoised data, noise
    std map, rank map) relative to the current working directory.
    """

    input_spec = PatchDenoiseInputSpec
    output_spec = PatchDenoiseOutputSpec

    # Input names copied one-to-one onto a ``DenoiseParameters`` instance
    # when no ``denoise_str`` is provided.
    _denoise_attrs = [
        "method",
        "patch_shape",
        "patch_overlap",
        "mask_threshold",
        "recombination",
    ]

    def _run_interface(self, runtime):
        """Load the inputs, run the denoiser and save the result files.

        Parameters
        ----------
        runtime
            Runtime object supplied by nipype; returned unchanged.

        Returns
        -------
        runtime
            The object received as argument.

        Raises
        ------
        ValueError
            If the requested method is not a key of ``DENOISER_MAP``.
        """
        # INPUT
        if isdefined(self.inputs.denoise_str):
            d_par = DenoiseParameters.from_str(self.inputs.denoise_str)
        else:
            d_par = DenoiseParameters()
            for attr in PatchDenoise._denoise_attrs:
                setattr(d_par, attr, getattr(self.inputs, attr))

        if isdefined(self.inputs.in_mag):
            data_mag_nii = nib.load(self.inputs.in_mag)
            data = data_mag_nii.get_fdata(dtype=np.float32)
            basename = self.inputs.in_mag
            self._affine = data_mag_nii.affine
        else:
            data_real_nii = nib.load(self.inputs.in_real)
            self._affine = data_real_nii.affine
            data_real = data_real_nii.get_fdata(dtype=np.float32)
            data_imag = nib.load(self.inputs.in_imag).get_fdata(dtype=np.float32)
            # Build the complex volume as real + 1j * imag.
            data = 1j * data_imag
            data += data_real
            basename = self.inputs.in_real

        if isdefined(self.inputs.mask) and self.inputs.mask:
            # Any non-zero voxel of the mask image is kept.
            mask = np.abs(nib.load(self.inputs.mask).get_fdata()) > 0
        else:
            mask = None

        try:
            denoise_func = DENOISER_MAP[d_par.method]
        except KeyError:
            # Report the key that was actually looked up: it may come from
            # ``denoise_str`` rather than from an individual input.
            raise ValueError(
                f"unknown denoising method '{d_par.method}', "
                f"available are {list(DENOISER_MAP.keys())}"
            ) from None

        if isdefined(self.inputs.extra_kwargs) and self.inputs.extra_kwargs:
            # Copy so that adding ``noise_std`` below does not modify the
            # interface inputs.
            extra_kwargs = dict(self.inputs.extra_kwargs)
        else:
            extra_kwargs = dict()

        # These methods receive the noise map loaded from the
        # ``noise_std_map`` input.
        if d_par.method in [
            "nordic",
            "hybrid-pca",
            "adaptive-qut",
            "optimal-fro-noise",
        ]:
            extra_kwargs["noise_std"] = nib.load(self.inputs.noise_std_map).get_fdata()

        if denoise_func is not None:
            # CORE CALL
            denoised_data, _, noise_std_map, rank_map = denoise_func(
                data,
                patch_shape=d_par.patch_shape,
                patch_overlap=d_par.patch_overlap,
                mask=mask,
                mask_threshold=d_par.mask_threshold,
                recombination=d_par.recombination,
                **extra_kwargs,
            )
        else:
            # No denoiser mapped to this method: pass the data through and
            # report the std along the last axis with an all-zero rank map.
            denoised_data = data
            noise_std_map = np.std(data, axis=-1, dtype=np.float32)
            rank_map = np.zeros_like(noise_std_map)

        # OUTPUT
        if np.any(np.iscomplex(denoised_data)):
            # Complex results are saved as magnitude.
            denoised_data = np.abs(denoised_data, dtype=np.float32)

        _, base, _ = split_filename(basename)
        base = base.replace("_mag", "")
        base = base.replace("_real", "")

        self._make_results_file("rank_map", f"{base}_rank_map.nii", rank_map)
        self._make_results_file(
            "denoised_file",
            f"{base}_d_{d_par.method}.nii",
            denoised_data,
        )
        self._make_results_file("noise_std_map", f"{base}_noise_map.nii", noise_std_map)
        return runtime

    def _make_results_file(self, result_file, file_name, array):
        """Add a new results file.

        Saves ``array`` as a NIfTI image named ``file_name`` using the affine
        stored in ``self._affine`` and registers the absolute path under the
        ``result_file`` key of ``self._results``.
        """
        self._results[result_file] = os.path.abspath(file_name)
        nib.save(nib.Nifti1Image(array, self._affine), file_name)
[docs]
class NoiseStdMapOutputSpec(TraitedSpec):
    """OutputSpec for Noise Map Estimation."""

    # Path of the estimated noise map; the field holds a standard deviation
    # (see its name), so the description says so rather than "variance".
    noise_std_map = File(desc="Spatial variation of noise standard deviation")
[docs]
class NoiseStdMap(SimpleInterface):
    """Noise std estimation."""

    input_spec = NoiseStdMapInputSpec
    output_spec = NoiseStdMapOutputSpec

    def _run_interface(self, runtime):
        """Estimate a noise std map from a noise acquisition and save it.

        The data of ``noise_map_file`` is divided by ``fft_scale`` and passed
        to ``estimate_noise`` together with ``block_size``. The result is
        saved as ``<base>_std.nii`` relative to the current working
        directory, reusing the affine of the input image.

        Parameters
        ----------
        runtime
            Runtime object supplied by nipype; returned unchanged.

        Returns
        -------
        runtime
            The object received as argument.
        """
        noise_map = nib.load(self.inputs.noise_map_file)
        noise_std_map = estimate_noise(
            noise_map.get_fdata() / self.inputs.fft_scale, self.inputs.block_size
        )
        noise_std_map_img = nib.Nifti1Image(noise_std_map, affine=noise_map.affine)

        # ``split_filename`` strips the directory and the (possibly compound)
        # extension, so a base name containing dots is kept whole instead of
        # being cut at its first dot.
        _, base, _ = split_filename(self.inputs.noise_map_file)
        filename = os.path.abspath(f"{base}_std.nii")

        noise_std_map_img.to_filename(filename)
        self._results["noise_std_map"] = filename
        return runtime