first test

This commit is contained in:
Dominik Demuth 2023-02-05 18:05:14 +01:00
parent df8a5e5699
commit b20d7e61b2

View File

@ -0,0 +1,88 @@
import multiprocessing
import numpy as np
from numpy import arange
from numpy.random import default_rng
from scipy.optimize import least_squares
class Bootstrap:
def __init__(self, func, x, y, p, bounds=None, n_sims=1000, seed=None):
self._func = func
self._x = x
self._y = y
self._bounds = bounds
self.n_sims = n_sims
self.idx = arange(len(self._x))
self.num = len(self._x)
self._p_start = p
self.manager = multiprocessing.Manager()
self.rng = default_rng(seed=seed)
def resid(self, pp, xx, yy):
return self._func(xx, *pp) - yy
def run(self):
shared_list = self.manager.list()
sims_to_do = self.n_sims
while sims_to_do > 0:
# print('next_round', sims_to_do)
jobs = []
for i in range(sims_to_do):
# drawing inside fit gives same ind for all
ind = self.rng.choice(self.idx, self.num, replace=True)
p = multiprocessing.Process(target=self.fit, args=(ind, shared_list))
jobs.append(p)
p.start()
for p in jobs:
p.join()
sims_to_do = self.n_sims - len(shared_list)
parameter = np.empty((self.n_sims, len(self._p_start)))
chi = np.empty(self.n_sims)
for i, (p, c) in enumerate(shared_list):
parameter[i] = p
chi[i] = c
return parameter, chi
def fit(self, ind, ret_list):
r = least_squares(self.resid, self._p_start, bounds=self._bounds, args=(self._x[ind], self._y[ind]))
if not r.success: # r.status == 0:
print('failure', r.status)
return
res = []
res.extend(r.x.tolist())
ret_list.append((res, sum(r.fun**2)))
def mag(xx, *p):
return p[0]*(1-np.exp(-(xx/p[1])**p[2])) + p[3]*(1-np.exp(-(xx/p[4])**p[5])) + p[6]
if __name__ == '__main__':
x = np.logspace(-4, 2, num=31)
p = [1000, 0.03, 1, 100, 0.9, 0.5, 0]
bounds = ([0] * 6 + [-np.inf], [np.inf, np.inf, 1, np.inf, 20, 1, np.inf])
# bounds = (-np.inf, np.inf)
y = mag(x, *p) + 10 * (2 * np.random.randn(len(x)) - 1)
import matplotlib.pyplot as plt
plt.semilogx(x, y)
plt.show()
bootstrap3 = Bootstrap(mag, x, y, p, bounds=bounds, n_sims=10)
from pprint import pprint
pprint(bootstrap3.run())