168 lines
4.0 KiB
Python
168 lines
4.0 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass, field
|
|
from itertools import product
|
|
from math import prod
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
|
|
from .functions import pulse_attn
|
|
from .distributions import BaseDistribution
|
|
from .motions import BaseMotion
|
|
|
|
|
|
# Public API of this module: the parameter containers and the filename helper.
__all__ = [
    'SimParameter',
    'MoleculeParameter',
    'StimEchoParameter',
    'SpectrumParameter',
    'DistParameter',
    'MotionParameter',
    'Parameter',
    'make_filename',
]
|
|
|
|
|
|
@dataclass
class SimParameter:
    """Simulation-level settings.

    Values are stored exactly as given; nothing is derived here.
    """

    seed: int | None  # stored as-is; None is accepted
    num_walker: int   # reported as ``num_traj`` in the header
    t_max: float

    def header(self) -> str:
        """Return the ``num_traj`` and ``seed`` lines, newline-separated."""
        lines = (
            f'num_traj = {self.num_walker}',
            f'seed = {self.seed}',
        )
        return '\n'.join(lines)
|
|
|
|
|
|
@dataclass()
class MoleculeParameter:
    """Plain container for the molecular parameters ``delta`` and ``eta``.

    No validation or derived values are computed.
    """

    delta: float  # NOTE(review): presumably a coupling/anisotropy strength; confirm units
    eta: float    # NOTE(review): presumably an asymmetry parameter; confirm valid range
|
|
|
|
|
|
@dataclass
class StimEchoParameter:
    """Timing parameters ``t_evo``, ``t_mix`` and ``t_echo``.

    ``t_max`` is derived in ``__post_init__`` and is not an ``__init__``
    argument.
    """

    t_evo: 'ArrayLike'
    t_mix: 'ArrayLike'
    t_echo: float
    t_max: float = field(init=False)

    def __post_init__(self):
        # Longest mixing time + twice the longest evolution time + twice the
        # echo delay (presumably the total time a trajectory must cover).
        self.t_max = np.max(self.t_mix) + 2 * np.max(self.t_evo) + 2 * self.t_echo

    def header(self) -> str:
        """Return one ``key = value`` line per input parameter.

        Uses the same ``key = value`` spacing as the other ``header`` methods
        and has no trailing newline, so it joins cleanly with other sections.
        """
        return (
            f't_evo = {self.t_evo}\n'
            f't_mix = {self.t_mix}\n'
            f't_echo = {self.t_echo}'
        )
|
|
|
|
|
|
@dataclass
class SpectrumParameter:
    """Acquisition settings for a spectrum plus arrays derived from them.

    ``t_acq``, ``dampening``, ``t_max``, ``freq`` and ``pulse_attn`` are not
    ``__init__`` arguments; ``__post_init__`` computes them.
    """

    dwell_time: float
    num_points: int
    t_echo: 'ArrayLike'
    lb: float
    t_pulse: float
    t_acq: 'ArrayLike' = field(init=False)
    freq: 'ArrayLike' = field(init=False)
    t_max: float = field(init=False)
    dampening: 'ArrayLike' = field(init=False)
    pulse_attn: 'ArrayLike' = field(init=False)

    def __post_init__(self):
        n, dt = self.num_points, self.dwell_time

        # Acquisition time axis 0, dt, ..., (n - 1) * dt.
        acq = np.arange(n) * dt
        self.t_acq = acq
        # Exponential weights exp(-lb * t) over the acquisition axis.
        self.dampening = np.exp(-self.lb * acq)
        # Last acquisition time plus twice the longest echo time.
        self.t_max = np.max(acq) + 2 * np.max(self.t_echo)

        # Frequency axis of the n-point FFT, shifted so zero is centred.
        raw_freq = np.fft.fftfreq(n, dt)
        self.freq = np.fft.fftshift(raw_freq)
        # Inside a method, ``pulse_attn`` resolves to the module-level helper,
        # not to this dataclass field.
        self.pulse_attn = pulse_attn(self.freq, self.t_pulse)

    def header(self) -> str:
        """Return one ``key = value`` line per input parameter."""
        lines = [
            f'dwell_time = {self.dwell_time}',
            f'num_points = {self.num_points}',
            f't_echo = {self.t_echo}',
            f'lb = {self.lb}',
            f't_pulse = {self.t_pulse}',
        ]
        return '\n'.join(lines)
|
|
|
|
|
|
@dataclass
class DistParameter:
    """A named distribution plus a grid of variable values.

    ``variables`` maps a variable name to a sized sequence of values.
    Iterating over the instance yields one ``{name: value}`` dict per
    combination in the Cartesian product of those sequences.  With an empty
    ``variables`` dict, iteration yields a single ``{}``.
    """

    name: str
    dist_type: BaseDistribution
    # ``field(...)`` must be the default value, not the annotation; otherwise
    # the empty-dict default never applies and the argument becomes required.
    variables: dict[str, Any] = field(default_factory=dict)
    # Number of combinations; always recomputed in __post_init__.
    num_variables: int = 0
    # Internal iteration state: not an __init__ argument, hidden from repr,
    # and ignored by ==, so iterating does not change equality.
    iter: product | None = field(default=None, init=False, repr=False, compare=False)

    def __post_init__(self):
        # Size of the Cartesian product (1 for an empty dict, matching the
        # single empty tuple that ``product()`` yields).
        self.num_variables = prod(map(len, self.variables.values()))

    def __iter__(self):
        # Start each new loop from the first combination, even if a previous
        # loop stopped early.
        self.iter = None
        return self

    def __next__(self) -> dict[str, Any]:
        """Return the next ``{name: value}`` combination.

        Raises:
            StopIteration: when all combinations have been produced; the
                internal state is reset so the instance can be iterated again.
        """
        if self.iter is None:
            self.iter = product(*self.variables.values())
        try:
            values = next(self.iter)
        except StopIteration:
            self.iter = None
            raise
        return dict(zip(self.variables.keys(), values))
|
|
|
|
|
|
@dataclass
class MotionParameter:
    """A named motion model plus a grid of variable values.

    ``variables`` maps a variable name to a sized sequence of values.
    Iterating over the instance yields one ``{name: value}`` dict per
    combination in the Cartesian product of those sequences.  With an empty
    ``variables`` dict, iteration yields a single ``{}``.
    """

    name: str
    model: BaseMotion
    # ``field(...)`` must be the default value, not the annotation; otherwise
    # the empty-dict default never applies and the argument becomes required.
    variables: dict[str, Any] = field(default_factory=dict)
    # Number of combinations; always recomputed in __post_init__.
    num_variables: int = 0
    # Internal iteration state: not an __init__ argument, hidden from repr,
    # and ignored by ==, so iterating does not change equality.
    iter: product | None = field(default=None, init=False, repr=False, compare=False)

    def __post_init__(self):
        # Size of the Cartesian product (1 for an empty dict, matching the
        # single empty tuple that ``product()`` yields).
        self.num_variables = prod(map(len, self.variables.values()))

    def __iter__(self):
        # Start each new loop from the first combination, even if a previous
        # loop stopped early.
        self.iter = None
        return self

    def __next__(self) -> dict[str, Any]:
        """Return the next ``{name: value}`` combination.

        Raises:
            StopIteration: when all combinations have been produced; the
                internal state is reset so the instance can be iterated again.
        """
        if self.iter is None:
            self.iter = product(*self.variables.values())
        try:
            values = next(self.iter)
        except StopIteration:
            self.iter = None
            raise
        return dict(zip(self.variables.keys(), values))
|
|
|
|
|
|
@dataclass
class Parameter:
    """Bundle of all parameter groups for one run."""

    ste: StimEchoParameter | None
    spec: SpectrumParameter | None
    sim: SimParameter
    dist: DistParameter
    motion: MotionParameter
    molecule: MoleculeParameter

    def header(self, sim: bool = True, spec: bool = False, ste: bool = False) -> str:
        """Join the headers of the selected groups with newlines.

        Sections appear in the fixed order sim, spec, ste.  Selecting a group
        that is ``None`` raises ``AttributeError`` (``None`` has no ``header``).
        """
        candidates = (
            (sim, self.sim),
            (spec, self.spec),
            (ste, self.ste),
        )
        return '\n'.join(group.header() for wanted, group in candidates if wanted)
|
|
|
|
|
|
def make_filename(dist: BaseDistribution, motion: BaseMotion) -> str:
    """Build a file-name stem from the string forms of *dist* and *motion*.

    The two strings are joined with ``_``.  Every space becomes ``_`` and
    every dot becomes ``p`` (e.g. ``1.5`` -> ``1p5``).
    """
    stem = f'{dist}_{motion}'
    # A single translate pass is equivalent to replacing spaces first and
    # then dots, because neither replacement character is itself replaced.
    return stem.translate(str.maketrans({' ': '_', '.': 'p'}))
|