Source code for snake.core.sampling.samplers

"""Samplers generate kspace trajectories."""

from __future__ import annotations
import ismrmrd as mrd
import numpy as np
from numpy.typing import NDArray
from tqdm.auto import tqdm
from ..simulation import SimConfig
from .base import BaseSampler
from .factories import (
    AngleRotation,
    VDSorder,
    VDSpdf,
    stack_spiral_factory,
    stacked_epi_factory,
    evi_factory,
    rotate_trajectory,
)
from snake.mrd_utils.utils import ACQ
from snake._meta import batched, EnvConfig
from mrinufft.io import read_trajectory
from collections.abc import Generator


class NonCartesianAcquisitionSampler(BaseSampler):
    """
    Base class for non-cartesian acquisition samplers.

    Parameters
    ----------
    constant: bool
        If True, the trajectory is constant. The attribute is read by this
        class but not declared here (presumably provided by the base class
        or by subclasses -- TODO confirm).
    obs_time_ms: int
        Time spent to acquire a single shot
    in_out: bool
        If true, the trajectory is acquired with a double join pattern
        from/to the periphery

    Notes
    -----
    The trajectory dimension is not an attribute: it is read from the last
    axis of the frame returned by ``get_next_frame``.
    """

    __engine__ = "NUFFT"
    in_out: bool = True
    obs_time_ms: int = 30

    def add_all_acq_mrd(
        self,
        dataset: mrd.Dataset,
        sim_conf: SimConfig,
    ) -> mrd.Dataset:
        """Generate all mrd_acquisitions.

        One frame is drawn first to size the dataset (number of shots,
        samples per shot, trajectory dimension). The XML header is then
        updated with the encoding limits, and one acquisition per shot and
        per frame is written in the ``"data"`` HDF5 dataset, with zero-filled
        k-space data and the shot trajectory.

        Parameters
        ----------
        dataset: mrd.Dataset
            Dataset whose header is updated and that receives the
            acquisitions.
        sim_conf: SimConfig
            Simulation configuration. ``max_sim_time`` is updated in place
            when it is not a multiple of the volume TR.

        Returns
        -------
        mrd.Dataset
            The dataset given as input.

        Raises
        ------
        ValueError
            If not even one complete frame fits in ``max_sim_time``.
        """
        single_frame = self.get_next_frame(sim_conf)
        n_shots_frame = single_frame.shape[0]
        n_samples = single_frame.shape[1]
        TR_vol_ms = sim_conf.seq.TR * single_frame.shape[0]
        n_ksp_frames_true = sim_conf.max_sim_time * 1000 / TR_vol_ms
        n_ksp_frames = int(n_ksp_frames_true)

        trajectory_dimension = single_frame.shape[-1]

        self.log.info("Generating %d frames", n_ksp_frames)
        self.log.info("Frame have %d shots", n_shots_frame)
        self.log.info("Shot have %d samples", n_samples)
        self.log.info("Tobs %.3f ms", n_samples * sim_conf.hardware.dwell_time_ms)
        self.log.info("volume TR: %.3f ms", TR_vol_ms)
        if self.constant:
            self.log.info("Constant Trajectory")

        if n_ksp_frames == 0:
            raise ValueError(
                "No frame can be generated with the current configuration"
                " (TR/shot too long or max_sim_time too short)"
            )
        if n_ksp_frames != n_ksp_frames_true:
            self.log.warning(
                "Volumic TR does not align with max simulation time, "
                "last incomplete frame will be discarded."
            )
            self.log.warning("Updating the max_sim_time to match.")
            sim_conf.max_sim_time = TR_vol_ms * n_ksp_frames / 1000
        self.log.info("Start Sampling pattern generation")

        # Placeholder k-space data: every acquisition is written zero-filled.
        kspace_data_vol = np.zeros(
            (n_shots_frame, sim_conf.hardware.n_coils, n_samples),
            dtype=np.complex64,
        )

        hdr = mrd.xsd.CreateFromDocument(dataset.read_xml_header())
        hdr.encoding[0].encodingLimits = mrd.xsd.encodingLimitsType(
            kspace_encoding_step_0=mrd.xsd.limitType(0, n_samples, n_samples // 2),
            kspace_encoding_step_1=mrd.xsd.limitType(
                0, n_shots_frame, n_shots_frame // 2
            ),
            repetition=mrd.xsd.limitType(0, n_ksp_frames, 0),
        )
        dataset.write_xml_header(mrd.xsd.ToXML(hdr))  # write the updated header back

        # Write the acquisition.
        # We create the dataset manually with custom dtype.
        # Compared to using mrd.Dataset.append_acquisition
        # - this is faster (20-50%)
        # - uses fixed sized array (All shot have the same size !)
        # - allow for smart chunking (useful for reading/writing efficiently)
        acq_dtype = np.dtype(
            [
                ("head", mrd.hdf5.acquisition_header_dtype),
                ("data", np.float32, (sim_conf.hardware.n_coils * n_samples * 2,)),
                ("traj", np.float32, (n_samples * trajectory_dimension,)),
            ]
        )
        acq_size = np.empty((1,), dtype=acq_dtype).nbytes
        chunk = int(
            np.ceil((n_shots_frame * acq_size) / EnvConfig["SNAKE_HDF5_CHUNK_SIZE"])
        )
        chunk = min(chunk, n_shots_frame)

        # Sizes of the successive in-memory buffers flushed to the dataset.
        # They always sum to the total number of acquisitions.
        # NOTE(review): the batch length divides the chunk size by the size of
        # the *whole* dataset, which gives 1 as soon as the dataset is larger
        # than one chunk -- confirm this is the intended write granularity.
        chunk_write_sizes = [
            len(c)
            for c in batched(
                range(n_shots_frame * n_ksp_frames),
                int(
                    np.ceil(
                        EnvConfig["SNAKE_HDF5_CHUNK_SIZE"]
                        / (acq_size * n_shots_frame * n_ksp_frames)
                    )
                ),
            )
        ]
        self.log.debug("chunk size for hdf5 %s, elem %s Bytes", chunk, acq_size)

        dataset._dataset.create_dataset(
            "data",
            shape=(n_ksp_frames * n_shots_frame,),
            dtype=acq_dtype,
            chunks=(chunk,),
        )
        write_start = 0  # number of acquisitions already written to disk
        counter = 0  # position inside the current in-memory buffer
        # The context manager closes the progress bar even if writing fails.
        with tqdm(total=n_ksp_frames * n_shots_frame) as pbar:
            for i in range(n_ksp_frames):
                kspace_traj_vol = self.get_next_frame(sim_conf)
                for j in range(n_shots_frame):
                    flags = 0
                    if j == 0:
                        flags |= ACQ.FIRST_IN_ENCODE_STEP1
                        flags |= ACQ.FIRST_IN_REPETITION
                    if j == n_shots_frame - 1:
                        flags |= ACQ.LAST_IN_ENCODE_STEP1
                        flags |= ACQ.LAST_IN_REPETITION
                    if counter == 0:
                        # Start a new buffer (sizes are taken from the end of
                        # the list).
                        current_chunk_size = chunk_write_sizes.pop()
                        acq_chunk = np.empty((current_chunk_size,), dtype=acq_dtype)

                    acq_chunk[counter]["head"] = np.frombuffer(
                        mrd.AcquisitionHeader(
                            version=1,
                            flags=flags,
                            # Global acquisition index: `counter` alone
                            # restarts at 0 for every buffer.
                            scan_counter=write_start + counter,
                            sample_time_us=self.obs_time_ms * 1000 / n_samples,
                            center_sample=n_samples // 2 if self.in_out else 0,
                            idx=mrd.EncodingCounters(
                                repetition=i,
                                kspace_encode_step_1=j,
                                kspace_encode_step_2=1,
                            ),
                            active_channels=sim_conf.hardware.n_coils,
                            available_channels=sim_conf.hardware.n_coils,
                            number_of_samples=n_samples,
                            trajectory_dimensions=trajectory_dimension,
                        ),
                        dtype=mrd.hdf5.acquisition_header_dtype,
                    )
                    # Complex data are stored as interleaved float32 values.
                    acq_chunk[counter]["data"] = (
                        kspace_data_vol[j, :, :].view(np.float32).ravel()
                    )
                    acq_chunk[counter]["traj"] = np.float32(
                        kspace_traj_vol[j, :]
                    ).ravel()
                    counter += 1
                    if counter == current_chunk_size:
                        counter = 0
                        # write to hdf5 mrd
                        dataset._dataset["data"][
                            write_start : write_start + current_chunk_size
                        ] = acq_chunk
                        write_start += current_chunk_size
                    pbar.update(1)
        return dataset
class LoadTrajectorySampler(NonCartesianAcquisitionSampler):
    """Sampler replaying a trajectory stored in a file.

    Parameters
    ----------
    path: str
        Location of the trajectory file, handed to ``read_trajectory``.
    constant: bool
        If True, the trajectory is constant.
    obs_time_ms: int
        Time spent to acquire a single shot
    raster_time: float
        Raster time forwarded to ``read_trajectory``.
    in_out: bool
        If true, the trajectory is acquired with a double join pattern
        from/to the periphery
    """

    __sampler_name__ = "load-trajectory"
    __engine__ = "NUFFT"

    path: str
    constant: bool = True
    obs_time_ms: int = 25
    raster_time: float = 0.05
    in_out: bool = True

    def _single_frame(self, sim_conf: SimConfig) -> NDArray:
        """Read the trajectory file and clamp its values to [-0.5, 0.5]."""
        # Only the first element returned by the reader is used.
        loaded = read_trajectory(self.path, raster_time=self.raster_time)
        upper_bounded = np.minimum(loaded[0], 0.5)
        return np.maximum(upper_bounded, -0.5)
class StackOfSpiralSampler(NonCartesianAcquisitionSampler):
    """
    Stack of spirals acquisition sampler.

    Every field below is forwarded to ``stack_spiral_factory`` when a frame
    is generated.

    Parameters
    ----------
    acsz: float | int
        Number/ proportion of lines to be acquired in the center of k-space.
    accelz: int
        Acceleration factor for the rest of the lines.
    orderz: str | VDSorder
        Ordering of the acquired lines along z.
    nb_revolutions: int
        Number of revolutions of the spiral.
    spiral_name: str
        Name of the spiral, passed as ``spiral`` to the factory.
    pdfz: str | VDSpdf
        Probability density function of the sampling.
    constant: bool
        If True, the trajectory is constant.
    in_out: bool
        If true, the spiral is acquired with a double join pattern from/to
        the periphery
    rotate_angle: AngleRotation
        Rotation setting handed to the factory.
    obs_time_ms: int
        Time spent to acquire a single shot
    n_shot_slices: int
        Passed as ``n_shot_slices`` to the factory.
    """

    __sampler_name__ = "stack-of-spiral"

    acsz: float | int
    accelz: int
    orderz: str | VDSorder = VDSorder.TOP_DOWN
    nb_revolutions: int = 10
    spiral_name: str = "archimedes"
    pdfz: str | VDSpdf = VDSpdf.GAUSSIAN
    constant: bool = False
    in_out: bool = True
    rotate_angle: AngleRotation = AngleRotation.ZERO
    obs_time_ms: int = 30
    n_shot_slices: int = 1

    def _single_frame(self, sim_conf: SimConfig) -> NDArray:
        """Build one frame of the stack of spirals trajectory."""
        # Samples per shot: observation time over dwell time, truncated.
        samples_per_shot = int(self.obs_time_ms / sim_conf.hardware.dwell_time_ms)
        factory_kwargs = {
            "shape": sim_conf.shape,
            "accelz": self.accelz,
            "acsz": self.acsz,
            "n_samples": samples_per_shot,
            "nb_revolutions": self.nb_revolutions,
            "pdfz": self.pdfz,
            "orderz": self.orderz,
            "spiral": self.spiral_name,
            "rotate_angle": self.rotate_angle,
            "in_out": self.in_out,
            "n_shot_slices": self.n_shot_slices,
            "rng": sim_conf.rng,
        }
        return stack_spiral_factory(**factory_kwargs)
class RotatedStackOfSpiralSampler(StackOfSpiralSampler):
    """
    Stack of spirals sampler whose frames are rotated one after the other.

    Parameters
    ----------
    rotate_frame_angle: AngleRotation | int
        Angle of rotation of the frame. It is converted from degrees to
        radians before use.
    frame_index: int
        Index of the frame. Incremented each time a rotated frame is
        generated.
    """

    __sampler_name__ = "rotated-stack-of-spiral"

    rotate_frame_angle: AngleRotation | int = 0
    frame_index: int = 0

    def fix_angle_rotation(
        self, frame: Generator[np.ndarray, None, None], angle: AngleRotation | float = 0
    ) -> Generator[np.ndarray, None, None]:
        """Yield every trajectory of ``frame`` rotated by ``angle``."""
        for shot in frame:
            # rotate_trajectory is fed one trajectory at a time.
            one_shot = (item for item in (shot,))
            yield from rotate_trajectory(one_shot, angle)

    def get_next_frame(self, sim_conf: SimConfig) -> NDArray:
        """Return the next frame, rotated unless rotation is disabled.

        The frame is returned as generated when ``constant`` is set or when
        ``rotate_frame_angle`` is 0. Otherwise ``frame_index`` is incremented
        and the frame is rotated by ``frame_index`` times the frame angle.
        """
        # NOTE(review): a fresh frame is generated on every call, including
        # when `constant` is set -- confirm this is intended.
        frame = self._single_frame(sim_conf)
        if not (self.constant or self.rotate_frame_angle == 0):
            self.frame_index += 1
            step_rad = np.pi * (self.rotate_frame_angle / 180)
            total_angle = float(step_rad * self.frame_index)
            # Give each shot a leading axis so the rotated shots can be
            # concatenated back along axis 0.
            shots = (shot[None, ...] for shot in frame)
            rotated = self.fix_angle_rotation(shots, total_angle)
            return np.concatenate(
                [shot.astype(np.float32) for shot in rotated], axis=0
            )
        return frame
class EPI3dAcquisitionSampler(BaseSampler):
    """Sampling pattern for EPI-3D.

    Parameters
    ----------
    acsz: float | int
        Forwarded as ``acsz`` to ``stacked_epi_factory``.
    accelz: int
        Forwarded as ``accelz`` to ``stacked_epi_factory``.
    orderz: VDSorder
        Forwarded as ``orderz`` to ``stacked_epi_factory``.
    pdfz: VDSpdf
        Forwarded as ``pdfz`` to ``stacked_epi_factory``.
    """

    __sampler_name__ = "epi-3d"
    __engine__ = "EPI"

    in_out = True
    acsz: float | int
    accelz: int
    orderz: VDSorder = VDSorder.CENTER_OUT
    pdfz: VDSpdf = VDSpdf.GAUSSIAN

    def _single_frame(self, sim_conf: SimConfig) -> NDArray:
        """Generate the sampling pattern."""
        return stacked_epi_factory(
            shape=sim_conf.shape,
            accelz=self.accelz,
            acsz=self.acsz,
            orderz=self.orderz,
            pdfz=self.pdfz,
            rng=sim_conf.rng,
        )

    def add_all_acq_mrd(
        self,
        dataset: mrd.Dataset,
        sim_conf: SimConfig,
    ) -> mrd.Dataset:
        """Create the acquisitions associated with this sampler.

        Each shot of a frame is split in ``sim_conf.shape[1]`` readout lines
        of ``sim_conf.shape[2]`` samples, and one acquisition is written per
        readout line, with zero-filled k-space data.

        Parameters
        ----------
        dataset: mrd.Dataset
            Dataset whose header is updated and that receives the
            acquisitions.
        sim_conf: SimConfig
            Simulation configuration. ``max_sim_time`` is updated in place
            when it is not a multiple of the volume TR.

        Returns
        -------
        mrd.Dataset
            The dataset given as input.

        Raises
        ------
        ValueError
            If not even one complete frame fits in ``max_sim_time``.
        """
        single_frame = self._single_frame(sim_conf)
        n_shots_frame = single_frame.shape[0]
        n_lines = sim_conf.shape[1]
        n_samples = single_frame.shape[1]
        TR_vol_ms = sim_conf.seq.TR * single_frame.shape[0]
        n_ksp_frames_true = sim_conf.max_sim_time * 1000 / TR_vol_ms
        n_ksp_frames = int(n_ksp_frames_true)

        self.log.info("Generating %d frames", n_ksp_frames)
        self.log.info("Frame have %d shots", n_shots_frame)
        self.log.info("Tobs %.3f ms", n_samples * sim_conf.hardware.dwell_time_ms)
        self.log.info("Shot have %d samples", n_samples)
        self.log.info("volume TR: %f ms", TR_vol_ms)

        if n_ksp_frames == 0:
            raise ValueError(
                "No frame can be generated with the current configuration"
                " (TR/shot too long or max_sim_time too short)"
            )
        if n_ksp_frames != n_ksp_frames_true:
            self.log.warning(
                "Volumic TR does not align with max simulation time, "
                "last incomplete frame will be discarded."
            )
            self.log.warning("Updating the max_sim_time to match.")
            sim_conf.max_sim_time = TR_vol_ms * n_ksp_frames / 1000
        self.log.info("Start Sampling pattern generation")

        zero_data = np.zeros(
            (sim_conf.hardware.n_coils, sim_conf.shape[2]), dtype=np.complex64
        )
        # Same zero payload for every acquisition: complex values stored as
        # interleaved float32.
        zero_payload = zero_data.view(np.float32).ravel()

        # Update the encoding limits.
        # step 0 : frequency (readout directionz)
        # step 1 : phase encoding (blip epi)
        #
        hdr = mrd.xsd.CreateFromDocument(dataset.read_xml_header())
        hdr.encoding[0].encodingLimits = mrd.xsd.encodingLimitsType(
            kspace_encoding_step_0=mrd.xsd.limitType(
                0, sim_conf.shape[2], sim_conf.shape[2] // 2
            ),
            kspace_encoding_step_1=mrd.xsd.limitType(
                0, sim_conf.shape[1], sim_conf.shape[1] // 2
            ),
            slice=mrd.xsd.limitType(0, sim_conf.shape[0], sim_conf.shape[0] // 2),
            repetition=mrd.xsd.limitType(0, n_ksp_frames, 0),
        )
        dataset.write_xml_header(mrd.xsd.ToXML(hdr))  # write the updated header back

        acq_dtype = np.dtype(
            [
                ("head", mrd.hdf5.acquisition_header_dtype),
                (
                    "data",
                    np.float32,
                    (sim_conf.hardware.n_coils * sim_conf.shape[2] * 2,),
                ),
                ("traj", np.uint32, (sim_conf.shape[2] * 3,)),
            ]
        )
        acq_size = np.empty((1,), dtype=acq_dtype).nbytes
        chunk = int(
            np.ceil(
                (n_shots_frame * acq_size * n_lines)
                / EnvConfig["SNAKE_HDF5_CHUNK_SIZE"]
            )
        )
        chunk = min(
            chunk, n_shots_frame * n_lines
        )  # write at least a chunk per frame.

        # Sizes of the successive in-memory buffers flushed to the dataset.
        # They always sum to the total number of acquisitions.
        chunk_write_sizes = [
            len(c)
            for c in batched(
                range(n_lines * n_shots_frame * n_ksp_frames),
                int(
                    np.ceil(
                        EnvConfig["SNAKE_HDF5_CHUNK_SIZE"]
                        / (acq_size * n_lines * n_shots_frame * n_ksp_frames)
                    )
                ),
            )
        ]
        self.log.debug("chunk size for hdf5 %s, elem %s Bytes", chunk, acq_size)

        dataset._dataset.create_dataset(
            "data",
            shape=(n_ksp_frames * n_shots_frame * n_lines,),
            dtype=acq_dtype,
            chunks=(chunk,),
        )
        write_start = 0  # number of acquisitions already written to disk
        counter = 0  # position inside the current in-memory buffer
        # The context manager closes the progress bar even if writing fails.
        with tqdm(total=n_ksp_frames * n_shots_frame) as pbar:
            for i in range(n_ksp_frames):
                stack_epi3d = self.get_next_frame(sim_conf)  # of shape N_stack, N, 3
                for j, epi2d in enumerate(stack_epi3d):
                    # One row per readout line.
                    epi2d_r = epi2d.reshape(sim_conf.shape[1], sim_conf.shape[2], 3)
                    for k, readout in enumerate(epi2d_r):
                        flags = 0
                        if k == 0:
                            flags |= ACQ.FIRST_IN_ENCODE_STEP1
                            flags |= ACQ.FIRST_IN_SLICE
                            if j == 0:
                                flags |= ACQ.FIRST_IN_REPETITION
                        if k == len(epi2d_r) - 1:
                            flags |= ACQ.LAST_IN_ENCODE_STEP1
                            flags |= ACQ.LAST_IN_SLICE
                            if j == len(stack_epi3d) - 1:
                                flags |= ACQ.LAST_IN_REPETITION
                                if i == n_ksp_frames - 1:
                                    flags |= ACQ.LAST_IN_MEASUREMENT
                        if counter == 0:
                            # Start a new buffer (sizes are taken from the
                            # end of the list).
                            current_chunk_size = chunk_write_sizes.pop()
                            acq_chunk = np.empty(
                                (current_chunk_size,), dtype=acq_dtype
                            )
                        # NOTE(review): center_sample is derived from the
                        # samples of a whole shot while number_of_samples is
                        # one readout line -- confirm which one is intended.
                        acq_chunk[counter]["head"] = np.frombuffer(
                            mrd.AcquisitionHeader(
                                version=1,
                                flags=flags,
                                # Global acquisition index: `counter` alone
                                # restarts at 0 for every buffer.
                                scan_counter=write_start + counter,
                                # The dwell time is already the duration of
                                # one sample (see the "Tobs" log above).
                                sample_time_us=sim_conf.hardware.dwell_time_ms
                                * 1000,
                                center_sample=n_samples // 2 if self.in_out else 0,
                                idx=mrd.EncodingCounters(
                                    repetition=i,
                                    kspace_encode_step_1=readout[0, 1],
                                    slice=readout[0, 0],
                                ),
                                read_dir=dir_cos(readout[0], readout[1]),
                                active_channels=sim_conf.hardware.n_coils,
                                available_channels=sim_conf.hardware.n_coils,
                                number_of_samples=len(readout),
                                trajectory_dimensions=3,
                            ),
                            dtype=mrd.hdf5.acquisition_header_dtype,
                        ).copy()
                        acq_chunk[counter]["data"] = zero_payload
                        acq_chunk[counter]["traj"] = readout.astype(
                            np.uint32, copy=False
                        ).ravel()
                        counter += 1
                        if counter == current_chunk_size:
                            counter = 0
                            # write to hdf5 mrd
                            dataset._dataset["data"][
                                write_start : write_start + current_chunk_size
                            ] = acq_chunk
                            write_start += current_chunk_size
                    pbar.update(1)
        dataset._file.flush()  # Empty all buffers to disk
        return dataset
class EVI3dAcquisitionSampler(BaseSampler):
    """Sampler for EVI acquisition."""

    __sampler_name__ = "evi"
    __engine__ = "EVI"

    in_out = True

    def _single_frame(self, sim_conf: SimConfig) -> NDArray:
        """Generate the sampling pattern.

        The coordinates returned by ``evi_factory`` are reshaped to
        ``(*sim_conf.shape, 3)``.
        """
        epi_coords = evi_factory(
            shape=sim_conf.shape,
        ).reshape(*sim_conf.shape, 3)
        return epi_coords

    def add_all_acq_mrd(
        self,
        dataset: mrd.Dataset,
        sim_conf: SimConfig,
    ) -> mrd.Dataset:
        """Create the acquisitions associated with this sampler.

        A frame is a single shot covering the whole volume. It is written as
        one acquisition per readout line of ``sim_conf.shape[2]`` samples,
        with zero-filled k-space data.

        Parameters
        ----------
        dataset: mrd.Dataset
            Dataset whose header is updated and that receives the
            acquisitions.
        sim_conf: SimConfig
            Simulation configuration. ``max_sim_time`` is updated in place
            when it is not a multiple of the volume TR.

        Returns
        -------
        mrd.Dataset
            The dataset given as input.

        Raises
        ------
        ValueError
            If not even one complete frame fits in ``max_sim_time``.
        """
        single_frame = self._single_frame(sim_conf)
        n_samples = (
            single_frame.shape[1] * single_frame.shape[2] * single_frame.shape[0]
        )
        TR_vol_ms = sim_conf.seq.TR
        n_ksp_frames_true = sim_conf.max_sim_time * 1000 / TR_vol_ms
        n_ksp_frames = int(n_ksp_frames_true)

        self.log.info("Generating %d frames", n_ksp_frames)
        self.log.info("Frame have %d shots", 1)
        self.log.info("Tobs %.3f ms", n_samples * sim_conf.hardware.dwell_time_ms)
        self.log.info("Shot have %d samples", n_samples)
        self.log.info("volume TR: %f ms", TR_vol_ms)

        if n_ksp_frames == 0:
            raise ValueError(
                "No frame can be generated with the current configuration"
                " (TR/shot too long or max_sim_time too short)"
            )
        if n_ksp_frames != n_ksp_frames_true:
            self.log.warning(
                "Volumic TR does not align with max simulation time, "
                "last incomplete frame will be discarded."
            )
            self.log.warning("Updating the max_sim_time to match.")
            sim_conf.max_sim_time = TR_vol_ms * n_ksp_frames / 1000
        self.log.info("Start Sampling pattern generation")

        zero_data = np.zeros(
            (sim_conf.hardware.n_coils, sim_conf.shape[2]), dtype=np.complex64
        )
        # Same zero payload for every acquisition: complex values stored as
        # interleaved float32.
        zero_payload = zero_data.view(np.float32).ravel()

        # Update the encoding limits.
        # step 0 : frequency (readout directionz)
        # step 1 : phase encoding (blip epi)
        #
        hdr = mrd.xsd.CreateFromDocument(dataset.read_xml_header())
        hdr.encoding[0].encodingLimits = mrd.xsd.encodingLimitsType(
            kspace_encoding_step_0=mrd.xsd.limitType(
                0, sim_conf.shape[2], sim_conf.shape[2] // 2
            ),
            kspace_encoding_step_1=mrd.xsd.limitType(
                0, sim_conf.shape[1], sim_conf.shape[1] // 2
            ),
            slice=mrd.xsd.limitType(0, sim_conf.shape[0], sim_conf.shape[0] // 2),
            repetition=mrd.xsd.limitType(0, n_ksp_frames, 0),
        )
        dataset.write_xml_header(mrd.xsd.ToXML(hdr))  # write the updated header back

        acq_dtype = np.dtype(
            [
                ("head", mrd.hdf5.acquisition_header_dtype),
                (
                    "data",
                    np.float32,
                    (sim_conf.hardware.n_coils * sim_conf.shape[2] * 2,),
                ),
                ("traj", np.uint32, (sim_conf.shape[2] * 3,)),
            ]
        )

        # Write the acquisition.
        # We create the dataset manually with custom dtype.
        # Compared to using mrd.Dataset.append_acquisition
        # - this is faster !
        # - uses fixed sized array (All shot have the same size !)
        # - allow for smart chunking (useful for reading/writing efficiently)
        acq = np.empty(
            (n_ksp_frames * sim_conf.shape[1] * sim_conf.shape[0],), dtype=acq_dtype
        )
        counter = 0  # global index of the acquisition being filled
        for i in range(n_ksp_frames):
            stack_epi3d = self._single_frame(sim_conf)  # of shape N_stack, N, 3
            for j, epi2d in enumerate(stack_epi3d):
                # One row per readout line.
                epi2d_r = epi2d.reshape(sim_conf.shape[1], sim_conf.shape[2], 3)
                for k, readout in enumerate(epi2d_r):
                    flags = 0
                    if k == 0:
                        flags |= ACQ.FIRST_IN_ENCODE_STEP1
                        flags |= ACQ.FIRST_IN_SLICE
                        if j == 0:
                            flags |= ACQ.FIRST_IN_REPETITION
                    if k == len(epi2d_r) - 1:
                        flags |= ACQ.LAST_IN_ENCODE_STEP1
                        flags |= ACQ.LAST_IN_SLICE
                        if j == len(stack_epi3d) - 1:
                            flags |= ACQ.LAST_IN_REPETITION
                            if i == n_ksp_frames - 1:
                                flags |= ACQ.LAST_IN_MEASUREMENT
                    # NOTE(review): center_sample is derived from the samples
                    # of the whole volume while number_of_samples is one
                    # readout line -- confirm which one is intended.
                    acq[counter]["head"] = np.frombuffer(
                        mrd.AcquisitionHeader(
                            version=1,
                            flags=flags,
                            scan_counter=counter,
                            # The dwell time is already the duration of one
                            # sample (see the "Tobs" log above).
                            sample_time_us=sim_conf.hardware.dwell_time_ms * 1000,
                            center_sample=n_samples // 2 if self.in_out else 0,
                            idx=mrd.EncodingCounters(
                                repetition=i,
                                kspace_encode_step_1=readout[0, 1],
                                slice=readout[0, 0],
                            ),
                            read_dir=dir_cos(readout[0], readout[1]),
                            active_channels=sim_conf.hardware.n_coils,
                            available_channels=sim_conf.hardware.n_coils,
                            number_of_samples=len(readout),
                            trajectory_dimensions=3,
                        ),
                        dtype=mrd.hdf5.acquisition_header_dtype,
                    ).copy()
                    acq[counter]["data"] = zero_payload
                    # The "traj" field is declared as uint32: convert directly
                    # instead of going through float32.
                    acq[counter]["traj"] = readout.astype(
                        np.uint32, copy=False
                    ).ravel()
                    counter += 1

        dataset._dataset.create_dataset(
            "data",
            data=acq,
            chunks=(min(sim_conf.shape[1] * sim_conf.shape[0], len(acq)),),
        )
        return dataset
def dir_cos(start: NDArray, end: NDArray) -> tuple[np.float32, ...]:
    """Compute the direction cosines of the vector going from start to end.

    Parameters
    ----------
    start: NDArray
        Coordinates of the first point.
    end: NDArray
        Coordinates of the second point.

    Returns
    -------
    tuple[np.float32, ...]
        Components of the unit vector pointing from ``start`` to ``end``,
        one per coordinate. All components are NaN when both points are
        identical, as the vector has no direction.
    """
    # Convert before subtracting so that unsigned integer coordinates cannot
    # wrap around.
    diff = np.asarray(end, dtype=np.float32) - np.asarray(start, dtype=np.float32)
    cos = diff / np.sqrt(np.sum(diff**2))
    return tuple(cos)