forked from IPKM/nmreval
		
	Compare commits
	
		
			2 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 64f6697573 | ||
|  | b20d7e61b2 | 
							
								
								
									
										103
									
								
								src/nmreval/math/bootstrap.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										103
									
								
								src/nmreval/math/bootstrap.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,103 @@ | ||||
| import multiprocessing | ||||
| from typing import Callable | ||||
|  | ||||
| import numpy as np | ||||
|  | ||||
| from numpy import arange | ||||
| from numpy.random import default_rng | ||||
| from scipy.optimize import least_squares | ||||
|  | ||||
| from nmreval.models.relaxation import TwoSatRecAbsolute | ||||
| from nmreval.utils.text import convert | ||||
|  | ||||
|  | ||||
| class Bootstrap: | ||||
|     def __init__(self, func, x, y, p, bounds=None, n_sims=1000, seed=None): | ||||
|         if hasattr(func, 'func'): | ||||
|             self._func = func.func | ||||
|             self.model = func | ||||
|         else: | ||||
|             self._func = func | ||||
|             self.model = None | ||||
|         self._x = x | ||||
|         self._y = y | ||||
|         self._bounds = bounds | ||||
|         self.n_sims = n_sims | ||||
|         self.idx = arange(len(self._x)) | ||||
|         self.num = len(self._x) | ||||
|         self._p_start = p | ||||
|  | ||||
|         self.rng = default_rng(seed=seed) | ||||
|  | ||||
|     def resid(self, pp, xx, yy): | ||||
|         return self._func(xx, *pp) - yy | ||||
|  | ||||
|     def run(self): | ||||
|  | ||||
|         manager = multiprocessing.Manager() | ||||
|         shared_list = manager.list() | ||||
|  | ||||
|         sims_to_do = self.n_sims | ||||
|         while sims_to_do > 0: | ||||
|             # print('next_round', sims_to_do) | ||||
|             jobs = [] | ||||
|             for i in range(sims_to_do): | ||||
|                 # drawing inside fit gives same ind for all | ||||
|                 ind = self.rng.choice(self.idx, self.num, replace=True) | ||||
|                 p = multiprocessing.Process(target=self.fit, args=(ind, shared_list)) | ||||
|                 jobs.append(p) | ||||
|                 p.start() | ||||
|  | ||||
|             for p in jobs: | ||||
|                 p.join() | ||||
|  | ||||
|             sims_to_do = self.n_sims - len(shared_list) | ||||
|  | ||||
|         return self.create_results(list(shared_list)) | ||||
|  | ||||
|     def create_results(self, raw_results: list) -> dict: | ||||
|  | ||||
|         if self.model is not None: | ||||
|             keys = [convert(p, old='tex', new='str', brackets=False) for p in self.model.params] + ['chi2'] | ||||
|         else: | ||||
|             keys = ['p'+str(i) for i in range(len(self._p_start))] + ['chi2'] | ||||
|  | ||||
|         dic = {k: np.empty(self.n_sims) for k in keys} | ||||
|  | ||||
|         for i, p in enumerate(raw_results): | ||||
|             for k, p_k in zip(keys, p): | ||||
|                 dic[k][i] = p_k | ||||
|  | ||||
|         return dic | ||||
|  | ||||
|     def fit(self, ind, ret_list): | ||||
|         r = least_squares(self.resid, self._p_start, bounds=self._bounds, args=(self._x[ind], self._y[ind])) | ||||
|         if not r.success:  # r.status == 0: | ||||
|             print('failure', r.status) | ||||
|             return | ||||
|  | ||||
|         res = r.x.tolist() | ||||
|         res.append(np.sum(r.fun**2)) | ||||
|  | ||||
|         ret_list.append(res) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     x = np.logspace(-4, 2, num=31) | ||||
|  | ||||
|     p = [1000, 0.03, 1, 100, 0.9, 0.5, 0] | ||||
|     bounds = ([0] * 6 + [-np.inf], [np.inf, np.inf, 1, np.inf, 20, 1, np.inf]) | ||||
|     # bounds = (-np.inf, np.inf) | ||||
|  | ||||
|     mag = TwoSatRecAbsolute.func | ||||
|  | ||||
|     y = mag(x, *p) + 10 * (2 * np.random.randn(len(x)) - 1) | ||||
|  | ||||
|     import matplotlib.pyplot as plt | ||||
|     plt.semilogx(x, y) | ||||
|     plt.show() | ||||
|  | ||||
|  | ||||
|     bootstrap3 = Bootstrap(TwoSatRecAbsolute, x, y, p, bounds=bounds, n_sims=10) | ||||
|     print(bootstrap3.run()) | ||||
|  | ||||
		Reference in New Issue
	
	Block a user