forked from IPKM/nmreval
use Parameter when collecting fit values
This commit is contained in:
parent
03d172bade
commit
bd1a227e4c
@ -208,11 +208,10 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit):
|
|||||||
if sid not in self.data_values:
|
if sid not in self.data_values:
|
||||||
self.data_values[sid] = [None] * len(self.data_parameter)
|
self.data_values[sid] = [None] * len(self.data_parameter)
|
||||||
|
|
||||||
def get_parameter(self, use_func=None) -> tuple[dict[str, list[Parameter]], list[Optional[Parameter]]]:
|
def get_parameter(self, use_func=None) -> tuple[dict, list]:
|
||||||
bds = []
|
bds = []
|
||||||
is_global = []
|
is_global = []
|
||||||
is_fixed = []
|
is_fixed = []
|
||||||
|
|
||||||
param_general = []
|
param_general = []
|
||||||
|
|
||||||
for g in self.global_parameter:
|
for g in self.global_parameter:
|
||||||
@ -262,16 +261,16 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit):
|
|||||||
else:
|
else:
|
||||||
kw_p[g.argname] = p_i
|
kw_p[g.argname] = p_i
|
||||||
|
|
||||||
|
data_parameter[sid] = (p, kw_p)
|
||||||
|
|
||||||
global_parameter = []
|
global_parameter = []
|
||||||
for param, global_flag in zip(param_general, is_global):
|
for param, global_flag in zip(param_general, is_global):
|
||||||
if global_flag:
|
if global_flag:
|
||||||
|
param.is_global = True
|
||||||
global_parameter.append(param)
|
global_parameter.append(param)
|
||||||
else:
|
else:
|
||||||
global_parameter.append(None)
|
global_parameter.append(None)
|
||||||
|
|
||||||
|
|
||||||
data_parameter[sid] = (p, kw_p)
|
|
||||||
|
|
||||||
return data_parameter, global_parameter
|
return data_parameter, global_parameter
|
||||||
|
|
||||||
def set_parameter(self, set_id: str | None, parameter: list[float]) -> int:
|
def set_parameter(self, set_id: str | None, parameter: list[float]) -> int:
|
||||||
|
@ -9,6 +9,7 @@ import numpy as np
|
|||||||
from pyqtgraph import mkPen
|
from pyqtgraph import mkPen
|
||||||
|
|
||||||
from nmreval.fit._meta import MultiModel, ModelFactory
|
from nmreval.fit._meta import MultiModel, ModelFactory
|
||||||
|
from nmreval.fit.model import Model
|
||||||
from nmreval.fit.result import FitResult
|
from nmreval.fit.result import FitResult
|
||||||
|
|
||||||
from .fit_forms import FitTableWidget
|
from .fit_forms import FitTableWidget
|
||||||
@ -219,16 +220,16 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog):
|
|||||||
|
|
||||||
def _prepare(self, model: list, function_use: list = None,
|
def _prepare(self, model: list, function_use: list = None,
|
||||||
parameter: dict = None, add_idx: bool = False, cnt: int = 0) -> tuple[dict, int]:
|
parameter: dict = None, add_idx: bool = False, cnt: int = 0) -> tuple[dict, int]:
|
||||||
|
|
||||||
if parameter is None:
|
if parameter is None:
|
||||||
parameter = {
|
parameter = {
|
||||||
'parameter': {},
|
'data_parameter': {},
|
||||||
'glob': [],
|
'global_parameter': [],
|
||||||
'links': [],
|
'links': [],
|
||||||
'color': [],
|
'color': [],
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, f in enumerate(model):
|
for i, f in enumerate(model):
|
||||||
print(i, f)
|
|
||||||
if not f['active']:
|
if not f['active']:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -239,33 +240,22 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog):
|
|||||||
QtWidgets.QMessageBox.Ok)
|
QtWidgets.QMessageBox.Ok)
|
||||||
return {}, -1
|
return {}, -1
|
||||||
|
|
||||||
print(p)
|
|
||||||
print(glob)
|
|
||||||
p_len = len(p)
|
|
||||||
parameter['color'].append(f['color'])
|
parameter['color'].append(f['color'])
|
||||||
|
parameter['global_parameter'].extend(glob)
|
||||||
print(parameter)
|
|
||||||
|
|
||||||
cnt = f['cnt']
|
cnt = f['cnt']
|
||||||
|
|
||||||
for p_k, v_k in p.items():
|
for p_k, v_k in p.items():
|
||||||
if add_idx:
|
if add_idx:
|
||||||
kw_k = {f'{k}_{cnt}': v for k, v in v_k[1].items()}
|
kw_k = {f'{k}_{cnt}': v for k, v in v_k[1].items()}
|
||||||
else:
|
else:
|
||||||
kw_k = v_k[1]
|
kw_k = v_k[1]
|
||||||
|
|
||||||
if p_k in parameter['parameter']:
|
if p_k in parameter['data_parameter']:
|
||||||
params, kw = parameter['parameter'][p_k]
|
params, kw = parameter['data_parameter'][p_k]
|
||||||
params += v_k[0]
|
params += v_k[0]
|
||||||
kw.update(kw_k)
|
kw.update(kw_k)
|
||||||
else:
|
else:
|
||||||
parameter['parameter'][p_k] = (v_k[0], kw_k)
|
parameter['data_parameter'][p_k] = (v_k[0], kw_k)
|
||||||
|
|
||||||
for g_k, g_v in glob.items():
|
|
||||||
if g_k != 'idx':
|
|
||||||
parameter['glob'][g_k] += g_v
|
|
||||||
else:
|
|
||||||
parameter['glob']['idx'] += [idx_i + p_len for idx_i in g_v]
|
|
||||||
|
|
||||||
if add_idx:
|
if add_idx:
|
||||||
cnt += 1
|
cnt += 1
|
||||||
@ -283,37 +273,43 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog):
|
|||||||
data = self.data_table.collect_data(default=self.default_combobox.currentData())
|
data = self.data_table.collect_data(default=self.default_combobox.currentData())
|
||||||
|
|
||||||
func_dict = {}
|
func_dict = {}
|
||||||
for k, mod in self.models.items():
|
for model_name, model_parameter in self.models.items():
|
||||||
func, order, param_len = ModelFactory.create_from_list(mod)
|
func, order, param_len = ModelFactory.create_from_list(model_parameter)
|
||||||
|
|
||||||
if func is None:
|
if func is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if k in data:
|
func = Model(func)
|
||||||
parameter, _ = self._prepare(mod, function_use=data[k], add_idx=isinstance(func, MultiModel))
|
|
||||||
|
|
||||||
# convert positions of global parameter to corresponding names
|
if model_name in data:
|
||||||
global_parameter: dict = parameter['glob']
|
parameter, _ = self._prepare(model_parameter, function_use=data[model_name], add_idx=isinstance(func, MultiModel))
|
||||||
positions = global_parameter.pop('idx')
|
|
||||||
global_parameter['key'] = [pname for i, pname in enumerate(func.params) if i in positions]
|
|
||||||
# print(global_parameter)
|
|
||||||
|
|
||||||
if parameter is None:
|
if parameter is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
for (data_parameter, _) in parameter['data_parameter'].values():
|
||||||
|
for pname, param in zip(func.params, data_parameter):
|
||||||
|
param.name = pname
|
||||||
|
|
||||||
|
if self._complex[model_name] is not None:
|
||||||
|
for p_k, p_v in parameter['data_parameter'].items():
|
||||||
|
p_v[1].update({'complex_mode': self._complex[model_name]})
|
||||||
|
parameter['data_parameter'][p_k] = p_v[0], p_v[1]
|
||||||
|
|
||||||
|
for pname, param_value in zip(func.params, parameter['global_parameter']):
|
||||||
|
if param_value is not None:
|
||||||
|
param_value.name = pname
|
||||||
|
func.set_global_parameter(param_value)
|
||||||
|
|
||||||
parameter['func'] = func
|
parameter['func'] = func
|
||||||
parameter['order'] = order
|
parameter['order'] = order
|
||||||
parameter['len'] = param_len
|
parameter['len'] = param_len
|
||||||
parameter['complex'] = self._complex[k]
|
parameter['complex'] = self._complex[model_name]
|
||||||
if self._complex[k] is not None:
|
|
||||||
for p_k, p_v in parameter['parameter'].items():
|
|
||||||
p_v[1].update({'complex_mode': self._complex[k]})
|
|
||||||
parameter['parameter'][p_k] = p_v[0], p_v[1]
|
|
||||||
|
|
||||||
func_dict[k] = parameter
|
func_dict[model_name] = parameter
|
||||||
|
|
||||||
replaceable = []
|
replaceable = []
|
||||||
for k, v in func_dict.items():
|
for model_name, v in func_dict.items():
|
||||||
for i, link_i in enumerate(v['links']):
|
for i, link_i in enumerate(v['links']):
|
||||||
if link_i is None:
|
if link_i is None:
|
||||||
continue
|
continue
|
||||||
@ -344,7 +340,7 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog):
|
|||||||
QtWidgets.QMessageBox.Ok)
|
QtWidgets.QMessageBox.Ok)
|
||||||
return
|
return
|
||||||
|
|
||||||
replaceable.append((k, i, rep_model, repl_idx))
|
replaceable.append((model_name, i, rep_model, repl_idx))
|
||||||
|
|
||||||
replace_value = None
|
replace_value = None
|
||||||
for p_k in f['parameter'].values():
|
for p_k in f['parameter'].values():
|
||||||
|
@ -441,21 +441,22 @@ class UpperManagement(QtCore.QObject):
|
|||||||
# all-encompassing error catch
|
# all-encompassing error catch
|
||||||
try:
|
try:
|
||||||
for model_id, model_p in parameter.items():
|
for model_id, model_p in parameter.items():
|
||||||
m = Model(model_p['func'])
|
m = model_p['func']
|
||||||
models[model_id] = m
|
models[model_id] = m
|
||||||
|
|
||||||
m_complex = model_p['complex']
|
m_complex = model_p['complex']
|
||||||
|
print(model_p)
|
||||||
|
|
||||||
# sets are not in active order but in order they first appeared in fit dialog
|
# sets are not in active order but in order they first appeared in fit dialog
|
||||||
# iterate over order of set id in active order and access parameter inside loop
|
# iterate over order of set id in active order and access parameter inside loop
|
||||||
# instead of directly looping
|
# instead of directly looping
|
||||||
list_ids = list(model_p['parameter'].keys())
|
list_ids = list(model_p['data_parameter'].keys())
|
||||||
set_order = [self.active_id.index(i) for i in list_ids]
|
set_order = [self.active_id.index(i) for i in list_ids]
|
||||||
for pos in set_order:
|
for pos in set_order:
|
||||||
set_id = list_ids[pos]
|
set_id = list_ids[pos]
|
||||||
|
|
||||||
data_i = self.data[set_id]
|
data_i = self.data[set_id]
|
||||||
set_params = model_p['parameter'][set_id]
|
set_params = model_p['data_parameter'][set_id]
|
||||||
|
|
||||||
if we_option.lower() == 'deltay':
|
if we_option.lower() == 'deltay':
|
||||||
we = data_i.y_err**2
|
we = data_i.y_err**2
|
||||||
@ -485,18 +486,13 @@ class UpperManagement(QtCore.QObject):
|
|||||||
d = fit_d.Data(_x[inside], _y[inside], we=we[inside], idx=set_id)
|
d = fit_d.Data(_x[inside], _y[inside], we=we[inside], idx=set_id)
|
||||||
|
|
||||||
d.set_model(m)
|
d.set_model(m)
|
||||||
d.set_parameter(set_params[0], var=model_p['var'],
|
d.set_parameter(set_params[0], fun_kwargs=set_params[1])
|
||||||
lb=model_p['lb'], ub=model_p['ub'],
|
# d.set_parameter(set_params[0], var=model_p['var'],
|
||||||
fun_kwargs=set_params[1])
|
# lb=model_p['lb'], ub=model_p['ub'],
|
||||||
|
# fun_kwargs=set_params[1])
|
||||||
|
|
||||||
self.fitter.add_data(d)
|
self.fitter.add_data(d)
|
||||||
|
|
||||||
model_globs = model_p['glob']
|
|
||||||
if model_globs:
|
|
||||||
for parameter_args in zip(*model_globs.values()):
|
|
||||||
m.set_global_parameter(**{k: v for k, v in zip(model_globs.keys(), parameter_args)})
|
|
||||||
# m.set_global_parameter(**model_p['glob'])
|
|
||||||
|
|
||||||
for links_i in links:
|
for links_i in links:
|
||||||
self.fitter.set_link_parameter((models[links_i[0]], links_i[1]),
|
self.fitter.set_link_parameter((models[links_i[0]], links_i[1]),
|
||||||
(models[links_i[2]], links_i[3]))
|
(models[links_i[2]], links_i[3]))
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .model import Model
|
from .model import Model
|
||||||
@ -69,7 +71,7 @@ class Data(object):
|
|||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
def set_parameter(self,
|
def set_parameter(self,
|
||||||
values: list[float],
|
values: list[float | Parameter],
|
||||||
*,
|
*,
|
||||||
var: list[bool] = None,
|
var: list[bool] = None,
|
||||||
ub: list[float] = None,
|
ub: list[float] = None,
|
||||||
@ -103,6 +105,15 @@ class Data(object):
|
|||||||
if len(values) != len(model.params):
|
if len(values) != len(model.params):
|
||||||
raise ValueError('Number of given parameter does not match number of model parameters')
|
raise ValueError('Number of given parameter does not match number of model parameters')
|
||||||
|
|
||||||
|
is_parameter = [isinstance(v, Parameter) for v in values]
|
||||||
|
if all(is_parameter):
|
||||||
|
for p_i in values:
|
||||||
|
key = f"p{next(Parameters.parameter_counter)}"
|
||||||
|
self.parameter.add_parameter(key, p_i)
|
||||||
|
elif any(is_parameter):
|
||||||
|
raise ValueError('list of parameter are not all float of Parameter')
|
||||||
|
|
||||||
|
else:
|
||||||
if var is None:
|
if var is None:
|
||||||
var = [True] * len(values)
|
var = [True] * len(values)
|
||||||
|
|
||||||
|
@ -80,13 +80,21 @@ class Model(object):
|
|||||||
if v.default is not inspect.Parameter.empty}
|
if v.default is not inspect.Parameter.empty}
|
||||||
|
|
||||||
def set_global_parameter(self,
|
def set_global_parameter(self,
|
||||||
key: str,
|
key: str | Parameter,
|
||||||
value: float | str,
|
value: float | str = None,
|
||||||
|
*,
|
||||||
var: bool = None,
|
var: bool = None,
|
||||||
lb: float = None,
|
lb: float = None,
|
||||||
ub: float = None,
|
ub: float = None,
|
||||||
default_bounds: bool = False
|
default_bounds: bool = False,
|
||||||
) -> Parameter:
|
) -> Parameter:
|
||||||
|
|
||||||
|
if isinstance(key, Parameter):
|
||||||
|
p = key
|
||||||
|
key = f'p{next(Parameters.parameter_counter)}'
|
||||||
|
self.parameter.add_parameter(key, p)
|
||||||
|
|
||||||
|
else:
|
||||||
idx = [self.params.index(key)]
|
idx = [self.params.index(key)]
|
||||||
if default_bounds:
|
if default_bounds:
|
||||||
if lb is None:
|
if lb is None:
|
||||||
|
Loading…
Reference in New Issue
Block a user