Source code for snake.core.engine.base

"""Engines are responsible for the acquisition of Kspace."""

from __future__ import annotations

import gc
import logging
import multiprocessing as mp
import os
from collections.abc import Mapping, Sequence
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from multiprocessing.managers import SharedMemoryManager
from tempfile import TemporaryDirectory
from typing import Any, ClassVar

import ismrmrd as mrd
import numpy as np
from numpy.typing import NDArray
from tqdm.auto import tqdm

from snake._meta import EnvConfig, MetaDCRegister, batched

from ...mrd_utils import MRDLoader, make_base_mrd
from ..parallel import ArrayProps
from ..phantom import DynamicData, Phantom, PropTissueEnum
from ..simulation import SimConfig
from ..sampling import BaseSampler
from ..handlers import AbstractHandler, HandlerList
from .utils import get_ideal_phantom, get_noise


[docs] class MetaEngine(MetaDCRegister): """MetaClass for engines.""" dunder_name = "engine"
[docs] class BaseAcquisitionEngine(metaclass=MetaEngine): """Base acquisition engine. Specific step can be overwritten in subclasses. """ __engine_name__: ClassVar[str] __registry__: ClassVar[dict[str, type[BaseAcquisitionEngine]]] log: ClassVar[logging.Logger] model: str = "simple" snr: float = np.inf def _get_chunk_list( self, data_loader: MRDLoader, ) -> Sequence[int]: return range(data_loader.n_acquisition) def _job_trajectories( self, dataset: mrd.Dataset, hdr: mrd.xsd.ismrmrdHeader, sim_conf: SimConfig, chunk: Sequence[int], ) -> NDArray: raise NotImplementedError @staticmethod def _job_get_T2s_decay( dwell_time_ms: float, echo_idx: int, n_samples: int, phantom: Phantom, ) -> NDArray: t = dwell_time_ms * (np.arange(n_samples, dtype=np.float32) - echo_idx) return np.exp(-t[None, :] / phantom.props[:, PropTissueEnum.T2s, None]) @staticmethod def _job_model_T2s( phantom: Phantom, dyn_datas: list[DynamicData], sim_conf: SimConfig, trajectories: NDArray, # (Chunksize, N, 3) smaps: NDArray, *args: Any, **kwargs: Any, ) -> NDArray: raise NotImplementedError @staticmethod def _job_model_simple( phantom: Phantom, dyn_datas: list[DynamicData], sim_conf: SimConfig, trajectories: NDArray, # (Chunksize, N, 3) smaps: NDArray, *args: Any, **kwargs: Any, ) -> NDArray: raise NotImplementedError def _write_chunk_data( self, dataset: mrd.Dataset, chunk: Sequence[int], chunk_data: NDArray ) -> None: raise NotImplementedError
[docs] def _acquire_ksp_job( self, filename: os.PathLike, chunk: Sequence[int], tmp_dir: str, shared_phantom_props: ( tuple[str, ArrayProps, ArrayProps, ArrayProps] | None ) = None, model: str = "T2s", **kwargs: Mapping[str, Any], ) -> str: """Entry point for worker. This handles the io part (Read dataset, write partial k-space), and dispatch to specialized functions for getting the k-space. """ # https://github.com/h5py/h5py/issues/712#issuecomment-562980532 # We know that we are going to read the dataset in read-only mode in # this function and use the main process to write the data. # This is an alternative to using swmr mode, that I could not get to work. os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" with MRDLoader(filename, swmr=True) as data_loader: hdr = data_loader.header # Get the Phantom, SimConfig, and all ... sim_conf = data_loader.get_sim_conf() ddatas = data_loader.get_all_dynamic() # sim_conf = SimConfig.from_mrd_dataset(dataset) for d in ddatas: # only keep the dynamic data that are in the chunk d.data = d.data[:, chunk] trajs = self._job_trajectories(data_loader, hdr, sim_conf, chunk) _job_model = getattr(self, f"_job_model_{model}") smaps = None if sim_conf.hardware.n_coils > 1: smaps = data_loader.get_smaps() if shared_phantom_props is None: phantom = data_loader.get_phantom() ksp = _job_model(phantom, ddatas, sim_conf, trajs, smaps, **kwargs) else: with Phantom.from_shared_memory(*shared_phantom_props) as phantom: ksp = _job_model(phantom, ddatas, sim_conf, trajs, smaps, **kwargs) chunk_file = os.path.join(tmp_dir, f"partial_{chunk[0]}-{chunk[-1]}.npy") np.save(chunk_file, ksp) return chunk_file
[docs] def __call__( self, filename: os.PathLike, sampler: BaseSampler, phantom: Phantom, sim_conf: SimConfig, handlers: list[AbstractHandler] | HandlerList | None = None, smaps: NDArray | None = None, coil_cov: NDArray | None = None, worker_chunk_size: int = 0, n_workers: int = 0, **kwargs: Any, ): """Perform the acquisition and fill the dataset.""" # Create the base dataset make_base_mrd(filename, sampler, phantom, sim_conf, handlers, smaps, coil_cov) # Guesstimate the workload if worker_chunk_size <= 0: # get the number of shot worker_chunk_size = sampler.get_next_frame(sim_conf).shape[0] if n_workers <= 0: n_workers = mp.cpu_count() // 2 with MRDLoader(filename) as data_loader: sim_conf = data_loader.get_sim_conf() phantom = data_loader.get_phantom() shot_idxs = self._get_chunk_list(data_loader) chunk_list = list(batched(shot_idxs, worker_chunk_size)) ideal_phantom = get_ideal_phantom(phantom, sim_conf) coil_cov = data_loader.get_coil_cov() or np.eye(sim_conf.hardware.n_coils) if self.snr > 0: energy = np.mean(ideal_phantom**2) coil_cov = coil_cov * energy / self.snr del ideal_phantom # https://github.com/h5py/h5py/issues/712#issuecomment-562980532 # We know that we are going to read the dataset in read-only mode # and use the main process (here) to write the data. # This is an alternative to using swmr mode, that I could not get to work. os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" if n_workers > 1: Executor = ProcessPoolExecutor( n_workers, mp_context=mp.get_context(self.__mp_mode__) ) else: Executor = ThreadPoolExecutor(max_workers=1) with ( SharedMemoryManager() as smm, Executor as executor, tqdm(total=len(shot_idxs)) as pbar, MRDLoader(filename, writeable=True) as data_loader, TemporaryDirectory( dir=EnvConfig["SNAKE_TMP_DIR"], prefix="snake-" ) as tmp_chunk_dir, ): # data_loader._file.swmr_mode = True phantom_props, shms = phantom.in_shared_memory(smm) # TODO: also put the smaps in shared memory futures = { executor.submit( self._acquire_ksp_job, filename, chunk_id, tmp_dir=tmp_chunk_dir, shared_phantom_props=phantom_props, model=self.model, **kwargs, ): chunk_id for chunk_id in chunk_list } for future in as_completed(futures): chunk = futures[future] try: f_chunk = str(future.result()) except Exception as exc: self.log.error(f"Error in chunk {min(chunk)}-{max(chunk)}") raise exc else: pbar.update(worker_chunk_size) chunk_ksp = np.load(f_chunk) # Add noise if self.snr > 0: noise = get_noise(chunk_ksp, coil_cov, sim_conf.rng) chunk_ksp += noise self._write_chunk_data( data_loader, chunk, chunk_ksp, ) os.remove(f_chunk) gc.collect()