"""K-space trajectory factories."""
from __future__ import annotations
import logging
from collections.abc import Callable, Generator, Mapping, Sequence
from typing import Any
import numpy as np
from numpy.typing import NDArray
from mrinufft.trajectories.maths import R2D
from mrinufft.trajectories.tools import rotate, stack
from mrinufft.trajectories.trajectory2D import (
initialize_2D_radial,
initialize_2D_spiral,
)
from mrinufft.trajectories.utils import (
check_hardware_constraints,
compute_gradients_and_slew_rates,
)
from scipy.stats import norm # type: ignore
from ..._meta import NoCaseEnum
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Type of the per-axis index lists used to address a mask array: each entry
# is a slice, an int64 index array or a single integer.
SlicerType = list[slice | np.ndarray[Any, np.dtype[np.int64]] | int]
[docs]
def validate_rng(rng: int | None | np.random.Generator = None) -> np.random.Generator:
    """Return a NumPy random ``Generator`` built from ``rng``.

    Parameters
    ----------
    rng: int, list, None or np.random.Generator
        - ``np.random.Generator``: returned unchanged.
        - ``None``: a freshly seeded generator is created.
        - integer (Python or NumPy integer) or list: used as the seed of
          ``np.random.default_rng``.

    Returns
    -------
    np.random.Generator
        The validated generator.

    Raises
    ------
    ValueError
        If ``rng`` is none of the supported types.
    """
    if isinstance(rng, np.random.Generator):
        return rng
    if rng is None:
        return np.random.default_rng()
    # NumPy integer scalars (e.g. values read from an array) are valid seeds
    # too, but are not instances of the builtin ``int``.
    if isinstance(rng, (int, np.integer, list)):
        return np.random.default_rng(rng)
    raise ValueError("rng should be a numpy Generator, None or an integer seed.")
[docs]
class VDSorder(NoCaseEnum):
    """Available ordering for variable density sampling."""

    # Lines are reordered with flip2center in get_kspace_slice_loc.
    CENTER_OUT = "center-out"
    # Lines are shuffled with the random generator in get_kspace_slice_loc.
    RANDOM = "random"
    # Lines are kept in ascending index order in get_kspace_slice_loc.
    TOP_DOWN = "top-down"
[docs]
class VDSpdf(NoCaseEnum):
    """Available law for variable density sampling."""

    # Border lines are drawn with weights taken from a normal pdf.
    GAUSSIAN = "gaussian"
    # Border lines are drawn with equal weights.
    UNIFORM = "uniform"
    # Border lines are taken at a regular stride (no random draw).
    EQUISPACED = "equispaced"
[docs]
def get_kspace_slice_loc(
    dim_size: int,
    center_prop: int | float,
    accel: int = 4,
    pdf: VDSpdf = VDSpdf.GAUSSIAN,
    rng: int | None | np.random.Generator = None,
    order: VDSorder = VDSorder.CENTER_OUT,
) -> np.ndarray:
    """Select the line indices to acquire along one k-space dimension.

    A contiguous block around the middle of the dimension is always kept,
    and the remaining ("border") lines are undersampled.

    Parameters
    ----------
    dim_size: int
        Dimension size.
    center_prop: float or int
        Size of the fully sampled center block. An integer (Python or NumPy)
        is a number of lines; any other value is a proportion of
        ``dim_size``.
    accel: int
        Undersampling/acceleration factor applied to the border lines.
        ``0`` disables undersampling: every index is returned.
    pdf: VDSpdf or str, optional
        Law used to pick the border lines: "gaussian" (default), "uniform"
        or "equispaced" (every ``accel``-th border line, no random draw).
    rng: int, None or np.random.Generator
        Random number generator or seed.
    order: VDSorder or str, optional
        Ordering of the returned lines: "center-out" (default), "random"
        or "top-down".

    Returns
    -------
    np.ndarray
        1D array with the selected line indices (center block plus sampled
        border lines) arranged according to ``order``.

    Raises
    ------
    ValueError
        If no border line would be sampled, or if ``pdf``/``order`` is not
        supported.
    """
    order = VDSorder(order)
    pdf = VDSpdf(pdf)
    if accel == 0:
        return np.arange(dim_size)  # type: ignore

    # NumPy integers are line counts as well, not proportions.
    if isinstance(center_prop, (int, np.integer)):
        center_prop = int(center_prop)
    else:
        center_prop = int(center_prop * dim_size)
    center_start = (dim_size - center_prop) // 2
    center_stop = (dim_size + center_prop) // 2

    indexes = list(range(dim_size))
    center_indexes = indexes[center_start:center_stop]
    borders = np.asarray([*indexes[:center_start], *indexes[center_stop:]])

    n_samples_borders = (dim_size - len(center_indexes)) // accel
    if n_samples_borders < 1:
        raise ValueError(
            "acceleration factor, center_prop and dimension not compatible."
            "Edges will not be sampled. "
        )
    rng = validate_rng(rng)

    def _get_samples(p: NDArray) -> list[int]:
        # Normalize the weights into probabilities before drawing.
        p = p / np.sum(p)
        return list(rng.choice(borders, size=n_samples_borders, replace=False, p=p))

    if pdf is VDSpdf.GAUSSIAN:
        # Normal pdf evaluated between its 0.1% and 99.9% quantiles, one
        # weight per border line.
        weights = norm.pdf(
            np.linspace(norm.ppf(0.001), norm.ppf(0.999), len(borders))
        )
        sampled_in_border = _get_samples(weights)
    elif pdf is VDSpdf.UNIFORM:
        sampled_in_border = _get_samples(np.ones(len(borders)))
    elif pdf is VDSpdf.EQUISPACED:
        sampled_in_border = list(borders[::accel])
    else:
        raise ValueError("Unsupported value for pdf.")
    # TODO: allow custom pdf as argument (vector or function.)

    # Sorted once here; every ordering below starts from the sorted lines.
    line_locs = np.array(sorted(center_indexes + sampled_in_border))

    # apply order of lines
    if order == VDSorder.CENTER_OUT:
        return flip2center(list(line_locs), dim_size // 2)
    if order == VDSorder.RANDOM:
        return rng.permutation(line_locs)
    if order == VDSorder.TOP_DOWN:
        return line_locs
    raise ValueError(f"Unknown direction '{order}'.")
[docs]
def get_cartesian_mask(
    shape: tuple[int, ...],
    n_frames: int,
    rng: int | None | np.random.Generator = None,
    constant: bool = False,
    center_prop: float | int = 0.3,
    accel: int = 4,
    accel_axis: int = 0,
    pdf: VDSpdf = VDSpdf.GAUSSIAN,
) -> np.ndarray:
    """
    Get a cartesian mask for fMRI kspace data.

    Parameters
    ----------
    shape: tuple
        shape of fMRI volume.
    n_frames: int
        number of frames.
    rng: Generator or int or None (default)
        Random number generator or seed.
    constant: bool
        If True, the mask is constant across time.
    center_prop: float or int
        Proportion (float) or number of lines (int) of the center of kspace
        to continuously sample.
    accel: int
        Undersampling/Acceleration factor
    accel_axis: int
        Spatial axis along which lines are undersampled. Negative values
        count from the last spatial axis.
    pdf: str, optional
        Probability density function for the remaining samples.
        "gaussian" (default), "uniform" or "equispaced".

    Returns
    -------
    np.ndarray
        Mask of shape ``(n_frames, *shape)`` holding 1 on the sampled lines
        and 0 elsewhere.

    Raises
    ------
    ValueError
        If ``accel_axis`` is not a valid spatial axis.
    """
    if accel_axis < 0:
        accel_axis = len(shape) + accel_axis
    # Axis 0 (the default) is a valid spatial axis.
    if not (0 <= accel_axis < len(shape)):
        raise ValueError(
            "accel_axis should be lower than the number of spatial dimension."
        )

    rng = validate_rng(rng)
    mask = np.zeros((n_frames, *shape))
    # One index per mask axis; axis 0 of the mask is the frame axis, hence
    # the +1 offset when addressing a spatial axis.
    slicer: SlicerType = [slice(None, None, None)] * (1 + len(shape))

    if constant:
        # A single draw, applied to every frame at once.
        mask_loc = get_kspace_slice_loc(shape[accel_axis], center_prop, accel, pdf, rng)
        slicer[accel_axis + 1] = mask_loc
        mask[tuple(slicer)] = 1
        return mask

    for i in range(n_frames):
        # A new draw for each frame.
        mask_loc = get_kspace_slice_loc(shape[accel_axis], center_prop, accel, pdf, rng)
        slicer[0] = i
        slicer[accel_axis + 1] = mask_loc
        mask[tuple(slicer)] = 1
    return mask
[docs]
def flip2center(mask_cols: Sequence[int], center_value: int) -> np.ndarray:
    """
    Reorder columns starting from the center and alternating left/right.

    The starting column is the entry of ``mask_cols`` closest to
    ``center_value``. The output then alternates between the next entry
    after it and the next entry before it, until both sides are exhausted.

    Parameters
    ----------
    mask_cols: list or np.array
        List of columns to reorder.
    center_value: int
        Value of the center column; the closest entry of ``mask_cols`` is
        used as starting point.

    Returns
    -------
    np.array: reordered columns.
    """
    cols = list(mask_cols)
    center_pos = np.argmin(np.abs(np.array(cols) - center_value))
    # ``left`` starts at the center entry and walks towards index 0.
    left = cols[center_pos::-1]
    right = cols[center_pos + 1 :]

    # Interleave the common part pairwise, then append the remainder of the
    # longer side (linear time, unlike repeated ``pop(0)``).
    n_paired = min(len(left), len(right))
    new_cols = [col for pair in zip(left, right) for col in pair]
    new_cols += left[n_paired:] + right[n_paired:]
    return np.array(new_cols)
[docs]
def check_trajectory(
    trajectory: NDArray, osf: int, gmax: float, smax: float
) -> np.bool_:
    """Check if a trajectory is feasible or not.

    Parameters
    ----------
    trajectory
        Trajectory array, indexed with three axes.
    osf
        Step used to subsample the second axis before the check.
    gmax
        Gradient limit forwarded to ``check_hardware_constraints``.
    smax
        Slew rate limit forwarded to ``check_hardware_constraints``.

    Returns
    -------
    np.bool_
        True when every entry of the feasibility result is true.
    """
    decimated = trajectory[:, ::osf, :]
    gradients, slew_rates = compute_gradients_and_slew_rates(decimated)
    # Only the feasibility flag is used; the reported maxima are ignored.
    feasible, _max_grad, _max_slew = check_hardware_constraints(
        gradients, slew_rates, gmax, smax
    )
    return np.all(feasible)
[docs]
def vds_factory(
    shape: tuple[int, ...],
    acs: float | int,
    accel: int,
    accel_axis: int,
    order: VDSorder = VDSorder.CENTER_OUT,
    pdf: VDSpdf = VDSpdf.GAUSSIAN,
    rng: int | None | np.random.Generator = None,
) -> np.ndarray:
    """
    Create a variable density sampling trajectory.

    Parameters
    ----------
    shape
        Shape of the kspace.
    acs
        autocalibration line number (int) or proportion (float)
    accel
        Acceleration factor along ``accel_axis``.
    accel_axis
        Axis of ``shape`` that is undersampled. Negative values count from
        the last axis.
    order
        Ordering of the sampled lines.
    pdf
        Probability density function of the sampling.
    rng
        Random number generator or seed.

    Returns
    -------
    np.ndarray
        int32 array of shape ``(n_lines, 1, len(shape))``. Each shot holds
        the line index on ``accel_axis`` and -1 on every other axis.

    Raises
    ------
    ValueError
        If ``accel_axis`` is not a valid axis of ``shape``.
    """
    n_dims = len(shape)
    axis = accel_axis + n_dims if accel_axis < 0 else accel_axis
    if axis < 0 or axis >= n_dims:
        raise ValueError(
            "accel_axis should be lower than the number of spatial dimension."
        )

    line_locs = get_kspace_slice_loc(shape[axis], acs, accel, pdf, rng, order)

    # -1 is the placeholder on every axis except the accelerated one, which
    # receives the line index of the shot (0-indexed).
    shots = np.full((len(line_locs), 1, n_dims), -1, dtype=np.int32)
    shots[:, 0, axis] = line_locs
    return shots
[docs]
def radial_factory(
    shape: tuple[int, ...],
    n_shots: int,
    n_points: int,
    expansion: str | None = None,
    n_repeat: int = 0,
    **kwargs: Mapping[str, Any],
) -> np.ndarray:
    """Create a radial sampling trajectory.

    Parameters
    ----------
    shape
        Shape of the kspace; only its length (2 or 3) is used.
    n_shots
        Number of shots, forwarded to ``initialize_2D_radial``.
    n_points
        Number of points per shot, forwarded to ``initialize_2D_radial``.
    expansion
        How the planar trajectory is extended for a 3D shape: "stacked" or
        "rotated". Required for 3D, ignored for 2D.
    n_repeat
        Number of stacks ("stacked") or rotations ("rotated"). Ignored for 2D.
    kwargs
        Accepted for signature compatibility; not used.

    Returns
    -------
    np.ndarray
        The planar radial trajectory for a 2D shape, or its stacked/rotated
        expansion for a 3D shape.

    Raises
    ------
    ValueError
        If ``shape`` is neither 2D nor 3D, or if a 3D shape comes without a
        supported ``expansion`` / without ``n_repeat``.
    """
    n_dims = len(shape)
    if n_dims not in (2, 3):
        raise ValueError("Only 2D and 3D trajectories are supported.")

    traj_points = initialize_2D_radial(n_shots, n_points)
    if n_dims == 2:
        # The planar trajectory is already the final result.
        return traj_points

    if expansion is None:
        raise ValueError("Expansion should be provided for 3D radial sampling.")
    if n_repeat is None:
        raise ValueError("n_repeat should be provided for 3D radial sampling.")
    if expansion == "stacked":
        return stack(
            traj_points,
            nb_stacks=n_repeat,
        )
    if expansion == "rotated":
        return rotate(
            traj_points,
            nb_rotations=n_repeat,
        )
    # An unrecognized expansion must not silently yield a planar trajectory
    # for a 3D shape.
    raise ValueError(
        f"Unknown expansion '{expansion}', expected 'stacked' or 'rotated'."
    )
[docs]
def stack_spiral_factory(
    shape: tuple[int, ...],
    accelz: int,
    acsz: int | float,
    n_samples: int,
    nb_revolutions: int,
    in_out: bool = True,
    spiral: str = "archimedes",
    n_shot_slices: int = 1,
    orderz: VDSorder = VDSorder.CENTER_OUT,
    pdfz: VDSpdf = VDSpdf.GAUSSIAN,
    rng: int | None | np.random.Generator = None,
    rotate_angle: AngleRotation | float = 0.0,
) -> np.ndarray:
    """Generate a trajectory of stack of spiral.

    Parameters
    ----------
    shape
        Shape of the kspace; the last entry is the size of the stacking axis.
    accelz
        Acceleration factor along the stacking axis.
    acsz
        Fully sampled center of the stacking axis, as a number of planes
        (int) or a proportion (float).
    n_samples
        Forwarded to ``initialize_2D_spiral`` as ``Ns``.
    nb_revolutions
        Forwarded to ``initialize_2D_spiral``.
    in_out
        Forwarded to ``initialize_2D_spiral``.
    spiral
        Forwarded to ``initialize_2D_spiral``.
    n_shot_slices
        Number of spiral shots per plane (``Nc`` of ``initialize_2D_spiral``).
    orderz
        Ordering of the planes.
    pdfz
        Law used to pick the undersampled planes.
    rng
        Random number generator or seed.
    rotate_angle
        Rotation increment between consecutive planes: plane ``i`` is
        rotated with ``R2D(rotate_angle * i)``. 0 disables the rotation.

    Returns
    -------
    np.ndarray
        float32 array of shape ``(n_planes * n_shot_slices, n_points, 3)``.
    """
    size_z = shape[-1]
    z_index = get_kspace_slice_loc(
        size_z, acsz, accelz, pdf=pdfz, rng=rng, order=orderz
    )
    angle_step = float(rotate_angle)

    flat_spiral = initialize_2D_spiral(
        Nc=n_shot_slices,
        Ns=n_samples,
        nb_revolutions=nb_revolutions,
        spiral=spiral,
        in_out=in_out,
    ).reshape(-1, 2)

    # Plane positions: indices are shifted by half the axis size and scaled
    # by the axis size.
    z_kspace = (z_index - size_z // 2) / size_z

    # create the equivalent 3d trajectory
    pts_per_shot = len(flat_spiral) // n_shot_slices
    planar = flat_spiral.reshape(n_shot_slices, pts_per_shot, 2)
    kspace_locs3d = np.zeros(
        (len(z_kspace) * n_shot_slices, pts_per_shot, 3), dtype=np.float32
    )
    # TODO use numpy api for this ?
    for i, kz in enumerate(z_kspace):
        # Rows of the output that belong to plane ``i``.
        rows = slice(i * n_shot_slices, (i + 1) * n_shot_slices)
        in_plane = planar if angle_step == 0 else planar @ R2D(angle_step * i)
        kspace_locs3d[rows, :, :2] = in_plane
        kspace_locs3d[rows, :, 2] = kz
    return kspace_locs3d.astype(np.float32)
#########################################
# Generators #
#########################################
[docs]
class AngleRotation(float, NoCaseEnum):
    """Available rotation angle for density sampling.

    Members are floats; the non-zero values are angles in radians.
    """

    ZERO = 0
    GOLDEN = 2.39996322972865332  # 2pi(2-phi)
    GOLDEN_MRI = 1.941678793  # 111.25 deg
[docs]
def rotate_trajectory(
    trajectories: Generator[np.ndarray, None, None], theta: AngleRotation | float = 0
) -> Generator[np.ndarray, None, None]:
    """Incrementally rotate a trajectory.

    The k-th yielded trajectory (starting at k = 1) is rotated by
    ``k * theta`` in the plane of the first two coordinates; a third
    coordinate, when present, is left unchanged.

    Parameters
    ----------
    trajectories:
        Trajectory to rotate. Each item stores coordinates on its last axis
        (2 or 3 components).
    theta:
        Rotation increment in radians, as a number or an ``AngleRotation``.

    Yields
    ------
    np.ndarray
        The rotated trajectory, with the same shape as the input item.
    """
    # Accept plain numbers (including the integer default) as well as enum
    # members exposing ``.value``.
    step = float(getattr(theta, "value", theta))
    angle = step
    for traj in trajectories:
        cos_a, sin_a = np.cos(angle), np.sin(angle)
        # The rotation matrix must match the number of coordinates.
        if traj.shape[-1] == 2:
            rot = np.array([[cos_a, -sin_a], [sin_a, cos_a]])
        else:
            rot = np.array(
                [
                    [cos_a, -sin_a, 0],
                    [sin_a, cos_a, 0],
                    [0, 0, 1],
                ]
            )
        # Advance by a constant increment (not a doubling of the angle).
        angle += step
        # Apply ``rot`` to the last axis, whatever the leading dimensions.
        yield np.einsum("ij,...j->...i", rot, traj)
[docs]
def stacked_epi2d(
    shape: tuple[int, int, int],
    freq_locs: NDArray,
    phase_locs: NDArray,
    slice_locs: NDArray,
) -> NDArray:
    """Generate a list of 2D epi plane, stacked.

    For every slice location, phase lines are visited in the given order.
    The frequency locations are traversed forward on even phase lines
    (0, 2, ...) and backward on odd ones, giving a zigzag path.

    Parameters
    ----------
    shape
        Not used by this function; kept for signature compatibility.
    freq_locs
        Locations along the frequency axis.
    phase_locs
        Locations along the phase axis.
    slice_locs
        Locations along the slice axis.

    Returns
    -------
    NDArray
        uint32 array of shape ``(Ns * Np * Nf, 3)`` whose columns are
        (slice, phase, frequency) locations.
    """
    slices = np.asarray(slice_locs)
    phases = np.asarray(phase_locs)
    freqs = np.asarray(freq_locs)
    n_slices, n_phases, n_freqs = len(slices), len(phases), len(freqs)

    # One plane of the zigzag: row p holds the frequency locations in the
    # order in which phase line p is traversed.
    odd_line = (np.arange(n_phases) % 2 == 1)[:, np.newaxis]
    plane_freqs = np.where(odd_line, freqs[::-1], freqs).ravel()
    plane_phases = np.repeat(phases, n_freqs)

    epi_3d_coord = np.zeros((n_slices * n_phases * n_freqs, 3), dtype=np.uint32)
    epi_3d_coord[:, 0] = np.repeat(slices, n_phases * n_freqs)
    # The same plane pattern is repeated for every slice.
    epi_3d_coord[:, 1] = np.tile(plane_phases, n_slices)
    epi_3d_coord[:, 2] = np.tile(plane_freqs, n_slices)
    return epi_3d_coord
[docs]
def stacked_epi_factory(
    shape: tuple[int, int, int],
    acsz: int | float,
    accelz: int,
    orderz: VDSorder = VDSorder.CENTER_OUT,
    pdfz: VDSpdf = VDSpdf.GAUSSIAN,
    rng: int | None | np.random.Generator = None,
) -> np.ndarray:
    """Generate a VDS stack of fully sampled EPI trajectory.

    Parameters
    ----------
    shape
        Shape of the kspace; the first entry is the size of the slice axis,
        which is the undersampled one.
    acsz
        Fully sampled center of the slice axis, as a number of planes (int)
        or a proportion (float).
    accelz
        Acceleration factor along the slice axis.
    orderz
        Ordering of the planes.
    pdfz
        Law used to pick the undersampled planes.
    rng
        Random number generator or seed.

    Returns
    -------
    np.ndarray
        Array of shape ``(n_planes, shape[1] * shape[2], 3)``: one fully
        sampled EPI plane per selected slice location.
    """
    n_slices_full = shape[0]
    plane_locs = get_kspace_slice_loc(
        n_slices_full, acsz, accelz, pdf=pdfz, rng=rng, order=orderz
    )
    flat_coords = stacked_epi2d(
        shape,
        slice_locs=plane_locs,
        phase_locs=np.arange(shape[1]),
        freq_locs=np.arange(shape[2]),
    )
    # One shot per plane.
    points_per_plane = shape[1] * shape[2]
    return flat_coords.reshape(len(plane_locs), points_per_plane, 3)
[docs]
def evi_factory(
    shape: tuple[int, int, int],
) -> np.ndarray:
    """Generate a Echo Volume Imaging trajectory.

    Every location of the three axes of ``shape`` is visited, using the
    stacked EPI ordering of ``stacked_epi2d``.

    Parameters
    ----------
    shape
        Shape of the kspace, as (slice, phase, frequency) sizes.

    Returns
    -------
    np.ndarray
        Array with one (slice, phase, frequency) row per location.
    """
    slice_axis, phase_axis, freq_axis = (np.arange(shape[i]) for i in range(3))
    full_volume = stacked_epi2d(
        shape,
        slice_locs=slice_axis,
        phase_locs=phase_axis,
        freq_locs=freq_axis,
    )
    return full_volume.reshape(-1, 3)
[docs]
def trajectory_generator(
    traj_factory: Callable[..., np.ndarray],
    shape: tuple[int, ...],
    **kwargs: Any,
) -> Generator[np.ndarray, None, None]:
    """Generate trajectories endlessly from a factory.

    Parameters
    ----------
    traj_factory
        Trajectory factory function, called as ``traj_factory(shape, **kwargs)``.
    shape
        Shape passed as first argument to the factory.
    kwargs
        Trajectory factory kwargs. The optional ``constant`` entry is
        consumed here and not forwarded: when true, the factory is called
        once and the same trajectory is yielded forever; otherwise the
        factory is called again for every yielded trajectory.

    Yields
    ------
    np.ndarray
        Kspace trajectory.
    """
    reuse_first = bool(kwargs.pop("constant", False))
    cached = None
    if reuse_first:
        logger.debug("constant trajectory")
        cached = traj_factory(shape, **kwargs)
    while True:
        yield cached if reuse_first else traj_factory(shape, **kwargs)