109 lines
2.6 KiB
Python
109 lines
2.6 KiB
Python
|
from __future__ import annotations
|
||
|
|
||
|
from dataclasses import dataclass, field
|
||
|
from itertools import product
|
||
|
from typing import Any
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
from src.rwsims.distributions import DeltaDistribution
|
||
|
from src.rwsims.motions import RandomJump
|
||
|
|
||
|
|
||
|
# Public API of this module: the parameter containers only.
__all__ = [
    "SimParameter",
    "MoleculeParameter",
    "StimEchoParameter",
    "SpectrumParameter",
    "DistParameter",
    "MotionParameter",
    "Parameter",
]
|
||
|
|
||
|
|
||
|
@dataclass
class SimParameter:
    """General settings of one simulation run.

    All three fields are required constructor arguments, in the order
    ``seed``, ``num_walker``, ``t_max``.
    """

    # Seed for the random number generator; None presumably means
    # "do not seed" — TODO confirm against the simulation code.
    seed: int | None
    # Number of walkers to simulate.
    num_walker: int
    # Maximum simulation time (units not fixed here).
    t_max: float
|
||
|
|
||
|
|
||
|
@dataclass
class MoleculeParameter:
    """Molecule-specific constants used by the simulation.

    Both fields are required constructor arguments, in the order
    ``delta``, ``eta``.
    """

    # Presumably the anisotropy parameter of the interaction tensor
    # — TODO confirm with the code consuming it.
    delta: float
    # Presumably the asymmetry parameter of the interaction tensor
    # — TODO confirm with the code consuming it.
    eta: float
|
||
|
|
||
|
|
||
|
@dataclass
class StimEchoParameter:
    """Time grids of a stimulated-echo calculation.

    Only ``t_evo`` and ``t_mix`` are constructor arguments. ``t_max`` is
    derived in ``__post_init__`` as the largest mixing time plus twice
    the largest evolution time.
    """

    # Evolution times; anything accepted by ``np.max`` works.
    t_evo: np.ndarray
    # Mixing times; anything accepted by ``np.max`` works.
    t_mix: np.ndarray
    # Derived, not an __init__ argument.
    t_max: float = field(init=False)

    def __post_init__(self):
        longest_evo = np.max(self.t_evo)
        longest_mix = np.max(self.t_mix)
        # The factor 2 presumably accounts for one evolution period before
        # and one after the mixing period — TODO confirm.
        self.t_max = longest_mix + 2 * longest_evo
|
||
|
|
||
|
|
||
|
@dataclass
class SpectrumParameter:
    """Acquisition settings for a simulated spectrum.

    Constructor arguments, in order: ``dwell_time``, ``num_points``,
    ``t_echo``, ``lb``. The fields ``t_acq``, ``t_max`` and ``dampening``
    are derived in ``__post_init__``:

    * ``t_acq``     = ``arange(num_points) * dwell_time``
    * ``t_max``     = ``max(t_acq) + 2 * max(t_echo)``
    * ``dampening`` = ``exp(-lb * t_acq)``
    """

    dwell_time: float
    num_points: int
    # Echo delays; anything accepted by ``np.max`` works.
    t_echo: np.ndarray
    t_acq: np.ndarray = field(init=False)
    t_max: float = field(init=False)
    # Declared after the derived fields; it stays here so the field order
    # seen by ``fields()``/repr is unchanged. Presumably an exponential
    # line-broadening factor — TODO confirm.
    lb: float
    dampening: np.ndarray = field(init=False)

    def __post_init__(self):
        acq_times = np.arange(self.num_points) * self.dwell_time
        self.t_acq = acq_times
        # Raises ValueError if num_points == 0 or t_echo is empty.
        self.t_max = np.max(acq_times) + 2 * np.max(self.t_echo)
        # Exponential weighting applied along the acquisition axis.
        self.dampening = np.exp(-self.lb * acq_times)
|
||
|
|
||
|
|
||
|
@dataclass
class DistParameter:
    """Grid of distribution parameters to sweep over.

    ``variables`` maps a parameter name to the collection of values it
    should take. Iterating over an instance yields one ``{name: value}``
    dict per element of the Cartesian product of all value collections,
    e.g. ``{'a': [1, 2], 'b': [3]}`` yields ``{'a': 1, 'b': 3}`` and then
    ``{'a': 2, 'b': 3}``. An empty ``variables`` yields a single ``{}``.

    The instance is its own iterator. Once the product is exhausted the
    internal state is reset, so the instance can be looped over again.
    Breaking out of a loop early keeps the current position, and the
    next loop resumes from there.
    """

    # Stored only; not used by the methods of this class.
    dist_type: DeltaDistribution
    # name -> collection of values; every value collection must support
    # ``len()``. This is a real default factory (it was previously written
    # as the annotation, which left the field without any default).
    variables: dict[str, Any] = field(default_factory=dict)
    # Total number of values across all variables (sum of the lengths,
    # not the number of combinations). Recomputed in __post_init__, so
    # any value passed here is overwritten.
    num_variables: int = 0
    # Active product iterator, or None when no iteration is in progress.
    # Still accepted by __init__ for backward compatibility, but excluded
    # from repr and equality because it is transient iteration state.
    iter: product | None = field(default=None, repr=False, compare=False)

    def __post_init__(self) -> None:
        self.num_variables = sum(map(len, self.variables.values()))

    def __iter__(self) -> DistParameter:
        return self

    def __next__(self) -> dict[str, Any]:
        """Return the next combination as a ``{name: value}`` dict.

        Raises:
            StopIteration: when all combinations were produced; the
                iterator state is reset first so iteration can restart.
        """
        if self.iter is None:
            self.iter = product(*self.variables.values())
        try:
            values = next(self.iter)
        except StopIteration:
            # Reset so a later loop starts from the first combination again.
            self.iter = None
            raise
        return dict(zip(self.variables.keys(), values))
|
||
|
|
||
|
|
||
|
@dataclass
class MotionParameter:
    """Grid of motion-model parameters to sweep over.

    ``variables`` maps a parameter name to the collection of values it
    should take. Iterating over an instance yields one ``{name: value}``
    dict per element of the Cartesian product of all value collections,
    e.g. ``{'a': [1, 2], 'b': [3]}`` yields ``{'a': 1, 'b': 3}`` and then
    ``{'a': 2, 'b': 3}``. An empty ``variables`` yields a single ``{}``.

    The instance is its own iterator. Once the product is exhausted the
    internal state is reset, so the instance can be looped over again.
    Breaking out of a loop early keeps the current position, and the
    next loop resumes from there.
    """

    # Stored only; not used by the methods of this class.
    model: RandomJump
    # name -> collection of values; every value collection must support
    # ``len()``. This is a real default factory (it was previously written
    # as the annotation, which left the field without any default).
    variables: dict[str, Any] = field(default_factory=dict)
    # Total number of values across all variables (sum of the lengths,
    # not the number of combinations). Recomputed in __post_init__, so
    # any value passed here is overwritten.
    num_variables: int = 0
    # Active product iterator, or None when no iteration is in progress.
    # Still accepted by __init__ for backward compatibility, but excluded
    # from repr and equality because it is transient iteration state.
    iter: product | None = field(default=None, repr=False, compare=False)

    def __post_init__(self) -> None:
        self.num_variables = sum(map(len, self.variables.values()))

    def __iter__(self) -> MotionParameter:
        return self

    def __next__(self) -> dict[str, Any]:
        """Return the next combination as a ``{name: value}`` dict.

        Raises:
            StopIteration: when all combinations were produced; the
                iterator state is reset first so iteration can restart.
        """
        if self.iter is None:
            self.iter = product(*self.variables.values())
        try:
            values = next(self.iter)
        except StopIteration:
            # Reset so a later loop starts from the first combination again.
            self.iter = None
            raise
        return dict(zip(self.variables.keys(), values))
|
||
|
|
||
|
|
||
|
@dataclass
class Parameter:
    """Bundle of all parameter groups describing one simulation setup.

    Every field is a required constructor argument, in the order shown.
    ``ste`` and ``spec`` have no default but may be None; presumably only
    the one matching the chosen experiment is set — TODO confirm.
    """

    # Stimulated-echo timing, or None.
    ste: StimEchoParameter | None
    # Spectrum acquisition settings, or None.
    spec: SpectrumParameter | None
    sim: SimParameter
    dist: DistParameter
    motion: MotionParameter
    molecule: MoleculeParameter
|