python/src/rwsims/parameter.py

109 lines
2.6 KiB
Python
Raw Normal View History

2024-06-19 17:10:49 +00:00
from __future__ import annotations
from dataclasses import dataclass, field
from itertools import product
from typing import Any
import numpy as np
from src.rwsims.distributions import DeltaDistribution
from src.rwsims.motions import RandomJump
# Public API of this module: the names bound by `from <module> import *`.
__all__ = ['SimParameter', 'MoleculeParameter', 'StimEchoParameter', 'SpectrumParameter', 'DistParameter', 'MotionParameter', 'Parameter']
@dataclass
class SimParameter:
    """Plain record of three simulation settings.

    The class only stores the values; it performs no validation or
    computation of its own.

    Attributes:
        seed: Integer seed or ``None``.
            NOTE(review): presumably seeds the random number generator, with
            ``None`` meaning "unseeded" -- confirm against the consumer.
        num_walker: Integer count.
            NOTE(review): presumably the number of simulated walkers -- confirm.
        t_max: Float upper time limit.
            NOTE(review): units are not visible here -- confirm.
    """

    seed: int | None
    num_walker: int
    t_max: float
@dataclass
class MoleculeParameter:
    """Plain record holding the two values ``delta`` and ``eta``.

    Attributes:
        delta: Float value stored as given.
            NOTE(review): presumably an anisotropy/coupling strength -- confirm.
        eta: Float value stored as given.
            NOTE(review): presumably an asymmetry parameter -- confirm.
    """

    delta: float
    eta: float
@dataclass
class StimEchoParameter:
    """Evolution and mixing times plus the total time span they require.

    NOTE(review): the name suggests these are stimulated-echo settings --
    confirm against the code that consumes this class.

    Attributes:
        t_evo: Evolution times (constructor argument).
        t_mix: Mixing times (constructor argument).
        t_max: Derived, not a constructor argument. Set in ``__post_init__``
            to ``max(t_mix) + 2 * max(t_evo)``.
    """

    t_evo: np.ndarray
    t_mix: np.ndarray
    t_max: float = field(init=False)

    def __post_init__(self):
        """Derive ``t_max`` from the largest mixing and evolution times."""
        longest_mix = np.max(self.t_mix)
        longest_evo = np.max(self.t_evo)
        # The largest evolution time is counted twice, the largest mixing
        # time once.
        self.t_max = longest_mix + 2 * longest_evo
@dataclass
class SpectrumParameter:
    """Acquisition settings and the arrays derived from them.

    The constructor takes ``dwell_time``, ``num_points``, ``t_echo`` and
    ``lb`` (in that order); the remaining fields are filled in by
    ``__post_init__``.

    Attributes:
        dwell_time: Spacing between consecutive acquisition points.
        num_points: Number of acquisition points.
        t_echo: Echo times.
        t_acq: Derived. ``num_points`` time points starting at 0 and spaced
            by ``dwell_time``.
        t_max: Derived. ``max(t_acq) + 2 * max(t_echo)``.
        lb: Decay rate used for ``dampening``.
            NOTE(review): presumably a line-broadening factor -- confirm.
        dampening: Derived. ``exp(-lb * t_acq)``, one value per acquisition
            point.
    """

    dwell_time: float
    num_points: int
    t_echo: np.ndarray
    t_acq: np.ndarray = field(init=False)
    t_max: float = field(init=False)
    lb: float
    dampening: np.ndarray = field(init=False)

    def __post_init__(self):
        """Build the acquisition time axis, its damping curve and ``t_max``."""
        sample_times = np.arange(self.num_points) * self.dwell_time
        self.t_acq = sample_times
        # Exponential decay evaluated on the acquisition time axis.
        self.dampening = np.exp(-self.lb * sample_times)
        # The largest echo time is counted twice on top of the acquisition
        # window.
        self.t_max = np.max(sample_times) + 2 * np.max(self.t_echo)
@dataclass
class DistParameter:
    """Distribution type plus a grid of variable values to iterate over.

    Iterating an instance yields one ``dict`` per combination of the values
    in ``variables`` (Cartesian product, keys in insertion order, last key
    varying fastest). Once all combinations have been produced the internal
    iterator is reset, so the same instance can be looped over again.

    Attributes:
        dist_type: Stored as given; not used by this class itself.
        variables: Maps each variable name to a sized collection of values.
            Defaults to a fresh empty dict per instance.
        num_variables: Sum of the lengths of all value collections. Always
            recomputed in ``__post_init__``, so a passed-in value is
            overwritten.
        iter: The ``itertools.product`` iterator of the loop in progress, or
            ``None`` when no loop is in progress. Not a constructor argument.
    """

    dist_type: DeltaDistribution
    # `field(...)` has to be the default VALUE; written as the annotation it
    # is inert and leaves the attribute without the intended behaviour.
    variables: dict[str, Any] = field(default_factory=dict)
    num_variables: int = 0
    iter: product | None = field(default=None, init=False)

    def __post_init__(self):
        """Count the total number of values over all variables."""
        self.num_variables = sum(map(len, self.variables.values()))

    def __iter__(self):
        """Return ``self``; the instance is its own iterator."""
        return self

    def __next__(self) -> dict[str, Any]:
        """Return the next combination as ``{variable name: value}``.

        Raises:
            StopIteration: When every combination has been produced. The
                internal iterator is cleared first, so the next call starts
                a new pass from the first combination.
        """
        if self.iter is None:
            # Start a new pass lazily on first use (or after exhaustion).
            self.iter = product(*self.variables.values())

        try:
            combination = next(self.iter)
        except StopIteration:
            self.iter = None
            raise

        return dict(zip(self.variables, combination))
@dataclass
class MotionParameter:
    """Motion model plus a grid of variable values to iterate over.

    Iterating an instance yields one ``dict`` per combination of the values
    in ``variables`` (Cartesian product, keys in insertion order, last key
    varying fastest). Once all combinations have been produced the internal
    iterator is reset, so the same instance can be looped over again.

    Attributes:
        model: Stored as given; not used by this class itself.
        variables: Maps each variable name to a sized collection of values.
            Defaults to a fresh empty dict per instance.
        num_variables: Sum of the lengths of all value collections. Always
            recomputed in ``__post_init__``, so a passed-in value is
            overwritten.
        iter: The ``itertools.product`` iterator of the loop in progress, or
            ``None`` when no loop is in progress. Not a constructor argument.
    """

    model: RandomJump
    # `field(...)` has to be the default VALUE; written as the annotation it
    # is inert and leaves the attribute without the intended behaviour.
    variables: dict[str, Any] = field(default_factory=dict)
    num_variables: int = 0
    iter: product | None = field(default=None, init=False)

    def __post_init__(self):
        """Count the total number of values over all variables."""
        self.num_variables = sum(map(len, self.variables.values()))

    def __iter__(self):
        """Return ``self``; the instance is its own iterator."""
        return self

    def __next__(self) -> dict[str, Any]:
        """Return the next combination as ``{variable name: value}``.

        Raises:
            StopIteration: When every combination has been produced. The
                internal iterator is cleared first, so the next call starts
                a new pass from the first combination.
        """
        if self.iter is None:
            # Start a new pass lazily on first use (or after exhaustion).
            self.iter = product(*self.variables.values())

        try:
            combination = next(self.iter)
        except StopIteration:
            self.iter = None
            raise

        return dict(zip(self.variables, combination))
@dataclass
class Parameter:
    """Bundle of all parameter groups, one attribute per group.

    The class only stores the objects it is given.

    Attributes:
        ste: ``StimEchoParameter`` instance, or ``None``.
        spec: ``SpectrumParameter`` instance, or ``None``.
        sim: ``SimParameter`` instance.
        dist: ``DistParameter`` instance.
        motion: ``MotionParameter`` instance.
        molecule: ``MoleculeParameter`` instance.
    """

    ste: StimEchoParameter | None
    spec: SpectrumParameter | None
    sim: SimParameter
    dist: DistParameter
    motion: MotionParameter
    molecule: MoleculeParameter