diff --git a/src/nmreval/math/bootstrap.py b/src/nmreval/math/bootstrap.py index de1eddf..7c3bef7 100644 --- a/src/nmreval/math/bootstrap.py +++ b/src/nmreval/math/bootstrap.py @@ -1,4 +1,5 @@ import multiprocessing +from typing import Callable import numpy as np @@ -6,10 +7,18 @@ from numpy import arange from numpy.random import default_rng from scipy.optimize import least_squares +from nmreval.models.relaxation import TwoSatRecAbsolute +from nmreval.utils.text import convert + class Bootstrap: def __init__(self, func, x, y, p, bounds=None, n_sims=1000, seed=None): - self._func = func + if hasattr(func, 'func'): + self._func = func.func + self.model = func + else: + self._func = func + self.model = None self._x = x self._y = y self._bounds = bounds @@ -18,15 +27,15 @@ class Bootstrap: self.num = len(self._x) self._p_start = p - self.manager = multiprocessing.Manager() - self.rng = default_rng(seed=seed) def resid(self, pp, xx, yy): return self._func(xx, *pp) - yy def run(self): - shared_list = self.manager.list() + + manager = multiprocessing.Manager() + shared_list = manager.list() sims_to_do = self.n_sims while sims_to_do > 0: @@ -44,13 +53,22 @@ class Bootstrap: sims_to_do = self.n_sims - len(shared_list) - parameter = np.empty((self.n_sims, len(self._p_start))) - chi = np.empty(self.n_sims) - for i, (p, c) in enumerate(shared_list): - parameter[i] = p - chi[i] = c + return self.create_results(list(shared_list)) - return parameter, chi + def create_results(self, raw_results: list) -> dict: + + if self.model is not None: + keys = [convert(p, old='tex', new='str', brackets=False) for p in self.model.params] + ['chi2'] + else: + keys = ['p'+str(i) for i in range(len(self._p_start))] + ['chi2'] + + dic = {k: np.empty(self.n_sims) for k in keys} + + for i, p in enumerate(raw_results): + for k, p_k in zip(keys, p): + dic[k][i] = p_k + + return dic def fit(self, ind, ret_list): r = least_squares(self.resid, self._p_start, bounds=self._bounds, args=(self._x[ind], self._y[ind])) @@ -58,15 +76,10 @@ class Bootstrap: print('failure', r.status) return - res = [] - res.extend(r.x.tolist()) - - ret_list.append((res, sum(r.fun**2))) - - -def mag(xx, *p): - return p[0]*(1-np.exp(-(xx/p[1])**p[2])) + p[3]*(1-np.exp(-(xx/p[4])**p[5])) + p[6] + res = r.x.tolist() + res.append(np.sum(r.fun**2)) + ret_list.append(res) if __name__ == '__main__': @@ -76,13 +89,15 @@ if __name__ == '__main__': bounds = ([0] * 6 + [-np.inf], [np.inf, np.inf, 1, np.inf, 20, 1, np.inf]) # bounds = (-np.inf, np.inf) + mag = TwoSatRecAbsolute.func + y = mag(x, *p) + 10 * (2 * np.random.randn(len(x)) - 1) import matplotlib.pyplot as plt plt.semilogx(x, y) plt.show() - bootstrap3 = Bootstrap(mag, x, y, p, bounds=bounds, n_sims=10) - from pprint import pprint - pprint(bootstrap3.run()) + + bootstrap3 = Bootstrap(TwoSatRecAbsolute, x, y, p, bounds=bounds, n_sims=10) + print(bootstrap3.run())