refactor odr
This commit is contained in:
parent
5a153585ee
commit
d2e63a5ee3
@ -49,6 +49,37 @@ def _cost_scipy(p, data, varpars, used_pars):
|
||||
return data.cost(actual_parameters)
|
||||
|
||||
|
||||
def _cost_odr(p: list[float], data: Data, varpars: list[str], used_pars: list[str], fitmode: int=0):
|
||||
for keys, values in zip(varpars, p):
|
||||
data.parameter[keys].scaled_value = values
|
||||
data.parameter[keys].namespace[keys] = data.parameter[keys].value
|
||||
|
||||
actual_parameters = [data.parameter[keys].value for keys in used_pars]
|
||||
|
||||
return data.func(actual_parameters, data.x)
|
||||
|
||||
|
||||
def _cost_odr_glob(p: list[float], data: list[Data], var_pars: list[str], used_pars: list[str]):
|
||||
# replace values
|
||||
for data_i in data:
|
||||
_update_parameter(data_i, var_pars, p)
|
||||
|
||||
r = []
|
||||
# unpack parameter and calculate y values and concatenate all
|
||||
for values, p_idx in zip(data, used_pars):
|
||||
actual_parameters = [values.parameter[keys].value for keys in p_idx]
|
||||
r = np.r_[r, values.func(actual_parameters, values.x)]
|
||||
|
||||
return r
|
||||
|
||||
|
||||
def _update_parameter(data: Data, varied_keys: list[str], parameter: list[float]):
|
||||
for keys, values in zip(varied_keys, parameter):
|
||||
if keys in data.parameter.keys():
|
||||
data.parameter[keys].scaled_value = values
|
||||
data.parameter[keys].namespace[keys] = data.parameter[keys].value
|
||||
|
||||
|
||||
class FitRoutine(object):
|
||||
def __init__(self, mode='lsq'):
|
||||
self._fitmethod = mode
|
||||
@ -189,7 +220,7 @@ class FitRoutine(object):
|
||||
logger.info('Fit aborted by user')
|
||||
self._abort = True
|
||||
|
||||
def run(self, mode='lsq'):
|
||||
def run(self, mode: str = 'lsq'):
|
||||
self._abort = False
|
||||
|
||||
fit_groups, linked_parameter = self.prepare_links()
|
||||
@ -246,7 +277,6 @@ class FitRoutine(object):
|
||||
return pp, lb, ub, var_pars
|
||||
|
||||
def _prep_global(self, data_group, linked):
|
||||
|
||||
p0 = []
|
||||
lb = []
|
||||
ub = []
|
||||
@ -264,16 +294,6 @@ class FitRoutine(object):
|
||||
p_k_used = p_k
|
||||
v_k_used = data.parameter[p_k]
|
||||
|
||||
# if i in data.model.parameter:
|
||||
# p_k_used = data.model.parameter[i]
|
||||
# v_k_used = self.parameter[p_k_used]
|
||||
# data.parameter.add_parameter(i, data.model.parameter[i])
|
||||
|
||||
# links trump global parameter
|
||||
# if p_k_used in linked:
|
||||
# p_k_used = linked[p_k_used]
|
||||
# v_k_used = self.parameter[p_k_used]
|
||||
|
||||
actual_pars.append(p_k_used)
|
||||
# parameter is variable and was not found before as shared parameter
|
||||
if v_k_used.var and p_k_used not in var:
|
||||
@ -292,27 +312,6 @@ class FitRoutine(object):
|
||||
|
||||
self._no_own_model = []
|
||||
|
||||
def __cost_odr(self, p: list[float], data: Data, varpars: list[str], used_pars: list[str]):
|
||||
for keys, values in zip(varpars, p):
|
||||
self.parameter[keys].scaled_value = values
|
||||
|
||||
actual_parameters = [self.parameter[keys].value for keys in used_pars]
|
||||
|
||||
return data.func(actual_parameters, data.x)
|
||||
|
||||
def __cost_odr_glob(self, p, data, varpars, used_pars):
|
||||
# replace values
|
||||
for keys, values in zip(varpars, p):
|
||||
self.parameter[keys].scaled_value = values
|
||||
|
||||
r = []
|
||||
# unpack parameter and calculate y values and concatenate all
|
||||
for values, p_idx in zip(data, used_pars):
|
||||
actual_parameters = [self.parameter[keys].value for keys in p_idx]
|
||||
r = np.r_[r, values.func(actual_parameters, values.x)]
|
||||
|
||||
return r
|
||||
|
||||
def _least_squares_single(self, data, p0, lb, ub, var):
|
||||
self.step = 0
|
||||
|
||||
@ -380,13 +379,18 @@ class FitRoutine(object):
|
||||
self.step += 1
|
||||
if self._abort:
|
||||
raise FitAbortException(f'Fit aborted by user')
|
||||
return self.__cost_odr(p, data, var_pars, data.para_keys)
|
||||
return _cost_odr(p, data, var_pars, data.para_keys)
|
||||
|
||||
odr_model = odr.Model(func)
|
||||
|
||||
corr, partial_corr, res = self._odr_fit(odr_data, odr_model, p0)
|
||||
|
||||
self.make_results(data, res.beta, var_pars, data.para_keys, (len(data), len(p0)),
|
||||
err=res.sd_beta, corr=corr, partial_corr=partial_corr)
|
||||
|
||||
def _odr_fit(self, odr_data, odr_model, p0):
|
||||
o = odr.ODR(odr_data, odr_model, beta0=p0)
|
||||
res = o.run()
|
||||
|
||||
corr = res.cov_beta / (res.sd_beta[:, None] * res.sd_beta[None, :]) * res.res_var
|
||||
try:
|
||||
corr_inv = np.linalg.inv(corr)
|
||||
@ -395,16 +399,14 @@ class FitRoutine(object):
|
||||
partial_corr[np.diag_indices_from(partial_corr)] = 1.
|
||||
except np.linalg.LinAlgError:
|
||||
partial_corr = corr
|
||||
|
||||
self.make_results(data, res.beta, var_pars, data.para_keys, (len(data), len(p0)),
|
||||
err=res.sd_beta, corr=corr, partial_corr=partial_corr)
|
||||
return corr, partial_corr, res
|
||||
|
||||
def _odr_global(self, data, p0, var, data_pars):
|
||||
def func(p, _):
|
||||
self.step += 1
|
||||
if self._abort:
|
||||
raise FitAbortException(f'Fit aborted by user')
|
||||
return self.__cost_odr_glob(p, data, var, data_pars)
|
||||
return _cost_odr_glob(p, data, var, data_pars)
|
||||
|
||||
x = []
|
||||
y = []
|
||||
@ -415,17 +417,7 @@ class FitRoutine(object):
|
||||
odr_data = odr.Data(x, y)
|
||||
odr_model = odr.Model(func)
|
||||
|
||||
o = odr.ODR(odr_data, odr_model, beta0=p0, ifixb=var)
|
||||
res = o.run()
|
||||
|
||||
corr = res.cov_beta / (res.sd_beta[:, None] * res.sd_beta[None, :]) * res.res_var
|
||||
try:
|
||||
corr_inv = np.linalg.inv(corr)
|
||||
corr_inv_diag = np.diag(np.sqrt(1 / np.diag(corr_inv)))
|
||||
partial_corr = -1. * np.dot(np.dot(corr_inv_diag, corr_inv), corr_inv_diag) # Partial correlation matrix
|
||||
partial_corr[np.diag_indices_from(partial_corr)] = 1.
|
||||
except np.linalg.LinAlgError:
|
||||
partial_corr = corr
|
||||
corr, partial_corr, res = self._odr_fit(odr_data, odr_model, p0)
|
||||
|
||||
for v, var_pars_k in zip(data, data_pars):
|
||||
self.make_results(v, res.beta, var, var_pars_k, (sum(len(d) for d in data), len(p0)),
|
||||
|
Loading…
Reference in New Issue
Block a user