From d17d0f251ece24d712de8425e00fe2e78940b9b3 Mon Sep 17 00:00:00 2001 From: Dominik Demuth Date: Sat, 26 Aug 2023 20:08:13 +0200 Subject: [PATCH] work on linked models --- src/nmreval/fit/minimizer.py | 14 ++------------ src/nmreval/fit/model.py | 6 ++++-- src/nmreval/fit/result.py | 7 +++---- 3 files changed, 9 insertions(+), 18 deletions(-) diff --git a/src/nmreval/fit/minimizer.py b/src/nmreval/fit/minimizer.py index 2dbca81..c4b8fae 100644 --- a/src/nmreval/fit/minimizer.py +++ b/src/nmreval/fit/minimizer.py @@ -153,10 +153,8 @@ class FitRoutine(object): if isinstance(replace_model, Data): linked_parameter[dismiss_key] = replace_key else: - # print('dismiss model', dismiss_model.parameter) - # print('replace model', replace_model.parameter) - dismiss_model.parameter.replace_parameter(dismiss_key, replace_key, replace_model.parameter[replace_key]) - # print('after replacement', dismiss_model.parameter) + p = dismiss_model.set_global_parameter(dismiss_param, replace_key) + p._expr_disp = replace_param for mm, m_data in _found_models.items(): if mm.parameter: @@ -212,7 +210,6 @@ class FitRoutine(object): self._odr_single(data, p0_k, var_pars_k) else: - # print('INSIDE RUN') data_pars, p0, lb, ub, var_pars = self._prep_global(data_groups, linked_parameter) if mode == 'lsq': @@ -249,8 +246,6 @@ class FitRoutine(object): return pp, lb, ub, var_pars def _prep_global(self, data_group, linked): - # print('PREP GLOBAL') - # print(data_group, linked) p0 = [] lb = [] @@ -261,14 +256,11 @@ class FitRoutine(object): # loopy-loop over data that belong to one fit (linked or global) for data in data_group: # is parameter replaced by global parameter? - # print('SET GLOBAL') for k, v in data.model.parameter.items(): - # print(k, v) data.replace_parameter(k, v) actual_pars = [] for i, p_k in enumerate(data.para_keys): - # print(i, p_k) p_k_used = p_k v_k_used = data.parameter[p_k] @@ -290,8 +282,6 @@ class FitRoutine(object): ub.append(v_k_used.ub / v_k_used.scale) var.append(p_k_used) - # print('aloha, ', actual_pars) - data_pars.append(actual_pars) return data_pars, p0, lb, ub, var diff --git a/src/nmreval/fit/model.py b/src/nmreval/fit/model.py index a4db952..faca53d 100644 --- a/src/nmreval/fit/model.py +++ b/src/nmreval/fit/model.py @@ -6,7 +6,7 @@ from typing import Sized from numpy import inf from ._meta import MultiModel -from .parameter import Parameters +from .parameter import Parameters, Parameter class Model(object): @@ -86,7 +86,7 @@ class Model(object): lb: float = None, ub: float = None, default_bounds: bool = False - ): + ) -> Parameter: idx = [self.params.index(key)] if default_bounds: if lb is None: @@ -97,6 +97,8 @@ class Model(object): p = self.parameter.add(key, value, var=var, lb=lb, ub=ub) p.is_global = True + return p + @staticmethod def _prep(param_len, val): if isinstance(val, Sized): diff --git a/src/nmreval/fit/result.py b/src/nmreval/fit/result.py index a2a83d1..09c550d 100644 --- a/src/nmreval/fit/result.py +++ b/src/nmreval/fit/result.py @@ -223,6 +223,7 @@ class FitResult(Points): return self.nobs-self.nvar def pprint(self, statistics=True, correlations=True): + sstream = StringIO() print('Fit result:') print(' model :', self.name) print(' #data :', self.nobs) @@ -243,7 +244,7 @@ class FitResult(Points): def parameter_string(self): ret_val = '' - for pval in self.parameter.values(): + for pkey, pval in self.parameter.items(): ret_val += convert(str(pval), old='tex', new='str') + '\n' if self.fun_kwargs: @@ -255,9 +256,7 @@ class FitResult(Points): def _correlation_string(self): ret_val = '' for p_i, p_j, corr_ij, pcorr_ij in self.correlation_list(): - ret_val += ' {} / {} : {:.4f} ({:.4f})\n'.format(convert(p_i, old='tex', new='str'), - convert(p_j, old='tex', new='str'), - corr_ij, pcorr_ij) + ret_val += f" {convert(p_i, old='tex', new='str')} / {convert(p_j, old='tex', new='str')} : {corr_ij:.4f} ({pcorr_ij:.4f})\n" return ret_val def correlation_list(self, limit=0.1):