work on linked models

This commit is contained in:
Dominik Demuth 2023-08-26 20:08:13 +02:00
parent 0b8f4932b2
commit d17d0f251e
3 changed files with 9 additions and 18 deletions

View File

@ -153,10 +153,8 @@ class FitRoutine(object):
if isinstance(replace_model, Data): if isinstance(replace_model, Data):
linked_parameter[dismiss_key] = replace_key linked_parameter[dismiss_key] = replace_key
else: else:
# print('dismiss model', dismiss_model.parameter) p = dismiss_model.set_global_parameter(dismiss_param, replace_key)
# print('replace model', replace_model.parameter) p._expr_disp = replace_param
dismiss_model.parameter.replace_parameter(dismiss_key, replace_key, replace_model.parameter[replace_key])
# print('after replacement', dismiss_model.parameter)
for mm, m_data in _found_models.items(): for mm, m_data in _found_models.items():
if mm.parameter: if mm.parameter:
@ -212,7 +210,6 @@ class FitRoutine(object):
self._odr_single(data, p0_k, var_pars_k) self._odr_single(data, p0_k, var_pars_k)
else: else:
# print('INSIDE RUN')
data_pars, p0, lb, ub, var_pars = self._prep_global(data_groups, linked_parameter) data_pars, p0, lb, ub, var_pars = self._prep_global(data_groups, linked_parameter)
if mode == 'lsq': if mode == 'lsq':
@ -249,8 +246,6 @@ class FitRoutine(object):
return pp, lb, ub, var_pars return pp, lb, ub, var_pars
def _prep_global(self, data_group, linked): def _prep_global(self, data_group, linked):
# print('PREP GLOBAL')
# print(data_group, linked)
p0 = [] p0 = []
lb = [] lb = []
@ -261,14 +256,11 @@ class FitRoutine(object):
# loopy-loop over data that belong to one fit (linked or global) # loopy-loop over data that belong to one fit (linked or global)
for data in data_group: for data in data_group:
# is parameter replaced by global parameter? # is parameter replaced by global parameter?
# print('SET GLOBAL')
for k, v in data.model.parameter.items(): for k, v in data.model.parameter.items():
# print(k, v)
data.replace_parameter(k, v) data.replace_parameter(k, v)
actual_pars = [] actual_pars = []
for i, p_k in enumerate(data.para_keys): for i, p_k in enumerate(data.para_keys):
# print(i, p_k)
p_k_used = p_k p_k_used = p_k
v_k_used = data.parameter[p_k] v_k_used = data.parameter[p_k]
@ -290,8 +282,6 @@ class FitRoutine(object):
ub.append(v_k_used.ub / v_k_used.scale) ub.append(v_k_used.ub / v_k_used.scale)
var.append(p_k_used) var.append(p_k_used)
# print('aloha, ', actual_pars)
data_pars.append(actual_pars) data_pars.append(actual_pars)
return data_pars, p0, lb, ub, var return data_pars, p0, lb, ub, var

View File

@ -6,7 +6,7 @@ from typing import Sized
from numpy import inf from numpy import inf
from ._meta import MultiModel from ._meta import MultiModel
from .parameter import Parameters from .parameter import Parameters, Parameter
class Model(object): class Model(object):
@ -86,7 +86,7 @@ class Model(object):
lb: float = None, lb: float = None,
ub: float = None, ub: float = None,
default_bounds: bool = False default_bounds: bool = False
): ) -> Parameter:
idx = [self.params.index(key)] idx = [self.params.index(key)]
if default_bounds: if default_bounds:
if lb is None: if lb is None:
@ -97,6 +97,8 @@ class Model(object):
p = self.parameter.add(key, value, var=var, lb=lb, ub=ub) p = self.parameter.add(key, value, var=var, lb=lb, ub=ub)
p.is_global = True p.is_global = True
return p
@staticmethod @staticmethod
def _prep(param_len, val): def _prep(param_len, val):
if isinstance(val, Sized): if isinstance(val, Sized):

View File

@ -223,6 +223,7 @@ class FitResult(Points):
return self.nobs-self.nvar return self.nobs-self.nvar
def pprint(self, statistics=True, correlations=True): def pprint(self, statistics=True, correlations=True):
sstream = StringIO()
print('Fit result:') print('Fit result:')
print(' model :', self.name) print(' model :', self.name)
print(' #data :', self.nobs) print(' #data :', self.nobs)
@ -243,7 +244,7 @@ class FitResult(Points):
def parameter_string(self): def parameter_string(self):
ret_val = '' ret_val = ''
for pval in self.parameter.values(): for pkey, pval in self.parameter.items():
ret_val += convert(str(pval), old='tex', new='str') + '\n' ret_val += convert(str(pval), old='tex', new='str') + '\n'
if self.fun_kwargs: if self.fun_kwargs:
@ -255,9 +256,7 @@ class FitResult(Points):
def _correlation_string(self): def _correlation_string(self):
ret_val = '' ret_val = ''
for p_i, p_j, corr_ij, pcorr_ij in self.correlation_list(): for p_i, p_j, corr_ij, pcorr_ij in self.correlation_list():
ret_val += ' {} / {} : {:.4f} ({:.4f})\n'.format(convert(p_i, old='tex', new='str'), ret_val += f" {convert(p_i, old='tex', new='str')} / {convert(p_j, old='tex', new='str')} : {corr_ij:.4f} ({pcorr_ij:.4f})\n"
convert(p_j, old='tex', new='str'),
corr_ij, pcorr_ij)
return ret_val return ret_val
def correlation_list(self, limit=0.1): def correlation_list(self, limit=0.1):