"""Utility functions for motion generation."""
import numpy as np
from numpy.typing import NDArray, DTypeLike
from functools import partial
from scipy.ndimage import affine_transform, shift
def motion_generator(
    n_frames: int,
    t_std: tuple[float, float, float],
    r_std: tuple[float, float, float],
    time_res: float,
    rng: np.random.Generator,
) -> np.ndarray:
    """Generate a cumulative rigid-motion trajectory.

    Parameters
    ----------
    n_frames
        Number of frames.
    t_std
        Per-axis (x, y, z) standard deviation of the translation speed, in mm/s.
    r_std
        Per-axis (x, y, z) standard deviation of the rotation speed, in
        radians/s.
    time_res
        Time step applied to every frame, in seconds.
    rng
        Random number generator.

    Returns
    -------
    np.ndarray
        Array of shape ``(n_frames, 6)``. Columns 0-2 hold the cumulative
        translations and columns 3-5 the cumulative rotations, each ordered
        (x, y, z).

    Notes
    -----
    Per-frame speeds are drawn from zero-mean normal distributions (all
    translation speeds first, then all rotation speeds) and multiplied by
    ``time_res`` to obtain per-frame displacements. These are accumulated
    along the frame axis, so each row is the absolute displacement relative to
    the reference position.

    NOTE(review): ``apply_rotation_at_center`` converts its angles with
    ``np.deg2rad``, i.e. it expects degrees. Confirm the intended units before
    feeding these radian values into it.
    """
    frame_shape = (n_frames, 3)
    # Translations are drawn before rotations: the order of draws from ``rng``
    # determines the trajectory obtained for a given seed.
    translation_steps = rng.normal(0, t_std, frame_shape) * time_res
    rotation_steps = rng.normal(0, r_std, frame_shape) * time_res
    steps = np.hstack((translation_steps, rotation_steps))
    return steps.cumsum(axis=0)
def rotation(
    x: float = 0.0,
    y: float = 0.0,
    z: float = 0.0,
    dtype: DTypeLike = "float32",
) -> NDArray:
    """Create a 4x4 homogeneous matrix holding a 3D rotation.

    Parameters
    ----------
    x, y, z : scalar
        Rotation angle around each axis (in rad), forwarded to ``rotation3d``.
    dtype : data-type, optional
        Data type of the returned matrix (default ``"float32"``).

    Returns
    -------
    r : array, shape = (4, 4)
        Matrix whose upper-left 3x3 block is ``rotation3d(x, y, z)``. The
        remaining entries are those of the identity, so there is no
        translation component.
    """
    # Allocate with the requested dtype. A bare ``np.eye(4)`` is always
    # float64, which silently ignored ``dtype``.
    r = np.eye(4, dtype=dtype)
    r[:3, :3] = rotation3d(x=x, y=y, z=z, dtype=dtype)
    return r
def rotation2D(
    angle: float,
    dtype: DTypeLike = "float32",
) -> NDArray:
    """Create a 2x2 rotation matrix.

    Parameters
    ----------
    angle : scalar
        Rotation angle (in rad).
    dtype : data-type, optional
        Data type of the returned matrix (default ``"float32"``).

    Returns
    -------
    array, shape = (2, 2)
        ``[[cos(angle), -sin(angle)], [sin(angle), cos(angle)]]``.
    """
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    return np.array([[cos_a, -sin_a], [sin_a, cos_a]], dtype=dtype)
def rotation3d(
    x: float = 0.0,
    y: float = 0.0,
    z: float = 0.0,
    dtype: DTypeLike = "float32",
) -> NDArray:
    """Create a 3x3 rotation matrix from three axis angles.

    The result equals ``Rz(z) @ Ry(y) @ Rx(x)``, where ``Rx``, ``Ry`` and
    ``Rz`` are the elementary rotations about the x, y and z axes. Applied to
    a column vector, the x rotation acts first, then y, then z.

    Parameters
    ----------
    x, y, z : scalar
        Rotation angle around each axis (in rad).
    dtype : data-type, optional
        Data type of the returned matrix (default ``"float32"``).

    Returns
    -------
    r : array, shape = (3, 3)
        The rotation matrix.
    """
    cos_x = np.cos(x)
    cos_y = np.cos(y)
    cos_z = np.cos(z)
    sin_x = np.sin(x)
    sin_y = np.sin(y)
    sin_z = np.sin(z)
    r = np.array(
        [
            [
                cos_y * cos_z,
                -cos_x * sin_z + sin_x * sin_y * cos_z,
                sin_x * sin_z + cos_x * sin_y * cos_z,
            ],
            [
                cos_y * sin_z,
                cos_x * cos_z + sin_x * sin_y * sin_z,
                -sin_x * cos_z + cos_x * sin_y * sin_z,
            ],
            [-sin_y, sin_x * cos_y, cos_x * cos_y],
        ],
        # Honor the requested dtype. It was previously hard-coded to ``float``,
        # which silently ignored the ``dtype`` argument.
        dtype=dtype,
    )
    return r
def translation(
    x: float = 0.0,
    y: float = 0.0,
    z: float = 0.0,
    dtype: DTypeLike = "float32",
) -> np.ndarray:
    """Create a 4x4 homogeneous translation matrix.

    Parameters
    ----------
    x, y, z : scalar
        Translation along each axis, stored in the last column.
    dtype : data-type, optional
        Data type of the returned matrix (default ``"float32"``).

    Returns
    -------
    m : array, shape = (4, 4)
        Identity matrix whose last column is ``(x, y, z, 1)``.
    """
    offsets = (x, y, z)
    # Identity rows with the offset appended, followed by the homogeneous row.
    rows = [
        [1 if col == row else 0 for col in range(3)] + [offsets[row]]
        for row in range(3)
    ]
    rows.append([0, 0, 0, 1])
    return np.array(rows, dtype=dtype)
def apply_rotation_at_center(
    data: NDArray,
    angles: tuple[float, float, float],
) -> NDArray:
    """Rotate a 3D array about the center of its grid.

    Parameters
    ----------
    data : NDArray
        3D array to rotate.
    angles : tuple of 3 floats
        Rotation angles around the x, y and z axes, in degrees. They are
        converted with ``np.deg2rad`` before building the rotation matrix.

    Returns
    -------
    NDArray
        The rotated array, resampled by ``scipy.ndimage.affine_transform``
        with its default settings.

    Notes
    -----
    The rotation center is ``shape[i] // 2`` along each axis. The previous
    ``shape[i] / 2`` fell on a voxel boundary for odd lengths (e.g. 2.5 for 5
    voxels), so every rotation also shifted the data by half a voxel. For
    even lengths the center is unchanged: index ``n // 2``, which is also
    where ``np.fft.fftshift`` places the origin.

    NOTE(review): if the geometric midpoint ``(n - 1) / 2`` is intended for
    even lengths, adjust ``center`` below.
    """
    center = tuple(data.shape[i] // 2 for i in range(3))
    # ``affine_transform`` pulls values: output voxel ``o`` is sampled at input
    # coordinate ``M @ o``. Composing (move center to origin) -> inverse
    # rotation -> (move back) gives ``M @ o = c + R^-1 (o - c)``, i.e. the
    # output is the input rotated by ``R`` about ``c``.
    to_origin = translation(-center[0], -center[1], -center[2], dtype=np.float32)
    back_to_center = translation(center[0], center[1], center[2], dtype=np.float32)
    rad_angles = np.deg2rad(angles)
    R = rotation(
        rad_angles[0],
        rad_angles[1],
        rad_angles[2],
        dtype=np.float32,
    )
    M = back_to_center @ np.linalg.inv(R) @ to_origin
    return affine_transform(data, M)
# ``scipy.ndimage.shift`` pre-configured with ``mode="nearest"``: samples that
# fall outside the input take the value of the nearest edge element. Called as
# ``apply_shift(data, per_axis_shift)``; the shift is in array-index units.
apply_shift = partial(shift, mode="nearest")
def add_motion(
    data: NDArray[np.complexfloating] | NDArray[np.floating],
    motion_params: NDArray[np.floating],
    idx: int = 0,
) -> np.ndarray:
    """Add rigid motion to a base array.

    The array is first rotated about its center, then shifted.

    Parameters
    ----------
    data : np.ndarray
        The 3D data to which motion is added.
    motion_params : np.ndarray
        Either a single motion vector of shape ``(6,)`` or a full trajectory of
        shape ``(n_frames, 6)``. In both cases the entries are
        ``(tx, ty, tz, rx, ry, rz)``: translations first, rotations last.
    idx : int, optional
        Frame to use when ``motion_params`` is a trajectory. Ignored when a
        single ``(6,)`` vector is given.

    Returns
    -------
    np.ndarray
        The data with motion added.

    Raises
    ------
    ValueError
        If the selected motion vector does not have exactly 6 entries.

    Notes
    -----
    Rotations are forwarded to ``apply_rotation_at_center``, which converts
    them with ``np.deg2rad`` (i.e. treats them as degrees). Translations are
    forwarded to ``apply_shift`` (``scipy.ndimage.shift``), which treats them
    as array-index offsets.

    NOTE(review): ``motion_generator`` documents its output in radians and mm.
    Confirm the expected units before passing its output here unchanged.
    """
    params = np.asarray(motion_params)
    if params.ndim == 2:
        # A full trajectory was given: select the requested frame. Previously
        # ``idx`` was ignored and a 2D input was sliced by rows, which failed.
        params = params[idx]
    if params.shape != (6,):
        raise ValueError(
            f"Expected a motion vector of 6 parameters, got shape {params.shape}."
        )
    rotated = apply_rotation_at_center(data, tuple(params[3:]))
    return apply_shift(rotated, tuple(params[:3]))