1
0
forked from IPKM/nmreval

use Parameter when collecting fit values

This commit is contained in:
Dominik Demuth 2023-09-18 13:52:10 +02:00
parent 03d172bade
commit bd1a227e4c
5 changed files with 94 additions and 84 deletions

View File

@ -208,11 +208,10 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit):
if sid not in self.data_values: if sid not in self.data_values:
self.data_values[sid] = [None] * len(self.data_parameter) self.data_values[sid] = [None] * len(self.data_parameter)
def get_parameter(self, use_func=None) -> tuple[dict[str, list[Parameter]], list[Optional[Parameter]]]: def get_parameter(self, use_func=None) -> tuple[dict, list]:
bds = [] bds = []
is_global = [] is_global = []
is_fixed = [] is_fixed = []
param_general = [] param_general = []
for g in self.global_parameter: for g in self.global_parameter:
@ -262,16 +261,16 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit):
else: else:
kw_p[g.argname] = p_i kw_p[g.argname] = p_i
data_parameter[sid] = (p, kw_p)
global_parameter = [] global_parameter = []
for param, global_flag in zip(param_general, is_global): for param, global_flag in zip(param_general, is_global):
if global_flag: if global_flag:
param.is_global = True
global_parameter.append(param) global_parameter.append(param)
else: else:
global_parameter.append(None) global_parameter.append(None)
data_parameter[sid] = (p, kw_p)
return data_parameter, global_parameter return data_parameter, global_parameter
def set_parameter(self, set_id: str | None, parameter: list[float]) -> int: def set_parameter(self, set_id: str | None, parameter: list[float]) -> int:

View File

@ -9,6 +9,7 @@ import numpy as np
from pyqtgraph import mkPen from pyqtgraph import mkPen
from nmreval.fit._meta import MultiModel, ModelFactory from nmreval.fit._meta import MultiModel, ModelFactory
from nmreval.fit.model import Model
from nmreval.fit.result import FitResult from nmreval.fit.result import FitResult
from .fit_forms import FitTableWidget from .fit_forms import FitTableWidget
@ -219,16 +220,16 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog):
def _prepare(self, model: list, function_use: list = None, def _prepare(self, model: list, function_use: list = None,
parameter: dict = None, add_idx: bool = False, cnt: int = 0) -> tuple[dict, int]: parameter: dict = None, add_idx: bool = False, cnt: int = 0) -> tuple[dict, int]:
if parameter is None: if parameter is None:
parameter = { parameter = {
'parameter': {}, 'data_parameter': {},
'glob': [], 'global_parameter': [],
'links': [], 'links': [],
'color': [], 'color': [],
} }
for i, f in enumerate(model): for i, f in enumerate(model):
print(i, f)
if not f['active']: if not f['active']:
continue continue
@ -239,33 +240,22 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog):
QtWidgets.QMessageBox.Ok) QtWidgets.QMessageBox.Ok)
return {}, -1 return {}, -1
print(p)
print(glob)
p_len = len(p)
parameter['color'].append(f['color']) parameter['color'].append(f['color'])
parameter['global_parameter'].extend(glob)
print(parameter)
cnt = f['cnt'] cnt = f['cnt']
for p_k, v_k in p.items(): for p_k, v_k in p.items():
if add_idx: if add_idx:
kw_k = {f'{k}_{cnt}': v for k, v in v_k[1].items()} kw_k = {f'{k}_{cnt}': v for k, v in v_k[1].items()}
else: else:
kw_k = v_k[1] kw_k = v_k[1]
if p_k in parameter['parameter']: if p_k in parameter['data_parameter']:
params, kw = parameter['parameter'][p_k] params, kw = parameter['data_parameter'][p_k]
params += v_k[0] params += v_k[0]
kw.update(kw_k) kw.update(kw_k)
else: else:
parameter['parameter'][p_k] = (v_k[0], kw_k) parameter['data_parameter'][p_k] = (v_k[0], kw_k)
for g_k, g_v in glob.items():
if g_k != 'idx':
parameter['glob'][g_k] += g_v
else:
parameter['glob']['idx'] += [idx_i + p_len for idx_i in g_v]
if add_idx: if add_idx:
cnt += 1 cnt += 1
@ -283,37 +273,43 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog):
data = self.data_table.collect_data(default=self.default_combobox.currentData()) data = self.data_table.collect_data(default=self.default_combobox.currentData())
func_dict = {} func_dict = {}
for k, mod in self.models.items(): for model_name, model_parameter in self.models.items():
func, order, param_len = ModelFactory.create_from_list(mod) func, order, param_len = ModelFactory.create_from_list(model_parameter)
if func is None: if func is None:
continue continue
if k in data: func = Model(func)
parameter, _ = self._prepare(mod, function_use=data[k], add_idx=isinstance(func, MultiModel))
# convert positions of global parameter to corresponding names if model_name in data:
global_parameter: dict = parameter['glob'] parameter, _ = self._prepare(model_parameter, function_use=data[model_name], add_idx=isinstance(func, MultiModel))
positions = global_parameter.pop('idx')
global_parameter['key'] = [pname for i, pname in enumerate(func.params) if i in positions]
# print(global_parameter)
if parameter is None: if parameter is None:
return return
for (data_parameter, _) in parameter['data_parameter'].values():
for pname, param in zip(func.params, data_parameter):
param.name = pname
if self._complex[model_name] is not None:
for p_k, p_v in parameter['data_parameter'].items():
p_v[1].update({'complex_mode': self._complex[model_name]})
parameter['data_parameter'][p_k] = p_v[0], p_v[1]
for pname, param_value in zip(func.params, parameter['global_parameter']):
if param_value is not None:
param_value.name = pname
func.set_global_parameter(param_value)
parameter['func'] = func parameter['func'] = func
parameter['order'] = order parameter['order'] = order
parameter['len'] = param_len parameter['len'] = param_len
parameter['complex'] = self._complex[k] parameter['complex'] = self._complex[model_name]
if self._complex[k] is not None:
for p_k, p_v in parameter['parameter'].items():
p_v[1].update({'complex_mode': self._complex[k]})
parameter['parameter'][p_k] = p_v[0], p_v[1]
func_dict[k] = parameter func_dict[model_name] = parameter
replaceable = [] replaceable = []
for k, v in func_dict.items(): for model_name, v in func_dict.items():
for i, link_i in enumerate(v['links']): for i, link_i in enumerate(v['links']):
if link_i is None: if link_i is None:
continue continue
@ -344,7 +340,7 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog):
QtWidgets.QMessageBox.Ok) QtWidgets.QMessageBox.Ok)
return return
replaceable.append((k, i, rep_model, repl_idx)) replaceable.append((model_name, i, rep_model, repl_idx))
replace_value = None replace_value = None
for p_k in f['parameter'].values(): for p_k in f['parameter'].values():

View File

@ -441,21 +441,22 @@ class UpperManagement(QtCore.QObject):
# all-encompassing error catch # all-encompassing error catch
try: try:
for model_id, model_p in parameter.items(): for model_id, model_p in parameter.items():
m = Model(model_p['func']) m = model_p['func']
models[model_id] = m models[model_id] = m
m_complex = model_p['complex'] m_complex = model_p['complex']
print(model_p)
# sets are not in active order but in order they first appeared in fit dialog # sets are not in active order but in order they first appeared in fit dialog
# iterate over order of set id in active order and access parameter inside loop # iterate over order of set id in active order and access parameter inside loop
# instead of directly looping # instead of directly looping
list_ids = list(model_p['parameter'].keys()) list_ids = list(model_p['data_parameter'].keys())
set_order = [self.active_id.index(i) for i in list_ids] set_order = [self.active_id.index(i) for i in list_ids]
for pos in set_order: for pos in set_order:
set_id = list_ids[pos] set_id = list_ids[pos]
data_i = self.data[set_id] data_i = self.data[set_id]
set_params = model_p['parameter'][set_id] set_params = model_p['data_parameter'][set_id]
if we_option.lower() == 'deltay': if we_option.lower() == 'deltay':
we = data_i.y_err**2 we = data_i.y_err**2
@ -485,18 +486,13 @@ class UpperManagement(QtCore.QObject):
d = fit_d.Data(_x[inside], _y[inside], we=we[inside], idx=set_id) d = fit_d.Data(_x[inside], _y[inside], we=we[inside], idx=set_id)
d.set_model(m) d.set_model(m)
d.set_parameter(set_params[0], var=model_p['var'], d.set_parameter(set_params[0], fun_kwargs=set_params[1])
lb=model_p['lb'], ub=model_p['ub'], # d.set_parameter(set_params[0], var=model_p['var'],
fun_kwargs=set_params[1]) # lb=model_p['lb'], ub=model_p['ub'],
# fun_kwargs=set_params[1])
self.fitter.add_data(d) self.fitter.add_data(d)
model_globs = model_p['glob']
if model_globs:
for parameter_args in zip(*model_globs.values()):
m.set_global_parameter(**{k: v for k, v in zip(model_globs.keys(), parameter_args)})
# m.set_global_parameter(**model_p['glob'])
for links_i in links: for links_i in links:
self.fitter.set_link_parameter((models[links_i[0]], links_i[1]), self.fitter.set_link_parameter((models[links_i[0]], links_i[1]),
(models[links_i[2]], links_i[3])) (models[links_i[2]], links_i[3]))

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import numpy as np import numpy as np
from .model import Model from .model import Model
@ -69,7 +71,7 @@ class Data(object):
return self.model return self.model
def set_parameter(self, def set_parameter(self,
values: list[float], values: list[float | Parameter],
*, *,
var: list[bool] = None, var: list[bool] = None,
ub: list[float] = None, ub: list[float] = None,
@ -103,6 +105,15 @@ class Data(object):
if len(values) != len(model.params): if len(values) != len(model.params):
raise ValueError('Number of given parameter does not match number of model parameters') raise ValueError('Number of given parameter does not match number of model parameters')
is_parameter = [isinstance(v, Parameter) for v in values]
if all(is_parameter):
for p_i in values:
key = f"p{next(Parameters.parameter_counter)}"
self.parameter.add_parameter(key, p_i)
elif any(is_parameter):
raise ValueError('list of parameter are not all float of Parameter')
else:
if var is None: if var is None:
var = [True] * len(values) var = [True] * len(values)

View File

@ -80,13 +80,21 @@ class Model(object):
if v.default is not inspect.Parameter.empty} if v.default is not inspect.Parameter.empty}
def set_global_parameter(self, def set_global_parameter(self,
key: str, key: str | Parameter,
value: float | str, value: float | str = None,
*,
var: bool = None, var: bool = None,
lb: float = None, lb: float = None,
ub: float = None, ub: float = None,
default_bounds: bool = False default_bounds: bool = False,
) -> Parameter: ) -> Parameter:
if isinstance(key, Parameter):
p = key
key = f'p{next(Parameters.parameter_counter)}'
self.parameter.add_parameter(key, p)
else:
idx = [self.params.index(key)] idx = [self.params.index(key)]
if default_bounds: if default_bounds:
if lb is None: if lb is None: