Source code for snake.core.engine.nufft

"""Acquisition engine using nufft."""

from collections.abc import Sequence

import ismrmrd as mrd
import numpy as np
from mrinufft.operators import FourierOperatorBase, get_operator
from numpy.typing import NDArray

from snake.core.phantom import DynamicData, Phantom
from snake.core.simulation import SimConfig

from .base import BaseAcquisitionEngine
from .utils import get_phantom_state


class NufftAcquisitionEngine(BaseAcquisitionEngine):
    """Acquisition engine using nufft.

    K-space samples for each shot are simulated by applying an ``mrinufft``
    Fourier operator to the phantom state at that shot.
    """

    __engine_name__ = "NUFFT"
    __mp_mode__ = "spawn"
    # Presumably selects between ``_job_model_simple`` and ``_job_model_T2s``
    # in the base engine -- confirm against BaseAcquisitionEngine.
    model: str = "simple"
    snr: float = np.inf

    def _job_trajectories(
        self,
        dataset: mrd.Dataset,
        hdr: mrd.xsd.ismrmrdHeader,
        sim_conf: SimConfig,
        shot_idx: Sequence[int],
    ) -> NDArray:
        """Get Non Cartesian trajectories from the dataset.

        Parameters
        ----------
        dataset
            ISMRMRD dataset whose acquisitions hold the trajectories.
        hdr
            Dataset header (not used here; part of the job interface).
        sim_conf
            Simulation configuration (not used here).
        shot_idx
            A single shot index, or a sequence / array of shot indices.

        Returns
        -------
        NDArray
            The trajectories, a float32 array of shape
            ``(len(shot_idx), n_samples, ndim)``.
        """
        # np.atleast_1d accepts a scalar index, a list/tuple/range, or a numpy
        # array. A plain ``isinstance(x, Sequence)`` test misclassifies numpy
        # arrays and would wrap them in a one-element list.
        shot_idx = np.atleast_1d(shot_idx)
        # Sample count and trajectory dimensionality are read from the first
        # acquisition header and assumed identical for every shot.
        head = dataset._dataset["data"][0]["head"]
        n_samples = head["number_of_samples"]
        ndim = head["trajectory_dimensions"]
        trajectories = np.zeros((len(shot_idx), n_samples, ndim), dtype=np.float32)
        for i, s in enumerate(shot_idx):
            trajectories[i] = dataset._dataset["data"][int(s)]["traj"].reshape(
                n_samples, ndim
            )
        return trajectories

    @staticmethod
    def _init_model_nufft(
        samples: NDArray, sim_conf: SimConfig, smaps: NDArray, backend: str
    ) -> FourierOperatorBase:
        """Initialize the nufft operator.

        Parameters
        ----------
        samples
            Initial sample locations. The callers in this class overwrite
            ``nufft.samples`` for every shot.
        sim_conf
            Simulation configuration; ``sim_conf.shape`` is the image shape.
        smaps
            Coil sensitivity maps, or ``None`` for a single-coil operator.
        backend
            Name of the mrinufft backend passed to ``get_operator``.

        Returns
        -------
        FourierOperatorBase
            The configured NUFFT operator (no density compensation).
        """
        n_coils = len(smaps) if smaps is not None else 1
        kwargs = {}
        if "stacked" in backend:
            # Stacked backends take a ``z_index`` argument; "auto" lets
            # mrinufft derive it from the samples.
            kwargs["z_index"] = "auto"
        nufft = get_operator(backend)(
            samples,
            shape=sim_conf.shape,
            n_coils=n_coils,
            smaps=smaps,
            density=False,
            **kwargs,
        )
        return nufft

    @staticmethod
    def _job_model_T2s(
        phantom: Phantom,
        dyn_datas: list[DynamicData],
        sim_conf: SimConfig,
        trajectories: NDArray,
        smaps: NDArray,
        nufft_backend: str,
    ) -> np.ndarray:
        """Acquire k-space data with T2s relaxation effect.

        Each tissue is transformed separately, weighted by its own T2* decay
        along the readout, and then summed over tissues.

        Parameters
        ----------
        phantom, dyn_datas, sim_conf
            Passed to ``get_phantom_state`` to build the per-shot phantom.
        trajectories
            Array of shape ``(chunk_size, n_samples, ndim)``.
        smaps
            Coil sensitivity maps (or ``None``).
        nufft_backend
            mrinufft backend name.

        Returns
        -------
        np.ndarray
            complex64 array of shape ``(chunk_size, n_coils, n_samples)``.
        """
        chunk_size, n_samples, _ = trajectories.shape
        final_ksp = np.zeros(
            (chunk_size, sim_conf.hardware.n_coils, n_samples), dtype=np.complex64
        )
        nufft = NufftAcquisitionEngine._init_model_nufft(
            trajectories[0], sim_conf, smaps, backend=nufft_backend
        )
        # Echo time = the sample of the first trajectory closest to the k-space
        # origin. |k|^2 must be summed over the coordinate axis *before* the
        # argmin; summing everything first yields a scalar and echo_idx == 0.
        echo_idx = int(np.argmin(np.sum(np.abs(trajectories[0]) ** 2, axis=-1)))
        t2s_decay = BaseAcquisitionEngine._job_get_T2s_decay(
            sim_conf.hardware.dwell_time_ms, echo_idx, n_samples, phantom
        )
        for i, traj in enumerate(trajectories):
            phantom_state = get_phantom_state(phantom, dyn_datas, i, sim_conf)
            # Assumes axis 0 indexes tissues (one NUFFT batch per tissue); the
            # inserted singleton axis is presumably the channel axis expected
            # by nufft.op when smaps are used -- confirm.
            phantom_state = phantom_state[:, None, ...]
            nufft.samples = traj
            nufft.n_batchs = len(phantom_state)
            ksp = nufft.op(phantom_state)
            # Weight each tissue by its decay and sum over tissues:
            # (tissue, coil, sample) x (tissue, sample) -> (coil, sample).
            final_ksp[i] = np.einsum("kij,kj->ij", ksp, t2s_decay)
        return final_ksp

    @staticmethod
    def _job_model_simple(
        phantom: Phantom,
        dyn_datas: list[DynamicData],
        sim_conf: SimConfig,
        trajectories: NDArray,
        smaps: NDArray,
        nufft_backend: str,
    ) -> np.ndarray:
        """Acquire k-space data. No T2s decay.

        Parameters
        ----------
        phantom, dyn_datas, sim_conf
            Passed to ``get_phantom_state`` to build the per-shot phantom.
        trajectories
            Array of shape ``(chunk_size, n_samples, ndim)``.
        smaps
            Coil sensitivity maps (or ``None``).
        nufft_backend
            mrinufft backend name.

        Returns
        -------
        np.ndarray
            complex64 array of shape ``(chunk_size, n_coils, n_samples)``.
        """
        chunk_size, n_samples, _ = trajectories.shape
        final_ksp = np.zeros(
            (chunk_size, sim_conf.hardware.n_coils, n_samples), dtype=np.complex64
        )
        nufft = NufftAcquisitionEngine._init_model_nufft(
            trajectories[0], sim_conf, smaps, backend=nufft_backend
        )
        for i, traj in enumerate(trajectories):
            phantom_state = get_phantom_state(phantom, dyn_datas, i, sim_conf)
            # Without a per-tissue decay the NUFFT is linear, so tissues can be
            # summed in image space first and transformed once per shot.
            phantom_state = np.sum(phantom_state, axis=0)[None, ...]
            nufft.samples = traj
            nufft.n_batchs = len(phantom_state)
            final_ksp[i] = nufft.op(phantom_state)
        return final_ksp

    def _write_chunk_data(
        self, dataset: mrd.Dataset, chunk: Sequence[int], chunk_data: NDArray
    ) -> None:
        """Write simulated k-space for a chunk of shots back into the dataset.

        Parameters
        ----------
        dataset
            ISMRMRD dataset whose ``data`` records are updated in place.
        chunk
            Shot indices covered by ``chunk_data``.
        chunk_data
            Complex k-space for the chunk; reshaped to the stored ``data``
            field shape.
        """
        shot_idx = np.asarray(chunk)
        acq_chunk = dataset._dataset["data"][shot_idx]
        # Stored as interleaved (real, imag) float32 pairs. Casting to
        # contiguous complex64 first (a no-op when it already is) keeps
        # ``.view(np.float32)`` from reinterpreting e.g. complex128 bytes.
        chunk_data = np.ascontiguousarray(chunk_data, dtype=np.complex64)
        acq_chunk["data"] = chunk_data.view(np.float32).reshape(
            acq_chunk["data"].shape
        )
        dataset._dataset["data"][shot_idx] = acq_chunk