Source code for snake.toolkit.reconstructors.fourier
"""FFT operators for MRI reconstruction."""
from numpy.typing import NDArray
import scipy as sp
from snake.mrd_utils.loader import CartesianFrameDataLoader, NonCartesianFrameDataLoader
[docs]
def fft(image: NDArray, axis: int | tuple[int] = -1) -> NDArray:
"""Apply the FFT operator.
Parameters
----------
image : array
Image in space.
axis : int
Axis to apply the FFT.
Returns
-------
kspace_data : array
kspace data.
"""
return sp.fft.ifftshift(
sp.fft.fftn(sp.fft.fftshift(image, axes=axis), norm="ortho", axes=axis),
axes=axis,
)
[docs]
def ifft(kspace_data: NDArray, axis: int | tuple[int] = -1) -> NDArray:
"""Apply the inverse FFT operator.
Parameters
----------
kspace_data : array
Image in space.
axis : int
Axis to apply the FFT.
Returns
-------
image_data : array
image data.
"""
return sp.fft.fftshift(
sp.fft.ifftn(sp.fft.ifftshift(kspace_data, axes=axis), norm="ortho", axes=axis),
axes=axis,
)
[docs]
def init_nufft(
    data_loader: NonCartesianFrameDataLoader,
    nufft_backend: str,
    density_compensation: bool | str | NDArray | None = False,
):
    """Build an mri-nufft operator matching the acquisition of a data loader.

    The sensitivity maps, image shape and number of coils are read from
    ``data_loader``. The sampling trajectory is taken from the first k-space
    frame (``get_kspace_frame(0)``).

    Parameters
    ----------
    data_loader : NonCartesianFrameDataLoader
        Loader providing ``get_smaps()``, ``shape``, ``n_coils``,
        ``slice_2d``, ``n_shots`` and ``get_kspace_frame()``.
    nufft_backend : str
        Name of the mri-nufft backend passed to ``mrinufft.get_operator``.
        If the name contains ``"stacked"``, ``z_index="auto"`` is also passed.
    density_compensation : bool, str, array or None
        ``False`` (default) or ``None`` disables density compensation
        (``density=None``). Any other value is forwarded unchanged as the
        ``density`` argument of the operator. Presumably this is ``True``, a
        method name or precomputed weights; confirm against the mri-nufft
        documentation.

    Returns
    -------
    The object returned by ``mrinufft.get_operator(nufft_backend,
    samples=traj, ...)``.
    """
    # Imported lazily so that mri-nufft is only required when a NUFFT is used.
    from mrinufft import get_operator

    smaps = data_loader.get_smaps()
    shape = data_loader.shape
    traj, _ = data_loader.get_kspace_frame(0)
    if data_loader.slice_2d:
        # Keep only the in-plane image dimensions. Reshape the trajectory to
        # (n_shots, samples, dim), then keep the first entry along axis 0 and
        # the first two coordinates. NOTE(review): this assumes that entry
        # alone describes the 2D sampling pattern; confirm with the loader.
        shape = data_loader.shape[:2]
        traj = traj.reshape(data_loader.n_shots, -1, traj.shape[-1])[0, :, :2]

    kwargs = dict(
        shape=shape,
        n_coils=data_loader.n_coils,
        smaps=smaps,
    )
    # ``False`` means "no density compensation". Everything else (including
    # None) is forwarded as-is, which preserves the original behavior.
    kwargs["density"] = None if density_compensation is False else density_compensation
    if "stacked" in nufft_backend:
        kwargs["z_index"] = "auto"
    return get_operator(
        nufft_backend,
        samples=traj,
        **kwargs,
    )