From 0b8f4932b25f4ad16b7309575044ee568a6f50ac Mon Sep 17 00:00:00 2001 From: Dominik Demuth Date: Fri, 25 Aug 2023 18:46:36 +0200 Subject: [PATCH] seems mostly to be working --- src/gui_qt/fit/fit_parameter.py | 4 +- src/gui_qt/fit/fitwindow.py | 11 ++++- src/gui_qt/main/management.py | 4 +- src/nmreval/fit/data.py | 43 ++++++++++++++----- src/nmreval/fit/minimizer.py | 75 ++++++++++++++------------------- src/nmreval/fit/model.py | 12 +++++- src/nmreval/fit/parameter.py | 45 +++++++++++++++----- src/nmreval/models/diffusion.py | 2 +- 8 files changed, 123 insertions(+), 73 deletions(-) diff --git a/src/gui_qt/fit/fit_parameter.py b/src/gui_qt/fit/fit_parameter.py index 682893e..292c220 100644 --- a/src/gui_qt/fit/fit_parameter.py +++ b/src/gui_qt/fit/fit_parameter.py @@ -227,7 +227,7 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit): kw_p = {} p = [] if global_p is None: - global_p = {'p': [], 'idx': [], 'var': [], 'ub': [], 'lb': []} + global_p = {'value': [], 'idx': [], 'var': [], 'ub': [], 'lb': []} for i, (p_i, g) in enumerate(zip(parameter, self.global_parameter)): if isinstance(g, FitModelWidget): @@ -235,7 +235,7 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit): p.append(globs[i]) if is_global[i]: if i not in global_p['idx']: - global_p['p'].append(globs[i]) + global_p['value'].append(globs[i]) global_p['idx'].append(i) global_p['var'].append(is_fixed[i]) global_p['ub'].append(ub[i]) diff --git a/src/gui_qt/fit/fitwindow.py b/src/gui_qt/fit/fitwindow.py index 519216e..ea66eef 100644 --- a/src/gui_qt/fit/fitwindow.py +++ b/src/gui_qt/fit/fitwindow.py @@ -221,7 +221,7 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog): parameter: dict = None, add_idx: bool = False, cnt: int = 0) -> tuple[dict, int]: if parameter is None: parameter = {'parameter': {}, 'lb': (), 'ub': (), 'var': [], - 'glob': {'idx': [], 'p': [], 'var': [], 'lb': [], 'ub': []}, + 'glob': {'idx': [], 'value': [], 'var': [], 'lb': [], 'ub': []}, 'links': [], 'color': []} for i, f in enumerate(model): @@ -269,7 +269,7 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog): if f['children']: # recurse for children - child_parameter, cnt = self._prepare(f['children'], parameter=parameter, add_idx=add_idx, cnt=cnt) + _, cnt = self._prepare(f['children'], parameter=parameter, add_idx=add_idx, cnt=cnt) return parameter, cnt @@ -288,6 +288,13 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog): if k in data: parameter, _ = self._prepare(mod, function_use=data[k], add_idx=isinstance(func, MultiModel)) + + # convert positions of global parameter to corresponding names + global_parameter: dict = parameter['glob'] + positions = global_parameter.pop('idx') + global_parameter['key'] = [pname for i, pname in enumerate(func.params) if i in positions] + # print(global_parameter) + if parameter is None: return diff --git a/src/gui_qt/main/management.py b/src/gui_qt/main/management.py index 191fe7c..dfc98d7 100644 --- a/src/gui_qt/main/management.py +++ b/src/gui_qt/main/management.py @@ -467,7 +467,9 @@ class UpperManagement(QtCore.QObject): model_globs = model_p['glob'] if model_globs: - m.set_global_parameter(**model_p['glob']) + for parameter_args in zip(*model_globs.values()): + m.set_global_parameter(**{k: v for k, v in zip(model_globs.keys(), parameter_args)}) + # m.set_global_parameter(**model_p['glob']) for links_i in links: self.fitter.set_link_parameter((models[links_i[0]], links_i[1]), diff --git a/src/nmreval/fit/data.py b/src/nmreval/fit/data.py index e33bfcf..0d7ed86 100644 --- a/src/nmreval/fit/data.py +++ b/src/nmreval/fit/data.py @@ -68,12 +68,19 @@ class Data(object): def get_model(self): return self.model - def set_parameter(self, parameter, var=None, ub=None, lb=None, - default_bounds=False, fun_kwargs=None): + def set_parameter(self, + values: list[float], + *, + var: list[bool] = None, + ub: list[float] = None, + lb: list[float] = None, + default_bounds: bool = False, + fun_kwargs: dict = None + ): """ Creates parameter for this data. If no Model is available, it falls back to the model - :param parameter: list of parameters + :param values: list of parameters :param var: list of boolean or boolean; False fixes parameter at given list index. Single value is broadcast to all parameter :param ub: list of upper boundaries or float; Single value is broadcast to all parameter. @@ -87,23 +94,37 @@ class Data(object): model = self.model if model is None: # Data has no unique - if self.minimizer is None: - model = None - else: + if self.minimizer is not None: model = self.minimizer.fit_model - self.fun_kwargs.update(model.fun_kwargs) if model is None: raise ValueError('No model found, please set model before parameters') - if default_bounds: - if lb is None: + if len(values) != len(model.params): + raise ValueError('Number of given parameter does not match number of model parameters') + + if var is None: + var = [True] * len(values) + + if lb is None: + if default_bounds: lb = model.lb - if ub is None: + else: + lb = [None] * len(values) + + if ub is None: + if default_bounds: ub = model.ub + else: + ub = [None] * len(values) - self.para_keys = self.parameter.add_parameter(parameter, var=var, lb=lb, ub=ub, names=model.params) + arg_names = ['name', 'value', 'var', 'lb', 'ub'] + for parameter_arg in zip(model.params, values, var, lb, ub): + self.parameter.add(**{arg_name: arg_value for arg_name, arg_value in zip(arg_names, parameter_arg)}) + self.para_keys = list(self.parameter.keys()) + + self.fun_kwargs.update(model.fun_kwargs) if fun_kwargs is not None: self.fun_kwargs.update(fun_kwargs) diff --git a/src/nmreval/fit/minimizer.py b/src/nmreval/fit/minimizer.py index 71ab2c6..2dbca81 100644 --- a/src/nmreval/fit/minimizer.py +++ b/src/nmreval/fit/minimizer.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import warnings from itertools import product @@ -26,7 +28,7 @@ def _cost_scipy_glob(p: list[float], data: list[Data], varpars: list[str], used_ # replace values for keys, values in zip(varpars, p): for data_i in data: - if keys in data_i.parameter: + if keys in data_i.parameter.keys(): data_i.parameter[keys].scaled_value = values data_i.parameter[keys].namespace[keys] = data_i.parameter[keys].value r = [] @@ -53,7 +55,6 @@ class FitRoutine(object): self.data = [] self.fit_model = None self._no_own_model = [] - self.parameter = Parameters() self.result = [] self.linked = [] self._abort = False @@ -107,28 +108,25 @@ class FitRoutine(object): return self.fit_model - def set_link_parameter(self, parameter: tuple, replacement: tuple): + def set_link_parameter(self, dismissed_param: tuple[Model | Data, str], replacement: tuple[Model, str]): if isinstance(replacement[0], Model): - if replacement[1] not in replacement[0].global_parameter: - raise KeyError(f'Parameter at pos {replacement[1]} of ' - f'model {str(replacement[0])} is not global') + if replacement[1] not in replacement[0].parameter: + raise KeyError(f'Parameter {replacement[1]} of ' + f'model {replacement[0]} is not global') - if isinstance(parameter[0], Model): - warnings.warn(f'Replaced parameter at pos {parameter[1]} in {str(parameter[0])} ' + if isinstance(dismissed_param[0], Model): + warnings.warn(f'Replaced parameter {dismissed_param[1]} in {dismissed_param[0]} ' f'becomes global with linkage.') - self.linked.append((*parameter, *replacement)) + self.linked.append((*dismissed_param, *replacement)) def prepare_links(self): self._no_own_model = [] - self.parameter = Parameters() _found_models = {} linked_sender = {} for v in self.data: linked_sender[v] = set() - for k, p_i in v.parameter.items(): - self.parameter.add_parameter(k, p_i.copy()) # set temporary model if v.model is None: @@ -138,35 +136,29 @@ class FitRoutine(object): # register model if v.model not in _found_models: _found_models[v.model] = [] - for k, p in v.model.parameter.items(): - self.parameter.add_parameter(k, p) - # m_param = v.model.parameter.copy() - # self.parameter.update(m_param) - # + _found_models[v.model].append(v) if v.model not in linked_sender: linked_sender[v.model] = set() linked_parameter = {} - # for par, par_parm, repl, repl_par in self.linked: - # if isinstance(par, Data): - # if isinstance(repl, Data): - # linked_parameter[par.para_keys[par_parm]] = repl.para_keys[repl_par] - # else: - # linked_parameter[par.para_keys[par_parm]] = repl.parameter[repl_par] - # - # else: - # if isinstance(repl, Data): - # par.global_parameter[par_parm] = repl.para_keys[repl_par] - # else: - # par.global_parameter[par_parm] = repl.global_parameter[repl_par] - # - # linked_sender[repl].add(par) - # linked_sender[par].add(repl) + for dismiss_model, dismiss_param, replace_model, replace_param in self.linked: + linked_sender[replace_model].add(dismiss_model) + linked_sender[replace_model].add(replace_model) + + replace_key = replace_model.parameter.get_key(replace_param) + dismiss_key = dismiss_model.parameter.get_key(dismiss_param) + + if isinstance(replace_model, Data): + linked_parameter[dismiss_key] = replace_key + else: + # print('dismiss model', dismiss_model.parameter) + # print('replace model', replace_model.parameter) + dismiss_model.parameter.replace_parameter(dismiss_key, replace_key, replace_model.parameter[replace_key]) + # print('after replacement', dismiss_model.parameter) for mm, m_data in _found_models.items(): - # print('has global', mm.parameter) if mm.parameter: for dd in m_data: linked_sender[mm].add(dd) @@ -174,14 +166,12 @@ class FitRoutine(object): coupled_data = [] visited_data = [] - # print('linked', linked_sender) for s in linked_sender.keys(): if s in visited_data: continue sub_graph = [] self.find_paths(s, linked_sender, sub_graph, visited_data) if sub_graph: - # print('sub', sub_graph) coupled_data.append(sub_graph) return coupled_data, linked_parameter @@ -203,12 +193,8 @@ class FitRoutine(object): def run(self, mode='lsq'): self._abort = False - self.parameter = Parameters() fit_groups, linked_parameter = self.prepare_links() - - # print(fit_groups, self.linked) - for data_groups in fit_groups: if len(data_groups) == 1 and not self.linked: data = data_groups[0] @@ -226,6 +212,7 @@ class FitRoutine(object): self._odr_single(data, p0_k, var_pars_k) else: + # print('INSIDE RUN') data_pars, p0, lb, ub, var_pars = self._prep_global(data_groups, linked_parameter) if mode == 'lsq': @@ -262,18 +249,21 @@ class FitRoutine(object): return pp, lb, ub, var_pars def _prep_global(self, data_group, linked): + # print('PREP GLOBAL') + # print(data_group, linked) + p0 = [] lb = [] ub = [] var = [] data_pars = [] - # print(data_group) - # loopy-loop over data that belong to one fit (linked or global) for data in data_group: # is parameter replaced by global parameter? + # print('SET GLOBAL') for k, v in data.model.parameter.items(): + # print(k, v) data.replace_parameter(k, v) actual_pars = [] @@ -312,7 +302,6 @@ class FitRoutine(object): self._no_own_model = [] - def __cost_odr(self, p: list[float], data: Data, varpars: list[str], used_pars: list[str]): for keys, values in zip(varpars, p): self.parameter[keys].scaled_value = values @@ -361,6 +350,7 @@ class FitRoutine(object): with np.errstate(all='ignore'): res = optimize.least_squares(cost, p0, bounds=(lb, ub), max_nfev=500 * len(p0)) + err, corr, partial_corr = self._calc_error(res.jac, np.sum(res.fun**2), *res.jac.shape) for v, var_pars_k in zip(data, data_pars): self.make_results(v, res.x, var, var_pars_k, res.jac.shape, @@ -457,7 +447,6 @@ class FitRoutine(object): if err is None: err = [0] * len(p) - print(p, var_pars, used_pars) # update parameter values for keys, p_value, err_value in zip(var_pars, p, err): if keys in data.parameter: diff --git a/src/nmreval/fit/model.py b/src/nmreval/fit/model.py index 76a9da2..a4db952 100644 --- a/src/nmreval/fit/model.py +++ b/src/nmreval/fit/model.py @@ -79,7 +79,14 @@ class Model(object): self.fun_kwargs = {k: v.default for k, v in inspect.signature(model.func).parameters.items() if v.default is not inspect.Parameter.empty} - def set_global_parameter(self, key, value, var=None, lb=None, ub=None, default_bounds=False): + def set_global_parameter(self, + key: str, + value: float | str, + var: bool = None, + lb: float = None, + ub: float = None, + default_bounds: bool = False + ): idx = [self.params.index(key)] if default_bounds: if lb is None: @@ -87,7 +94,8 @@ class Model(object): if ub is None: ub = [self.lb[i] for i in idx] - self.parameter.add(key, value, var=var, lb=lb, ub=ub) + p = self.parameter.add(key, value, var=var, lb=lb, ub=ub) + p.is_global = True @staticmethod def _prep(param_len, val): diff --git a/src/nmreval/fit/parameter.py b/src/nmreval/fit/parameter.py index 0ccb3c6..e9bf2c1 100644 --- a/src/nmreval/fit/parameter.py +++ b/src/nmreval/fit/parameter.py @@ -8,6 +8,7 @@ import numpy as np class Parameters(dict): parameter_counter = count() + # is one global namespace a good idea? namespace: dict = {} def __init__(self): @@ -24,7 +25,14 @@ class Parameters(dict): return super().__getitem__(item) def __setitem__(self, key, value): - super().__setitem__(key, value) + self.add_parameter(key, value) + + def __contains__(self, item): + for v in self.values(): + if item == v.name: + return True + + return False def add(self, name: str, @@ -34,7 +42,6 @@ class Parameters(dict): lb: float = -np.inf, ub: float = np.inf) -> Parameter: par = Parameter(name=name, value=value, var=var, lb=lb, ub=ub) - key = f'p{next(Parameters.parameter_counter)}' self.add_parameter(key, par) @@ -43,11 +50,8 @@ class Parameters(dict): def add_parameter(self, key: str, parameter: Parameter): self._mapping[parameter.name] = key - self[key] = parameter + super().__setitem__(key, parameter) - self._mapping[parameter.name] = key - - self[key] = parameter parameter.eval_allowed = False self.namespace[key] = parameter.value parameter.namespace = self.namespace @@ -63,13 +67,17 @@ class Parameters(dict): p._expr = expression def replace_parameter(self, key_out: str, key_in: str, parameter: Parameter): + # print('replace par', key_out, key_in, parameter) + # print('name', parameter.name) + + self.add_parameter(key_in, parameter) for k, v in self._mapping.items(): if v == key_out: self._mapping[k] = key_in break - self.add_parameter(key_in, parameter) - del self.namespace[key_out] + if key_out in self.namespace: + del self.namespace[key_out] for p in self.values(): try: @@ -80,6 +88,13 @@ class Parameters(dict): expression = re.sub(re.escape(n), k, expression) p._expr = expression + def get_key(self, name: str) -> str | None: + for k, v in self.items(): + if name == v.name: + return k + + return + def get_state(self): return {k: v.get_state() for k, v in self.items()} @@ -102,6 +117,7 @@ class Parameter: self.eval_allowed: bool = True self._expr: None | str = None self._expr_disp: None | str = None + self.is_global = False if isinstance(value, str): self._expr_disp = value @@ -126,15 +142,19 @@ class Parameter: start = StringIO() if self.name: if self.function: - start.write(f"{self.name} ({self.function}): ") + start.write(f"{self.name} ({self.function})") else: start.write(self.name) - start.write(": ") + + if self.is_global: + start.write("*") + + start.write(": ") if self.var: start.write(f"{self.value:.4g} +/- {self.error:.4g}, init={self.init_val}") else: - start.write(f"{self.value:}") + start.write(f"{self.value:.4g}") if self._expr is None: start.write(" (fixed)") else: @@ -204,6 +224,9 @@ class Parameter: para_copy = Parameter(name=self.name, value=val, var=self.var, lb=self.lb, ub=self.ub) para_copy._expr = self._expr para_copy.namespace = self.namespace + para_copy.is_global = self.is_global + para_copy.error = self.error + para_copy.function = self.function return para_copy diff --git a/src/nmreval/models/diffusion.py b/src/nmreval/models/diffusion.py index 4e03f92..32dba92 100644 --- a/src/nmreval/models/diffusion.py +++ b/src/nmreval/models/diffusion.py @@ -125,7 +125,7 @@ class Peschier: q = nucleus*g*tp r1s, r1f = 1 / t1s, 1 / t1f - kf, ks = pf*k, (1-pf)*k + kf, ks = (1-pf)*k, pf*k a_plus = 0.5 * (d*q*q + kf + ks + r1f + r1s + np.sqrt((d*q*q + kf + r1f - ks - r1s)**2 + 4*kf*ks)) a_minu = 0.5 * (d*q*q + kf + ks + r1f + r1s - np.sqrt((d*q*q + kf + r1f - ks - r1s)**2 + 4*kf*ks))