Source code for snake.core.simulation
"""Simulation base objects."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
import numpy as np
def _type_label(annotation: Any) -> str:
    """Return a short, HTML-safe display name for a field annotation."""
    from html import escape

    if isinstance(annotation, str):
        # Unresolved (string) annotation: show it verbatim.
        text = annotation
    else:
        # Plain classes and ``tuple[...]``-style aliases expose ``__name__``;
        # anything that does not falls back to its repr instead of raising.
        text = getattr(annotation, "__name__", None)
        if not isinstance(text, str):
            text = repr(annotation)
    return escape(text, quote=False)


def _value_html(value: Any, vertical: bool) -> str:
    """Render one field value, delegating to its own ``_repr_html_`` if any."""
    from html import escape

    render = getattr(value, "_repr_html_", None)
    if render is None:
        # No rich renderer: escape the repr so ``<``, ``>`` and ``&`` cannot
        # be mistaken for markup inside the table cell.
        return escape(repr(value), quote=False)
    return render(vertical=vertical)


def _repr_html_(obj: Any, vertical: bool = True) -> str:
    """
    Recursive HTML representation for dataclasses.

    This function generates an HTML table representation of a dataclass,
    including nested dataclasses. Each nesting level is rendered with the
    opposite orientation of its parent.

    Parameters
    ----------
    obj : Any
        The dataclass instance.
    vertical : bool, default True
        If True, emit one ``name(type) | value`` row per field. If False,
        emit one header row with all ``name (type)`` cells followed by one
        row with all values.

    Returns
    -------
    str
        An HTML table string representing the dataclass.
    """
    from dataclasses import fields
    from typing import get_type_hints

    cls = type(obj)
    try:
        # Resolve on the class, not the instance: string annotations are then
        # evaluated in the namespace of the module that defines the class.
        hints = get_type_hints(cls)
    except (NameError, TypeError):
        # A repr should not fail on an unresolvable annotation; fall back to
        # whatever each field recorded as its declared type.
        hints = {}

    dc_fields = fields(obj)
    names = [f.name for f in dc_fields]
    labels = {f.name: _type_label(hints.get(f.name, f.type)) for f in dc_fields}
    # Nested values get the flipped orientation.
    cells = {name: _value_html(getattr(obj, name), not vertical) for name in names}

    rows = [
        '<table style="border:1px solid lightgray;">'
        '<caption style="border:1px solid lightgray;">'
        f"<strong>{cls.__name__}</strong></caption>"
    ]
    if vertical:
        rows.extend(
            f"<tr><td>{name}(<i>{labels[name]}</i>)</td>"
            f"<td>{cells[name]}</td></tr>"
            for name in names
        )
    else:
        header = "".join(f"<td>{name} (<i>{labels[name]}</i>)</td>" for name in names)
        values = "".join(f"<td>{cells[name]}</td>" for name in names)
        rows.append("<tr>" + header + "</tr>")
        rows.append("<tr>" + values + "</tr>")
    rows.append("</table>")
    return "\n".join(rows)
[docs]
@dataclass
class GreConfig:
    """Parameters of a gradient-recalled echo (GRE) sequence.

    All three values are required; the class declares no defaults.

    Attributes
    ----------
    TR
        Used by ``SimConfig.sim_tr_ms`` as a duration in milliseconds
        (presumably the repetition time -- confirm).
    TE
        Presumably the echo time, in the same unit as ``TR`` -- confirm.
    FA
        Presumably the flip angle (unit not stated here) -- confirm.
    """

    TR: float
    TE: float
    FA: float

    # Bind the module-level table renderer as this class's ``_repr_html_``
    # method (the hook name looked up by IPython/Jupyter rich display).
    _repr_html_ = _repr_html_
[docs]
@dataclass
class HardwareConfig:
    """Parameters describing the scanner hardware.

    Every field has a default, so ``HardwareConfig()`` is a valid instance.

    Attributes
    ----------
    gmax
        Presumably the maximum gradient amplitude (unit not stated) -- confirm.
    smax
        Presumably the maximum slew rate (unit not stated) -- confirm.
    n_coils
        Number of coils.
    dwell_time_ms
        Dwell time; the name indicates milliseconds.
    raster_time_ms
        Raster time; the name indicates milliseconds.
    field
        Presumably the main magnetic field strength in tesla -- confirm.
    """

    gmax: float = 40
    smax: float = 200
    n_coils: int = 8
    dwell_time_ms: float = 1e-3
    raster_time_ms: float = 5e-3
    field: float = 3.0

    # Bind the module-level table renderer as this class's ``_repr_html_``
    # method (the hook name looked up by IPython/Jupyter rich display).
    _repr_html_ = _repr_html_
# Module-level default instances, created once at import time.
# Hardware description with every field left at its declared default.
default_hardware = HardwareConfig()
# GRE sequence using the same values as SimConfig's default ``seq`` factory.
default_gre = GreConfig(TR=50, TE=30, FA=15)
[docs]
@dataclass
class SimConfig:
    """All base configuration of a simulation.

    Attributes
    ----------
    max_sim_time
        Upper bound on the simulated duration. ``max_n_shots`` multiplies it
        by 1000 before dividing by a millisecond value, i.e. it is treated
        as seconds.
    seq
        Sequence parameters; a fresh ``GreConfig(TR=50, TE=30, FA=15)`` per
        instance by default.
    hardware
        Hardware parameters; a fresh default ``HardwareConfig`` per instance.
    fov_mm
        Field of view along each axis, in millimetres.
    shape
        Target reconstruction shape (number of voxels along each axis).
    rng_seed
        Seed used to create ``rng`` in ``__post_init__``.

    Notes
    -----
    ``rng`` is a plain instance attribute, not a dataclass field, so it is
    not part of the generated ``__init__``/``__repr__``/``__eq__``.
    """

    max_sim_time: float = 300
    seq: GreConfig = field(default_factory=lambda: GreConfig(TR=50, TE=30, FA=15))
    # The class itself is the zero-argument factory; no lambda wrapper needed.
    hardware: HardwareConfig = field(default_factory=HardwareConfig)
    fov_mm: tuple[float, float, float] = (192.0, 192.0, 128.0)
    shape: tuple[int, int, int] = (192, 192, 128)  # Target reconstruction shape
    rng_seed: int = 19290506

    # Bind the module-level table renderer as this class's ``_repr_html_``.
    _repr_html_ = _repr_html_

    def __post_init__(self) -> None:
        """Create the random generator from ``rng_seed``."""
        # NOTE: this plain assignment works because the dataclass is not
        # frozen; a frozen variant would need ``object.__setattr__`` here.
        self.rng: np.random.Generator = np.random.default_rng(self.rng_seed)

    @property
    def max_n_shots(self) -> int:
        """Maximum number of shots: ``max_sim_time`` (s) over ``sim_tr_ms``."""
        return int(self.max_sim_time * 1000 / self.sim_tr_ms)

    @property
    def res_mm(self) -> tuple[float, ...]:
        """Voxel resolution in mm (field of view divided by shape, per axis)."""
        return tuple(fov / n for fov, n in zip(self.fov_mm, self.shape))

    @property
    def sim_tr_ms(self) -> float:
        """Simulation resolution in ms (the sequence ``TR``)."""
        return self.seq.TR