from __future__ import annotations from dataclasses import dataclass, field from itertools import product from math import prod from typing import Any import numpy as np from .functions import pulse_attn from .distributions import BaseDistribution from .motions import BaseMotion __all__ = [ 'SimParameter', 'MoleculeParameter', 'StimEchoParameter', 'SpectrumParameter', 'DistParameter', 'MotionParameter', 'Parameter', 'make_filename' ] @dataclass class SimParameter: seed: int | None num_walker: int t_max: float def header(self) -> str: return f'num_traj = {self.num_walker}\nseed = {self.seed}' @dataclass class MoleculeParameter: delta: float eta: float @dataclass class StimEchoParameter: t_evo: 'ArrayLike' t_mix: 'ArrayLike' t_echo: float t_max: float = field(init=False) def __post_init__(self): self.t_max = np.max(self.t_mix) + 2 * np.max(self.t_evo) + 2*self.t_echo def header(self) -> str: return ( f't_evo = {self.t_evo}\n' f't_mix = {self.t_mix}\n' f't_echo={self.t_echo}\n' ) @dataclass class SpectrumParameter: dwell_time: float num_points: int t_echo: 'ArrayLike' lb: float t_pulse: float t_acq: 'ArrayLike' = field(init=False) freq: 'ArrayLike' = field(init=False) t_max: float = field(init=False) dampening: 'ArrayLike' = field(init=False) pulse_attn: 'ArrayLike' = field(init=False) def __post_init__(self): self.t_acq = np.arange(self.num_points) * self.dwell_time self.dampening = np.exp(-self.lb * self.t_acq) self.t_max = np.max(self.t_acq) + 2 * np.max(self.t_echo) self.freq = np.fft.fftshift(np.fft.fftfreq(self.num_points, self.dwell_time)) self.pulse_attn = pulse_attn(self.freq, self.t_pulse) def header(self) -> str: return ( f'dwell_time = {self.dwell_time}\n' f'num_points = {self.num_points}\n' f't_echo = {self.t_echo}\n' f'lb = {self.lb}\n' f't_pulse = {self.t_pulse}' ) @dataclass class DistParameter: name: str dist_type: BaseDistribution variables: field(default_factory=dict) num_variables: int = 0 iter: field(init=False) = None def __post_init__(self): self.num_variables = prod(map(len, self.variables.values())) def __iter__(self): return self def __next__(self) -> dict[str, Any]: if self.iter is None: self.iter = product(*self.variables.values()) try: return dict(zip(self.variables.keys(), next(self.iter))) except StopIteration: self.iter = None raise StopIteration @dataclass class MotionParameter: name: str model: BaseMotion variables: field(default_factory=dict) num_variables: int = 0 iter: field(init=False) = None def __post_init__(self): self.num_variables = prod(map(len, self.variables.values())) def __iter__(self): return self def __next__(self) -> dict[str, Any]: if self.iter is None: self.iter = product(*self.variables.values()) try: return dict(zip(self.variables.keys(), next(self.iter))) except StopIteration: self.iter = None raise StopIteration @dataclass class Parameter: ste: StimEchoParameter | None spec: SpectrumParameter | None sim: SimParameter dist: DistParameter motion: MotionParameter molecule: MoleculeParameter def header(self, sim: bool = True, spec: bool = False, ste: bool = False) -> str: text = [] if sim: text.append(self.sim.header()) if spec: text.append(self.spec.header()) if ste: text.append(self.ste.header()) return '\n'.join(text) def make_filename(dist: BaseDistribution, motion: BaseMotion) -> str: filename = f'{dist}_{motion}' filename = filename.replace(' ', '_') filename = filename.replace('.', 'p') return filename