from __future__ import annotations from dataclasses import dataclass, field from itertools import product from math import prod from typing import Any import numpy as np from numpy._typing import ArrayLike from functions import pulse_attn from .distributions import BaseDistribution from .motions import BaseMotion __all__ = [ 'SimParameter', 'MoleculeParameter', 'StimEchoParameter', 'SpectrumParameter', 'DistParameter', 'MotionParameter', 'Parameter', ] @dataclass class SimParameter: seed: int | None num_walker: int t_max: float def totext(self) -> str: return f'num_traj={self.num_walker}\nseed={self.seed}' @dataclass class MoleculeParameter: delta: float eta: float @dataclass class StimEchoParameter: t_evo: ArrayLike t_mix: ArrayLike t_echo: float t_max: float = field(init=False) def __post_init__(self): self.t_max = np.max(self.t_mix) + 2 * np.max(self.t_evo) + 2*self.t_echo @dataclass class SpectrumParameter: dwell_time: float num_points: int t_echo: ArrayLike lb: float t_pulse: float t_acq: ArrayLike = field(init=False) freq: ArrayLike = field(init=False) t_max: float = field(init=False) dampening: ArrayLike = field(init=False) pulse_attn: ArrayLike = field(init=False) def __post_init__(self): self.t_acq = np.arange(self.num_points) * self.dwell_time self.dampening = np.exp(-self.lb * self.t_acq) self.t_max = np.max(self.t_acq) + 2 * np.max(self.t_echo) self.freq = np.fft.fftshift(np.fft.fftfreq(self.num_points, self.dwell_time)) self.pulse_attn = pulse_attn(self.freq, self.t_pulse) def totext(self) -> str: return (f'dwell_time{self.dwell_time}\n' f'num_points={self.num_points}\n' f't_echo={self.t_echo}\n' f'lb={self.lb}\n' f't_pulse={self.t_pulse}') @dataclass class DistParameter: dist_type: BaseDistribution variables: field(default_factory=dict) num_variables: int = 0 iter: field(init=False) = None def __post_init__(self): self.num_variables = prod(map(len, self.variables.values())) def __iter__(self): return self def __next__(self) -> dict[str, Any]: if self.iter is None: self.iter = product(*self.variables.values()) try: return dict(zip(self.variables.keys(), next(self.iter))) except StopIteration: self.iter = None raise StopIteration @dataclass class MotionParameter: model: BaseMotion variables: field(default_factory=dict) num_variables: int = 0 iter: field(init=False) = None def __post_init__(self): self.num_variables = prod(map(len, self.variables.values())) def __iter__(self): return self def __next__(self) -> dict[str, Any]: if self.iter is None: self.iter = product(*self.variables.values()) try: return dict(zip(self.variables.keys(), next(self.iter))) except StopIteration: self.iter = None raise StopIteration @dataclass class Parameter: ste: StimEchoParameter | None spec: SpectrumParameter | None sim: SimParameter dist: DistParameter motion: MotionParameter molecule: MoleculeParameter def totext(self, sim: bool = True, spec: bool = True) -> str: text = [] if sim: text.append(self.sim.totext()) if spec: text.append(self.spec.totext()) return '\n'.join(text)