diff --git a/src/gui_qt/fit/fit_parameter.py b/src/gui_qt/fit/fit_parameter.py index 47b6ae0..7426c1b 100644 --- a/src/gui_qt/fit/fit_parameter.py +++ b/src/gui_qt/fit/fit_parameter.py @@ -208,11 +208,10 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit): if sid not in self.data_values: self.data_values[sid] = [None] * len(self.data_parameter) - def get_parameter(self, use_func=None) -> tuple[dict[str, list[Parameter]], list[Optional[Parameter]]]: + def get_parameter(self, use_func=None) -> tuple[dict, list]: bds = [] is_global = [] is_fixed = [] - param_general = [] for g in self.global_parameter: @@ -262,16 +261,16 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit): else: kw_p[g.argname] = p_i - global_parameter = [] - for param, global_flag in zip(param_general, is_global): - if global_flag: - global_parameter.append(param) - else: - global_parameter.append(None) - - data_parameter[sid] = (p, kw_p) + global_parameter = [] + for param, global_flag in zip(param_general, is_global): + if global_flag: + param.is_global = True + global_parameter.append(param) + else: + global_parameter.append(None) + return data_parameter, global_parameter def set_parameter(self, set_id: str | None, parameter: list[float]) -> int: diff --git a/src/gui_qt/fit/fitwindow.py b/src/gui_qt/fit/fitwindow.py index e3b708f..3f9ce48 100644 --- a/src/gui_qt/fit/fitwindow.py +++ b/src/gui_qt/fit/fitwindow.py @@ -9,6 +9,7 @@ import numpy as np from pyqtgraph import mkPen from nmreval.fit._meta import MultiModel, ModelFactory +from nmreval.fit.model import Model from nmreval.fit.result import FitResult from .fit_forms import FitTableWidget @@ -219,16 +220,16 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog): def _prepare(self, model: list, function_use: list = None, parameter: dict = None, add_idx: bool = False, cnt: int = 0) -> tuple[dict, int]: + if parameter is None: parameter = { - 'parameter': {}, - 'glob': [], + 'data_parameter': {}, + 'global_parameter': [], 'links': [], 'color': [], } for i, f in enumerate(model): - print(i, f) if not f['active']: continue @@ -239,33 +240,22 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog): QtWidgets.QMessageBox.Ok) return {}, -1 - print(p) - print(glob) - p_len = len(p) parameter['color'].append(f['color']) - - print(parameter) + parameter['global_parameter'].extend(glob) cnt = f['cnt'] - for p_k, v_k in p.items(): if add_idx: kw_k = {f'{k}_{cnt}': v for k, v in v_k[1].items()} else: kw_k = v_k[1] - if p_k in parameter['parameter']: - params, kw = parameter['parameter'][p_k] + if p_k in parameter['data_parameter']: + params, kw = parameter['data_parameter'][p_k] params += v_k[0] kw.update(kw_k) else: - parameter['parameter'][p_k] = (v_k[0], kw_k) - - for g_k, g_v in glob.items(): - if g_k != 'idx': - parameter['glob'][g_k] += g_v - else: - parameter['glob']['idx'] += [idx_i + p_len for idx_i in g_v] + parameter['data_parameter'][p_k] = (v_k[0], kw_k) if add_idx: cnt += 1 @@ -283,37 +273,43 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog): data = self.data_table.collect_data(default=self.default_combobox.currentData()) func_dict = {} - for k, mod in self.models.items(): - func, order, param_len = ModelFactory.create_from_list(mod) + for model_name, model_parameter in self.models.items(): + func, order, param_len = ModelFactory.create_from_list(model_parameter) if func is None: continue - if k in data: - parameter, _ = self._prepare(mod, function_use=data[k], add_idx=isinstance(func, MultiModel)) + func = Model(func) - # convert positions of global parameter to corresponding names - global_parameter: dict = parameter['glob'] - positions = global_parameter.pop('idx') - global_parameter['key'] = [pname for i, pname in enumerate(func.params) if i in positions] - # print(global_parameter) + if model_name in data: + parameter, _ = self._prepare(model_parameter, function_use=data[model_name], add_idx=isinstance(func, MultiModel)) if parameter is None: return + for (data_parameter, _) in parameter['data_parameter'].values(): + for pname, param in zip(func.params, data_parameter): + param.name = pname + + if self._complex[model_name] is not None: + for p_k, p_v in parameter['data_parameter'].items(): + p_v[1].update({'complex_mode': self._complex[model_name]}) + parameter['data_parameter'][p_k] = p_v[0], p_v[1] + + for pname, param_value in zip(func.params, parameter['global_parameter']): + if param_value is not None: + param_value.name = pname + func.set_global_parameter(param_value) + parameter['func'] = func parameter['order'] = order parameter['len'] = param_len - parameter['complex'] = self._complex[k] - if self._complex[k] is not None: - for p_k, p_v in parameter['parameter'].items(): - p_v[1].update({'complex_mode': self._complex[k]}) - parameter['parameter'][p_k] = p_v[0], p_v[1] + parameter['complex'] = self._complex[model_name] - func_dict[k] = parameter + func_dict[model_name] = parameter replaceable = [] - for k, v in func_dict.items(): + for model_name, v in func_dict.items(): for i, link_i in enumerate(v['links']): if link_i is None: continue @@ -344,7 +340,7 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog): QtWidgets.QMessageBox.Ok) return - replaceable.append((k, i, rep_model, repl_idx)) + replaceable.append((model_name, i, rep_model, repl_idx)) replace_value = None for p_k in f['parameter'].values(): diff --git a/src/gui_qt/main/management.py b/src/gui_qt/main/management.py index e4a0dc4..8d4c4ee 100644 --- a/src/gui_qt/main/management.py +++ b/src/gui_qt/main/management.py @@ -441,21 +441,22 @@ class UpperManagement(QtCore.QObject): # all-encompassing error catch try: for model_id, model_p in parameter.items(): - m = Model(model_p['func']) + m = model_p['func'] models[model_id] = m m_complex = model_p['complex'] + print(model_p) # sets are not in active order but in order they first appeared in fit dialog # iterate over order of set id in active order and access parameter inside loop # instead of directly looping - list_ids = list(model_p['parameter'].keys()) + list_ids = list(model_p['data_parameter'].keys()) set_order = [self.active_id.index(i) for i in list_ids] for pos in set_order: set_id = list_ids[pos] data_i = self.data[set_id] - set_params = model_p['parameter'][set_id] + set_params = model_p['data_parameter'][set_id] if we_option.lower() == 'deltay': we = data_i.y_err**2 @@ -485,18 +486,13 @@ class UpperManagement(QtCore.QObject): d = fit_d.Data(_x[inside], _y[inside], we=we[inside], idx=set_id) d.set_model(m) - d.set_parameter(set_params[0], var=model_p['var'], - lb=model_p['lb'], ub=model_p['ub'], - fun_kwargs=set_params[1]) + d.set_parameter(set_params[0], fun_kwargs=set_params[1]) + # d.set_parameter(set_params[0], var=model_p['var'], + # lb=model_p['lb'], ub=model_p['ub'], + # fun_kwargs=set_params[1]) self.fitter.add_data(d) - model_globs = model_p['glob'] - if model_globs: - for parameter_args in zip(*model_globs.values()): - m.set_global_parameter(**{k: v for k, v in zip(model_globs.keys(), parameter_args)}) - # m.set_global_parameter(**model_p['glob']) - for links_i in links: self.fitter.set_link_parameter((models[links_i[0]], links_i[1]), (models[links_i[2]], links_i[3])) diff --git a/src/nmreval/fit/data.py b/src/nmreval/fit/data.py index 0d7ed86..4a34409 100644 --- a/src/nmreval/fit/data.py +++ b/src/nmreval/fit/data.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from .model import Model @@ -69,7 +71,7 @@ class Data(object): return self.model def set_parameter(self, - values: list[float], + values: list[float | Parameter], *, var: list[bool] = None, ub: list[float] = None, @@ -103,24 +105,33 @@ class Data(object): if len(values) != len(model.params): raise ValueError('Number of given parameter does not match number of model parameters') - if var is None: - var = [True] * len(values) + is_parameter = [isinstance(v, Parameter) for v in values] + if all(is_parameter): + for p_i in values: + key = f"p{next(Parameters.parameter_counter)}" + self.parameter.add_parameter(key, p_i) + elif any(is_parameter): + raise ValueError('list of parameter are not all float of Parameter') - if lb is None: - if default_bounds: - lb = model.lb - else: - lb = [None] * len(values) + else: + if var is None: + var = [True] * len(values) - if ub is None: - if default_bounds: - ub = model.ub - else: - ub = [None] * len(values) + if lb is None: + if default_bounds: + lb = model.lb + else: + lb = [None] * len(values) - arg_names = ['name', 'value', 'var', 'lb', 'ub'] - for parameter_arg in zip(model.params, values, var, lb, ub): - self.parameter.add(**{arg_name: arg_value for arg_name, arg_value in zip(arg_names, parameter_arg)}) + if ub is None: + if default_bounds: + ub = model.ub + else: + ub = [None] * len(values) + + arg_names = ['name', 'value', 'var', 'lb', 'ub'] + for parameter_arg in zip(model.params, values, var, lb, ub): + self.parameter.add(**{arg_name: arg_value for arg_name, arg_value in zip(arg_names, parameter_arg)}) self.para_keys = list(self.parameter.keys()) diff --git a/src/nmreval/fit/model.py b/src/nmreval/fit/model.py index faca53d..c80a2a3 100644 --- a/src/nmreval/fit/model.py +++ b/src/nmreval/fit/model.py @@ -80,22 +80,30 @@ class Model(object): if v.default is not inspect.Parameter.empty} def set_global_parameter(self, - key: str, - value: float | str, + key: str | Parameter, + value: float | str = None, + *, var: bool = None, lb: float = None, ub: float = None, - default_bounds: bool = False + default_bounds: bool = False, ) -> Parameter: - idx = [self.params.index(key)] - if default_bounds: - if lb is None: - lb = [self.lb[i] for i in idx] - if ub is None: - ub = [self.lb[i] for i in idx] - p = self.parameter.add(key, value, var=var, lb=lb, ub=ub) - p.is_global = True + if isinstance(key, Parameter): + p = key + key = f'p{next(Parameters.parameter_counter)}' + self.parameter.add_parameter(key, p) + + else: + idx = [self.params.index(key)] + if default_bounds: + if lb is None: + lb = [self.lb[i] for i in idx] + if ub is None: + ub = [self.lb[i] for i in idx] + + p = self.parameter.add(key, value, var=var, lb=lb, ub=ub) + p.is_global = True return p