2024-06-19 17:10:49 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
from dataclasses import dataclass, field
|
|
|
|
from itertools import product
|
2024-06-20 17:19:55 +00:00
|
|
|
from math import prod
|
2024-06-19 17:10:49 +00:00
|
|
|
from typing import Any
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
2024-06-20 17:19:55 +00:00
|
|
|
from .distributions import BaseDistribution
|
|
|
|
from .motions import BaseMotion
|
2024-06-19 17:10:49 +00:00
|
|
|
|
2024-06-20 17:19:55 +00:00
|
|
|
# Names exported by ``from <module> import *``: the parameter dataclasses.
__all__ = [
    "SimParameter",
    "MoleculeParameter",
    "StimEchoParameter",
    "SpectrumParameter",
    "DistParameter",
    "MotionParameter",
    "Parameter",
]
|
2024-06-19 17:10:49 +00:00
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class SimParameter:
    """Top-level settings shared by a whole simulation run.

    Attributes:
        seed: Random seed, or ``None``; presumably forwarded to the random
            number generator (TODO confirm against caller).
        num_walker: Number of walkers to simulate.
        t_max: Simulated time span; presumably taken from the experiment's
            derived ``t_max`` (e.g. StimEchoParameter) -- confirm.
    """

    seed: int | None
    num_walker: int
    t_max: float
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class MoleculeParameter:
    """Per-molecule interaction constants used by the simulation.

    Attributes:
        delta: Presumably the anisotropy / coupling strength of the
            interaction (TODO confirm units against caller).
        eta: Presumably the asymmetry parameter of the interaction -- confirm.
    """

    delta: float
    eta: float
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class StimEchoParameter:
    """Timing of a stimulated-echo experiment.

    ``t_max`` is derived rather than passed in:
    ``max(t_mix) + 2 * max(t_evo) + 2 * t_echo``, the longest span the
    sequence covers for the given timings (presumably the trajectory length
    the simulation has to provide -- confirm against caller).

    Note: dataclass equality compares the array fields with ``==``, so
    comparing instances whose arrays hold more than one element may raise
    numpy's "truth value of an array is ambiguous" ValueError.
    """

    t_evo: np.ndarray  # evolution times
    t_mix: np.ndarray  # mixing times
    t_echo: float  # echo delay
    t_max: float = field(init=False)  # filled in by __post_init__

    def __post_init__(self):
        longest_mix = np.max(self.t_mix)
        longest_evo = np.max(self.t_evo)
        # Same left-to-right summation order as the docstring formula.
        self.t_max = longest_mix + 2 * longest_evo + 2 * self.t_echo
|
2024-06-19 17:10:49 +00:00
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class SpectrumParameter:
    """Acquisition settings for a spectrum (time-domain signal) calculation.

    Constructor arguments, in order: ``dwell_time``, ``num_points``,
    ``t_echo``, ``lb``. The remaining fields are derived:

    - ``t_acq``: the ``num_points`` sampling times ``k * dwell_time``.
    - ``dampening``: ``exp(-lb * t_acq)``, presumably an exponential
      line-broadening window (TODO confirm).
    - ``t_max``: ``max(t_acq) + 2 * max(t_echo)``.

    ``num_points`` must be at least 1, otherwise taking the maximum of the
    empty ``t_acq`` raises ValueError.
    """

    dwell_time: float
    num_points: int
    t_echo: np.ndarray
    t_acq: np.ndarray = field(init=False)
    t_max: float = field(init=False)
    # Declared after the init=False fields but still a constructor argument;
    # keep this order, it fixes the positional signature and the repr.
    lb: float
    dampening: np.ndarray = field(init=False)

    def __post_init__(self):
        sample_times = np.arange(self.num_points) * self.dwell_time
        self.t_acq = sample_times
        self.dampening = np.exp(-self.lb * sample_times)
        self.t_max = sample_times.max() + 2 * np.max(self.t_echo)
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class DistParameter:
    """Sweep over every combination of a distribution's parameter values.

    ``variables`` maps a parameter name to a sized iterable of candidate
    values. Iterating the instance yields one ``{name: value}`` dict per
    element of the Cartesian product of those iterables (keys in
    ``variables`` order), i.e. ``num_variables`` dicts per pass. An empty
    ``variables`` yields a single empty dict, matching ``num_variables == 1``.

    Every ``iter()`` call (every ``for`` loop) restarts the sweep, so a loop
    that ``break``s early does not leak its position into the next loop. The
    cursor lives on the instance, so two simultaneous loops over the *same*
    object still interfere with each other.
    """

    # Distribution model this sweep parameterizes.
    dist_type: BaseDistribution
    # Parameter name -> sized iterable of values; defaults to an empty dict.
    variables: dict[str, Any] = field(default_factory=dict)
    # Number of combinations. Recomputed in __post_init__, so a value passed
    # to the constructor is ignored (kept as an argument for compatibility).
    num_variables: int = 0
    # Internal cursor into the current sweep. Not a constructor argument, and
    # excluded from repr/equality so iteration progress does not affect them.
    iter: product | None = field(
        default=None, init=False, repr=False, compare=False
    )

    def __post_init__(self):
        self.num_variables = prod(map(len, self.variables.values()))

    def __iter__(self):
        """Restart the sweep and return ``self`` as its own iterator."""
        self.iter = product(*self.variables.values())
        return self

    def __next__(self) -> dict[str, Any]:
        """Return the next parameter combination as a ``{name: value}`` dict.

        Raises:
            StopIteration: once every combination has been produced; the
                cursor is cleared so a later ``next()`` starts over.
        """
        if self.iter is None:
            # Direct ``next()`` call without a preceding ``iter()``.
            self.iter = product(*self.variables.values())
        try:
            values = next(self.iter)
        except StopIteration:
            self.iter = None
            raise
        return dict(zip(self.variables.keys(), values))
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class MotionParameter:
    """Sweep over every combination of a motion model's parameter values.

    ``variables`` maps a parameter name to a sized iterable of candidate
    values. Iterating the instance yields one ``{name: value}`` dict per
    element of the Cartesian product of those iterables (keys in
    ``variables`` order), i.e. ``num_variables`` dicts per pass. An empty
    ``variables`` yields a single empty dict, matching ``num_variables == 1``.

    Every ``iter()`` call (every ``for`` loop) restarts the sweep, so a loop
    that ``break``s early does not leak its position into the next loop. The
    cursor lives on the instance, so two simultaneous loops over the *same*
    object still interfere with each other.
    """

    # Motion model this sweep parameterizes.
    model: BaseMotion
    # Parameter name -> sized iterable of values; defaults to an empty dict.
    variables: dict[str, Any] = field(default_factory=dict)
    # Number of combinations. Recomputed in __post_init__, so a value passed
    # to the constructor is ignored (kept as an argument for compatibility).
    num_variables: int = 0
    # Internal cursor into the current sweep. Not a constructor argument, and
    # excluded from repr/equality so iteration progress does not affect them.
    iter: product | None = field(
        default=None, init=False, repr=False, compare=False
    )

    def __post_init__(self):
        self.num_variables = prod(map(len, self.variables.values()))

    def __iter__(self):
        """Restart the sweep and return ``self`` as its own iterator."""
        self.iter = product(*self.variables.values())
        return self

    def __next__(self) -> dict[str, Any]:
        """Return the next parameter combination as a ``{name: value}`` dict.

        Raises:
            StopIteration: once every combination has been produced; the
                cursor is cleared so a later ``next()`` starts over.
        """
        if self.iter is None:
            # Direct ``next()`` call without a preceding ``iter()``.
            self.iter = product(*self.variables.values())
        try:
            values = next(self.iter)
        except StopIteration:
            self.iter = None
            raise
        return dict(zip(self.variables.keys(), values))
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class Parameter:
    """Bundle of every parameter group describing one simulation setup.

    ``ste`` and ``spec`` are annotated as optional; presumably a run uses at
    least one of them (TODO confirm against callers). All fields are
    required and positional in the order declared here.
    """

    ste: StimEchoParameter | None  # stimulated-echo timing, or None
    spec: SpectrumParameter | None  # spectrum acquisition settings, or None
    sim: SimParameter  # seed / num_walker / t_max
    dist: DistParameter  # distribution parameter sweep
    motion: MotionParameter  # motion-model parameter sweep
    molecule: MoleculeParameter  # delta / eta constants
|