Source code for snake.toolkit.cli.config
"""Configuration of SNAKE using Hydra."""
from pathlib import Path
from typing import Any
from dataclasses import dataclass, field
import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import OmegaConf, DictConfig
from snake.core.simulation import SimConfig
from snake.core.phantom.static import TissueFile
from snake.core.handlers import AbstractHandler
from snake.core.sampling import BaseSampler
from snake.toolkit.reconstructors import BaseReconstructor
[docs]
@dataclass
class EngineConfig:
"""Engine configuration for SNAKE."""
n_jobs: int = 1
chunk_size: int = 1
model: str = "simple"
snr: float = 0
nufft_backend: str = "finufft"
[docs]
@dataclass
class PhantomConfig:
"""PhantomConfig."""
name: str = "brainweb"
sub_id: int = 4
tissue_select: list[str] = field(default_factory=list)
tissue_ignore: list[str] = field(default_factory=list)
tissue_file: str | TissueFile = TissueFile.tissue_1T5
[docs]
@dataclass
class StatConfig:
"""Statistical configuration for SNAKE."""
roi_tissue_name: str = "ROI"
roi_threshold: float = 0.5
event_name: str = "block_on"
[docs]
@dataclass
class ConfigSNAKE:
"""Configuration for SNAKE."""
handlers: Any
sampler: Any
reconstructors: Any
sim_conf: SimConfig = SimConfig()
engine: EngineConfig = EngineConfig()
phantom: PhantomConfig = PhantomConfig()
stats: StatConfig = StatConfig()
cache_dir: Path = "${oc.env:PWD}/cache" # type: ignore
result_dir: Path = "${oc.env:PWD}/results" # type: ignore
filename: Path = "test.mrd" # type: ignore
[docs]
def conf_validator(cfg: DictConfig) -> ConfigSNAKE:
"""Validate the simulation configuration."""
cfg_obj: ConfigSNAKE = OmegaConf.to_object(cfg)
cfg_obj.sim_conf.fov_mm = tuple(cfg_obj.sim_conf.fov_mm)
cfg_obj.sim_conf.shape = tuple(cfg_obj.sim_conf.shape)
return cfg_obj
# Custom Resolver for OmegaConf
# allows to do:
# _target_: {$snake.handler:motion-image}
# instead of hardcoding the path to the class
[docs]
def snake_handler_resolver(name: str) -> str:
"""Get Custom resolver for OmegaConf to get handler name."""
from snake.core.handlers import H
cls = H[name]
return cls.__module__ + "." + cls.__name__
[docs]
def snake_sampler_resolver(name: str) -> str:
"""Get Custom resolver for OmegaConf to get handler name."""
from snake.core.sampling import BaseSampler
cls = BaseSampler.__registry__[name]
return cls.__module__ + "." + cls.__name__
OmegaConf.register_new_resolver("snake.handler", snake_handler_resolver)
OmegaConf.register_new_resolver("snake.sampler", snake_sampler_resolver)
cs = ConfigStore.instance()
cs.store(name="base_config", node=ConfigSNAKE)
for handler_name, h_cls in AbstractHandler.__registry__.items():
cs.store(group="handlers", name=handler_name, node={handler_name: h_cls})
for sampler, s_cls in BaseSampler.__registry__.items():
cs.store(group="sampler", name=sampler, node={sampler: s_cls})
for reconstructor, r_cls in BaseReconstructor.__registry__.items():
cs.store(group="reconstructors", name=reconstructor, node={reconstructor: r_cls})
[docs]
def cleanup_cuda() -> None:
"""Cleanup CUDA."""
import cupy as cp
cp.get_default_memory_pool().free_all_blocks()
cp.get_default_pinned_memory_pool().free_all_blocks()
cp._default_memory_pool = cp.cuda.MemoryPool()
cp._default_pinned_memory_pool = cp.cuda.PinnedMemoryPool()
[docs]
def make_hydra_cli(fun: callable) -> callable:
"""Create a Hydra CLI for the function."""
return hydra.main(
version_base=None, config_path="../../../cli-conf", config_name="config"
)(fun)