Compare commits
2 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
64f6697573 | ||
|
b20d7e61b2 |
103
src/nmreval/math/bootstrap.py
Normal file
103
src/nmreval/math/bootstrap.py
Normal file
@ -0,0 +1,103 @@
|
||||
import multiprocessing
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from numpy import arange
|
||||
from numpy.random import default_rng
|
||||
from scipy.optimize import least_squares
|
||||
|
||||
from nmreval.models.relaxation import TwoSatRecAbsolute
|
||||
from nmreval.utils.text import convert
|
||||
|
||||
|
||||
class Bootstrap:
|
||||
def __init__(self, func, x, y, p, bounds=None, n_sims=1000, seed=None):
|
||||
if hasattr(func, 'func'):
|
||||
self._func = func.func
|
||||
self.model = func
|
||||
else:
|
||||
self._func = func
|
||||
self.model = None
|
||||
self._x = x
|
||||
self._y = y
|
||||
self._bounds = bounds
|
||||
self.n_sims = n_sims
|
||||
self.idx = arange(len(self._x))
|
||||
self.num = len(self._x)
|
||||
self._p_start = p
|
||||
|
||||
self.rng = default_rng(seed=seed)
|
||||
|
||||
def resid(self, pp, xx, yy):
|
||||
return self._func(xx, *pp) - yy
|
||||
|
||||
def run(self):
|
||||
|
||||
manager = multiprocessing.Manager()
|
||||
shared_list = manager.list()
|
||||
|
||||
sims_to_do = self.n_sims
|
||||
while sims_to_do > 0:
|
||||
# print('next_round', sims_to_do)
|
||||
jobs = []
|
||||
for i in range(sims_to_do):
|
||||
# drawing inside fit gives same ind for all
|
||||
ind = self.rng.choice(self.idx, self.num, replace=True)
|
||||
p = multiprocessing.Process(target=self.fit, args=(ind, shared_list))
|
||||
jobs.append(p)
|
||||
p.start()
|
||||
|
||||
for p in jobs:
|
||||
p.join()
|
||||
|
||||
sims_to_do = self.n_sims - len(shared_list)
|
||||
|
||||
return self.create_results(list(shared_list))
|
||||
|
||||
def create_results(self, raw_results: list) -> dict:
|
||||
|
||||
if self.model is not None:
|
||||
keys = [convert(p, old='tex', new='str', brackets=False) for p in self.model.params] + ['chi2']
|
||||
else:
|
||||
keys = ['p'+str(i) for i in range(len(self._p_start))] + ['chi2']
|
||||
|
||||
dic = {k: np.empty(self.n_sims) for k in keys}
|
||||
|
||||
for i, p in enumerate(raw_results):
|
||||
for k, p_k in zip(keys, p):
|
||||
dic[k][i] = p_k
|
||||
|
||||
return dic
|
||||
|
||||
def fit(self, ind, ret_list):
|
||||
r = least_squares(self.resid, self._p_start, bounds=self._bounds, args=(self._x[ind], self._y[ind]))
|
||||
if not r.success: # r.status == 0:
|
||||
print('failure', r.status)
|
||||
return
|
||||
|
||||
res = r.x.tolist()
|
||||
res.append(np.sum(r.fun**2))
|
||||
|
||||
ret_list.append(res)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
x = np.logspace(-4, 2, num=31)
|
||||
|
||||
p = [1000, 0.03, 1, 100, 0.9, 0.5, 0]
|
||||
bounds = ([0] * 6 + [-np.inf], [np.inf, np.inf, 1, np.inf, 20, 1, np.inf])
|
||||
# bounds = (-np.inf, np.inf)
|
||||
|
||||
mag = TwoSatRecAbsolute.func
|
||||
|
||||
y = mag(x, *p) + 10 * (2 * np.random.randn(len(x)) - 1)
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
plt.semilogx(x, y)
|
||||
plt.show()
|
||||
|
||||
|
||||
bootstrap3 = Bootstrap(TwoSatRecAbsolute, x, y, p, bounds=bounds, n_sims=10)
|
||||
print(bootstrap3.run())
|
||||
|
Loading…
x
Reference in New Issue
Block a user