refactor odr

This commit is contained in:
Dominik Demuth 2023-08-29 19:44:09 +02:00
parent 5a153585ee
commit d2e63a5ee3

View File

@ -49,6 +49,37 @@ def _cost_scipy(p, data, varpars, used_pars):
return data.cost(actual_parameters)
def _cost_odr(p: list[float], data: Data, varpars: list[str], used_pars: list[str], fitmode: int=0):
for keys, values in zip(varpars, p):
data.parameter[keys].scaled_value = values
data.parameter[keys].namespace[keys] = data.parameter[keys].value
actual_parameters = [data.parameter[keys].value for keys in used_pars]
return data.func(actual_parameters, data.x)
def _cost_odr_glob(p: list[float], data: list[Data], var_pars: list[str], used_pars: list[str]):
# replace values
for data_i in data:
_update_parameter(data_i, var_pars, p)
r = []
# unpack parameter and calculate y values and concatenate all
for values, p_idx in zip(data, used_pars):
actual_parameters = [values.parameter[keys].value for keys in p_idx]
r = np.r_[r, values.func(actual_parameters, values.x)]
return r
def _update_parameter(data: Data, varied_keys: list[str], parameter: list[float]):
for keys, values in zip(varied_keys, parameter):
if keys in data.parameter.keys():
data.parameter[keys].scaled_value = values
data.parameter[keys].namespace[keys] = data.parameter[keys].value
class FitRoutine(object):
def __init__(self, mode='lsq'):
self._fitmethod = mode
@ -189,7 +220,7 @@ class FitRoutine(object):
logger.info('Fit aborted by user')
self._abort = True
def run(self, mode='lsq'):
def run(self, mode: str = 'lsq'):
self._abort = False
fit_groups, linked_parameter = self.prepare_links()
@ -246,7 +277,6 @@ class FitRoutine(object):
return pp, lb, ub, var_pars
def _prep_global(self, data_group, linked):
p0 = []
lb = []
ub = []
@ -264,16 +294,6 @@ class FitRoutine(object):
p_k_used = p_k
v_k_used = data.parameter[p_k]
# if i in data.model.parameter:
# p_k_used = data.model.parameter[i]
# v_k_used = self.parameter[p_k_used]
# data.parameter.add_parameter(i, data.model.parameter[i])
# links trump global parameter
# if p_k_used in linked:
# p_k_used = linked[p_k_used]
# v_k_used = self.parameter[p_k_used]
actual_pars.append(p_k_used)
# parameter is variable and was not found before as shared parameter
if v_k_used.var and p_k_used not in var:
@ -292,27 +312,6 @@ class FitRoutine(object):
self._no_own_model = []
def __cost_odr(self, p: list[float], data: Data, varpars: list[str], used_pars: list[str]):
for keys, values in zip(varpars, p):
self.parameter[keys].scaled_value = values
actual_parameters = [self.parameter[keys].value for keys in used_pars]
return data.func(actual_parameters, data.x)
def __cost_odr_glob(self, p, data, varpars, used_pars):
# replace values
for keys, values in zip(varpars, p):
self.parameter[keys].scaled_value = values
r = []
# unpack parameter and calculate y values and concatenate all
for values, p_idx in zip(data, used_pars):
actual_parameters = [self.parameter[keys].value for keys in p_idx]
r = np.r_[r, values.func(actual_parameters, values.x)]
return r
def _least_squares_single(self, data, p0, lb, ub, var):
self.step = 0
@ -380,13 +379,18 @@ class FitRoutine(object):
self.step += 1
if self._abort:
raise FitAbortException(f'Fit aborted by user')
return self.__cost_odr(p, data, var_pars, data.para_keys)
return _cost_odr(p, data, var_pars, data.para_keys)
odr_model = odr.Model(func)
corr, partial_corr, res = self._odr_fit(odr_data, odr_model, p0)
self.make_results(data, res.beta, var_pars, data.para_keys, (len(data), len(p0)),
err=res.sd_beta, corr=corr, partial_corr=partial_corr)
def _odr_fit(self, odr_data, odr_model, p0):
o = odr.ODR(odr_data, odr_model, beta0=p0)
res = o.run()
corr = res.cov_beta / (res.sd_beta[:, None] * res.sd_beta[None, :]) * res.res_var
try:
corr_inv = np.linalg.inv(corr)
@ -395,16 +399,14 @@ class FitRoutine(object):
partial_corr[np.diag_indices_from(partial_corr)] = 1.
except np.linalg.LinAlgError:
partial_corr = corr
self.make_results(data, res.beta, var_pars, data.para_keys, (len(data), len(p0)),
err=res.sd_beta, corr=corr, partial_corr=partial_corr)
return corr, partial_corr, res
def _odr_global(self, data, p0, var, data_pars):
def func(p, _):
self.step += 1
if self._abort:
raise FitAbortException(f'Fit aborted by user')
return self.__cost_odr_glob(p, data, var, data_pars)
return _cost_odr_glob(p, data, var, data_pars)
x = []
y = []
@ -415,17 +417,7 @@ class FitRoutine(object):
odr_data = odr.Data(x, y)
odr_model = odr.Model(func)
o = odr.ODR(odr_data, odr_model, beta0=p0, ifixb=var)
res = o.run()
corr = res.cov_beta / (res.sd_beta[:, None] * res.sd_beta[None, :]) * res.res_var
try:
corr_inv = np.linalg.inv(corr)
corr_inv_diag = np.diag(np.sqrt(1 / np.diag(corr_inv)))
partial_corr = -1. * np.dot(np.dot(corr_inv_diag, corr_inv), corr_inv_diag) # Partial correlation matrix
partial_corr[np.diag_indices_from(partial_corr)] = 1.
except np.linalg.LinAlgError:
partial_corr = corr
corr, partial_corr, res = self._odr_fit(odr_data, odr_model, p0)
for v, var_pars_k in zip(data, data_pars):
self.make_results(v, res.beta, var, var_pars_k, (sum(len(d) for d in data), len(p0)),