"""Parameter containers for the simulation.

Each dataclass groups one set of inputs, and ``Parameter`` bundles them.
``DistParameter`` and ``MotionParameter`` are also iterators: they yield one
``{name: value}`` dict per combination in the Cartesian product of their
``variables`` mapping.
"""

from __future__ import annotations

from collections.abc import Collection, Iterator
from dataclasses import dataclass, field
from itertools import product
from math import prod
from typing import Any

import numpy as np

from .distributions import BaseDistribution
from .motions import BaseMotion

__all__ = [
    'SimParameter',
    'MoleculeParameter',
    'StimEchoParameter',
    'SpectrumParameter',
    'DistParameter',
    'MotionParameter',
    'Parameter',
]


@dataclass
class SimParameter:
    """General simulation settings."""

    # Random seed; presumably ``None`` means "unseeded" -- TODO confirm
    # against the consumer of this value.
    seed: int | None
    # Number of walkers (presumably independent trajectories -- confirm).
    num_walker: int
    t_max: float


@dataclass
class MoleculeParameter:
    """Molecule-level parameters ``delta`` and ``eta``.

    NOTE(review): presumably the anisotropy and asymmetry parameters of the
    interaction -- confirm against the code that consumes them.
    """

    delta: float
    eta: float


@dataclass
class StimEchoParameter:
    """Timing parameters for the stimulated-echo (``Parameter.ste``) setup.

    ``t_max`` is derived in ``__post_init__`` and is not an ``__init__``
    argument.
    """

    t_evo: np.ndarray
    t_mix: np.ndarray
    t_echo: float
    t_max: float = field(init=False)

    def __post_init__(self) -> None:
        # Equals the maximum of ``t_mix + 2 * t_evo + 2 * t_echo`` over all
        # (t_evo, t_mix) pairs, because the terms are independent.
        self.t_max = (
            np.max(self.t_mix) + 2 * np.max(self.t_evo) + 2 * self.t_echo
        )


@dataclass
class SpectrumParameter:
    """Acquisition parameters for the spectrum (``Parameter.spec``) setup.

    ``t_acq``, ``t_max`` and ``dampening`` are derived in ``__post_init__``.
    The ``__init__`` arguments are, in order:
    ``dwell_time, num_points, t_echo, lb``.
    """

    dwell_time: float
    num_points: int
    t_echo: np.ndarray
    # Derived: ``num_points`` sample times spaced by ``dwell_time``,
    # starting at 0.
    t_acq: np.ndarray = field(init=False)
    # Derived: max(t_acq) + 2 * max(t_echo).
    t_max: float = field(init=False)
    # Exponential damping rate applied over ``t_acq`` (presumably line
    # broadening -- TODO confirm units). Although declared after derived
    # fields, it is still the 4th ``__init__`` argument because ``init=False``
    # fields are skipped when building ``__init__``.
    lb: float
    # Derived: exp(-lb * t_acq), one factor per acquisition point.
    dampening: np.ndarray = field(init=False)

    def __post_init__(self) -> None:
        self.t_acq = np.arange(self.num_points) * self.dwell_time
        self.dampening = np.exp(-self.lb * self.t_acq)
        self.t_max = np.max(self.t_acq) + 2 * np.max(self.t_echo)


class _VariableSweep:
    """Shared iteration logic for dataclasses driven by a ``variables`` dict.

    The host dataclass must declare ``variables`` (name -> sized, iterable
    collection of candidate values), ``num_variables`` and ``iter``.

    Iterating yields one ``{name: value}`` dict per element of the Cartesian
    product of the value collections, in ``itertools.product`` order. An empty
    mapping yields a single empty dict.

    When the product is exhausted, the internal iterator is reset, so a later
    ``for`` loop starts a fresh sweep. A loop that stops early resumes where it
    left off on the next iteration.
    """

    def __post_init__(self) -> None:
        # Number of combinations one full sweep yields.
        self.num_variables = prod(map(len, self.variables.values()))

    def __iter__(self) -> Iterator[dict[str, Any]]:
        return self

    def __next__(self) -> dict[str, Any]:
        if self.iter is None:
            self.iter = product(*self.variables.values())
        try:
            values = next(self.iter)
        except StopIteration:
            # Reset so the next iteration starts over from the first
            # combination.
            self.iter = None
            raise
        return dict(zip(self.variables, values))


@dataclass
class DistParameter(_VariableSweep):
    """A distribution type plus the grid of variables to sweep it over.

    Iterating yields ``{name: value}`` dicts over the Cartesian product of
    ``variables`` (see ``_VariableSweep``).
    """

    dist_type: BaseDistribution
    # Variable name -> collection of values to sweep.
    variables: dict[str, Collection[Any]] = field(default_factory=dict)
    # Recomputed in ``__post_init__``; any value passed in is overwritten.
    num_variables: int = 0
    # Active ``itertools.product`` cursor during a sweep, ``None`` otherwise.
    # Internal state, so it is kept out of ``__init__``, ``repr`` and ``==``.
    iter: Iterator[tuple[Any, ...]] | None = field(
        default=None, init=False, repr=False, compare=False
    )


@dataclass
class MotionParameter(_VariableSweep):
    """A motion model plus the grid of variables to sweep it over.

    Iterating yields ``{name: value}`` dicts over the Cartesian product of
    ``variables`` (see ``_VariableSweep``).
    """

    model: BaseMotion
    # Variable name -> collection of values to sweep.
    variables: dict[str, Collection[Any]] = field(default_factory=dict)
    # Recomputed in ``__post_init__``; any value passed in is overwritten.
    num_variables: int = 0
    # Active ``itertools.product`` cursor during a sweep, ``None`` otherwise.
    # Internal state, so it is kept out of ``__init__``, ``repr`` and ``==``.
    iter: Iterator[tuple[Any, ...]] | None = field(
        default=None, init=False, repr=False, compare=False
    )


@dataclass
class Parameter:
    """Top-level container grouping every parameter set.

    ``ste`` and ``spec`` accept ``None``; presumably only the requested
    calculation(s) are populated -- TODO confirm with the caller.
    """

    ste: StimEchoParameter | None
    spec: SpectrumParameter | None
    sim: SimParameter
    dist: DistParameter
    motion: MotionParameter
    molecule: MoleculeParameter