"""Export data to mrd format."""
from __future__ import annotations
import logging
import os
from typing import TYPE_CHECKING
import ismrmrd as mrd
import numpy as np
from numpy.typing import NDArray
from hydra_callbacks import PerfLogger
from mrinufft.trajectories.utils import Gammas
from snake._version import __version__ as version
from .utils import get_waveform_id, obj2b64encode
if TYPE_CHECKING:
from snake.core.phantom import DynamicData, Phantom
from snake.core.handlers import AbstractHandler, HandlerList
from snake.core.sampling import BaseSampler
from snake.core.simulation import SimConfig
log = logging.getLogger(__name__)
[docs]
def add_phantom_mrd(
    dataset: mrd.Dataset, phantom: Phantom, sim_conf: SimConfig
) -> mrd.Dataset:
    """Write ``phantom`` into ``dataset``.

    All the work is delegated to ``phantom.to_mrd_dataset``; whatever that
    method returns is handed back to the caller.
    """
    updated_dataset = phantom.to_mrd_dataset(dataset, sim_conf)
    return updated_dataset
[docs]
def add_smaps_mrd(
    dataset: mrd.Dataset,
    sim_conf: SimConfig,
    smaps: NDArray | None = None,
) -> mrd.Dataset:
    """Add the coil sensitivity maps (smaps) to the dataset.

    The maps are appended as an MRD image named ``"smaps"``.

    Parameters
    ----------
    dataset : mrd.Dataset
        The dataset to write into.
    sim_conf : SimConfig
        Simulation configuration; ``hardware.n_coils``, ``shape`` and
        ``fov_mm`` are read.
    smaps : NDArray, optional
        Sensitivity maps of shape ``(n_coils, *sim_conf.shape)``. When None,
        nothing is written.

    Returns
    -------
    mrd.Dataset
        The same ``dataset`` object.

    Raises
    ------
    ValueError
        If ``smaps.shape`` is not ``(n_coils, *sim_conf.shape)``.
    """
    if smaps is None:
        return dataset
    expected_shape = (sim_conf.hardware.n_coils, *sim_conf.shape)
    if smaps.shape != expected_shape:
        raise ValueError(
            f"Incompatible smaps shape {smaps.shape} != {expected_shape}"
        )
    dataset.append_image(
        "smaps",
        mrd.image.Image(
            head=mrd.image.ImageHeader(
                # Spatial matrix size: drop the leading coil axis.
                matrixSize=mrd.xsd.matrixSizeType(*smaps.shape[1:]),
                fieldOfView_mm=mrd.xsd.fieldOfViewMm(*sim_conf.fov_mm),
                channels=len(smaps),
                acquisition_time_stamp=0,
            ),
            data=smaps,
        ),
    )
    return dataset
[docs]
def add_dynamic_mrd(
    dataset: mrd.Dataset, dynamic: DynamicData, sim_conf: SimConfig
) -> mrd.Dataset:
    """Add the dynamic data to the dataset.

    Two things are written:

    * a ``waveformInformation`` entry in the XML header, holding
      ``dynamic.func`` base64-encoded (via ``obj2b64encode``) and a
      ``"domain"`` string (``"kspace"`` or ``"image"`` from
      ``dynamic.in_kspace``);
    * a waveform carrying ``dynamic.data``.

    Parameters
    ----------
    dataset : mrd.Dataset
        The dataset to write into.
    dynamic : DynamicData
        The dynamic data; ``name``, ``func``, ``in_kspace`` and ``data`` are
        read. ``data`` must be 1D ``(n_samples,)`` or 2D
        ``(channels, n_samples)``.
    sim_conf : SimConfig
        Simulation configuration; ``sim_tr_ms`` sets the waveform sample time.

    Returns
    -------
    mrd.Dataset
        The same ``dataset`` object.

    Raises
    ------
    ValueError
        If ``dynamic.data`` is neither 1D nor 2D. This is checked before the
        dataset is touched, so the header is never left with a waveform entry
        that has no matching waveform.
    """
    data = dynamic.data
    if data.ndim == 1:
        channels = 1
        nsamples = data.shape[0]
    elif data.ndim == 2:
        channels, nsamples = data.shape
    else:
        raise ValueError(f"Invalid data shape: {data.shape}")

    waveform_id = get_waveform_id(dynamic.name)
    # Register the waveform type in the XML header.
    hdr = mrd.xsd.CreateFromDocument(dataset.read_xml_header())
    hdr.waveformInformation.append(
        mrd.xsd.waveformInformationType(
            waveformName=dynamic.name,
            waveformType=waveform_id,
            userParameters=mrd.xsd.userParametersType(
                userParameterBase64=[
                    mrd.xsd.userParameterBase64Type(
                        dynamic.name, obj2b64encode(dynamic.func)
                    )
                ],
                userParameterString=[
                    mrd.xsd.userParameterStringType(
                        "domain", "kspace" if dynamic.in_kspace else "image"
                    )
                ],
            ),
        )
    )
    dataset.write_xml_header(mrd.xsd.ToXML(hdr))
    dataset.append_waveform(
        mrd.Waveform(
            mrd.WaveformHeader(
                waveform_id=waveform_id,
                number_of_samples=nsamples,
                channels=channels,
                # sim_tr_ms is in milliseconds; the header field is in microseconds.
                sample_time_us=sim_conf.sim_tr_ms * 1000,
            ),
            # Reinterpret the float32 bit pattern as uint32 (no value conversion).
            data=np.float32(data).view(np.uint32),
        )
    )
    return dataset
[docs]
def add_coil_cov_mrd(
    dataset: mrd.Dataset,
    sim_conf: SimConfig,
    coil_cov: NDArray | None = None,
) -> mrd.Dataset:
    """Add the coil covariance matrix to the dataset.

    The matrix is appended as an MRD image named ``"coil_cov"``.

    Parameters
    ----------
    dataset : mrd.Dataset
        The dataset to write into.
    sim_conf : SimConfig
        Simulation configuration; only ``hardware.n_coils`` is read, and only
        when ``coil_cov`` is given.
    coil_cov : NDArray, optional
        Covariance matrix of shape ``(n_coils, n_coils)``. When None, nothing
        is written.

    Returns
    -------
    mrd.Dataset
        The same ``dataset`` object.

    Raises
    ------
    ValueError
        If ``coil_cov.shape`` is not ``(n_coils, n_coils)``.
    """
    if coil_cov is None:
        return dataset
    n_coils = sim_conf.hardware.n_coils
    if coil_cov.shape != (n_coils, n_coils):
        raise ValueError(
            f"Incompatible coil_cov shape {coil_cov.shape} != {(n_coils, n_coils)} "
        )
    dataset.append_image(
        "coil_cov",
        mrd.image.Image(
            head=mrd.image.ImageHeader(
                matrixSize=mrd.xsd.matrixSizeType(*coil_cov.shape),
                # NOTE(review): a covariance matrix has no physical field of
                # view; the matrix shape is reused here - confirm intended.
                fieldOfView_mm=mrd.xsd.fieldOfViewMm(*coil_cov.shape),
                channels=1,
                acquisition_time_stamp=0,
            ),
            data=coil_cov,
        ),
    )
    return dataset
[docs]
def make_base_mrd(
    filename: os.PathLike,
    sampler: BaseSampler,
    phantom: Phantom,
    sim_conf: SimConfig,
    handlers: list[AbstractHandler] | HandlerList | None = None,
    smaps: NDArray | None = None,
    coil_cov: NDArray | None = None,
) -> mrd.Dataset:
    """
    Create a base `.mrd` file from the simulation configurations.

    Any existing file at ``filename`` is removed first. The new file then
    receives:
    * the MRD XML header;
    * the sampler's acquisitions;
    * the phantom (after handler modifications);
    * the handlers' dynamic data;
    * for multi-coil setups, the sensitivity maps and the coil covariance.

    Parameters
    ----------
    filename : os.PathLike
        The output filename.
    sampler : BaseSampler
        The sampling pattern generator.
    phantom : Phantom
        The phantom object.
    sim_conf : SimConfig
        The simulation configurations.
    handlers : list[AbstractHandler] | HandlerList, optional
        Handlers applied to the phantom, by default None. Every handler first
        updates the static phantom (``get_static``); the dynamic data is then
        collected from the final phantom (``get_dynamic``), skipping None.
    smaps : NDArray, optional
        The coil sensitivity maps, by default None. Only written when
        ``sim_conf.hardware.n_coils > 1``.
    coil_cov : NDArray, optional
        The coil covariance matrix, by default None. Only written when
        ``sim_conf.hardware.n_coils > 1``.

    Returns
    -------
    mrd.Dataset
        The dataset object, already flushed and closed.
    """
    # Only warn/remove when there actually is a previous file; previously every
    # fresh run logged a spurious warning plus a FileNotFoundError.
    if os.path.exists(filename):
        log.warning("Existing %s it will be overwritten", filename)
        try:
            os.remove(filename)
        except OSError as e:
            # Best effort, as before: report and continue.
            log.error(e)

    dataset = mrd.Dataset(filename, "dataset", create_if_needed=True)
    try:
        dataset.write_xml_header(
            mrd.xsd.ToXML(get_mrd_header(sim_conf, sampler.__engine__))
        )
        with PerfLogger(logger=log, name="acq"):
            sampler.add_all_acq_mrd(dataset, sim_conf)

        # Apply all static modifications first, then collect the dynamic data
        # from the resulting phantom.
        if handlers is None:
            handlers = []
        for h in handlers:
            phantom = h.get_static(phantom, sim_conf)
        dynamic_data = [h.get_dynamic(phantom, sim_conf) for h in handlers]

        with PerfLogger(logger=log, name="phantom"):
            add_phantom_mrd(dataset, phantom, sim_conf)
        with PerfLogger(logger=log, name="dynamic"):
            for dyn in dynamic_data:
                if dyn is not None:
                    add_dynamic_mrd(dataset, dyn, sim_conf)
        with PerfLogger(logger=log, name="smaps"):
            if sim_conf.hardware.n_coils > 1 and smaps is not None:
                add_smaps_mrd(dataset, sim_conf, smaps)
        with PerfLogger(logger=log, name="coil_cov"):
            if sim_conf.hardware.n_coils > 1 and coil_cov is not None:
                add_coil_cov_mrd(dataset, sim_conf, coil_cov)
        dataset._file.flush()
    finally:
        # Release the file even if one of the steps above raised.
        dataset.close()
    return dataset