python/rwsims/ste.py
2024-08-03 19:04:13 +02:00

127 lines
4.0 KiB
Python

from __future__ import annotations
import numpy as np
from matplotlib import pyplot as plt
from scipy.optimize import curve_fit
from .distributions import BaseDistribution
from .functions import ste
from .motions import BaseMotion
from .parameter import Parameter, make_filename
def save_ste_data(
    cc: np.ndarray,
    ss: np.ndarray,
    param: Parameter,
    dist: BaseDistribution,
    motion: BaseMotion,
) -> None:
    """Save simulated STE data to text files and write a normalized plot of each.

    For both data sets (``cc`` and ``ss``) a ``<filename>_<label>.dat`` text
    file and a ``<filename>_<label>.png`` plot are written, where
    ``filename = make_filename(dist, motion)``.

    Args:
        cc: CC STE data; written next to ``param.ste.t_mix`` as first column,
            so presumably one row per mixing time and one column per
            evolution time in ``param.ste.t_evo`` — TODO confirm with caller.
        ss: SS STE data, same layout as ``cc``.
        param: simulation parameters; provides the file header and the
            ``ste.t_mix`` / ``ste.t_evo`` axes.
        dist: distribution used in the simulation (header, plot title, filename).
        motion: motion model used in the simulation (header, plot title, filename).
    """
    filename = make_filename(dist, motion)
    header = param.header(sim=True, ste=True)
    header += '\n' + dist.header()
    header += '\n' + motion.header()
    t_evo_string = [f'{x:.3e}' for x in param.ste.t_evo]
    header += '\nx\t' + '\t'.join(t_evo_string)

    for ste_data, ste_label in ((cc, 'cc'), (ss, 'ss')):
        np.savetxt(filename + f'_{ste_label}.dat', np.c_[param.ste.t_mix, ste_data], header=header)

        fig, ax = plt.subplots()
        try:
            # every column is normalized to its value in the first row
            lines = ax.semilogx(param.ste.t_mix, ste_data/ste_data[0, :])
            ax.set_title(f'{dist}, {motion}')
            ax.set_xlabel('t_mix / s')
            ax.set_ylabel(f'F_{ste_label}(t) / F_{ste_label}(0)')
            ax.legend(lines, t_evo_string)
            fig.savefig(filename + f'_{ste_label}.png')
        finally:
            # pyplot keeps every figure alive until it is closed explicitly;
            # without this, repeated calls accumulate open figures.
            plt.close(fig)
def fit_ste(
    cc: np.ndarray,
    ss: np.ndarray,
    t_evo: np.ndarray,
    t_mix: np.ndarray,
    dist_values: dict,
    num_variables: int
) -> tuple[np.ndarray, np.ndarray]:
    """Fit the STE decay of every evolution time with the ``ste`` model.

    Column ``k`` of ``cc`` and ``ss`` (indexed ``[:, k]``) is fitted against
    ``t_mix`` with parameters ``[amplitude, f_infty, tau, beta]``, bounded to
    ``amplitude >= 0``, ``0 <= f_infty <= 1``, ``tau >= 0``, ``0 <= beta <= 1``.

    Args:
        cc: CC data, indexed ``[t_mix index, t_evo index]``.
        ss: SS data, same layout as ``cc``.
        t_evo: evolution times, one per data column.
        t_mix: mixing times, the x-values of every fit.
        dist_values: distribution parameters; ``dist_values['tau']`` (default 1)
            is the initial guess for the fitted ``tau``.
        num_variables: unused. It is kept only for backward compatibility; the
            previous loop over ``range(num_variables)`` repeated the identical
            fit and discarded all but one result.

    Returns:
        ``(p_cc, p_ss)``: one row ``[t_evo, amplitude, f_infty, tau, beta]``
        per evolution time. The four fit values are NaN when ``curve_fit``
        raised ``RuntimeError`` (fit did not converge).
    """
    p_cc = []
    p_ss = []

    # fit ste decay for every evolution time
    for k, t_evo_k in enumerate(t_evo):
        for ste_data, ste_fits in ((cc, p_cc), (ss, p_ss)):
            # initial guess: [amplitude, f_infty, tau, beta]
            p0 = [ste_data[0, k], 0.1, dist_values.get('tau', 1), 1]
            try:
                popt, _ = curve_fit(
                    ste, t_mix, ste_data[:, k], p0=p0,
                    bounds=([0, 0, 0, 0], [np.inf, 1, np.inf, 1]),
                )
            except RuntimeError:
                # raised by curve_fit when the least-squares fit fails
                ste_fits.append([t_evo_k, np.nan, np.nan, np.nan, np.nan])
            else:
                ste_fits.append([t_evo_k] + popt.tolist())

    return np.array(p_cc), np.array(p_ss)
def save_ste_fit(cc: np.ndarray, ss: np.ndarray, param: Parameter, dist: BaseDistribution, motion: BaseMotion):
    """Write the CC and SS fit-parameter tables to ``<filename>_cc_fit.dat`` / ``_ss_fit.dat``.

    ``filename`` comes from ``make_filename(dist, motion)``. Both files share one
    header: parameter, distribution and motion headers plus a column-label line.

    Args:
        cc: table of CC fit parameters, written as-is with ``np.savetxt``.
        ss: table of SS fit parameters, written as-is with ``np.savetxt``.
        param: simulation parameters (header source).
        dist: distribution used in the simulation (header and filename).
        motion: motion model used in the simulation (header and filename).
    """
    basename = make_filename(dist, motion)
    header = '\n'.join([
        param.header(sim=True, ste=True),
        dist.header(),
        motion.header(),
        # NOTE(review): if the tables come from fit_ste, the first column holds
        # evolution times (t_evo); the 't_echo' label is kept as-is for
        # compatibility with existing output files.
        't_echo\tamp\tf_infty\ttau\tbeta',
    ])

    for suffix, fit_table in (('_cc_fit.dat', cc), ('_ss_fit.dat', ss)):
        np.savetxt(basename + suffix, fit_table, header=header)
def plot_ste_fits(fits_cc, fits_ss, dist, motion):
    """Plot every STE fit parameter against evolution time, one PNG per parameter.

    Four two-panel figures are produced (top: CC, bottom: SS) for amplitude,
    f_infty, tau and beta. They are saved as ``<dist.name>_<motion.name>`` plus
    ``_amp.png``, ``_finfty.png``, ``_tau.png`` and ``_beta.png``.

    Args:
        fits_cc: CC fit results, converted with ``np.array`` and indexed
            ``[row, t_evo index, column]``; column 0 is the x-value and
            columns 1-4 are amplitude, f_infty, tau, beta.
        fits_ss: SS fit results, same layout as ``fits_cc``.
        dist: iterable of per-distribution parameter dicts, with a ``name``.
        motion: iterable of per-motion parameter dicts, with ``name`` and
            ``num_variables``.
    """
    fits_cc = np.array(fits_cc)
    fits_ss = np.array(fits_ss)

    num_motion = motion.num_variables
    filename = f'{dist.name}_{motion.name}'

    # (column in the fit arrays, title, filename suffix, log-scaled y axes)
    panels = (
        (1, 'Amplitude (top: CC, bottom: SS)', '_amp.png', True),
        (2, 'F_infty (top: CC, bottom: SS)', '_finfty.png', True),
        (3, 'tau (top: CC, bottom: SS)', '_tau.png', True),
        (4, 'beta (top: CC, bottom: SS)', '_beta.png', False),
    )

    figures = []
    try:
        for _ in panels:
            figures.append(plt.subplots(2))

        for i, dist_values in enumerate(dist):
            for j, motion_values in enumerate(motion):
                # rows are assumed to be ordered distribution-major, motion-minor
                row = i * num_motion + j
                label = ', '.join(
                    [f'{key}={val}' for key, val in dist_values.items()]
                    + [f'{key}={val}' for key, val in motion_values.items()]
                )
                for (col, _title, _suffix, _log_y), (_fig, axes) in zip(panels, figures):
                    axes[0].plot(fits_cc[row, :, 0], fits_cc[row, :, col], 'o--', label=label)
                    axes[1].plot(fits_ss[row, :, 0], fits_ss[row, :, col], 'o--')

        for (_col, title, suffix, log_y), (fig, axes) in zip(panels, figures):
            axes[0].legend()
            axes[0].set_title(title)
            if log_y:
                axes[0].set_yscale('log')
                axes[1].set_yscale('log')
            # Save this panel's own figure: plt.savefig saves the *current*
            # figure, i.e. the last one created, which wrote the beta plot
            # into every file.
            fig.savefig(filename + suffix)
    finally:
        for fig, _axes in figures:
            plt.close(fig)