"""Reconstructors using PySAP-fMRI toolbox."""
import copy
import os
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Any
import numpy as np
from numpy.typing import NDArray
# Local imports
from snake.mrd_utils import (
CartesianFrameDataLoader,
MRDLoader,
NonCartesianFrameDataLoader,
)
from snake.core.parallel import (
ArrayProps,
SharedMemoryManager,
array_from_shm,
array_to_shm,
)
from snake._meta import NoCaseEnum
from snake.core.simulation import SimConfig
from tqdm.auto import tqdm
from .base import BaseReconstructor
from .fourier import ifft
def _reconstruct_cartesian_frame(
filename: os.PathLike,
idx: int,
smaps_props: ArrayProps | None,
final_props: ArrayProps,
) -> int:
"""Reconstruct a single frame."""
with (
array_from_shm(final_props) as final_images,
CartesianFrameDataLoader(filename) as data_loader,
):
mask, kspace = data_loader.get_kspace_frame(idx)
sim_conf = data_loader.get_sim_conf()
adj_data = ifft(kspace, axis=tuple(range(len(sim_conf.shape), 0, -1)))
if smaps_props is not None and data_loader.n_coils > 1:
with array_from_shm(smaps_props) as smaps_info:
smaps = smaps_info[0]
adj_data_smaps_comb = abs(
np.sum(adj_data * smaps.conj(), axis=0)
/ np.sum(smaps * smaps.conj(), axis=0)
).astype(np.float32, copy=False)
elif data_loader.n_coils > 1:
adj_data_smaps_comb = np.sqrt(np.sum(abs(adj_data) ** 2, axis=0)).astype(
np.float32, copy=False
)
else:
adj_data_smaps_comb = abs(adj_data).astype(np.float32, copy=False)
final_images[0][idx] = adj_data_smaps_comb
return idx
[docs]
class ZeroFilledReconstructor(BaseReconstructor):
"""Zero Filled Reconstructor."""
__reconstructor_name__ = "adjoint"
n_jobs: int = 10
nufft_backend: str = "gpunufft"
density_compensation: str | bool = "pipe"
[docs]
def setup(self, sim_conf: SimConfig) -> None:
"""Initialize Reconstructor."""
pass
[docs]
def reconstruct(self, data_loader: MRDLoader, sim_conf: SimConfig) -> NDArray:
"""Reconstruct data with zero-filled method."""
with data_loader:
if isinstance(data_loader, CartesianFrameDataLoader):
return self._reconstruct_cartesian(data_loader, sim_conf)
elif isinstance(data_loader, NonCartesianFrameDataLoader):
return self._reconstruct_nufft(data_loader, sim_conf)
else:
raise ValueError("Unknown dataloader")
def _reconstruct_cartesian(
self, data_loader: CartesianFrameDataLoader, sim_conf: SimConfig
) -> NDArray:
smaps = data_loader.get_smaps()
if smaps is None and data_loader.n_coils > 1:
raise NotImplementedError("Missing coil combine code.")
final_images = np.ones(
(data_loader.n_frames, *data_loader.shape), dtype=np.float32
)
with (
SharedMemoryManager() as smm,
ProcessPoolExecutor(self.n_jobs) as executor,
tqdm(total=data_loader.n_frames) as pbar,
):
smaps_props = None
if smaps is not None:
smaps_props, smaps_shared, smaps_sm = array_to_shm(smaps, smm)
final_props, final_shared, final_sm = array_to_shm(final_images, smm)
futures = {
executor.submit(
_reconstruct_cartesian_frame,
data_loader._filename,
idx,
smaps_props,
final_props,
): idx
for idx in range(data_loader.n_frames)
}
for future in as_completed(futures):
future.result()
pbar.update(1)
final_images[:] = final_shared.copy()
final_sm.close()
if smaps_props is not None:
smaps_sm.close()
smm.shutdown()
return final_images
[docs]
def _reconstruct_nufft(
self, data_loader: NonCartesianFrameDataLoader, sim_conf: SimConfig
) -> NDArray:
"""Reconstruct data with nufft method."""
from mrinufft import get_operator
smaps = data_loader.get_smaps()
traj, kspace_data = data_loader.get_kspace_frame(0)
kwargs = dict(
shape=data_loader.shape,
n_coils=data_loader.n_coils,
smaps=smaps,
)
print(self.density_compensation, type(self.density_compensation))
if self.density_compensation is False:
kwargs["density"] = None
else:
kwargs["density"] = self.density_compensation
if "stacked" in self.nufft_backend:
kwargs["z_index"] = "auto"
nufft_operator = get_operator(
self.nufft_backend,
samples=traj,
**kwargs,
)
final_images = np.empty(
(data_loader.n_frames, *data_loader.shape), dtype=np.float32
)
for i in tqdm(range(data_loader.n_frames)):
traj, data = data_loader.get_kspace_frame(i)
nufft_operator.samples = traj
final_images[i] = abs(nufft_operator.adj_op(data))
return final_images
[docs]
class RestartStrategy(NoCaseEnum):
"""Restart strategies for the reconstruction."""
WARM = "warm"
COLD = "cold"
REFINE = "refine"
[docs]
class SequentialReconstructor(BaseReconstructor):
"""Use a sequential Reconstruction.
Parameters
----------
max_iter_frame
Number of iteration to allow per frame.
optimizer
Optimizer name, available are pogm and fista.
threshold
Threshold value for the wavelet regularisation.
"""
__reconstructor_name__ = "sequential"
max_iter_per_frame: int = 15
optimizer: str = "pogm"
wavelet: str = "db4"
threshold: float | str = "sure"
nufft_backend: str = "gpunufft"
density_compensation: str | bool = "pipe"
restart_strategy: RestartStrategy = RestartStrategy.WARM
compute_backend: str = "cupy"
def __str__(self) -> str:
"""Return a string representation of the reconstructor."""
return f"{self.__reconstructor_name__}-{self.restart_strategy}"
[docs]
def setup(self, sim_conf: SimConfig) -> None:
"""Set up the reconstructor."""
from fmri.operators.weighted import AutoWeightedSparseThreshold
from modopt.opt.linear import Identity
from modopt.opt.linear.wavelet import WaveletTransform
from modopt.opt.proximity import SparseThreshold
from modopt.base.backend import get_backend
self.space_linear_op = WaveletTransform(
self.wavelet,
shape=sim_conf.shape,
level=3,
mode="zero",
compute_backend=self.compute_backend,
)
xp, _ = get_backend(self.compute_backend)
_ = self.space_linear_op.op(xp.zeros(sim_conf.shape, dtype=np.complex64))
if self.threshold == "sure":
self.space_prox_op = AutoWeightedSparseThreshold(
self.space_linear_op.coeffs_shape,
linear=None,
threshold_estimation="hybrid-sure",
threshold_scaler=0.6,
)
else:
self.threshold = float(self.threshold)
self.space_prox_op = SparseThreshold(
linear=Identity(), weights=self.threshold
)
[docs]
def reconstruct(self, data_loader: MRDLoader, sim_conf: SimConfig) -> np.ndarray:
"""Reconstruct with Sequential."""
self.setup(sim_conf)
from fmri.operators.gradient import (
GradAnalysis,
GradSynthesis,
)
from modopt.base.backend import get_backend
from mrinufft import get_operator
from mrinufft.density import pipe
xp, _ = get_backend(self.compute_backend)
traj, data = data_loader.get_kspace_frame(0)
smaps = data_loader.get_smaps()
density_compensation = self.density_compensation
if (
isinstance(self.density_compensation, str)
and "first" in self.density_compensation
):
density_compensation = False
kwargs = {}
if "stacked" in self.nufft_backend:
kwargs["z_index"] = "auto"
if self.nufft_backend == "cufinufft":
kwargs["smaps_cached"] = True
fourier_op = get_operator(
self.nufft_backend,
samples=traj,
shape=data_loader.shape,
n_coils=data_loader.n_coils,
smaps=smaps,
# smaps=xp.array(smaps) if smaps is not None else None,
density=density_compensation,
**kwargs,
)
final_estimate = np.zeros(
(data_loader.n_frames, *data_loader.shape), dtype=np.float32
)
grad_kwargs = dict(
fourier_op=fourier_op,
input_data_writeable=True,
dtype=np.complex64,
compute_backend=self.compute_backend,
num_check_lips=0,
verbose=0,
)
if self.optimizer in ["fista"]:
grad_op = GradAnalysis(**grad_kwargs)
if self.optimizer in ["pogm"]:
grad_op = GradSynthesis(linear_op=self.space_linear_op, **grad_kwargs)
x_init = xp.zeros(sim_conf.shape, dtype=np.complex64)
if (
isinstance(self.density_compensation, str)
and "first" in self.density_compensation
):
density_comp_vector = pipe(traj, sim_conf.shape, self.nufft_backend)
x_init = fourier_op.adj_op(xp.array(data * density_comp_vector, copy=False))
else:
x_init = fourier_op.adj_op(xp.array(data, copy=False))
pbar_frames = tqdm(total=data_loader.n_frames, position=0)
pbar_iter = tqdm(total=self.max_iter_per_frame, position=1)
for i, traj, data in data_loader.iter_frames():
grad_op.fourier_op.samples = traj
spec_rad = grad_op.fourier_op.get_lipschitz_cst(20)
grad_op._obs_data = xp.array(data)
grad_op.spec_rad = spec_rad
grad_op.inv_spec_rad = 1 / spec_rad
x_iter = self._reconstruct_frame(
grad_op,
x_init,
n_iter=self.max_iter_per_frame,
progbar=pbar_iter,
)
# Prepare for next iteration and save results
x_init = (
x_iter.copy()
if self.restart_strategy != RestartStrategy.COLD
else x_init.copy()
)
if self.compute_backend == "cupy":
final_estimate[i, ...] = abs(x_iter).get() # type: ignore
else:
final_estimate[i, ...] = abs(x_iter)
pbar_frames.update(1)
if self.restart_strategy != RestartStrategy.REFINE:
return final_estimate
# else, we do a second pass on the data using the last iteration as a slotion.
pbar_frames.reset()
pbar_iter.reset()
x_init = x_iter.copy() # last iteration results.
for i, traj, data in data_loader.iter_frames():
grad_op.fourier_op.samples = traj
spec_rad = grad_op.fourier_op.get_lipschitz_cst()
grad_op._obs_data = xp.array(data)
grad_op.spec_rad = spec_rad
grad_op.inv_spec_rad = 1 / spec_rad
x_iter = self._reconstruct_frame(
grad_op,
x_init,
n_iter=self.max_iter_per_frame,
progbar=pbar_iter,
)
if self.compute_backend == "cupy":
final_estimate[i, ...] = abs(x_iter).get() # type: ignore
else:
final_estimate[i, ...] = abs(x_iter)
pbar_frames.update(1)
return final_estimate
def _reconstruct_frame(
self,
grad_op: Any,
x_init: NDArray,
n_iter: int = 15,
progbar: tqdm | None = None,
) -> NDArray:
from fmri.reconstructors.utils import initialize_opt
from modopt.base.backend import get_backend
xp, _ = get_backend(self.compute_backend)
# only recreate gradient if the trajectory change.
# reset Smaps and optimizer if required.
opt = initialize_opt(
opt_name=self.optimizer,
grad_op=grad_op,
linear_op=copy.deepcopy(self.space_linear_op),
prox_op=copy.deepcopy(self.space_prox_op),
x_init=x_init,
synthesis_init=False,
metric_kwargs={},
compute_backend=self.compute_backend,
opt_kwargs=dict(
verbose=0,
cost="auto",
),
)
# if no reset, the internal state is kept.
if progbar is not None:
progbar.reset(total=n_iter)
opt.iterate(max_iter=n_iter, progbar=progbar)
if hasattr(grad_op, "linear_op"):
img = grad_op.linear_op.adj_op(opt.x_final)
else:
img = opt.x_final
return img