"""Acquisition engine for Cartesian trajectories."""
from collections.abc import Sequence
from copy import deepcopy
import ismrmrd as mrd
import numpy as np
from numpy.typing import NDArray
from snake.core.phantom import DynamicData, Phantom
from snake.core.simulation import SimConfig
from snake.mrd_utils import MRDLoader
from .base import BaseAcquisitionEngine
from .utils import fft, get_contrast_gre, get_phantom_state
class EPIAcquisitionEngine(BaseAcquisitionEngine):
    """Acquisition engine for EPI base trajectories.

    Each chunk element is one EPI shot made of ``n_lines_epi`` consecutive
    acquisitions of the MRD dataset. The stored trajectory is reinterpreted
    as integer k-space grid indices, and k-space is sampled by direct
    indexing into the FFT of the phantom.
    """

    __engine_name__ = "EPI"
    __mp_mode__ = "forkserver"
    model: str = "simple"
    snr: float = np.inf

    def _get_chunk_list(self, data_loader: MRDLoader) -> Sequence[int]:
        """Return the indices of the EPI shots stored in the dataset.

        Also caches ``n_lines_epi`` (acquisitions per shot) on the instance,
        because ``_write_chunk_data`` needs it later.
        """
        limits = data_loader.header.encoding[0].encodingLimits
        self.n_lines_epi = limits.kspace_encoding_step_1.maximum
        n_epi = data_loader.n_acquisition // self.n_lines_epi
        return range(n_epi)

    def _job_trajectories(
        self,
        data_loader: MRDLoader,
        hdr: mrd.xsd.ismrmrdHeader,
        sim_conf: SimConfig,
        chunk: Sequence[int],
    ) -> np.ndarray:
        """Read the k-space sampling indices of the shots in ``chunk``.

        Returns
        -------
        np.ndarray
            ``uint32`` array of shape
            ``(len(chunk), n_lines_epi, readout_length, 3)``.
        """
        if not isinstance(chunk, Sequence):
            chunk = [chunk]
        limits = hdr.encoding[0].encodingLimits
        n_lines_epi = limits.kspace_encoding_step_1.maximum
        readout_length = limits.kspace_encoding_step_0.maximum
        # Read every acquisition spanned by the chunk in one go. The reshape
        # below only succeeds if the shots in ``chunk`` are contiguous.
        raw_traj = data_loader._dataset["data"][
            chunk[0] * n_lines_epi : (chunk[-1] + 1) * n_lines_epi
        ]["traj"].copy()
        # The stored coordinates are reinterpreted bit-for-bit as grid indices.
        return raw_traj.view(np.uint32).reshape(
            len(chunk), n_lines_epi, readout_length, 3
        )

    @staticmethod
    def _job_model_T2s(
        phantom: Phantom,
        dyn_datas: list[DynamicData],
        sim_conf: SimConfig,
        trajectories: NDArray,  # (chunk_size, n_lines_epi, readout_length, 3)
        smaps: NDArray,
    ) -> np.ndarray:
        """Acquire k-space data. With T2s decay.

        Each entry along the first axis of the phantom state (presumably one
        volume per tissue) is sampled separately, weighted by the matching
        row of the T2s decay, and accumulated per coil.

        Returns
        -------
        np.ndarray
            complex64 array of shape ``(chunk_size, n_coils, *frame_shape)``,
            where ``frame_shape == trajectories.shape[1:-1]``.
        """
        frame_shape = trajectories.shape[1:-1]
        n_samples = int(np.prod(frame_shape))
        n_coils = sim_conf.hardware.n_coils
        shape = sim_conf.shape
        center = (shape[0] // 2, shape[1] // 2, shape[2] // 2)
        # The sample of the first shot closest to the k-space center is taken
        # as the echo position for the decay curve.
        echo_idx = int(
            np.argmin(
                np.sum((trajectories[0].reshape(-1, 3) - center) ** 2, axis=-1)
            )
        )
        t2s_decay = BaseAcquisitionEngine._job_get_T2s_decay(
            sim_conf.hardware.dwell_time_ms, echo_idx, n_samples, phantom
        )
        final_ksp = np.zeros(
            (len(trajectories), n_coils, *frame_shape), dtype=np.complex64
        )
        for i, shot in enumerate(trajectories):
            phantom_state = get_phantom_state(phantom, dyn_datas, i, sim_conf)
            weighted = phantom_state[:, None, ...]
            if smaps is not None:
                weighted = weighted * smaps
            ksp = fft(weighted, axis=(-3, -2, -1))
            sample_idx = tuple(shot.reshape(-1, 3).T)
            for c in range(n_coils):
                ksp_coil_sum = np.zeros(n_samples, dtype=np.complex64)
                for b in range(phantom_state.shape[0]):
                    ksp_coil_sum += ksp[b, c][sample_idx] * t2s_decay[b]
                final_ksp[i, c] = ksp_coil_sum.reshape(frame_shape)
        return final_ksp

    @staticmethod
    def _job_model_simple(
        phantom: Phantom,
        dyn_datas: list[DynamicData],
        sim_conf: SimConfig,
        trajectories: NDArray,  # (chunk_size, n_lines_epi, readout_length, 3)
        smaps: NDArray,
    ) -> np.ndarray:
        """Acquire k-space data. No T2s decay.

        Returns
        -------
        np.ndarray
            complex64 array of shape ``(chunk_size, n_coils, *frame_shape)``,
            where ``frame_shape == trajectories.shape[1:-1]``.
        """
        frame_shape = trajectories.shape[1:-1]
        n_coils = sim_conf.hardware.n_coils
        final_ksp = np.zeros(
            (len(trajectories), n_coils, *frame_shape), dtype=np.complex64
        )
        for i, shot in enumerate(trajectories):
            phantom_state = get_phantom_state(phantom, dyn_datas, i, sim_conf)
            # With no per-sample decay the FFT is linear in the phantom, so
            # the first axis can be summed before transforming.
            image = np.sum(phantom_state, axis=0)[None, ...]
            if smaps is not None:
                image = image * smaps
            ksp = fft(image, axis=(-3, -2, -1))
            sample_idx = tuple(shot.reshape(-1, 3).T)
            for c in range(n_coils):
                final_ksp[i, c] = ksp[c][sample_idx].reshape(frame_shape)
        return final_ksp

    def _write_chunk_data(
        self, data_loader: MRDLoader, chunk: Sequence[int], chunk_data: NDArray
    ) -> None:
        """Store the simulated k-space of ``chunk`` back into the MRD dataset.

        ``chunk_data`` is the complex output of a ``_job_model_*`` method:
        the coil axis is at position 1 and the readout samples are last.
        """
        shots = np.concatenate(
            [
                np.arange(
                    shot * self.n_lines_epi,
                    (shot + 1) * self.n_lines_epi,
                    dtype=np.int32,
                )
                for shot in chunk
            ]
        )
        # complex64 -> interleaved float32 (re, im) along the readout axis.
        chunk_data = chunk_data.view(np.float32)
        # Move the coil axis just before the readout axis. Each acquisition
        # then becomes a (n_coils, 2 * readout_length) block, assuming the
        # MRD data field holds that many floats per acquisition.
        chunk_data = np.moveaxis(chunk_data, 1, -2)
        acq_chunk = data_loader._dataset["data"][shots]
        acq_chunk["data"] = chunk_data.reshape(acq_chunk["data"].shape)
        data_loader._dataset["data"][shots] = acq_chunk
class EVIAcquisition(EPIAcquisitionEngine):
    """EVI Acquisition engine. Same as EPI, but the shots are longer.

    Each chunk element is one shot made of ``n_slice_epi * n_lines_epi``
    consecutive acquisitions, ordered slice by slice and then line by line.
    """

    __engine_name__ = "EVI"
    __mp_mode__ = "forkserver"
    model: str = "simple"
    snr: float = np.inf

    def _get_chunk_list(self, data_loader: MRDLoader) -> Sequence[int]:
        """Return the indices of the EVI shots stored in the dataset.

        Also caches ``n_lines_epi`` and ``n_slice_epi`` on the instance,
        because ``_write_chunk_data`` needs them later.
        """
        limits = data_loader.header.encoding[0].encodingLimits
        self.n_lines_epi = limits.kspace_encoding_step_1.maximum
        self.n_slice_epi = limits.slice.maximum
        n_evi = data_loader.n_acquisition // (self.n_lines_epi * self.n_slice_epi)
        return range(n_evi)

    def _job_trajectories(
        self,
        data_loader: MRDLoader,
        hdr: mrd.xsd.ismrmrdHeader,
        sim_conf: SimConfig,
        chunk: Sequence[int],
    ) -> np.ndarray:
        """Read the k-space sampling indices of the shots in ``chunk``.

        Returns
        -------
        np.ndarray
            ``uint32`` array of shape
            ``(len(chunk), n_slice, n_lines_epi, readout_length, 3)``.
        """
        if not isinstance(chunk, Sequence):
            chunk = [chunk]
        limits = hdr.encoding[0].encodingLimits
        n_lines_epi = limits.kspace_encoding_step_1.maximum
        readout_length = limits.kspace_encoding_step_0.maximum
        n_slice = limits.slice.maximum
        acq_per_shot = n_lines_epi * n_slice
        # Read every acquisition spanned by the chunk in one go. The reshape
        # below only succeeds if the shots in ``chunk`` are contiguous.
        raw_traj = data_loader._dataset["data"][
            chunk[0] * acq_per_shot : (chunk[-1] + 1) * acq_per_shot
        ]["traj"]
        # The stored coordinates are reinterpreted bit-for-bit as grid indices.
        return raw_traj.view(np.uint32).reshape(
            len(chunk), n_slice, n_lines_epi, readout_length, 3
        )

    @staticmethod
    def _job_model_T2s(
        phantom: Phantom,
        dyn_datas: list[DynamicData],
        sim_conf: SimConfig,
        trajectories: NDArray,  # (chunk_size, n_slice, n_lines_epi, readout_length, 3)
        smaps: NDArray,
    ) -> np.ndarray:
        """Acquire k-space data. With T2s decay.

        Returns
        -------
        np.ndarray
            complex64 array of shape
            ``(chunk_size, n_coils, n_slice, n_lines_epi, readout_length)``.
        """
        frame_shape = trajectories.shape[1:-1]
        n_samples = int(np.prod(frame_shape))
        n_coils = sim_conf.hardware.n_coils
        final_ksp = np.zeros(
            (len(trajectories), n_coils, *frame_shape), dtype=np.complex64
        )
        shape = sim_conf.shape
        center = (shape[0] // 2, shape[1] // 2, shape[2] // 2)
        # The sample of the first shot closest to the k-space center is taken
        # as the echo position for the decay curve.
        echo_idx = int(
            np.argmin(
                np.sum((trajectories[0].reshape(-1, 3) - center) ** 2, axis=-1)
            )
        )
        t2s_decay = BaseAcquisitionEngine._job_get_T2s_decay(
            sim_conf.hardware.dwell_time_ms, echo_idx, n_samples, phantom
        )
        for i, evi in enumerate(trajectories):
            # Apply every dynamic modifier for this shot to a private copy.
            frame_phantom = deepcopy(phantom)
            for dyn_data in dyn_datas:
                frame_phantom = dyn_data.func(frame_phantom, dyn_data.data, i)
            contrast = get_contrast_gre(
                frame_phantom,
                sim_conf.seq.FA,
                sim_conf.seq.TE,
                sim_conf.seq.TR,
            )
            # Broadcast the contrast over the spatial axes of the masks.
            phantom_state = (
                contrast[(..., *([None] * len(frame_phantom.anat_shape)))]
                * frame_phantom.masks
            )
            if smaps is None:
                ksp = fft(phantom_state[:, None, ...], axis=(-3, -2, -1))
            else:
                ksp = fft(phantom_state[:, None, ...] * smaps, axis=(-3, -2, -1))
            sample_idx = tuple(evi.reshape(-1, 3).T)
            for c in range(n_coils):
                ksp_coil_sum = np.zeros(n_samples, dtype=np.complex64)
                for b in range(phantom_state.shape[0]):
                    ksp_coil_sum += ksp[b, c][sample_idx] * t2s_decay[b]
                final_ksp[i, c] = ksp_coil_sum.reshape(frame_shape)
        return final_ksp

    @staticmethod
    def _job_model_simple(
        phantom: Phantom,
        dyn_datas: list[DynamicData],
        sim_conf: SimConfig,
        trajectories: NDArray,  # (chunk_size, n_slice, n_lines_epi, readout_length, 3)
        smaps: NDArray,
    ) -> np.ndarray:
        """Acquire k-space data. No T2s decay.

        Returns
        -------
        np.ndarray
            complex64 array of shape
            ``(chunk_size, n_coils, n_slice, n_lines_epi, readout_length)``,
            the same layout as ``_job_model_T2s``.
        """
        # The full per-shot sample grid. The slice axis must be kept,
        # otherwise the reshape fails as soon as n_slice > 1.
        frame_shape = trajectories.shape[1:-1]
        n_coils = sim_conf.hardware.n_coils
        final_ksp = np.zeros(
            (len(trajectories), n_coils, *frame_shape), dtype=np.complex64
        )
        for i, evi in enumerate(trajectories):
            # Apply every dynamic modifier for this shot to a private copy.
            frame_phantom = deepcopy(phantom)
            for dyn_data in dyn_datas:
                frame_phantom = dyn_data.func(frame_phantom, dyn_data.data, i)
            contrast = get_contrast_gre(
                frame_phantom,
                sim_conf.seq.FA,
                sim_conf.seq.TE,
                sim_conf.seq.TR,
            )
            # With no per-sample decay the FFT is linear in the phantom, so
            # the first axis can be summed before transforming.
            phantom_state = np.sum(
                contrast[(..., *([None] * len(phantom.anat_shape)))]
                * frame_phantom.masks,
                axis=0,
            )
            if smaps is None:
                ksp = fft(phantom_state[None, ...], axis=(-3, -2, -1))
            else:
                ksp = fft(phantom_state[None, ...] * smaps, axis=(-3, -2, -1))
            sample_idx = tuple(evi.reshape(-1, 3).T)
            for c in range(n_coils):
                final_ksp[i, c] = ksp[c][sample_idx].reshape(frame_shape)
        return final_ksp

    def _write_chunk_data(
        self, data_loader: MRDLoader, chunk: Sequence[int], chunk_data: NDArray
    ) -> None:
        """Store the simulated k-space of ``chunk`` back into the MRD dataset.

        ``chunk_data`` has shape
        ``(chunk_size, n_coils, n_slice, n_lines_epi, readout_length)``.
        """
        block = self.n_lines_epi * self.n_slice_epi
        shots = np.concatenate(
            [
                np.arange(shot * block, (shot + 1) * block, dtype=np.int32)
                for shot in chunk
            ]
        )
        # complex64 -> interleaved float32 (re, im). Then move the coil axis
        # just before the readout axis, giving
        # (chunk, n_slice, n_lines, n_coils, 2 * readout). Each acquisition
        # then holds its own (n_coils, 2 * readout) block. Moving it only to
        # axis 2 would place it after the slice axis and mix coils with lines.
        chunk_data = np.moveaxis(chunk_data.view(np.float32), 1, -2)
        acq_chunk = data_loader._dataset["data"][shots]
        acq_chunk["data"] = chunk_data.reshape(*acq_chunk.shape, -1)
        data_loader._dataset["data"][shots] = acq_chunk