python/rwsims/parameter.py

168 lines
4.0 KiB
Python
Raw Normal View History

2024-06-19 17:10:49 +00:00
from __future__ import annotations
from dataclasses import dataclass, field
from itertools import product
2024-06-20 17:19:55 +00:00
from math import prod
2024-06-19 17:10:49 +00:00
from typing import Any
import numpy as np
2024-08-01 16:46:28 +00:00
from .functions import pulse_attn
2024-06-20 17:19:55 +00:00
from .distributions import BaseDistribution
from .motions import BaseMotion
2024-06-19 17:10:49 +00:00
2024-08-03 17:04:13 +00:00
2024-06-20 17:19:55 +00:00
# Public names exported by ``from <this module> import *``.
__all__ = [
    'SimParameter',
    'MoleculeParameter',
    'StimEchoParameter',
    'SpectrumParameter',
    'DistParameter',
    'MotionParameter',
    'Parameter',
    'make_filename',
]
2024-06-19 17:10:49 +00:00
@dataclass
class SimParameter:
    """General settings of one simulation run."""

    # Seed passed on to the simulation; None presumably means "unseeded" — TODO confirm.
    seed: int | None
    num_walker: int
    t_max: float

    def header(self) -> str:
        """Return ``num_traj`` and ``seed`` as ``key = value`` lines.

        ``num_walker`` is reported under the key ``num_traj``; ``t_max`` is not
        part of the header.
        """
        entries = (('num_traj', self.num_walker), ('seed', self.seed))
        return '\n'.join(f'{key} = {value}' for key, value in entries)
2024-06-30 10:06:44 +00:00
2024-06-19 17:10:49 +00:00
@dataclass
class MoleculeParameter:
    """Per-molecule parameters ``delta`` and ``eta``."""

    # Meaning/units are not visible in this module; presumably the anisotropy
    # (delta) and asymmetry (eta) parameters — confirm against the simulation code.
    delta: float
    eta: float
@dataclass
class StimEchoParameter:
    """Timing parameters of a stimulated-echo simulation.

    ``t_max`` is derived in ``__post_init__`` and is not an ``__init__`` argument.
    """

    # Evolution times; anything np.max accepts (scalar or array).
    t_evo: 'ArrayLike'
    # Mixing times; anything np.max accepts (scalar or array).
    t_mix: 'ArrayLike'
    t_echo: float
    # Derived: max(t_mix) + 2 * max(t_evo) + 2 * t_echo.
    t_max: float = field(init=False)

    def __post_init__(self):
        self.t_max = np.max(self.t_mix) + 2 * np.max(self.t_evo) + 2 * self.t_echo

    def header(self) -> str:
        """Return the timing parameters as ``key = value`` lines.

        Uses the same ``key = value`` layout, without a trailing newline, as the
        other parameter headers in this module, so sections join cleanly with
        ``'\\n'``.
        """
        return (
            f't_evo = {self.t_evo}\n'
            f't_mix = {self.t_mix}\n'
            f't_echo = {self.t_echo}'
        )
2024-06-19 17:10:49 +00:00
@dataclass
class SpectrumParameter:
    """Acquisition parameters for a simulated spectrum.

    Fields declared with ``init=False`` are derived in ``__post_init__``.
    """

    dwell_time: float
    num_points: int
    t_echo: 'ArrayLike'
    lb: float
    t_pulse: float
    # Derived fields (not __init__ arguments).
    t_acq: 'ArrayLike' = field(init=False)
    freq: 'ArrayLike' = field(init=False)
    t_max: float = field(init=False)
    dampening: 'ArrayLike' = field(init=False)
    pulse_attn: 'ArrayLike' = field(init=False)

    def __post_init__(self):
        # Sample times: num_points samples spaced dwell_time apart, starting at 0.
        acq_times = np.arange(self.num_points) * self.dwell_time
        self.t_acq = acq_times
        # Exponential damping exp(-lb * t) evaluated on the sample times.
        self.dampening = np.exp(-self.lb * acq_times)
        self.t_max = np.max(acq_times) + 2 * np.max(self.t_echo)
        # FFT sample frequencies for this grid, reordered so zero frequency is centred.
        frequencies = np.fft.fftshift(np.fft.fftfreq(self.num_points, self.dwell_time))
        self.freq = frequencies
        # `pulse_attn` below is the module-level function, not this field: names
        # inside a method resolve to module globals, skipping the class scope.
        self.pulse_attn = pulse_attn(frequencies, self.t_pulse)

    def header(self) -> str:
        """Return the user-supplied parameters as ``key = value`` lines."""
        lines = [
            f'dwell_time = {self.dwell_time}',
            f'num_points = {self.num_points}',
            f't_echo = {self.t_echo}',
            f'lb = {self.lb}',
            f't_pulse = {self.t_pulse}',
        ]
        return '\n'.join(lines)
2024-06-19 17:10:49 +00:00
@dataclass
class DistParameter:
    """Named distribution together with a grid of its parameter values.

    Iterating over an instance yields one ``dict`` per point of the Cartesian
    product of ``variables``, mapping each variable name to a single value, in
    the order produced by :func:`itertools.product`.
    """

    name: str
    dist_type: BaseDistribution
    # Maps variable name -> sequence of values to scan; each value must support len().
    # (Previously `field(...)` was used as the annotation, which left this field
    # without any default.)
    variables: dict[str, Any] = field(default_factory=dict)
    # Number of grid points; always recomputed in __post_init__, so any value
    # passed in is overwritten.
    num_variables: int = 0
    # Active iterator over the grid, or None when no iteration is in progress.
    # Internal state: excluded from __init__, repr and equality.
    iter: product | None = field(default=None, init=False, repr=False, compare=False)

    def __post_init__(self):
        # prod() of an empty iterable is 1, matching the single empty dict that
        # product() yields for an empty grid.
        self.num_variables = prod(map(len, self.variables.values()))

    def __iter__(self):
        # NOTE: returns self, so iteration state is shared between loops; leaving
        # a loop early means the next loop resumes where the previous one stopped.
        return self

    def __next__(self) -> dict[str, Any]:
        if self.iter is None:
            self.iter = product(*self.variables.values())
        try:
            values = next(self.iter)
        except StopIteration:
            # Reset so the grid can be iterated again from the beginning.
            self.iter = None
            raise
        return dict(zip(self.variables.keys(), values))
@dataclass
class MotionParameter:
    """Named motion model together with a grid of its parameter values.

    Iterating over an instance yields one ``dict`` per point of the Cartesian
    product of ``variables``, mapping each variable name to a single value, in
    the order produced by :func:`itertools.product`.
    """

    name: str
    model: BaseMotion
    # Maps variable name -> sequence of values to scan; each value must support len().
    # (Previously `field(...)` was used as the annotation, which left this field
    # without any default.)
    variables: dict[str, Any] = field(default_factory=dict)
    # Number of grid points; always recomputed in __post_init__, so any value
    # passed in is overwritten.
    num_variables: int = 0
    # Active iterator over the grid, or None when no iteration is in progress.
    # Internal state: excluded from __init__, repr and equality.
    iter: product | None = field(default=None, init=False, repr=False, compare=False)

    def __post_init__(self):
        # prod() of an empty iterable is 1, matching the single empty dict that
        # product() yields for an empty grid.
        self.num_variables = prod(map(len, self.variables.values()))

    def __iter__(self):
        # NOTE: returns self, so iteration state is shared between loops; leaving
        # a loop early means the next loop resumes where the previous one stopped.
        return self

    def __next__(self) -> dict[str, Any]:
        if self.iter is None:
            self.iter = product(*self.variables.values())
        try:
            values = next(self.iter)
        except StopIteration:
            # Reset so the grid can be iterated again from the beginning.
            self.iter = None
            raise
        return dict(zip(self.variables.keys(), values))
@dataclass
class Parameter:
    """Bundle of all parameter groups describing one simulation."""

    ste: StimEchoParameter | None
    spec: SpectrumParameter | None
    sim: SimParameter
    dist: DistParameter
    motion: MotionParameter
    molecule: MoleculeParameter

    def header(self, sim: bool = True, spec: bool = False, ste: bool = False) -> str:
        """Join the headers of the selected sections with newlines.

        Sections appear in the fixed order sim, spec, ste. Selecting a section
        whose parameter object is None raises AttributeError.
        """
        sections = ((sim, self.sim), (spec, self.spec), (ste, self.ste))
        return '\n'.join(group.header() for wanted, group in sections if wanted)
2024-08-03 17:04:13 +00:00
def make_filename(dist: BaseDistribution, motion: BaseMotion) -> str:
    """Build a file-name stem from the string forms of *dist* and *motion*.

    The two strings are joined with ``_``; then every space becomes ``_`` and
    every dot becomes ``p`` (e.g. ``0.5`` -> ``0p5``).
    """
    stem = f'{dist}_{motion}'
    return stem.translate(str.maketrans({' ': '_', '.': 'p'}))