from __future__ import annotations

import numpy as np
from matplotlib import pyplot as plt
from scipy.optimize import curve_fit

from .distributions import BaseDistribution
from .functions import ste
from .motions import BaseMotion
from .parameter import Parameter, make_filename


def _build_header(param: Parameter, dist: BaseDistribution, motion: BaseMotion) -> str:
    """Return the common file header: simulation parameters, distribution, motion.

    Each part is placed on its own line, in that order.
    """
    return '\n'.join((
        param.header(sim=True, ste=True),
        dist.header(),
        motion.header(),
    ))


def save_ste_data(
        cc: np.ndarray,
        ss: np.ndarray,
        param: Parameter,
        dist: BaseDistribution,
        motion: BaseMotion,
) -> None:
    """Save simulated STE decays as text files and quick-look plots.

    For each of ``cc`` and ``ss`` (labelled ``'cc'`` / ``'ss'``) this writes

    * ``<filename>_<label>.dat``: first column ``param.ste.t_mix``, followed by
      the columns of the data array; the header ends with one entry per
      evolution time in ``param.ste.t_evo``;
    * ``<filename>_<label>.png``: each column normalised by its first row and
      plotted against ``t_mix`` on a logarithmic x axis.

    ``<filename>`` comes from ``make_filename(dist, motion)``.

    Args:
        cc: Data with one row per mixing time (``np.c_`` requires
            ``len(t_mix)`` rows); presumably one column per evolution time,
            as the header and legend assume.
        ss: Same layout as ``cc``.
        param: Simulation parameters; ``param.ste.t_mix`` and
            ``param.ste.t_evo`` are read.
        dist: Distribution, used for the header, file name and plot title.
        motion: Motion model, used for the header, file name and plot title.

    Note:
        The created figures are left open (as before), e.g. for a later
        ``plt.show()``. Callers looping over many parameter sets may want to
        close them.
    """
    filename = make_filename(dist, motion)

    t_evo_string = [f'{x:.3e}' for x in param.ste.t_evo]
    header = _build_header(param, dist, motion) + '\nx\t' + '\t'.join(t_evo_string)

    for ste_data, ste_label in ((cc, 'cc'), (ss, 'ss')):
        np.savetxt(
            filename + f'_{ste_label}.dat',
            np.c_[param.ste.t_mix, ste_data],
            header=header,
        )

        fig, ax = plt.subplots()
        # Normalise every column by its value at the first mixing time.
        lines = ax.semilogx(param.ste.t_mix, ste_data / ste_data[0, :])
        ax.set_title(f'{dist}, {motion}')
        ax.set_xlabel('t_mix / s')
        ax.set_ylabel(f'F_{ste_label}(t) / F_{ste_label}(0)')
        ax.legend(lines, t_evo_string)
        # Save this figure explicitly rather than relying on pyplot's
        # "current figure" state.
        fig.savefig(filename + f'_{ste_label}.png')


def fit_ste(
        cc: np.ndarray,
        ss: np.ndarray,
        t_evo: np.ndarray,
        t_mix: np.ndarray,
        dist_values: dict,
        num_variables: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Fit the ``ste`` model to the decay at every evolution time.

    For each index ``k`` of ``t_evo``, column ``k`` of ``cc`` and of ``ss``
    is fitted against ``t_mix`` with ``scipy.optimize.curve_fit``. The fit
    parameters are ``[amplitude, f_infty, tau, beta]``, bounded to
    ``[0, inf] x [0, 1] x [0, inf] x [0, 1]``.

    Args:
        cc: Data indexed as ``cc[t_mix_index, t_evo_index]``.
        ss: Same layout as ``cc``.
        t_evo: Evolution times; one fit per entry and per data set.
        t_mix: Mixing times, the independent variable of each fit.
        dist_values: Distribution parameters. ``dist_values['tau']`` (if
            present, else ``1``) is the initial guess for ``tau``.
        num_variables: Unused. It is kept for backward compatibility; it
            previously drove a loop that repeated identical fits
            ``num_variables`` times.

    Returns:
        ``(p_cc, p_ss)``: arrays with one row per evolution time and columns
        ``[t_evo, amplitude, f_infty, tau, beta]``. A row whose fit raised
        ``RuntimeError`` (e.g. no convergence) has NaN for all four parameters.
    """
    p_cc: list[list[float]] = []
    p_ss: list[list[float]] = []

    tau_guess = dist_values.get('tau', 1)
    bounds = ([0, 0, 0, 0], [np.inf, 1, np.inf, 1])

    # fit ste decay for every evolution time
    for k, t_evo_k in enumerate(t_evo):
        for ste_data, ste_fits in ((cc, p_cc), (ss, p_ss)):
            # [amplitude, f_infty, tau, beta]
            p0 = [ste_data[0, k], 0.1, tau_guess, 1]
            try:
                popt, _ = curve_fit(ste, t_mix, ste_data[:, k], p0=p0, bounds=bounds)
            except RuntimeError:
                # Failed fit: keep the row so rows stay aligned with t_evo.
                ste_fits.append([t_evo_k, np.nan, np.nan, np.nan, np.nan])
            else:
                ste_fits.append([t_evo_k] + popt.tolist())

    return np.array(p_cc), np.array(p_ss)


def save_ste_fit(
        cc: np.ndarray,
        ss: np.ndarray,
        param: Parameter,
        dist: BaseDistribution,
        motion: BaseMotion,
) -> None:
    """Write fit-parameter tables to ``<filename>_cc_fit.dat`` / ``_ss_fit.dat``.

    ``<filename>`` comes from ``make_filename(dist, motion)``. The header
    combines the simulation, distribution and motion headers with a column
    line ``t_echo amp f_infty tau beta``.

    NOTE(review): presumably ``cc``/``ss`` are the arrays returned by
    ``fit_ste``, whose first column holds the evolution times ``t_evo``
    (labelled ``t_echo`` in the header). Confirm the label is intended.
    """
    filename = make_filename(dist, motion)
    header = _build_header(param, dist, motion) + '\nt_echo\tamp\tf_infty\ttau\tbeta'

    np.savetxt(filename + '_cc_fit.dat', cc, header=header)
    np.savetxt(filename + '_ss_fit.dat', ss, header=header)


def plot_ste_fits(fits_cc, fits_ss, dist, motion):
    """Plot the STE fit parameters for all distribution/motion combinations.

    One figure with two stacked panels (top: CC, bottom: SS) is made per fit
    parameter. Each is saved as ``<dist.name>_<motion.name><suffix>`` with
    suffix ``_amp.png``, ``_finfty.png``, ``_tau.png`` or ``_beta.png``.

    Args:
        fits_cc: Indexed as ``fits_cc[row, t_evo_index, column]``, where
            ``row = i * motion.num_variables + j`` for the ``i``-th
            distribution and ``j``-th motion parameter set. Column 0 is
            plotted on x; columns 1..4 are the amplitude, f_infty, tau and
            beta panels.
        fits_ss: Same layout as ``fits_cc``.
        dist: Iterable of distribution parameter dicts; also provides ``.name``.
        motion: Iterable of motion parameter dicts; also provides ``.name``
            and ``.num_variables``.

    Note:
        Figures are left open (as before), e.g. for a later ``plt.show()``.
    """
    fits_cc = np.array(fits_cc)
    fits_ss = np.array(fits_ss)

    # (file suffix, title, log-scaled y axes) for fit columns 1..4, in order.
    panels = (
        ('_amp.png', 'Amplitude (top: CC, bottom: SS)', True),
        ('_finfty.png', 'F_infty (top: CC, bottom: SS)', True),
        ('_tau.png', 'tau (top: CC, bottom: SS)', True),
        ('_beta.png', 'beta (top: CC, bottom: SS)', False),
    )
    figures = [plt.subplots(2) for _ in panels]

    num_motion = motion.num_variables
    filename = f'{dist.name}_{motion.name}'

    for i, dist_values in enumerate(dist):
        for j, motion_values in enumerate(motion):
            row = i * num_motion + j
            label = ', '.join(
                [f'{key}={val}' for key, val in dist_values.items()]
                + [f'{key}={val}' for key, val in motion_values.items()]
            )
            for k, (_, axes) in enumerate(figures):
                axes[0].plot(fits_cc[row, :, 0], fits_cc[row, :, k + 1], 'o--', label=label)
                axes[1].plot(fits_ss[row, :, 0], fits_ss[row, :, k + 1], 'o--')

    for (suffix, title, log_y), (fig, axes) in zip(panels, figures):
        axes[0].legend()
        axes[0].set_title(title)
        if log_y:
            axes[0].set_yscale('log')
            axes[1].set_yscale('log')
        # Save *this* figure. plt.savefig writes pyplot's current figure,
        # which here is always the last one created (beta), so every
        # file would otherwise contain the beta plot.
        fig.savefig(filename + suffix)