"""Module to create phantom for simulation."""
from __future__ import annotations
from copy import deepcopy
import base64
import contextlib
import logging
import os
from collections.abc import Generator
from dataclasses import dataclass
from enum import IntEnum
from importlib.resources import files
from multiprocessing.managers import SharedMemoryManager
from multiprocessing.shared_memory import SharedMemory
from typing import TypeVar, Any
import ismrmrd as mrd
import numpy as np
from numpy.typing import NDArray
from ..parallel import ArrayProps, array_from_shm, array_to_shm, run_parallel
from snake._meta import NoCaseEnum
from ..simulation import SimConfig
from snake.core.smaps import get_smaps
log = logging.getLogger(__name__)
[docs]
class PropTissueEnum(IntEnum):
"""Enum for the tissue properties."""
T1 = 0
T2 = 1
T2s = 2
rho = 3
chi = 4
[docs]
class TissueFile(str, NoCaseEnum):
"""Enum for the tissue properties file."""
tissue_1T5 = str(files("snake.core.phantom.data") / "tissues_properties_1T5.csv")
tissue_7T = str(files("snake.core.phantom.data") / "tissues_properties_7T.csv")
[docs]
@dataclass
class Phantom:
"""A Phantom consist of a list of tissue mask and parameters for those tissues."""
name: str
masks: NDArray[np.float32]
labels: NDArray[np.string_]
props: NDArray[np.float32]
smaps: NDArray[np.complex64] | None = None
[docs]
def add_tissue(
self,
tissue_name: str,
mask: NDArray[np.float32],
props: NDArray[np.float32],
phantom_name: str | None = None,
) -> Phantom:
"""Add a tissue to the phantom. Creates a new Phantom object."""
masks = np.concatenate((self.masks, mask[None, ...]), axis=0)
labels = np.concatenate((self.labels, np.array([tissue_name])))
props = np.concatenate((self.props, props), axis=0)
return Phantom(
phantom_name or self.name, masks, labels, props, smaps=self.smaps
)
@property
def labels_idx(self) -> dict[str, int]:
"""Get the index of the labels."""
return {label: i for i, label in enumerate(self.labels)}
[docs]
def make_smaps(
self, n_coils: int = None, sim_conf: SimConfig = None, antenna: str = "birdcage"
) -> None:
"""Get coil sensitivity maps for the phantom."""
if n_coils is None and sim_conf is not None:
n_coils = sim_conf.hardware.n_coils
elif sim_conf is None and n_coils is None:
raise ValueError("Either n_coils or sim_conf must be provided.")
if n_coils == 1:
log.warning("Only one coil, no need for smaps.")
elif n_coils > 1 and self.smaps is None:
self.smaps = get_smaps(self.anat_shape, n_coils=n_coils, antenna=antenna)
log.debug(f"Created smaps for {n_coils} coils.")
elif self.smaps is not None:
log.warning("Smaps already exists.")
[docs]
@classmethod
def from_brainweb(
cls,
sub_id: int,
sim_conf: SimConfig,
tissue_file: str | TissueFile = TissueFile.tissue_1T5,
tissue_select: list[str] | None = None,
tissue_ignore: list[str] | None = None,
) -> Phantom:
"""Get the Brainweb Phantom."""
from brainweb_dl import BrainWebTissuesV2, get_mri
from .utils import resize_tissues
if tissue_ignore and tissue_select:
raise ValueError("Only one of tissue_select or tissue_ignore can be used.")
# TODO: Use the sim shape properly.
tissues_mask = get_mri(sub_id, contrast="fuzzy").astype(np.float32)
z = (
np.array([0.5, 0.5, 0.5])
* np.array(sim_conf.shape)
/ np.array(sim_conf.fov_mm)
)
tissues_mask = np.ascontiguousarray(tissues_mask.T)
tissues_list = []
try:
if isinstance(tissue_file, TissueFile):
tissue_file = tissue_file.value
else:
tissue_file = TissueFile[tissue_file].value
except ValueError as exc:
if not os.path.exists(tissue_file):
raise FileNotFoundError(f"File {tissue_file} does not exist.") from exc
finally:
tissue_file = str(tissue_file)
log.info(f"Using tissue file:{tissue_file} ")
with open(tissue_file) as f:
lines = f.readlines()
select = []
for line in lines[1:]:
vals = line.split(",")
t1, t2, t2s, rho, chi = map(np.float32, vals[1:])
name = vals[0]
t = (name, t1, t2, t2s, rho, chi)
if (
(tissue_select and name in tissue_select)
or (tissue_ignore and name not in tissue_ignore)
or (not tissue_select and not tissue_ignore)
):
tissues_list.append(t)
select.append(BrainWebTissuesV2[name.upper()])
log.info(
f"Selected tissues: {select}, {[t[0] for t in tissues_list]}",
)
tissues_mask = tissues_mask[select]
shape = tissues_mask.shape
new_shape = (shape[0], *np.round(np.array(shape[1:]) * z).astype(int))
tissue_resized = np.zeros(new_shape, dtype=np.float32)
run_parallel(
resize_tissues,
tissues_mask,
tissue_resized,
parallel_axis=0,
z=tuple(z),
)
tissues_mask = tissue_resized
smaps = None
if sim_conf.hardware.n_coils > 1:
smaps = get_smaps(
tissues_mask.shape[1:],
n_coils=sim_conf.hardware.n_coils,
)
return cls(
"brainweb",
tissues_mask,
labels=np.array([t[0] for t in tissues_list]),
props=np.array([t[1:] for t in tissues_list]),
smaps=smaps,
)
[docs]
@classmethod
def from_shepp_logan(cls, resolution: tuple[int]) -> Phantom:
"""Get the Shepp-Logan Phantom."""
raise NotImplementedError
[docs]
@classmethod
def from_guerin_kern(cls, resolution: tuple[int]) -> Phantom:
"""Get the Guerin-Kern Phantom."""
raise NotImplementedError
[docs]
@classmethod
def from_mrd_dataset(
cls, dataset: mrd.Dataset | os.PathLike, imnum: int = 0
) -> Phantom:
"""Load the phantom from a mrd dataset."""
if not isinstance(dataset, mrd.Dataset):
dataset = mrd.Dataset(dataset, create_if_needed=False)
image = dataset.read_image("phantom", imnum)
name = image.meta.pop("name")
labels = np.array(image.meta["labels"].split(","))
props = unserialize_array(image.meta["props"])
return cls(
masks=image.data,
labels=labels,
props=props,
name=name,
)
[docs]
def to_mrd_dataset(
self, dataset: mrd.Dataset, sim_conf: SimConfig, imnum: int = 0
) -> mrd.Dataset:
"""Add the phantom as an image to the dataset."""
# Create the image
if not isinstance(dataset, mrd.Dataset):
dataset = mrd.Dataset(dataset, create_if_needed=True)
meta_sr = mrd.Meta(
{
"name": self.name,
"labels": f"{','.join(self.labels)}",
"props": serialize_array(self.props),
}
).serialize()
# Add the phantom data
dataset.append_image(
"phantom",
mrd.image.Image(
head=mrd.image.ImageHeader(
matrixSize=mrd.xsd.matrixSizeType(*self.anat_shape),
fieldOfView_mm=mrd.xsd.fieldOfViewMm(*sim_conf.fov_mm),
channels=self.n_tissues,
acquisition_time_stamp=0,
attribute_string_len=len(meta_sr),
),
data=self.masks,
attribute_string=meta_sr,
),
)
# Add the smaps
if self.smaps is not None:
dataset.append_image(
"smaps",
mrd.image.Image(
head=mrd.image.ImageHeader(
matrixSize=mrd.xsd.matrixSizeType(*self.smaps.shape[1:]),
fieldOfView_mm=mrd.xsd.fieldOfViewMm(*sim_conf.fov_mm),
channels=len(self.smaps),
acquisition_time_stamp=0,
),
data=self.smaps,
),
)
return dataset
[docs]
@classmethod
@contextlib.contextmanager
def from_shared_memory(
cls,
name: str,
mask_prop: ArrayProps,
properties_prop: ArrayProps,
label_prop: ArrayProps,
smaps_prop: ArrayProps,
) -> Generator[Phantom, None, None]:
"""Give access the tissue masks and properties in shared memory."""
with array_from_shm(mask_prop, label_prop, properties_prop, smaps_prop) as arrs:
yield cls(name, *arrs)
[docs]
def in_shared_memory(
self, manager: SharedMemoryManager
) -> tuple[
tuple[str, ArrayProps, ArrayProps, ArrayProps, ArrayProps | None],
tuple[SharedMemory, SharedMemory, SharedMemory, SharedMemory | None],
]:
"""Add a copy of the phantom in shared memory."""
tissue_mask, _, tisue_mask_smm = array_to_shm(self.masks, manager)
tissue_props, _, tissue_prop_smm = array_to_shm(self.props, manager)
labels, _, labels_sm = array_to_shm(self.labels, manager)
if self.smaps is not None:
smaps, _, smaps_sm = array_to_shm(self.smaps, manager)
else:
smaps, smaps_sm = None, None
return (
(self.name, tissue_mask, tissue_props, labels, smaps),
(tisue_mask_smm, tissue_prop_smm, labels_sm, smaps_sm),
)
@property
def anat_shape(self) -> tuple[int, ...]:
"""Get the shape of the base volume."""
return self.masks.shape[1:]
@property
def n_tissues(self) -> int:
"""Get the number of tissues."""
return len(self.masks)
[docs]
def __repr__(self):
ret = f"Phantom[{self.name}]: {self.props.shape}\n"
ret += f"{'tissue name':14s}" + "".join(
f"{str(prop):4s}" for prop in PropTissueEnum
)
ret += "\n"
for i, tissue_name in enumerate(self.labels):
props = self.props[i]
ret += (
f"{tissue_name:14s}"
+ "".join(f"{props[p]:4}" for p in PropTissueEnum.__members__.values())
+ "\n"
)
return ret
[docs]
def __deepcopy__(self, memo: Any) -> Phantom:
"""Create a copy of the phantom."""
return Phantom(
name=self.name,
masks=deepcopy(self.masks, memo),
labels=deepcopy(self.labels, memo),
props=deepcopy(self.props, memo),
smaps=deepcopy(self.smaps, memo),
)
[docs]
def copy(self) -> Phantom:
"""Return deep copy of the Phantom."""
return deepcopy(self)
T = TypeVar("T")
[docs]
def serialize_array(arr: NDArray) -> str:
"""Serialize the array for mrd compatible format."""
return "__".join(
[
base64.b64encode(arr.tobytes()).decode(),
str(arr.shape),
str(arr.dtype),
]
)
[docs]
def unserialize_array(s: str) -> NDArray:
"""Unserialize the array for mrd compatible format."""
data, shape, dtype = s.split("__")
shape = eval(shape) # FIXME
return np.frombuffer(base64.b64decode(data.encode()), dtype=dtype).reshape(*shape)